pytorch的view函数怎么使用

1029
2023/12/5 11:07:09
栏目: 编程语言
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中,view()函数用于调整张量的形状。它的使用方式如下:

output = input.view(*shape)

这里的input是输入张量,shape是一个元组,用于指定调整后的形状。具体来说:

  • shape中的每个元素可以是一个具体的维度大小,或者-1表示根据其他维度的大小自动计算。
  • 调整后的张量和原始张量共享内存空间,即它们指向相同的数据。

下面是一些示例:

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 将x的形状调整为(3, 2)
output = x.view(3, 2)
print(output)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# 将x的形状调整为(6, -1),其中-1表示自动计算
output = x.view(6, -1)
print(output)
# 输出:
# tensor([[1],
#         [2],
#         [3],
#         [4],
#         [5],
#         [6]])

# 将x的形状调整为(1, 6)
output = x.view(1, 6)
print(output)
# 输出:
# tensor([[1, 2, 3, 4, 5, 6]])

需要注意的是,调整后的形状必须和原始张量的元素总数保持一致,否则会抛出错误。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: PyTorch中怎么使用反向传播