PyTorch中怎么初始化模型权重

1311
2024/4/12 15:34:54
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中,可以通过定义一个函数来初始化模型的权重。以下是一个示例代码:

import torch
import torch.nn as nn

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

# 定义模型
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.ReLU(),
    nn.Linear(64*28*28, 10)
)

# 初始化模型权重
model.apply(init_weights)

在上面的代码中,定义了一个init_weights函数,该函数根据模型的类型对权重进行初始化。然后通过调用model.apply(init_weights)来初始化模型的权重。这样就可以保证模型的权重被正确地初始化。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: PyTorch模型训练怎样优化超参数