pytorch中flatten函数的用法是什么

1063
2024/2/28 14:29:27
栏目: 编程语言
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中,flatten函数用于将输入张量展平为一维张量。它的用法如下:

torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

  • input:输入的张量。
  • start_dim:开始展平的维度,默认为0。
  • end_dim:结束展平的维度,默认为-1,表示展平到最后一维。

flatten函数将沿着指定的维度范围将输入张量展平为一维张量。展平后的张量将包含原始张量中的所有元素,并将其重新排列为一维。

示例:

import torch

x = torch.randn(3, 4, 5)
flattened = torch.flatten(x)
print(flattened.shape)  # 输出: torch.Size([60])

flattened_dim1 = torch.flatten(x, start_dim=1)
print(flattened_dim1.shape)  # 输出: torch.Size([3, 20])

flattened_dim1_dim2 = torch.flatten(x, start_dim=1, end_dim=2)
print(flattened_dim1_dim2.shape)  # 输出: torch.Size([3, 20, 5])

在上面的示例中,flatten函数首先将形状为(3, 4, 5)的张量x展平为形状为(60,)的一维张量。然后,通过指定start_dim=1,将张量x的第二个维度展平,得到形状为(3, 20)的张量。最后,通过指定start_dim=1, end_dim=2,将张量x的第二个和第三个维度展平,得到形状为(3, 20, 5)的张量。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: PyTorch中怎么加载数据集