pytorch如何加载自己的数据集

1017
2024/4/7 15:23:15
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中加载自己的数据集,一般可以通过自定义数据集类和数据加载器来实现。下面是一个简单的示例:

  1. 创建自定义数据集类:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample
  1. 准备数据集并实例化自定义数据集类:
# 准备数据集
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# 实例化自定义数据集类
dataset = CustomDataset(data)
  1. 创建数据加载器:
from torch.utils.data import DataLoader

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
  1. 使用数据加载器迭代数据集:
for batch in dataloader:
    print(batch)

以上就是使用PyTorch加载自定义数据集的简单示例。在实际应用中,你可能需要根据数据集的具体情况来修改自定义数据集类中的方法,并根据需求设置数据加载器的参数。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: pytorch图像处理如何优化