pytorch中如何加载训练数据

1.为什么不需要自己写加载方法

pytorch中提供了两个类用于训练数据的加载,分别是 torch.utils.data.Dataset和 torch.utils.data.DataLoader 。不像torchvision中集合了很多常用的计算机视觉的常用数据集,作为在音乐信息检索这方面,数据集要自己设计加载方法。如果每次不同的数据集都要自己写函数加载,

  • 每次读取代码不能够重用,不同的数据读取代码不同
  • 自己写的加载函数也会有各种问题,比如说限制数据读取速度,或者当数据集太大,直接加载到字典或者列表中会很占用内存,数据读取阶段也会占用大量时间
  • 只能单线程读取数据

这次我做的实验需要加载歌曲的梅尔频谱,每个歌曲的片段为30秒,大约是一个1290*128大小的矩阵。所以这次我决定使用pytorch的Dataset类来加载数据。

2.Dataset类

class torch.utils.data.Dataset

这个抽象类代表了数据集,任何我们自己设计的数据集类都应该是这个类的子类,继承这个类,重写 __len__() 方法,这个方法是用来获得数据集的大小,和__getitem__()方法,这个方法用来返回数据集中索引值为0到len(dataset)的元素。

  • def __getitem__(self, index): 实现这个函数,就可以通过索引值来返回训练样本数据
  • def __len__(self): 实现这个函数,返回数据集的大小
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Dataset(object):
"""An abstract class representing a Dataset.

All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""

def __getitem__(self, index):
raise NotImplementedError

def __len__(self):
raise NotImplementedError

def __add__(self, other):
return ConcatDataset([self, other])

如果不重写这两个私有函数,就会触发错误。

3.定义自己的数据集类

于是我就针对自己的需求实现了以下的类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Fma_dataset(Dataset):
# root 是训练集的根目录, mode可选的参数是train,test,validation,分别读取相应的文件夹
def __init__(self, root, mode):
self.mode = mode
self.root = root + "/fma_" + self.mode
self.mel_cepstrum_path = self.get_sample(self.root)

def __getitem__(self, index):
sample = np.load(self.mel_cepstrum_path[index])
data = torch.from_numpy(sample[0])
target = torch.from_numpy(sample[1].astype(np.float32))
return data, target

def __len__(self):
if self.mode == "train":
return 23733 # 训练集大小
elif self.mode == "validation":
return 6780 # 验证集大小
elif self.mode == "test":
return 3390 # 测试集大小

def get_sample(self, root):
cepstrum = []
for entry in os.scandir(root):
if entry.is_file():
cepstrum.append(entry.path)
return cepstrum

4.DataLoader类

classtorch.utils.data.DataLoader(dataset, batch_size=1**, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)**

仅仅有通过索引返回训练数据数不够的,我们还需要DataLoad类提供拓展功能。

  • 可以分批次读取:batch-size
  • 可以对数据进行shuffle操作
  • 可以用多个线程来读取数据

这个类我们不需要实现代码,直接调用,设置好参数就行了。