Some notes on memory in PyTorch
You may have been confused by some functions in torch that seem to do the same thing under different names. For example:
reshape(), view(), permute(), transpose()
Do they really do different things? Not really! But to understand why, we first need to know a little about how tensors are implemented in PyTorch.
Tensors are abstract, logical constructs, just like arrays, that can't be implemented exactly the way they are conceived. The obvious reason is that memory cells are sequential, so we need to find a way to lay them out in memory. For example, say we have a 2D tensor (or array) like the one below.
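Here is a minimal sketch; I'm assuming a 4×3 tensor holding the values 0 through 11, but any 2D tensor behaves the same way:

```python
import torch

# a 4x3 tensor holding the values 0..11
t = torch.arange(12).reshape(4, 3)
print(t)
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])
```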
The normal (or contiguous) way of saving it into memory is row by row. So we will have:
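[0, 1, 2, 3, 4, …, 11]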
Every tensor has metadata that tells us how to read it. For example, in this 2D tensor, to get to the next row we have to move 3 steps forward in memory, and to get to the next column we move 1 step. We call these two numbers the strides, and we can extract them like below.
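Continuing with the 4×3 example tensor `t` from above:

```python
print(t.stride())         # (3, 1): 3 steps to the next row, 1 step to the next column
print(t.is_contiguous())  # True
```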
This opens up new possibilities, because to change the matrix we can simply change the stride metadata! For example, if we change the (3, 1) strides to (1, 3), we have actually transposed the matrix without touching any of the items in memory.
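This is exactly what `t.t()` (or `t.transpose(0, 1)`) does: it returns a tensor that shares the same memory and just swaps the strides:

```python
t_T = t.t()                 # transpose: no data is moved, only metadata changes
print(t_T.shape)            # torch.Size([3, 4])
print(t_T.stride())         # (1, 3): the strides are swapped
print(t_T.data_ptr() == t.data_ptr())  # True: same underlying memory
print(t_T.is_contiguous())  # False
```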
As you noticed, the tensor is not contiguous anymore because we changed it! To go to the next row, we now only have to skip 1 value, while we have to skip 3 to move to the next column.
This makes sense if we recall the memory layout of the tensor:
[0, 1, 2, 3, 4, …, 11]
To move to the next column (e.g. from 0 to 3), we would have to skip 3 values. The tensor is thus no longer contiguous! To make it contiguous, just call contiguous() on it:
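Continuing with the transposed tensor `t_T` from above:

```python
t_C = t_T.contiguous()
print(t_C.stride())          # (4, 1): a fresh row-by-row layout for the 3x4 shape
print(t_C.is_contiguous())   # True
```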
When you call contiguous(), it actually makes a copy of the tensor (if it isn't contiguous already), so that the order of the elements in memory is the same as if a tensor of the same shape had been created from scratch.
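We can check that it really is a copy: the new tensor lives in its own block of memory, even though the values are identical:

```python
print(t_C.data_ptr() == t_T.data_ptr())  # False: contiguous() allocated new storage
print(torch.equal(t_C, t_T))             # True: the values are the same
```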
Note that the word “contiguous” is a bit misleading: it's not that the content of the tensor is spread out across disconnected blocks of memory. The bytes are still allocated in one block of memory, but the order of the elements is different!
By the same token, view returns just a view of the original tensor, which means that if you change the original memory, the view changes too:
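A small sketch of that, using the same tensor `t` as before:

```python
v = t.view(3, 4)    # same 12 elements, new shape, no copy
t[0, 0] = 100       # modify the original tensor
print(v[0, 0])      # tensor(100): the view sees the change
print(v.data_ptr() == t.data_ptr())  # True: both share the same storage
```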
This is actually very efficient, because we don't have to allocate new memory for the transformation. But reshape may copy the original data. From the official docs:
Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
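A sketch of that difference (whether you get a view or a copy is an implementation detail, as the docs warn, so don't rely on it):

```python
r1 = t.reshape(3, 4)                     # contiguous input: reshape can return a view
print(r1.data_ptr() == t.data_ptr())     # True here, but not guaranteed

r2 = t_T.reshape(12)                     # non-contiguous input: reshape has to copy
print(r2.data_ptr() == t_T.data_ptr())   # False
```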
For example, suppose we try to use view() on the non-contiguous transposed tensor from above, with code like the sketch below:
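```python
t_T = t.t()     # non-contiguous: strides are (1, 3)
t_T.view(12)    # try to flatten it with view()
```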
This gives an error along these lines (the exact message may vary between PyTorch versions):
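```
RuntimeError: view size is not compatible with input tensor's size and stride (at least
one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```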
view() doesn't work on non-contiguous data.
Also, note that permute is another function that only works on the metadata, so it typically produces non-contiguous data as well. permute changes the order of the axes, so it is quite different from view or reshape, which change the shape of the tensor.
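A quick sketch with a 3D tensor (the 2×3×4 shape is just for illustration):

```python
x = torch.arange(24).reshape(2, 3, 4)
y = x.permute(2, 0, 1)        # reorder the axes; only the metadata changes
print(y.shape)                # torch.Size([4, 2, 3])
print(y.stride())             # (1, 12, 4): the old strides, reordered
print(y.is_contiguous())      # False
print(y.data_ptr() == x.data_ptr())  # True: still the same storage
```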