如何在Torch中定义一个神经网络模型

812
2024/4/14 19:29:59
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在Torch中定义一个神经网络模型通常需要使用nn.Module类。下面是一个示例代码,展示了如何定义一个简单的全连接神经网络模型:

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNN()

在上面的代码中,我们定义了一个名为SimpleNN的神经网络模型,它包含两个全连接层和一个ReLU激活函数。在__init__方法中,我们定义了模型的各个层,然后在forward方法中定义了数据在模型中的流动路径。

需要注意的是,在定义神经网络模型时,通常需要继承nn.Module类,并实现__init__forward方法。__init__方法用于初始化模型的结构,forward方法用于定义数据在模型中的传播路径。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: 如何在Torch中保存和加载模型