2023年11月29日发(作者:)

Pytorch函数expand()详解

Pytorch函数 .expand( )

其将单个维度扩⼤成更⼤维度,返回⼀个新的tensor,具体看下例:

import torch

a = torch.Tensor([[1], [2], [3],[4]])

# expanda

未使⽤()函数前的

print(': ', a.size())

print('a: ', a)

b = a.expand(4, 2)

# expand

使⽤()函数后的输出

print(': ', a.size())

print('a: ', a)

print(': ', b.size())

print('b: ', b)

expand()函数使⽤前后a没有发⽣变化,输出都是:

: ([4, 1])

a:

1

2

3

4

[ensor of size 4x1]

b 的输出为:

: ([4, 2])

b:

1 1

2 2

3 3

4 4

[ensor of size 4x2]

由此得出结论,a通过expand()函数扩展某⼀维度后⾃⾝不会发⽣变化

a = torch.Tensor([[[[1,2], [2,3], [3,4],[4,5]]]])

b = a.expand(2, 1, 4, 2)

c = a.expand(1, 2, 4, 2)

# expand

使⽤()函数后的输出

print(': ', a.size())

print(': ', b.size())

print('b: ', b)

print(': ', c.size())

print('c: ', c)

b2 = b.expand(3, 1, 4, 2) # b: ([2, 1, 4, 2])

print(': ', b2.size())

输出:

: ([1, 1, 4, 2])

: ([2, 1, 4, 2])

b: