Five must-know PyTorch tensor functions
These will make you stand out from the crowd
This article is inspired by the course Deep Learning with PyTorch: Zero to GANs by Jovian.ai. Enrol now to start learning a practical and coding-focused introduction to deep learning using the PyTorch framework.
In this article, we will look at five PyTorch tensor functions.
torch.Tensor.apply_(callable)
apply_() is a tensor method that applies the function callable to each element of the tensor, replacing each element with the value returned by callable. Note that it works only on CPU tensors. Let’s have a look at a couple of examples.
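A minimal sketch of the first example (the tensor values here are illustrative):

```python
import torch

def square(x):
    # called once per element; must return a number
    return x * x

t = torch.tensor([1., 2., 3., 4.])
t.apply_(square)          # modifies t in place (CPU tensors only)
print(t.tolist())         # [1.0, 4.0, 9.0, 16.0]
```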
The above example uses a function named square, which is applied to all the elements of tensor t in place. As a result, t holds the square of each of its original elements.
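The lambda variant can be sketched like this (again with illustrative values):

```python
import torch

t = torch.tensor([1., 2., 3.])
# the lambda is called on each element and its return value replaces it
t.apply_(lambda x: x * 10)
print(t.tolist())         # [10.0, 20.0, 30.0]
```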
Here we have used a lambda function as callable. It multiplies every element of the tensor by 10 (in place).
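The failing case described below can be sketched as follows; square_root() prints but returns None, which apply_() cannot use:

```python
import torch

def square_root(x):
    # prints the root but returns None instead of a number
    print(x ** 0.5)

t = torch.tensor([4., 9.])
error = None
try:
    t.apply_(square_root)
except Exception as e:
    error = str(e)
print(error)              # complains that the callable returned NoneType
```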
We see that the above example gives an error. For apply_() to work, callable must return a real number, not None. The function square_root() only prints the square root of each element and does not return anything (i.e. it returns None).
apply_() is a useful method when you want to apply a function to all the elements of a tensor in one go. It saves you from writing a loop or nested loops, although, since it calls a Python function per element, it is slow compared to vectorised tensor operations.
torch.baddbmm()
This function performs a batch matrix-matrix product of the matrices in batch1 and batch2, then adds input to the result: out = beta * input + alpha * (batch1 @ batch2).
batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
For example,
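A sketch with the shapes described below (random values, default beta and alpha of 1):

```python
import torch

input = torch.randn(2, 3, 6)    # added to the batched product
batch1 = torch.randn(2, 3, 5)
batch2 = torch.randn(2, 5, 6)

# out = beta * input + alpha * (batch1 @ batch2)
out = torch.baddbmm(input, batch1, batch2)
print(out.shape)                # torch.Size([2, 3, 6])
```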
Here, input has size (2, 3, 6); batch1 of size (2, 3, 5) and batch2 of size (2, 5, 6) are batch-multiplied to produce a (2, 3, 6) result, to which the scaled input is added.
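The single-batch variant can be sketched like this (random illustrative values):

```python
import torch

input = torch.randn(1, 3, 6)
batch1 = torch.randn(1, 3, 5)   # a batch holding a single 3 x 5 matrix
batch2 = torch.randn(1, 5, 6)
out = torch.baddbmm(input, batch1, batch2)
print(out.shape)                # torch.Size([1, 3, 6])
```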
This example is the same as the one above, but the first (batch) dimension is 1, which means each batch essentially holds a single 2-D matrix.
The example fails because, as mentioned earlier, the function expects only 3-D tensors. Even if your matrices are 2-D, you must pass them to the function as 3-D tensors of shape (1 x m x n).
torch.trace()
This function returns the trace of a 2-D matrix, i.e. the sum of the elements on its main diagonal.
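A minimal sketch of the 3 x 3 example (the values here are illustrative):

```python
import torch

m = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
tr = torch.trace(m)       # sum of the main diagonal: 1 + 5 + 9
print(tr.item())          # 15.0
```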
As we can see, the trace of the matrix is 1 + 5 + 9 = 15.
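The 1 x 1 case can be sketched as (illustrative value):

```python
import torch

m = torch.tensor([[7.]])          # 1 x 1 matrix
tr = torch.trace(m)
print(tr.item())                  # 7.0: the single element itself
```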
This example uses a 1 x 1 matrix. Here the trace is simply the single element inside the matrix.
We have to be careful here because torch.trace() does not work on complex numbers.
torch.eig()
The function returns the eigenvalues and, optionally, the eigenvectors of a real square matrix. (In recent PyTorch releases, torch.eig has been removed in favour of torch.linalg.eig.)
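A sketch of such a call. Because torch.eig was removed from current PyTorch, this uses its replacement torch.linalg.eig, which always returns eigenvectors and complex-valued outputs; the matrix values are illustrative:

```python
import torch

m = torch.tensor([[2., 0.],
                  [0., 3.]])
# older PyTorch: vals, vecs = torch.eig(m, eigenvectors=True)
vals, vecs = torch.linalg.eig(m)
print(vals)               # complex eigenvalues 2 and 3
```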
The outputs are the eigenvalues and, if eigenvectors=True is set, the eigenvectors of the given matrix. To learn more about eigenvalues and eigenvectors, refer to this brilliant article by Farhad Malik here (personally my favourite).
Here is another example which finds the eigenvalues and eigenvectors of a 3 x 3 matrix.
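A sketch of the 3 x 3 case (again via torch.linalg.eig, the current replacement for the removed torch.eig; the values are illustrative):

```python
import torch

m = torch.tensor([[4., 0., 0.],
                  [0., 5., 0.],
                  [0., 0., 6.]])
vals, vecs = torch.linalg.eig(m)  # torch.eig(m, eigenvectors=True) in older PyTorch
print(vals)                       # complex eigenvalues 4, 5 and 6
```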
This error is a basic one: to find the eigenvalues of a matrix, the matrix must be square, i.e. its dimensions must be equal.
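The non-square failure described above can be sketched as:

```python
import torch

m = torch.randn(2, 3)     # not square
error = None
try:
    torch.linalg.eig(m)   # torch.eig(m) in older PyTorch fails the same way
except RuntimeError as e:
    error = str(e)
print(error)              # complains that the input must be square
```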
torch.eig() is useful in linear-algebra work for directly finding the eigenvalues and the corresponding eigenvectors of a given matrix.
torch.mean()
It returns the mean value of all elements in the input tensor.
We know that the mean of all the elements of a matrix is the sum of the elements divided by the total number of elements.
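A minimal sketch (illustrative values):

```python
import torch

m = torch.tensor([[1., 2.],
                  [3., 4.]])
avg = torch.mean(m)       # (1 + 2 + 3 + 4) / 4
print(avg.item())         # 2.5
```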
In order to find the mean along an axis, we pass the dim argument to the function: dim=0 reduces over the rows, giving one mean per column, while dim=1 reduces over the columns, giving one mean per row.
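The per-axis behaviour can be sketched as (illustrative values):

```python
import torch

m = torch.tensor([[1., 2.],
                  [3., 4.]])
col_means = torch.mean(m, dim=0)   # reduce over rows: one mean per column
row_means = torch.mean(m, dim=1)   # reduce over columns: one mean per row
print(col_means.tolist())          # [2.0, 3.0]
print(row_means.tolist())          # [1.5, 3.5]
```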
The above example fails because we used an empty tensor: the mean of zero elements is undefined.
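A sketch of the empty-tensor case; note that in current PyTorch versions torch.mean returns nan here rather than raising an error:

```python
import torch

t = torch.tensor([])      # empty float tensor
result = torch.mean(t)
print(result)             # tensor(nan): the mean of zero elements is undefined
```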
Conclusion
Tensors are at the heart of PyTorch, and there are many more functions to explore. We have looked at just five tensor operations that may help you in your projects. You can go to jovian.ai for more data science resources.
There is always a first time for everything. This is my first article on medium.com, and I hope you have learned something new. Any feedback is highly appreciated. Thank you all, and thanks to jovian.ai.
References
- Official documentation for torch.Tensor: https://pytorch.org/docs/stable/tensors.html
- Here is the full code.