1.为什么不需要自己写加载方法
pytorch中提供了两个类用于训练数据的加载,分别是 torch.utils.data.Dataset和 torch.utils.data.DataLoader 。不像torchvision中集合了很多常用的计算机视觉的常用数据集,作为在音乐信息检索这方面,数据集要自己设计加载方法。如果每次不同的数据集都要自己写函数加载,
- 每次读取代码不能够重用,不同的数据读取代码不同
- 自己写的加载函数也会有各种问题,比如说限制数据读取速度,或者当数据集太大,直接加载到字典或者列表中会很占用内存,数据读取阶段也会占用大量时间
- 只能单线程读取数据
这次我做的实验需要加载歌曲的梅尔频谱,每个歌曲的片段为30秒,大约是一个1290*128大小的矩阵。所以这次我决定使用pytorch的Dataset类来加载数据。
2.Dataset类
class torch.utils.data.Dataset
这个抽象类代表了数据集,任何我们自己设计的数据集类都应该是这个类的子类,继承这个类,重写 __len__() 方法,这个方法是用来获得数据集的大小,和__getitem__()方法,这个方法用来返回数据集中索引值为0到len(dataset)的元素。
- def __getitem__(self, index): 实现这个函数,就可以通过索引值来返回训练样本数据
- def __len__(self): 实现这个函数,返回数据集的大小
1 | class Dataset(object): |
如果不重写这两个私有函数,就会触发错误。
3.定义自己的数据集类
于是我就针对自己的需求实现了以下的类。
1 | class Fma_dataset(Dataset): |
4.DataLoader类
class
torch.utils.data.DataLoader
(dataset, batch_size=1**, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)**
仅仅有通过索引返回训练数据数不够的,我们还需要DataLoad类提供拓展功能。
- 可以分批次读取:batch-size
- 可以对数据进行shuffle操作
- 可以用多个线程来读取数据
这个类我们不需要实现代码,直接调用,设置好参数就行了。