3. Advanced Vectorization

In the previous guide on Jax, we introduced its basic building block, the multi-dimensional array, as well as a new paradigm for writing code without loops or if/else expressions. Here, we continue building on this paradigm with more advanced (yet common) uses of Jax in ML.

Let’s get ourselves started by importing some libraries.

import jax.numpy as jnp
import chex

Acknowledgement. Parts of this tutorial have been adapted from this NumPy tutorial.

3.1. Indexing with Boolean Arrays

Whereas slicing is typically used to extract contiguous chunks from an array (i.e. elements that sit next to each other in the original array), indexing with arrays allows us to extract non-contiguous parts. How? It’s easier shown than explained.

You may remember from the introductory guide to Jax that performing boolean operations on an array returns a boolean array. Here’s an example:

a = jnp.arange(10)
print('a       =', a)
print('is even =', a % 2 == 0)
a       = [0 1 2 3 4 5 6 7 8 9]
is even = [ True False  True False  True False  True False  True False]

By using these boolean arrays, we can select all elements for which the boolean expression a % 2 == 0 is True:

b = a[a % 2 == 0]
print('Even elements from a =', b)
Even elements from a = [0 2 4 6 8]
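Boolean arrays can also be combined elementwise with & (and), | (or), and ~ (not). Python’s and/or keywords do not work elementwise on arrays, and & and | bind more tightly than comparison operators such as == and >, so each comparison needs its own parentheses. A small sketch:

```python
import jax.numpy as jnp

a = jnp.arange(10)

# Each comparison is wrapped in parentheses because & and | bind more
# tightly than == and > in Python.
even_and_large = (a % 2 == 0) & (a > 4)
small_or_odd = (a < 2) | (a % 2 == 1)

print(a[even_and_large])  # [6 8]
print(a[small_or_odd])    # [0 1 3 5 7 9]
print(a[~(a > 4)])        # [0 1 2 3 4]
```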

Using boolean arrays, you can easily write functions with piecewise behavior. As an example, the function

\[\begin{align*} f(x) &= \begin{cases} x^2 & \text{if $x$ is even} \\ x^3 & \text{if $x$ is odd} \\ \end{cases} \end{align*}\]

can be implemented as follows:

def f(x):
    is_even = (x % 2 == 0)
    return is_even * (x ** 2.0) + ~is_even * (x ** 3.0)

print(f(jnp.arange(10)))
[  0.   1.   4.  27.  16. 125.  36. 343.  64. 729.]

In the above, we first created a boolean array is_even, which is True only at indices where \(x\) is even. Multiplying is_even by the float array x ** 2.0 automatically casts it to a number: wherever is_even is True it becomes 1, and wherever it is False it becomes 0. As such, is_even * (x ** 2.0) returns an array in which every element is \(x^2\) where \(x\) is even and \(0\) otherwise. In contrast, ~is_even * (x ** 3.0) returns an array in which every element is \(0\) where \(x\) is even and \(x^3\) otherwise (the ~ operator flips the boolean values). Adding the two resulting arrays gives the answer we were looking for.
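The casting described above can be checked directly on small arrays. One caveat: ~ flips values only for boolean arrays. On integer arrays it is bitwise NOT (~n equals -n - 1), which maps 1 to -2 and 0 to -1, so is_even should stay boolean until it is multiplied:

```python
import jax.numpy as jnp

is_even = jnp.array([True, False, True])

# Multiplying a boolean array by a float array casts True -> 1.0, False -> 0.0.
print(is_even * jnp.array([2.0, 3.0, 4.0]))  # [2. 0. 4.]

# On a boolean array, ~ flips each value.
print(~is_even)  # [False  True False]

# On an integer array, ~ is bitwise NOT (~n == -n - 1), not a logical flip.
as_int = is_even.astype(jnp.int32)
print(~as_int)  # [-2 -1 -2]
```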

While the above example has pedagogical value, all of the casting obscures the logic a little. Here’s another, cleaner way to implement the same function using jnp.where:

def f(x):
    return jnp.where(x % 2 == 0, x ** 2.0, x ** 3.0)
    
print(f(jnp.arange(10)))
[  0.   1.   4.  27.  16. 125.  36. 343.  64. 729.]

Here, jnp.where uses the condition x % 2 == 0 (is even) to select each element from one of two arrays: \(x^2\) where the condition is True, and \(x^3\) where it is False.
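When there are more than two cases, jnp.where calls can be nested, or jnp.select can take a list of conditions and a matching list of choices, using the first condition that is True for each element. The piecewise function g below is an illustrative example of our own, not one from the text:

```python
import jax.numpy as jnp

def g(x):
    # Conditions are checked in order; the first True one picks the value.
    conditions = [x < 0, x < 3]
    choices = [-x, x ** 2]
    # Elements matching no condition fall back to `default`.
    return jnp.select(conditions, choices, default=10)

def g_nested(x):
    # The same function written with nested jnp.where calls.
    return jnp.where(x < 0, -x, jnp.where(x < 3, x ** 2, 10))

x = jnp.arange(-2, 5)
print(g(x))         # [ 2  1  0  1  4 10 10]
print(g_nested(x))  # [ 2  1  0  1  4 10 10]
```

For many cases, jnp.select tends to stay more readable than deeply nested jnp.where calls.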

Lastly, remember that a multi-dimensional array can be interpreted as an array of arrays. For example, an array of shape \((N, D)\) can be thought of as an array of \(N\) elements, each of which is an array with \(D\) elements. As such, indexing into an \((N, D)\) array using a 1-dimensional boolean array will select rows. To show how this works, suppose we have the array below:

a = jnp.arange(12).reshape(4, 3) + 1
print('a =')
print(a)
a =
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]

And suppose we’d like to select all rows containing a number that’s divisible by 5. First, let’s see which numbers are divisible by 5:

a % 5 == 0
Array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [ True, False, False]], dtype=bool)

Next, we need to determine which rows have at least one True. We can do this using .any(axis=-1), which returns True if at least one element is True. The axis=-1 argument tells any to reduce over the last axis, i.e. to check each row independently:

contains_number_divisible_by_5 = (a % 5 == 0).any(axis=-1)
print('Row contains a number that is divisible by 5:', contains_number_divisible_by_5)
Row contains a number that is divisible by 5: [False  True False  True]

Finally, we can use contains_number_divisible_by_5 to index into a and select the rows that contain a number divisible by 5:

print('Rows with number divisible by 5:')
print(a[contains_number_divisible_by_5])
Rows with number divisible by 5:
[[ 4  5  6]
 [10 11 12]]

Note that in addition to any, there’s also all, which returns True only when all elements are True.
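For instance, with the same 4 × 3 array, .all(axis=-1) keeps only the rows in which every element is greater than 3, while axis=0 reduces down each column instead, producing one value per column:

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(4, 3) + 1

# One boolean per row: True only if every element in that row is > 3.
all_greater_than_3 = (a > 3).all(axis=-1)
print(all_greater_than_3)     # [False  True  True  True]
print(a[all_greater_than_3])  # rows [4 5 6], [7 8 9], [10 11 12]

# axis=0 reduces down each column instead: one boolean per column.
print((a % 5 == 0).any(axis=0))  # [ True  True False]
```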

Exercise: Indexing with Boolean Arrays

Please solve the following using Jax library calls only (no loops, no if/else!):

Part 1: Write a function that takes in a 2-dimensional array a and returns only rows whose sum is positive. For example, given,

a = jnp.array([
    [1.0, 2.0, 3.0],
    [-1.0, -2.0, 3.0],
    [-1.0, -2.0, 4.0],
    [-1.0, -2.0, -3.0],
])

the result should be:

Array([[ 1.,  2.,  3.],
       [-1., -2.,  4.]], dtype=float32)

Use the following function signature:

def boolean_indexing_q1(a):
    pass # TODO implement

Part 2: Write a function that takes in an integer \(N > 0\), a coordinate \((x, y)\), and a radius \(r\) (where \(0 \leq x < N\), \(0 \leq y < N\), and \(r > 0\)). The function should return an integer array of shape \((N, N)\) in which every element is 0 except elements that are within radius \(r\) of \((x, y)\) (i.e. every element \((i, j)\) should be 1 if \((x - i)^2 + (y - j)^2 \leq r^2\) and 0 otherwise). For example, for \(N = 10\), \(x = 5\), \(y = 5\), and \(r = 2\), the function should return:

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
       [0, 0, 0, 1, 1, 1, 1, 1, 0, 0],
       [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)

Similarly, for \(N = 10\), \(x = 5\), \(y = 0\), and \(r = 2\), the function should return:

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)

Use the following function signature:

def boolean_indexing_q2(N, x, y, r):
    pass # TODO implement

3.2. Indexing with Arrays of Indices

Sometimes, it’s useful to select elements of an array not just using a single index, but using an array of indices.

a = jnp.arange(12) ** 2
print('a =', a)
a = [  0   1   4   9  16  25  36  49  64  81 100 121]
i = jnp.array([1, 1, 3, 8, 5]) # a 1-dimensional array of indices
print('a[i] =', a[i])
a[i] = [ 1  1  9 64 25]
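Two properties follow from this that are easy to check: the result takes the shape of the index array, so a 2-dimensional array of indices produces a 2-dimensional result, and negative indices count from the end, as with ordinary Python indexing:

```python
import jax.numpy as jnp

a = jnp.arange(12) ** 2

# The output has the same shape as the index array: here (2, 2).
i = jnp.array([[1, 1], [3, 8]])
print(a[i])
# [[ 1  1]
#  [ 9 64]]

# Negative indices count from the end of the array.
print(a[jnp.array([-1, -2, 0])])  # [121 100   0]
```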

Recall again that a multi-dimensional array can be interpreted as an array of arrays. As such, indexing into an \((N, D)\) array using a 1-dimensional index array will select rows:

a = jnp.arange(12).reshape(6, 2)
print('a =')
print(a)
a =
[[ 0  1]
 [ 2  3]
 [ 4  5]
 [ 6  7]
 [ 8  9]
 [10 11]]
i = jnp.array([2, 3, 5]) # a 1-dimensional array of indices
print('a[i] =')
print(a[i])
a[i] =
[[ 4  5]
 [ 6  7]
 [10 11]]

Lastly, instead of selecting whole rows, we can also index into each dimension independently. That is, we select elements from an array by giving it the “coordinates” of the elements we want as follows:

a = jnp.arange(12).reshape(6, 2)
print('a =')
print(a)
a =
[[ 0  1]
 [ 2  3]
 [ 4  5]
 [ 6  7]
 [ 8  9]
 [10 11]]
i = jnp.array([3, 5, 2])
j = jnp.array([0, 1, 1])
print('selected =', a[(i, j)])
selected = [ 6 11  5]

By indexing into a using the tuple (i, j), we paired each element of \(i\) with the corresponding element of \(j\) to select the items at indices \((3, 0)\), \((5, 1)\), and \((2, 1)\).
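A common use of this in ML is picking one entry per row, such as the probability a model assigns to each example’s true label. Pairing jnp.arange(N) with an array of N column indices does exactly that (writing probs[rows, labels] is equivalent to probs[(rows, labels)]). The probabilities and labels below are made-up values for illustration:

```python
import jax.numpy as jnp

# Hypothetical predicted probabilities: 3 examples, 4 classes each.
probs = jnp.array([
    [0.1, 0.2, 0.6, 0.1],
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.2, 0.1, 0.5],
])
labels = jnp.array([2, 0, 3])  # true class of each example

# Row k is paired with column labels[k]: (0, 2), (1, 0), (2, 3).
rows = jnp.arange(probs.shape[0])
print(probs[rows, labels])  # [0.6 0.7 0.5]
```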

Exercise: Indexing with Arrays of Indices

Please solve the following using Jax library calls only (no loops, no if/else!):

Part 1: Write a function that, given a square matrix a, returns the elements along its diagonal. For example, given,

a = jnp.arange(16).reshape(4, 4)
print(a)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

the function should return [0, 5, 10, 15]. Note that Jax already has a function that does exactly this, jnp.diagonal; please do not use it for this exercise!

Use the following function signature:

def integer_indexing_q1(a):
    pass # TODO implement

Part 2: Extend the function from the previous part to accept a second, keyword argument, offset (defaulting to 0), that offsets the diagonal. Continuing with the above example,

  • When offset=0, the function’s behavior should not change from the previous part.

  • When offset=1, the function should return: [1, 6, 11].

  • When offset=2, the function should return: [2, 7].

  • When offset=-1, the function should return: [4, 9, 14].

  • When offset=-2, the function should return: [8, 13].

Again, this functionality is already built into jnp.diagonal, but for the purpose of the exercise, implement it yourself.

Use the following function signature:

def integer_indexing_q2(a, offset=0):
    pass # TODO implement

3.3. Broadcasting

The term broadcasting describes how Jax treats arrays with different shapes during arithmetic operations. Subject to certain constraints, the smaller array is “broadcast” across the larger array so that they have compatible shapes. Broadcasting provides a means of vectorizing array operations so that the looping happens in optimized, compiled code rather than in Python, without making needless copies of data. Broadcasting follows two rules:

Rule 1: If the input arrays do not all have the same number of dimensions, a “1” is repeatedly prepended to the shapes of the lower-dimensional arrays until all the arrays have the same number of dimensions.

Example: If we were to add an array of shape \((N, D)\) to an array of shape \((D,)\), under the hood, broadcasting will change the second array’s shape to \((1, D)\). Notice that the arrays are still not the same shape, so how can we add them? This is where Rule 2 comes in.

Rule 2: Arrays with a size of 1 along a particular dimension act as if they had the size of the largest array along that dimension, with their values repeated along it. If the sizes along a dimension differ and neither is 1, the shapes are incompatible and Jax raises an error.

Example: If we were to add an array of shape \((N, D)\) to an array of shape \((1, D)\), under the hood, broadcasting would repeat the \((1, D)\)-array \(N\) times along its 0th axis. Both arrays then effectively have shape \((N, D)\) and can be added elementwise.

Here’s the example in code:

a = jnp.arange(12).reshape(3, 4)
print('a =')
print(a)
a =
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
print('a + [1, 2, 3, 4] =')
print(a + jnp.array([1, 2, 3, 4]))
a + [1, 2, 3, 4] =
[[ 1  3  5  7]
 [ 5  7  9 11]
 [ 9 11 13 15]]

And of course, the above addition can be replaced with any other operation and Jax will broadcast in the same way.
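The two rules can be written out as a small function that predicts the result shape without doing any arithmetic. The function broadcast_shape below is our own illustrative helper, not part of Jax; Jax’s jnp.broadcast_shapes performs the same computation and is used here only to cross-check it:

```python
import jax.numpy as jnp

def broadcast_shape(shape1, shape2):
    """Hypothetical helper applying the two broadcasting rules to two shapes."""
    ndim = max(len(shape1), len(shape2))
    # Rule 1: prepend 1s until both shapes have the same number of dimensions.
    s1 = (1,) * (ndim - len(shape1)) + tuple(shape1)
    s2 = (1,) * (ndim - len(shape2)) + tuple(shape2)
    result = []
    for d1, d2 in zip(s1, s2):
        # Rule 2: a size-1 dimension stretches to match the other size.
        if d1 == d2 or d2 == 1:
            result.append(d1)
        elif d1 == 1:
            result.append(d2)
        else:
            raise ValueError(f"incompatible shapes {shape1} and {shape2}")
    return tuple(result)

print(broadcast_shape((3, 4), (4,)))            # (3, 4)
print(broadcast_shape((3, 1), (1, 4)))          # (3, 4)
print(broadcast_shape((5, 1, 3), (4, 1)))       # (5, 4, 3)
print(jnp.broadcast_shapes((5, 1, 3), (4, 1)))  # (5, 4, 3)
```

With (3, 4) and (3,), the helper raises a ValueError: after padding, the last dimensions are 4 and 3, and neither is 1.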

While broadcasting is a powerful tool, it can sometimes catch us by surprise. For example, what would happen if we were to subtract a \((1, M)\) array from an \((N, 1)\) array? Let’s see:

a = jnp.array([[1, 2, 3, 4]])
b = jnp.array([[5], [8], [12]])

print('a =', a)
print('b =')
print(b)
a = [[1 2 3 4]]
b =
[[ 5]
 [ 8]
 [12]]
c = b - a
print('c =')
print(c)
c =
[[ 4  3  2  1]
 [ 7  6  5  4]
 [11 10  9  8]]

The resulting array c is of shape \((N, M)\), containing the differences between every pair of elements in a and b. That is, c[i, j] = b[i, 0] - a[0, j]. There are many applications in ML in which computing pairwise differences is helpful, and this is one great way to do it!
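In practice, the two inputs often start out as 1-dimensional arrays. Indexing with None (or equivalently jnp.newaxis) inserts a size-1 axis, turning a vector of length N into shape (N, 1) and a vector of length M into shape (1, M), which reproduces the c above:

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 4])   # shape (M,) with M = 4
b = jnp.array([5, 8, 12])     # shape (N,) with N = 3

# b[:, None] has shape (3, 1); a[None, :] has shape (1, 4).
c = b[:, None] - a[None, :]
print(c.shape)  # (3, 4)
print(c)
# [[ 4  3  2  1]
#  [ 7  6  5  4]
#  [11 10  9  8]]
```

Because of Rule 1, a[None, :] could be written as plain a; the explicit form just makes the intended shape visible.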

However, if you forget to check the shapes of your arrays before operating on them, you may trigger broadcasting unintentionally, resulting in bugs that are hard to find. For this reason, we recommend adding shape-checking code after every line that operates on arrays of unequal shapes. The shape-checking code should raise an error if the resulting array does not have its desired shape. How can we do this? See the section below about the library chex.

Exercise: Broadcasting

Please solve the following using Jax library calls only (no loops, no if/else!):

Part 1: Write a function that accepts two arrays, a of shape \((N, D)\) and b of shape \((M, D)\), and computes the sum of squared differences between every pair of their rows (i.e. the pairwise squared Euclidean distances). That is, return an array c of shape \((N, M)\), in which \(c_{i,j} = \sum\limits_{d=0}^{D-1} (b_{j,d} - a_{i,d})^2\).

For example, given,

a = jnp.array([
    [1.0, 2.0],
    [2.0, 4.0],
    [5.0, 6.0],
])

b = jnp.array([
    [5.0, 3.0],
    [4.0, 1.0],
    [6.0, 6.0],
    [7.0, 1.0],
])

the function should return:

Array([[17., 10., 41., 37.],
       [10., 13., 20., 34.],
       [ 9., 26.,  1., 29.]], dtype=float32)

Use the following function signature:

def broadcasting_q1(a, b):
    pass # TODO implement

Part 2: Extend your function for evaluating a polynomial so that it evaluates the polynomial on a batch of \(x\)’s at once. That is, write a function that takes in the degree of the polynomial, \(N\), an array of the \(N + 1\) coefficients, \(a\), and an array \(x\) of size \(B\). Your function should return:

\[\begin{align*} \begin{bmatrix} \sum\limits_{n=0}^N a_n \cdot x_1^{N - n} \\ \sum\limits_{n=0}^N a_n \cdot x_2^{N - n} \\ \vdots \\ \sum\limits_{n=0}^N a_n \cdot x_B^{N - n} \end{bmatrix} \end{align*}\]

Test cases:

  • Given \(N = 3\), \(a = [1.0, -1.0, 0.0, 2.0]\), and \(x = [1.0, 2.0, 3.0, 4.0, 5.0]\), your function should return \([2.0, 6.0, 20.0, 50.0, 102.0]\).

  • Given \(N = 2\), \(a = [5.0, -2.0, 3.0]\), and \(x = [1.0, 2.0, 3.0, 4.0, 5.0]\), your function should return \([6.0, 19.0, 42.0, 75.0, 118.0]\).

Use the following function signature:

def broadcasting_q2(N, a, x):
    pass # TODO implement

For more information on broadcasting, check out this tutorial.

3.4. Catching Bugs Early with chex

What’s chex? As we just saw, while powerful, Jax’s broadcasting can subtly introduce bugs into your code. Luckily, there’s a great Python library that can help us: chex. Among other things, chex provides assertions about the shapes of Jax arrays.

What are assertions? Assertions are lines in your code that raise an exception if something that should be true is false. They are useful for catching bugs and for debugging. For example, if you write a function that takes an argument \(p \in [0, 1]\), you can write:

assert 0.0 <= p <= 1.0

In this way, whenever \(p \notin [0, 1]\), your code will raise an exception. This tells you not only that something went wrong, but also where and why, speeding up your debugging process.
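An assert statement can also carry a message explaining why it failed, which makes the resulting traceback even more informative. A sketch with a hypothetical helper, check_probability:

```python
def check_probability(p):
    """Hypothetical helper: raise AssertionError unless 0 <= p <= 1."""
    # The string after the comma is shown in the error when the check fails.
    assert 0.0 <= p <= 1.0, f"p must be in [0, 1], got {p}"
    return p

check_probability(0.3)  # passes silently

try:
    check_probability(1.5)
except AssertionError as err:
    print(err)  # p must be in [0, 1], got 1.5
```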

Best practices: asserting array shapes. In Jax, we will use chex to ensure that all arguments passed into any function we write have the correct shapes, and that every operation we perform results in the correct shape. When chex raises an exception, the statement that it failed on will give us a clue as to where the bug is.

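Why use chex assertions instead of Python’s built-in assert? As we will see later, to speed things up, Jax can trace and compile functions, and while tracing, the values stored inside arrays are not yet known, so a built-in assert that inspects array contents cannot be evaluated. Array shapes, on the other hand, are known during tracing, and chex provides assertions specialized for checking them, with informative error messages. Python also skips every built-in assert statement when run with the -O flag, whereas chex assertions are ordinary function calls. As such, as a rule of thumb: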

  • Use chex on Jax arrays

  • Use assert on non-Jax variables (e.g. checking a float is within range, a list has the right length, etc.)

Example. Let’s give it a go! Suppose you have a function that takes an array and splits it into two parts. Specifically, the function:

  1. Takes an array of shape \((N, D)\) and a variable \(p \in [0, 1]\) (that determines the size of the split).

  2. Returns two arrays of shapes \((\lfloor p \cdot N \rfloor, D)\) and \((N - \lfloor p \cdot N \rfloor, D)\).

We can use chex to ensure the following properties are true:

  1. The input is a 2-dimensional array. In chex lingo, it is a “rank-2” array.

  2. The size of the 0th dimension of the resultant arrays should sum to \(N\).

We can implement this as follows:

def split_2d_array(a, p):
    # ensure 'p' is between 0 and 1 (a plain Python float, so use assert)
    assert 0.0 <= p <= 1.0

    # ensure 'a' is a 2-dimensional (rank-2) array
    chex.assert_rank(a, 2)

    partition_size = int(p * a.shape[0])

    part1 = a[:partition_size]
    part2 = a[partition_size:]

    # assert that the resulting shapes are correct (their 0th dimensions sum to a.shape[0])
    chex.assert_shape(part1, (partition_size, a.shape[1]))
    chex.assert_shape(part2, (a.shape[0] - partition_size, a.shape[1]))
    
    return part1, part2

If you were to pass in an array that is not 2-dimensional, or if there were a bug in the array-splitting code inside the function, chex would raise an exception (try it!).

You can find the chex documentation here. For this course, we will only focus on assertions that have to do with Jax array shapes. That is, only consider the following:

  • assert_shape

  • assert_equal_shape

  • assert_axis_dimension

  • assert_equal_shape_suffix

  • assert_equal_rank

  • assert_rank

Exercise: Assertions

Go back and add assertions to every function you implemented as part of this tutorial. Your assertions should, at the very least, check the inputs and outputs for correct shapes.