Torch中怎么加载和处理数据集

894
2024/2/19 13:29:33
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在Torch中加载和处理数据集通常通过使用torch.utils.data.Datasettorch.utils.data.DataLoader类来实现。以下是一个简单的示例代码:

import torch
from torch.utils.data import Dataset, DataLoader

# 定义自定义数据集类
class CustomDataset(Dataset):
    def __init__(self):
        # 初始化数据集
        self.data = torch.randn(100, 10)
        self.labels = torch.randint(0, 2, (100,))
        
    def __len__(self):
        # 返回数据集大小
        return len(self.data)
    
    def __getitem__(self, idx):
        # 获取数据集中的一个样本
        return self.data[idx], self.labels[idx]

# 创建数据集实例
dataset = CustomDataset()

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 遍历数据集
for data, labels in dataloader:
    # 处理每个批次的数据
    print(data.shape, labels.shape)

在上面的示例中,定义了一个自定义的数据集类CustomDataset,其中实现了__init____len____getitem__方法。然后创建了dataset实例和dataloader对象,并使用for循环遍历数据加载器,获取每个批次的数据。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: Torch中的序列标注模块有哪些