在PyTorch 中,自定义数据集类的作用主要是为了灵活地处理数据加载和预处理,使得数据集与 DataLoader 更好地配合工作,可以方便地进行数据预处理、数据增强等操作,并能够处理不同格式的数据。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| from torch.utils.data import Dataset
class CustomDataset(Dataset): def __ini__(self, path_dir): '''' '''' self.data_X = read_data_X(path_dir) self.data_y = read_data_y(path_dir) def __len__(self): return len(self.data) def __getotem__(self, idx): x = self.data_X y = self.data_y return {"label": y, "data": x} OR image, label
# 完成自定义后,调用DataLoader进行 from torch.utils.data import DataLoader from sklearn.model_selection import train_test_split
train_dataset = CustomDataset(path_dir) train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
''' '''
# 训练部分 for epoch in range(10): for batch in train_loader: out = model(batch) ''' '''
|