July 9, 2025

What is Einsum anyway?

4 minutes to read · 799 words

Einstein summation lets us write an algebraic expression in index notation and have NumPy/PyTorch/JAX generate an efficient set of tensor operations to compute it.

The main drawback is that it can be confusing and comes with a learning curve. On the plus side, einsum is backend-agnostic: the subscript notation is the same whether we call torch.einsum, jax.numpy.einsum or numpy.einsum.

A walkthrough with common use cases:

Matrix Multiplication

import numpy as np

# A is (3, 5) and B is (5, 2), so their product M is (3, 2).
A = np.random.rand(3, 5)
B = np.random.rand(5, 2)
M = np.empty((3, 2))

# Naive reference implementation: each output cell M[i, j] is the sum,
# over the shared index k, of A[i, k] * B[k, j].
for i in range(3):
    for j in range(2):
        M[i, j] = sum(A[i, k] * B[k, j] for k in range(5))

Using einsum, we can do the same multiplication like this, where each group of letters labels the axes of one operand:

  • ik : axes of matrix A
  • kj : axes of matrix B
  • ij : axes of the output matrix M
# Output subscripts are written without commas ('ij', not 'i,j'). k appears in
# both inputs but not in the output, so einsum sums over it.
M = np.einsum('ik,kj->ij', A, B)

Free indices are the indices specified in the output. Summation indices are all the other indices: those that appear in the input arguments but NOT in the output specification. So i and j are the free indices, and k is the summation index.

Example (outer product: every index is free, so nothing is summed):

import numpy as np

a = np.random.rand(6)
b = np.random.rand(4)
# No letter is repeated and both i and j appear in the output, so nothing is
# summed: each output cell is the single product a[i] * b[j].
outer = np.einsum('i,j->ij', a, b)
print(f"Using einsum: {outer}\n")


# Loop equivalent. Unlike matrix multiplication there is no summation index,
# so no running total is needed; sizes come from the operands themselves.
mat = np.empty((a.shape[0], b.shape[0]))
for i in range(a.shape[0]):
    for j in range(b.shape[0]):
        mat[i, j] = a[i] * b[j]
print(f"Using loops: {mat}\n")

Rules

  • Repeating a letter across different inputs means the values along those axes are multiplied together; if that letter is also missing from the output, the products are summed over it
    • M = np.einsum('ik,kj->ij', A, B)
  • Omitting a letter from the output means that axis will be summed
    • x = np.ones(3)
    • sum_x = np.einsum('i->', x)
  • We can return the free (unsummed) axes in any order
    • x = np.ones((5,4,3))
    • np.einsum('ijk->kji', x)

Operations

Permutation of Tensors

Transpose of a matrix

import torch

x = torch.rand(2, 3)
print(x)

# Listing the output labels in reverse order swaps the axes:
# out[j, i] = x[i, j].
transpose_spec = 'ij->ji'
out = torch.einsum(transpose_spec, x)
print(f"Transpose: {out}")
tensor([[0.5600, 0.9810, 0.7717],
         [0.1076, 0.1166, 0.8147]])

Transpose: tensor([[0.5600, 0.1076],
         [0.9810, 0.1166],
         [0.7717, 0.8147]])

Summation

import torch

x = torch.rand(2, 3)
print(x)

# An empty output spec keeps no axes, so every element is summed into a
# 0-dimensional tensor.
total_spec = 'ij->'
out = torch.einsum(total_spec, x)
print(f"Summation: {out}")
tensor([[0.2053, 0.0860, 0.5769],
        [0.8407, 0.0610, 0.0562]])

Summation: tensor(1.8262)

Column Sum

import torch

x = torch.rand(2, 3)
print(x)

# i is dropped from the output, so rows are summed away and one value per
# column remains: out[j] = sum_i x[i, j].
column_spec = 'ij->j'
out = torch.einsum(column_spec, x)
print(f"Column Sum: {out}")
tensor([[0.3466, 0.7609, 0.0815],
        [0.2268, 0.6954, 0.8935]])

Column Sum: tensor([0.5734, 1.4562, 0.9750])

Row Sum

import torch
x = torch.rand((2, 3))
print(x)

# j is dropped from the output, so columns are summed away and one value per
# row remains: out[i] = sum_j x[i, j].
out = torch.einsum('ij->i', x)
print(f"Row Sum: {out}")
tensor([[0.8775, 0.3120, 0.6106],
        [0.5312, 0.6179, 0.5185]])

Row Sum: tensor([1.8001, 1.6676])

Matrix Vector Multiplication

import torch

# v is stored as a 1x3 row matrix rather than a 1-D vector.
v = torch.rand(1, 3)
x = torch.rand(2, 3)
print(v)
print(x)

# j is shared by both operands and absent from the output, so it is summed:
# out[i, k] = sum_j x[i, j] * v[k, j], i.e. x @ v.T with shape (2, 1).
out = torch.einsum(
    'ij, kj -> ik',
    x,
    v,
)
print(f"Matrix Vector Multiplication: {out}")
tensor([[0.7960, 0.9741, 0.1856]])

tensor([[0.4157, 0.2612, 0.9272],
        [0.8721, 0.3250, 0.4674]])

Matrix Vector Multiplication: tensor([[0.7573], [1.0974]])

Matrix Matrix Multiplication

import torch

x = torch.rand(2, 3)
print(x)

# Both operands are x, and j is shared and summed, so this is x @ x.T:
# (2x3) * (3x2) = (2x2).
out = torch.einsum(
    'ij, kj -> ik',
    x,
    x,
)
print(f"Matrix Multiplication: {out}")
tensor([[0.3680, 0.3761, 0.5536],
        [0.4293, 0.4596, 0.7517]])

