torch.cat
2021/2/25
所幸这个和我印象中的用法非常一样
而且这次官方文档 写得更像个人了:EVEN BETTER(worse,为什么难的函数不写这么明白)
torch.cat
torch.cat(
tensors,
dim=0,
*,
out=None
) → Tensor
- 作用:在给定维度上Concatenate输入的tensor序列。所有的tensor应该有相同的shape(除了需要concatenate的维度)或为空。
-
torch.cat()
可以被看作是torch.split()
和torch.chunk()
的反操作。
Parameters
- tensors (sequence of Tensors) – 任何相同类型(type)的python tensor序列。非空的tensor要在非concatenate维度有相同的形状。
- dim (int, optional) – 要concate的维度。
Keyword Arguments
- out (Tensor, optional) – 输出的tensor。
Examples
官方例子:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])
但是我至今对高维tensor的concatenate缺乏想象力…还是尝试写了个例子:
>>> x =torch.ones(2, 2, 2, 2)
>>> x
tensor([[[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]],
[[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]]])
>>> y = torch.ones(2, 2, 3, 2).mul(2)
>>> torch.cat((x, y) ,2)
tensor([[[[1., 1.],
[1., 1.],
[2., 2.],
[2., 2.],
[2., 2.]],
[[1., 1.],
[1., 1.],
[2., 2.],
[2., 2.],
[2., 2.]]],
[[[1., 1.],
[1., 1.],
[2., 2.],
[2., 2.],
[2., 2.]],
[[1., 1.],
[1., 1.],
[2., 2.],
[2., 2.],
[2., 2.]]]])