The PyTorch torch.stack() method joins a sequence of tensors along a new dimension. It inserts a new dimension and concatenates the tensors along it, so the result has one more dimension than each input. All input tensors must have the same shape. torch.cat() can also join tensors, but it joins them along an existing dimension instead of adding a new one. This article covers torch.stack().
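To make the difference from torch.cat() concrete, here is a small sketch that compares the result shapes of both functions on the same two 1-D tensors. The tensor values are arbitrary and only the shapes matter.

```python
import torch

# Two 1-D tensors with 4 elements each (arbitrary example values)
x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])

# torch.cat joins along an EXISTING dimension: 4 + 4 elements on one axis
joined = torch.cat((x, y))
print(joined.shape)   # torch.Size([8])

# torch.stack adds a NEW dimension of size 2 (one slot per input tensor)
stacked = torch.stack((x, y))
print(stacked.shape)  # torch.Size([2, 4])
```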
Syntax: torch.stack(tensors, dim=0)
Arguments:
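- tensors: A sequence (such as a tuple or list) of tensors, all with the same shape.
- dim: The index at which the new dimension is inserted. For inputs with n dimensions it is an integer from 0 to n inclusive; negative values from -(n+1) to -1 count from the end. The default is 0.
Returns: A new tensor with one more dimension than the inputs, in which the inputs are joined along the new dimension dim.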
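The shape rule above can be checked directly. The sketch below stacks four inputs of shape (2, 3) (random placeholder values) for every valid dim, including negative ones, and asserts that a new axis of size 4 appears at the expected position.

```python
import torch

# Four 2-D inputs of identical shape (2, 3); the values are random placeholders
tensors = [torch.randn(2, 3) for _ in range(4)]
ndim = tensors[0].dim()   # 2 dimensions per input
n = len(tensors)          # 4 inputs

# Valid values of dim run from -(ndim + 1) to ndim inclusive
for dim in range(-(ndim + 1), ndim + 1):
    out = torch.stack(tensors, dim=dim)
    d = dim % (ndim + 1)  # map a negative dim to its position 0..ndim
    shape = tuple(tensors[0].shape)
    # The new axis of size n is inserted at position d
    expected = shape[:d] + (n,) + shape[d:]
    assert tuple(out.shape) == expected
    print(f"dim={dim}: {tuple(out.shape)}")
```

For example, dim=0 gives (4, 2, 3), dim=1 gives (2, 4, 3), and dim=2 (or dim=-1) gives (2, 3, 4).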
Let’s understand the torch.stack() method with the help of some Python 3 examples.
Example 1:
In the Python example below, we join two one-dimensional tensors using the torch.stack() method.
Python3
import torch

x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])
print("Tensor x:", x)
print("Tensor y:", y)

# No dim given: the default dim=0 is used
print("join tensors:")
t = torch.stack((x, y))
print(t)

# Each input becomes a row
print("join tensors dimension 0:")
t = torch.stack((x, y), dim=0)
print(t)

# Each input becomes a column
print("join tensors dimension 1:")
t = torch.stack((x, y), dim=1)
print(t)
Output:
Tensor x: tensor([ 1.,  3.,  6., 10.])
Tensor y: tensor([ 2.,  7.,  9., 13.])
join tensors:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.]])
join tensors dimension 0:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.]])
join tensors dimension 1:
tensor([[ 1.,  2.],
        [ 3.,  7.],
        [ 6.,  9.],
        [10., 13.]])
Explanation: In the above code, x and y are one-dimensional tensors with four elements each, and the stacked result is a 2-D tensor. Because each input has one dimension, dim can be 0 or 1 (or the equivalent negative values -2 and -1). With dim=0 (the default), each input becomes a row, giving a tensor of shape (2, 4). With dim=1, each input becomes a column, giving a tensor of shape (4, 2), which is the transpose of the dim=0 result.
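The row/column behavior described above can be verified with a few comparisons. This sketch reuses the x and y values from Example 1 and checks that the dim=1 result is the transpose of the dim=0 result, that dim=-1 means the same as dim=1 for 1-D inputs, and that each column holds one input.

```python
import torch

x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])

rows = torch.stack((x, y), dim=0)   # shape (2, 4): x and y are rows
cols = torch.stack((x, y), dim=1)   # shape (4, 2): x and y are columns

# For 1-D inputs, stacking on dim=1 equals transposing the dim=0 result
print(torch.equal(cols, rows.T))                        # True
# dim=-1 is the last axis of the 2-D result, i.e. dim=1 here
print(torch.equal(cols, torch.stack((x, y), dim=-1)))   # True
# Column 0 of the result is x, column 1 is y
print(torch.equal(cols[:, 0], x), torch.equal(cols[:, 1], y))  # True True
```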
Example 2:
In the Python example below, we join two two-dimensional tensors using the torch.stack() method.
Python3
import torch

x = torch.tensor([[1., 3., 6.], [10., 13., 20.]])
y = torch.tensor([[2., 7., 9.], [14., 21., 34.]])
print(f"Tensor x:\n{x}")
print(f"Tensor y:\n{y}")

# No dim given: the default dim=0 is used
print("join tensors")
t = torch.stack((x, y))
print(t)

print("join tensors in dimension 0:")
t = torch.stack((x, y), 0)
print(t)

print("join tensors in dimension 1:")
t = torch.stack((x, y), 1)
print(t)

print("join tensors in dimension 2:")
t = torch.stack((x, y), 2)
print(t)
Output:
Tensor x:
tensor([[ 1.,  3.,  6.],
        [10., 13., 20.]])
Tensor y:
tensor([[ 2.,  7.,  9.],
        [14., 21., 34.]])
join tensors
tensor([[[ 1.,  3.,  6.],
         [10., 13., 20.]],

        [[ 2.,  7.,  9.],
         [14., 21., 34.]]])
join tensors in dimension 0:
tensor([[[ 1.,  3.,  6.],
         [10., 13., 20.]],

        [[ 2.,  7.,  9.],
         [14., 21., 34.]]])
join tensors in dimension 1:
tensor([[[ 1.,  3.,  6.],
         [ 2.,  7.,  9.]],

        [[10., 13., 20.],
         [14., 21., 34.]]])
join tensors in dimension 2:
tensor([[[ 1.,  2.],
         [ 3.,  7.],
         [ 6.,  9.]],

        [[10., 14.],
         [13., 21.],
         [20., 34.]]])
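Explanation: In the above code, x and y are two-dimensional tensors of shape (2, 3), and the stacked result is a 3-D tensor. Because each input has two dimensions, dim can be 0, 1, or 2. With dim=0, the result has shape (2, 2, 3) and holds x and y as whole blocks. With dim=1, the result also has shape (2, 2, 3), but each block pairs the matching rows of x and y. With dim=2, the result has shape (2, 3, 2), and each innermost pair holds the matching elements of x and y. Compare the three outputs to see these differences.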
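One way to see what all three results have in common: whatever dim is used, taking index 0 along that dim gives back x, and index 1 gives back y. The sketch below checks this with Tensor.select(dim, index) on the Example 2 tensors.

```python
import torch

x = torch.tensor([[1., 3., 6.], [10., 13., 20.]])
y = torch.tensor([[2., 7., 9.], [14., 21., 34.]])

for dim in (0, 1, 2):
    t = torch.stack((x, y), dim=dim)
    # Index 0 along the new dimension recovers x, index 1 recovers y
    assert torch.equal(t.select(dim, 0), x)
    assert torch.equal(t.select(dim, 1), y)
    print(f"dim={dim}: shape {tuple(t.shape)}")
```

This prints the shapes (2, 2, 3), (2, 2, 3), and (2, 3, 2), matching the explanation above.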
Example 3:
In this example, we join three tensors. torch.stack() accepts any number of tensors in the sequence.
Python3
import torch

x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])
z = torch.tensor([4., 5., 8., 11.])
print("Tensor x:", x)
print("Tensor y:", y)
print("Tensor z:", z)

# No dim given: the default dim=0 is used
print("join tensors:")
t = torch.stack((x, y, z))
print(t)

print("join tensors dimension 0:")
t = torch.stack((x, y, z), dim=0)
print(t)

print("join tensors dimension 1:")
t = torch.stack((x, y, z), dim=1)
print(t)
Output:
Tensor x: tensor([ 1.,  3.,  6., 10.])
Tensor y: tensor([ 2.,  7.,  9., 13.])
Tensor z: tensor([ 4.,  5.,  8., 11.])
join tensors:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.],
        [ 4.,  5.,  8., 11.]])
join tensors dimension 0:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.],
        [ 4.,  5.,  8., 11.]])
join tensors dimension 1:
tensor([[ 1.,  2.,  4.],
        [ 3.,  7.,  5.],
        [ 6.,  9.,  8.],
        [10., 13., 11.]])
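Because tensors can be any sequence, including a Python list, a common pattern is to collect tensors in a list (for example, one per loop iteration) and stack them once at the end. The per-step tensor below is purely illustrative.

```python
import torch

# Collect one 1-D tensor per step in a Python list (illustrative values)
rows = []
for step in range(5):
    rows.append(torch.arange(4, dtype=torch.float32) * step)

# Stack the whole list at once: 5 tensors of shape (4,) -> shape (5, 4)
batch = torch.stack(rows)
print(batch.shape)   # torch.Size([5, 4])
print(batch[2])      # tensor([0., 2., 4., 6.])
```

Stacking once at the end is usually preferable to calling torch.stack() inside the loop, since every call allocates and copies a new result tensor.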
Example 4: Demonstrating Errors
The example below shows the error raised when the input tensors do not have the same shape.
Python3
import torch

x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9.])
print("Shape of x:", x.shape)
print("Shape of y:", y.shape)

# x has 4 elements and y has 3, so this call raises a RuntimeError
# and the lines after it never run
t = torch.stack((x, y))
print(t)
Output:
Shape of x: torch.Size([4])
Shape of y: torch.Size([3])
RuntimeError: stack expects each tensor to be equal size, but got [4] at entry 0 and [3] at entry 1
Notice that the two tensors have different shapes ([4] and [3]), so torch.stack() raises a RuntimeError. The same error is raised when the tensors have different numbers of dimensions, for example a tensor of shape [4] and one of shape [2, 2]. Try it yourself with tensors of different dimensions and compare the error messages.
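The sketch below tries the "different dimensions" case, plus an out-of-range dim, and catches both errors. It assumes the size mismatch raises RuntimeError (as in the output above) and that an invalid dim raises IndexError ("Dimension out of range"). The can_stack() helper is hypothetical, not part of PyTorch; it simply checks shapes before stacking.

```python
import torch

def can_stack(tensors):
    """Hypothetical helper: True if every tensor has the same shape as the first."""
    return len(tensors) > 0 and all(t.shape == tensors[0].shape for t in tensors)

x = torch.tensor([1., 3., 6., 10.])       # shape (4,)
m = torch.tensor([[1., 3.], [6., 10.]])   # shape (2, 2): same values, different ndim

caught = []
print(can_stack([x, m]))  # False
try:
    torch.stack((x, m))
except RuntimeError as err:   # assumed: size mismatch raises RuntimeError
    caught.append("RuntimeError")
    print("RuntimeError:", err)

# For 1-D inputs the valid range of dim is -2..1, so dim=2 is out of range
try:
    torch.stack((x, x), dim=2)
except IndexError as err:     # assumed: invalid dim raises IndexError
    caught.append("IndexError")
    print("IndexError:", err)
```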