0%

自定义Pytorch数据集类

在PyTorch 中,自定义数据集类的作用主要是为了灵活地处理数据加载和预处理,使得数据集与 DataLoader 更好地配合工作,可以方便地进行数据预处理、数据增强等操作,并能够处理不同格式的数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torch.utils.data import Dataset

class CustomDataset(Dataset):
def __ini__(self, path_dir):
''''
''''
self.data_X = read_data_X(path_dir)
self.data_y = read_data_y(path_dir)

def __len__(self):
return len(self.data)

def __getotem__(self, idx):
x = self.data_X
y = self.data_y
return {"label": y, "data": x} OR image, label

# 完成自定义后,调用DataLoader进行
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

train_dataset = CustomDataset(path_dir)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

'''
'''

# 训练部分
for epoch in range(10):
for batch in train_loader:
out = model(batch)
'''
'''