Tensor Puzzles

I recently came across Sasha Rush’s Tensor Puzzles, which were trending on Twitter, and had a lot of fun doing them, so I wanted to share my experience and hopefully convince you to give them a shot!

Puzzle format

There are 21 puzzles, each asking you to reconstruct a NumPy/PyTorch tensor function in a minimalist fashion. The challenge is to do so in one line, using as few characters as possible. The phrase “from scratch” is quite literal here: you’re limited to the provided arange() and where() functions, plus any solutions from previously solved puzzles. Your operations are confined to @, arithmetic, comparison, shape, and indexing such as a[:j], a[:, None] and a[arange(10)]. Notably absent from your toolkit are view, sum, take, squeeze and tensor.

Reflections

These puzzles are designed to focus on broadcasting. However, after solving them, I observed that only half of the puzzles actually require broadcasting, and even then, it only goes from 1-D to 2-D tensors.

In my experience, once you grasp the concept of broadcasting, the true challenge lies in identifying and constructing the transformation matrix or mask that enables you to achieve the desired output from the given input.

It’s normal to get stuck if you’re not familiar with the tricks some of the problems need, but once you get the idea, it’s easy to build on those tricks to crack tougher problems. If you’re stuck, just ask ChatGPT for a hint or refer to the solutions.

Concepts

The first trick to know is how to construct a boolean mask. Puzzle 5 asks you to construct the identity matrix, where only the main diagonal is 1 (True) and everything else is 0 (False). Thinking in terms of a 2-D array and loops, it’s simply a matter of setting arr[i][j] = 1 if i == j.

To achieve this, we need some broadcasting and comparison techniques. Let’s say we want $I_{5}$: take arange(5) and copy it 5 times row-wise, then turn it into a column and copy it 5 times column-wise:

rows = [
[0,1,2,3,4], 
[0,1,2,3,4],
[0,1,2,3,4],
[0,1,2,3,4],
[0,1,2,3,4]]

cols = [
[0,0,0,0,0],
[1,1,1,1,1],
[2,2,2,2,2],
[3,3,3,3,3],
[4,4,4,4,4]]

If you do an element-wise equality comparison between rows and cols, you get the answer. So how do we get rows and cols? arange(n) gives a row vector of shape (n,), and arange(n)[:, None] gives a column vector of shape (n, 1). Then you just compare them like so: arange(n) == arange(n)[:, None]. But how can you compare two tensors of different shapes? That’s where broadcasting comes in: it stretches each tensor along its missing or size-1 dimensions until the shapes on both sides of the operator match, here (n, n), making the computation possible. The stretching is conceptual; the expanded inputs are never materialized in memory.
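Here is a minimal sketch of the same comparison in PyTorch. It uses `* 1` to turn the boolean mask into 0/1 integers, and calls `expand()` (which is not allowed in the puzzles) purely to show the explicit copies that broadcasting stands in for.

```python
import torch

n = 5
rows = torch.arange(n)           # shape (5,): acts like [0, 1, 2, 3, 4] on every row
cols = torch.arange(n)[:, None]  # shape (5, 1): acts like [i, i, i, i, i] on row i
mask = rows == cols              # (5,) vs (5, 1) broadcasts to a (5, 5) boolean mask
eye = mask * 1                   # True/False -> 1/0 as int64

# Spelling out the copies with expand() yields exactly the same mask
assert torch.equal(mask, rows.expand(n, n) == cols.expand(n, n))
print(eye)
```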

This concept will allow you to solve puzzles like triu, cumsum, vstack, pad_to and sequence_mask.
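As a sketch of how little changes between these puzzles, here are triu, vstack and sequence_mask written with the same broadcast comparison. The signatures are my own simplification and may not match the notebook exactly; in particular, sequence_mask here takes a 2-D values tensor and a 1-D tensor of per-row lengths.

```python
import torch

def triu(n: int) -> torch.Tensor:
    # out[r, c] = 1 when c >= r: the identity comparison with >= instead of ==
    return (torch.arange(n) >= torch.arange(n)[:, None]) * 1

def vstack(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    r = torch.arange(2)[:, None]  # (2, 1) row selector
    # (i,) * (2, 1) broadcasts to (2, i): row 0 keeps a, row 1 keeps b
    return a * (r == 0) + b * (r == 1)

def sequence_mask(values: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # (j,) column positions vs (i, 1) lengths -> (i, j) mask of c < length[r]
    keep = torch.arange(values.shape[1]) < length[:, None]
    return values * keep

print(triu(3))
# tensor([[1, 1, 1],
#         [0, 1, 1],
#         [0, 0, 1]])
```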

The next trick to know is advanced indexing. I got really stuck on flatten because I didn’t know that an index tensor can be longer than the dimension it indexes and can repeat indices, so you can extract more values along an axis than the tensor actually holds. For example, with the puzzles’ own integer ones and eye, you can do eye(3)[ones(10)] to get 10 copies of the second row like this:

tensor([[0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0]])

The takeaway here is that indexing lets you extract whatever values you want from a tensor, as many times as you want; you just have to think about how to write the query.
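A small sketch of that query mindset in PyTorch. One assumption to flag: ones10 is built as integers with arange(10) * 0 + 1, because torch.ones returns floats, and float tensors cannot be used as indices.

```python
import torch

eye3 = (torch.arange(3) == torch.arange(3)[:, None]) * 1  # integer identity matrix
ones10 = torch.arange(10) * 0 + 1  # ten integer 1s, usable as an index

picked = eye3[ones10]            # row 1 requested ten times -> shape (10, 3)
rev = eye3[2 - torch.arange(3)]  # rows requested in the order 2, 1, 0
```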

So to write flatten, we need to extract each row in turn, which means constructing the sequence of “coordinates” to read. With i, j = a.shape, the solution is: a[arange(i*j) // j, arange(i*j) % j]
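A runnable version of that flatten solution, as a sketch with the shape unpacked explicitly. reshape is used only to build a test input, since it isn’t allowed inside the puzzles.

```python
import torch

def flatten(a: torch.Tensor) -> torch.Tensor:
    i, j = a.shape
    k = torch.arange(i * j)  # one entry per output position
    # k // j repeats each row number j times: 0, 0, 0, 1, 1, 1, ...
    # k % j cycles through the columns:      0, 1, 2, 0, 1, 2, ...
    return a[k // j, k % j]

a = torch.arange(6).reshape(2, 3) * 10  # [[0, 10, 20], [30, 40, 50]]
flat = flatten(a)                        # 0, 10, 20, 30, 40, 50
```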

The last few tricks are knowing how to use where() to modify and replace values in a tensor based on some condition, knowing how to “accumulate” values using matmuls, and knowing how to creatively combine techniques to solve problems.
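Both ideas fit in one short sketch. The where below is a stand-in I wrote with plain arithmetic (q * x + (~q) * y); it is in the spirit of the puzzles’ provided helper but not necessarily identical to it. The matmuls use float tensors so the dtypes on both sides match.

```python
import torch

def where(q: torch.Tensor, x, y):
    # Stand-in: pick x where q is True and y elsewhere, using arithmetic on the mask
    return q * x + (~q) * y

a = torch.tensor([3.0, -1.0, 4.0, -1.0, 5.0])
n = a.shape[0]

relu = where(a > 0, a, 0)         # replace negative values with 0
ones = torch.arange(n) * 0 + 1.0  # float ones
total = a @ ones                  # a dot product with ones accumulates a sum
upper = (torch.arange(n) >= torch.arange(n)[:, None]) * 1.0  # upper-triangular ones
running = a @ upper               # running[c] = sum of a[r] for r <= c (a cumsum)
```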

More Puzzles

If the tensor puzzles were a breeze for you, here are 3 more puzzles that test slightly different, more advanced concepts.

More broadcasting (across more dimensions)

Given two tensors A and B:

  • A is a 1-D tensor with shape (i,)
  • B is a 2-D tensor with shape (j, k)

Create a new 4-D tensor C with shape (i, j, k, i), where each element C[m, n, p, q] is calculated as follows:

  • If q is equal to m, then C[m, n, p, q] should be the product of A[m] and B[n, p].
  • Otherwise, C[m, n, p, q] is zero.

import torch

def complex_broadcast_multiply(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    i = A.shape[0]
    # (i,1,1,1) * (1,j,k,1) -> (i,j,k,1); the (i,1,1,i) identity mask zeroes every q != m
    return (A[:, None, None, None] * B[None, :, :, None]) * torch.eye(i, dtype=A.dtype, device=A.device)[:, None, None, :]
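To sanity-check a broadcast solution like this, it helps to compare it against a direct transcription of the definition with loops. The loop version below is a hypothetical helper written only for testing.

```python
import torch

def complex_broadcast_multiply(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    i = A.shape[0]
    mask = torch.eye(i, dtype=A.dtype)[:, None, None, :]  # (i, 1, 1, i), 1 only where q == m
    return (A[:, None, None, None] * B[None, :, :, None]) * mask

def complex_broadcast_multiply_loops(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    i, (j, k) = A.shape[0], B.shape
    C = torch.zeros(i, j, k, i, dtype=A.dtype)
    for m in range(i):
        for n in range(j):
            for p in range(k):
                C[m, n, p, m] = A[m] * B[n, p]  # every q != m stays zero
    return C

A = torch.tensor([1.0, 2.0, 3.0])
B = torch.arange(8.0).reshape(2, 4)
C = complex_broadcast_multiply(A, B)
assert torch.equal(C, complex_broadcast_multiply_loops(A, B))
```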

Sliding Window Sum (useful for moving average)

Compute the sum of elements in a sliding window of size k across a 1D tensor, without using any built-in convolution functions.

import torch

def sliding_window_sum(a: torch.Tensor, k: int) -> torch.Tensor:
    # Index matrix of shape (len(a) - k + 1, k): row s holds s, s+1, ..., s+k-1
    idx = torch.arange(len(a) - k + 1)[:, None] + torch.arange(k)
    windows = a[idx]  # gather every window at once via advanced indexing
    return windows.sum(dim=1)
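If you want to stay closer to the original puzzle rules, where sum is off-limits, the row sums can come from a matmul with a ones vector instead, and dividing by k turns the sums into a moving average. A sketch, with a function name of my own choosing:

```python
import torch

def sliding_window_sum_matmul(a: torch.Tensor, k: int) -> torch.Tensor:
    idx = torch.arange(len(a) - k + 1)[:, None] + torch.arange(k)  # (n - k + 1, k) window indices
    ones = torch.ones(k, dtype=a.dtype)  # matches a's dtype so the matmul is well-typed
    return a[idx] @ ones                 # each window row dotted with ones gives its sum

a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
sums = sliding_window_sum_matmul(a, 3)  # windows [1, 2, 3], [2, 3, 4], [3, 4, 5]
moving_avg = sums / 3
```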

Tensor Rotation (this is not a transpose!)

Rotate a 2D tensor by 90 degrees clockwise without using any built-in rotation functions.

import torch

def rotate_90_clockwise(a: torch.Tensor) -> torch.Tensor:
    r, c = a.shape
    # The output has shape (c, r) with out[i, j] = a[r - 1 - j, i]: broadcast a (1, r)
    # reversed row index against a (c, 1) column index to swap rows and columns
    return a[(r - 1 - torch.arange(r))[None, :], torch.arange(c)[:, None]]
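A quick way to check a rotation is to compare it with PyTorch’s own torch.rot90 (where k=-1 rotates clockwise) on a non-square input, since square inputs can hide row/column mix-ups. The built-ins appear only in the checks.

```python
import torch

def rotate_90_clockwise(a: torch.Tensor) -> torch.Tensor:
    r, c = a.shape
    # out[i, j] = a[r - 1 - j, i], built from broadcast (1, r) and (c, 1) index tensors
    return a[(r - 1 - torch.arange(r))[None, :], torch.arange(c)[:, None]]

a = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
out = rotate_90_clockwise(a)       # [[3, 0], [4, 1], [5, 2]]

assert torch.equal(out, torch.rot90(a, -1, [0, 1]))  # k=-1: clockwise
assert torch.equal(out, a.flip(0).T)                 # reverse rows, then transpose

# Reversing both rows and columns instead is a 180-degree rotation, a common slip
assert torch.equal(a.flip(0).flip(1), torch.rot90(a, 2, [0, 1]))
```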

Wrapping up

Once again, thanks to Sasha Rush for sharing these puzzles and walking through the solutions! He has several other puzzle sets, like GPU Puzzles and Transformer Puzzles, that I plan on doing soon!

It’s important to note that the implementations used to solve these puzzles are not at all how these functions are actually written. Under the hood, they are implemented in C/C++, using loops that exploit how tensors are laid out in memory. The puzzles are just good practice.