数据加载模块

本贴最后更新于 640 天前,其中的信息可能已经斗转星移

说明

这里主要是记录一下 DataSet​和 DataLoader​的用法,主要的知识点其实就是 pyhon 的迭代器模块。数据加载模块在深度学习中最基础的模块,这里不包括对数据的各种预处理,实现的功能是 输入整体的数据集,然后每一个迭代返回一个batch的数据。

DataSet

这个类的作用:

  • 1、接受传入的数据;
  • 2、魔术方法 __len__​获取数据集的长度;
  • 3、魔术方法 __getitem__​等价于 []​,负责从数据集中拿到特定索引的数据。

DataLoader

这个类的作用:

  • 1、接受传入的 DataSet​对象,batch_size​大小,以及数据是否需要进行 shuffle​。(注意:对数据 shuffle 的成本较高,目前通用的方法是 shuffle 数据的索引,再根据索引拿数据。)
  • 2、魔术方法 __iter__​,这里相当于声明类是一个迭代器,后续代码中 DataLoader​类的对象就是可迭代对象,可以使用 for 循环进行迭代操作,一般和 __next__​一起出现。
  • 3、魔术方法 __next__​,用于返回下一个元素,这里用于返回下一个 batch 的数据。这里需要注意的是,需要手动维护一个计数器,在计数超过数据集本身的长度时,触发 StopIteration​并计数器归零,这样迭代过程就停止,相应在 for 循环中就会进入下一次循环。

代码实现

import random """ 构建数据集加载模块,包括 dataset 和 dataloader """ class DataSet: def __init__(self, data_list: list) -> None: """ data_list: 输入的数据集内容,可以是数据,也可以是标签, 输入数据结构为列表 """ self.data = data_list def __len__(self): """得到数据集的数量""" return len(self.data) def __getitem__(self, index): """从数据集中拿到指定索引的数据""" return self.data[index] class DataLoader: def __init__(self, dataset, batch_size, shuffle): """ dataset: DataSet的实例化对象 batch_size: 一次输出的样本的数量 shuffle: 布尔值,标志位,是否需要打乱数据 """ self.dataset = dataset self.batch_size = batch_size self.data_count = 0 self.data_index = list(range(len(dataset))) if shuffle: random.shuffle(self.data_index) def __iter__(self): """返回一个迭代器对象""" return self def __next__(self): """返回迭代器下一个batch的数据""" # 如果数据计数大于数据集的数据数量,抛出异常,重置迭代器 if self.data_count > len(self.dataset): self.data_count = 0 raise StopIteration index_list = self.data_index[self.data_count: self.data_count+self.batch_size] data_batch = [self.dataset[i] for i in index_list] self.data_count += self.batch_size return data_batch if __name__ == "__main__": import numpy as np # 生成数据 data_gen = np.random.randint(0, 10, size=(105, 20)) data_set = DataSet(data_list=data_gen) data_loader = DataLoader(dataset=data_set, batch_size=16, shuffle=True) # 遍历迭代器,输出数据 count = 0 for i_d in data_loader: print(f"itration: {count}") print(f"data_length: {len(i_d)}") print(np.array(i_d)) count += 1
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    556 引用 • 675 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...