数据加载模块

本贴最后更新于 487 天前,其中的信息可能已经斗转星移

说明

这里主要是记录一下 DataSet​和 DataLoader​的用法,主要的知识点其实就是 pyhon 的迭代器模块。数据加载模块在深度学习中最基础的模块,这里不包括对数据的各种预处理,实现的功能是 输入整体的数据集,然后每一个迭代返回一个batch的数据。

DataSet

这个类的作用:

  • 1、接受传入的数据;
  • 2、魔术方法 __len__​获取数据集的长度;
  • 3、魔术方法 __getitem__​等价于 []​,负责从数据集中拿到特定索引的数据。

DataLoader

这个类的作用:

  • 1、接受传入的 DataSet​对象,batch_size​大小,以及数据是否需要进行 shuffle​。(注意:对数据 shuffle 的成本较高,目前通用的方法是 shuffle 数据的索引,再根据索引拿数据。)
  • 2、魔术方法 __iter__​,这里相当于声明类是一个迭代器,后续代码中 DataLoader​类的对象就是可迭代对象,可以使用 for 循环进行迭代操作,一般和 __next__​一起出现。
  • 3、魔术方法 __next__​,用于返回下一个元素,这里用于返回下一个 batch 的数据。这里需要注意的是,需要手动维护一个计数器,在计数超过数据集本身的长度时,触发 StopIteration​并计数器归零,这样迭代过程就停止,相应在 for 循环中就会进入下一次循环。

代码实现

import random

"""
构建数据集加载模块,包括 dataset 和 dataloader
"""

class DataSet:
    def __init__(self, data_list: list) -> None:
        """
        data_list: 输入的数据集内容,可以是数据,也可以是标签,
        输入数据结构为列表
        """
        self.data = data_list

    def __len__(self):
        """得到数据集的数量"""
        return len(self.data)
  
    def __getitem__(self, index):
        """从数据集中拿到指定索引的数据"""
        return self.data[index]
  

class DataLoader:
    def __init__(self, dataset, batch_size, shuffle):
        """
        dataset: DataSet的实例化对象
        batch_size: 一次输出的样本的数量
        shuffle: 布尔值,标志位,是否需要打乱数据
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.data_count = 0
        self.data_index = list(range(len(dataset)))

        if shuffle:
            random.shuffle(self.data_index)

    def __iter__(self):
        """返回一个迭代器对象"""
        return self
  
    def __next__(self):
        """返回迭代器下一个batch的数据"""
        # 如果数据计数大于数据集的数据数量,抛出异常,重置迭代器
        if self.data_count > len(self.dataset):
            self.data_count = 0
            raise StopIteration
      
        index_list = self.data_index[self.data_count: self.data_count+self.batch_size]
        data_batch = [self.dataset[i] for i in index_list]
        self.data_count += self.batch_size

        return data_batch
  

if __name__ == "__main__":
    import numpy as np

    # 生成数据
    data_gen = np.random.randint(0, 10, size=(105, 20))
    data_set = DataSet(data_list=data_gen)

    data_loader = DataLoader(dataset=data_set, batch_size=16, shuffle=True)

    # 遍历迭代器,输出数据
    count = 0
    for i_d in data_loader:
        print(f"itration: {count}")
        print(f"data_length: {len(i_d)}")
        print(np.array(i_d))
        count += 1
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    543 引用 • 672 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...