Matrix Multiplication: tensor([[0.5834, 0.7470], [0.7470, 0.9605]])

Dot product of the first row with itself

import torch

x = torch.rand(2, 3)
print(x)

first_row = x[0]
# i appears in both inputs and not in the output, so it is summed:
# out = sum_i x[0, i] * x[0, i].
out = torch.einsum('i, i -> ', first_row, first_row)
print(f"Dot product of first row: {out}")
tensor([[0.2510, 0.1493, 0.2530],
        [0.4679, 0.5010, 0.5318]])

Dot product of first row: tensor(0.1493)

Dot product of a matrix with itself (sum of all elementwise products)

import torch

x = torch.rand(2, 3)
print(x)

# Both i and j are repeated and neither appears in the output, so every
# elementwise product is summed: out = sum_ij x[i, j] * x[i, j].
frobenius_spec = 'ij, ij -> '
out = torch.einsum(frobenius_spec, x, x)
print(f"Dot product with a matrix: {out}")
tensor([[0.0405, 0.0519, 0.9728],
        [0.0479, 0.2052, 0.3783]])

Dot product with a matrix: tensor(1.1382)

Element wise multiplication

import torch

x = torch.rand(2, 3)
print(x)

# Every label is kept in the output, so nothing is summed; matching labels
# just pair up elements: out[i, j] = x[i, j] * x[i, j].
hadamard_spec = 'ij, ij -> ij'
out = torch.einsum(hadamard_spec, x, x)
print(f"Element wise multiplication: {out}")
tensor([[0.3876, 0.6711, 0.0713],
        [0.3814, 0.8174, 0.8426]])

Element wise multiplication: tensor([[0.1502, 0.4503, 0.0051], [0.1455, 0.6681, 0.7100]])

Outer Product

import torch

# Plain ints give 1-D tensors; `(3)` in the original was just the int 3.
a = torch.rand(3)
b = torch.rand(2)
print(a)
print(b)

# Two distinct free labels and no shared one: out[i, j] = a[i] * b[j].
outer_spec = 'i, j -> ij'
out = torch.einsum(outer_spec, a, b)
print(f"Outer Product: {out}")
tensor([0.3336, 0.4734, 0.1117])
tensor([0.8645, 0.2134])

Outer Product:
tensor([[0.2884, 0.0712],
        [0.4092, 0.1010],
        [0.0966, 0.0238]])

Batch Matrix Multiplication

import torch

a = torch.rand((3, 2, 5))
b = torch.rand((3, 5, 3))
print(a)
print(b)

# i is the batch label and is kept in the output; k is shared and summed, so
# each batch computes a[i] @ b[i]: (2, 5) @ (5, 3) -> (2, 3).
out = torch.einsum('ijk, ikl -> ijl', a, b)
print(f"Batch Matrix: {out}")
tensor([[[0.5067, 0.8090, 0.0088, 0.9936, 0.1828],
         [0.2544, 0.7564, 0.8248, 0.5977, 0.0694]],

        [[0.8260, 0.4089, 0.7738, 0.5748, 0.1143],
         [0.6959, 0.9731, 0.6790, 0.4556, 0.2643]],

        [[0.7593, 0.6693, 0.3735, 0.0406, 0.1825],
         [0.2616, 0.2867, 0.4369, 0.0290, 0.1599]]])

tensor([[[0.9256, 0.7068, 0.1455],
         [0.6793, 0.4292, 0.7994],
         [0.9526, 0.5082, 0.0190],
         [0.0798, 0.7878, 0.5322],
         [0.5871, 0.2096, 0.5116]],

        [[0.4090, 0.6878, 0.2609],
         [0.4642, 0.5034, 0.4563],
         [0.8460, 0.1870, 0.3020],
         [0.1418, 0.9659, 0.9026],
         [0.9834, 0.2295, 0.5935]],

        [[0.5233, 0.2923, 0.2487],
         [0.9908, 0.5659, 0.6832],
         [0.5437, 0.0347, 0.1832],
         [0.9032, 0.3923, 0.4762],
         [0.0904, 0.2979, 0.2598]]])

Batch Matrix:
tensor([[[1.2136, 1.5310, 1.3428],
         [1.6234, 1.4091, 1.0109]],

        [[1.3762, 1.5001, 1.2225],
         [1.6354, 1.5962, 1.3988]],

        [[1.3168, 0.6839, 0.7814],
         [0.6992, 0.3129, 0.3964]]])

Matrix Diagonal

import torch

x = torch.rand(3, 3)

print(x)

# Repeating a label within one operand selects the entries where both indices
# are equal; keeping i in the output returns them as a vector: out[i] = x[i, i].
diag_spec = 'ii -> i'
out = torch.einsum(diag_spec, x)
print(f"Diagonal: {out}")
tensor([[0.9194, 0.1331, 0.5036],
        [0.2516, 0.6575, 0.7909],
        [0.6403, 0.2983, 0.7398]])

Diagonal: tensor([0.9194, 0.6575, 0.7398])

Matrix Trace

import torch

x = torch.rand(3, 3)
print(x)

# Same diagonal selection as 'ii -> i', but i is also dropped from the output,
# so the diagonal entries are summed: out = sum_i x[i, i].
trace_spec = 'ii ->'
out = torch.einsum(trace_spec, x)
print(f"Trace: {out}")
tensor([[0.9849, 0.3388, 0.4579],
        [0.4265, 0.5766, 0.1355],
        [0.4080, 0.5419, 0.5608]])

Trace: tensor(2.1224)

© Ankit Sharma 2025