3. Advanced Vectorization
In the previous guide on Jax, we introduced the basic building blocks (multi-dimensional arrays) as well as a new paradigm for writing code without loops or if/else expressions. Here, we continue building on this paradigm with more advanced (yet common) uses of Jax in ML.
Let’s get ourselves started by importing some libraries.
import jax.numpy as jnp
import chex
Acknowledgement. Parts of this tutorial have been adapted from this NumPy tutorial.
3.1. Indexing with Boolean Arrays
Whereas slicing is typically used to extract contiguous chunks from an array (i.e. a subsection of elements that belonged together in the original array), indexing will allow us to extract non-contiguous parts. How? It’s easier shown than explained.
You may remember from the introductory guide to Jax that performing boolean operations on an array returns a boolean array. Here’s an example:
a = jnp.arange(10)
print('a =', a)
print('is even =', a % 2 == 0)
a = [0 1 2 3 4 5 6 7 8 9]
is even = [ True False True False True False True False True False]
By using these boolean arrays, we can select all elements for which the boolean expression a % 2 == 0 is True:
b = a[a % 2 == 0]
print('Even elements from a =', b)
Even elements from a = [0 2 4 6 8]
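Boolean arrays can also be combined before indexing. The following is a minimal sketch, assuming we want elements that satisfy two conditions at once; the variable names even_and_large and small_or_large are illustrative and not part of the tutorial.

```python
import jax.numpy as jnp

a = jnp.arange(10)

# Element-wise AND / OR of two boolean arrays. The parentheses are required
# because & and | bind more tightly than comparison operators in Python.
even_and_large = (a % 2 == 0) & (a > 3)
small_or_large = (a < 2) | (a > 7)

print('even and > 3:', a[even_and_large])   # selects 4, 6, 8
print('< 2 or > 7:', a[small_or_large])     # selects 0, 1, 8, 9
```

Each combined mask has the same shape as a, so it can be used for indexing exactly like a single comparison.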
Using this, you can easily write functions with fairly complicated behavior. As an example, the function,
\[f(x) = \begin{cases} x^2 & \text{if } x \text{ is even} \\ x^3 & \text{otherwise} \end{cases}\]
can be implemented as follows:
def f(x):
    is_even = (x % 2 == 0)
    return is_even * (x ** 2.0) + ~is_even * (x ** 3.0)
print(f(jnp.arange(10)))
[ 0. 1. 4. 27. 16. 125. 36. 343. 64. 729.]
In the above, we first created a boolean array is_even, which is True only at indices where \(x\) is even. Multiplying is_even by a numeric array automatically casts it to a number: wherever is_even is True, it becomes 1, and wherever it is False, it becomes 0. As such, is_even * (x ** 2.0) returns an array in which every element is \(x^2\) when \(x\) is even and \(0\) otherwise. In contrast, ~is_even * (x ** 3.0) returns an array in which every element is \(0\) when \(x\) is even and \(x^3\) otherwise (the ~ operator flips the boolean values). When we add the two resultant arrays, we get the answer we were looking for.
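To see the casting in isolation, here is a small sketch using made-up arrays (the names mask and values are illustrative only):

```python
import jax.numpy as jnp

mask = jnp.array([True, False, True])
values = jnp.array([2.0, 3.0, 4.0])

kept = mask * values          # True acts as 1, False acts as 0
dropped = (~mask) * values    # ~ flips each boolean first
as_int = mask.astype(jnp.int32)

print('kept =', kept)         # values 2, 0, 4
print('dropped =', dropped)   # values 0, 3, 0
print('as_int =', as_int)     # values 1, 0, 1
```

Adding kept and dropped recovers the original values, which is the same trick the function above relies on.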
While the above example has pedagogical value, the casting obscures the logic somewhat. Here’s a cleaner way to implement the same function using jnp.where:
def f(x):
    return jnp.where(x % 2 == 0, x ** 2.0, x ** 3.0)
print(f(jnp.arange(10)))
[ 0. 1. 4. 27. 16. 125. 36. 343. 64. 729.]
Here, jnp.where uses the condition x % 2 == 0 (is even) to select elements from one of two arrays, \(x^2\) or \(x^3\).
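The two alternatives passed to jnp.where do not both have to be full arrays. As a hedged sketch, assuming we want to replace negative entries by zero (the name clipped is illustrative), a scalar can serve as one of the branches:

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 1.5, 3.0])

# Keep x where the condition holds, otherwise use the scalar 0.0.
clipped = jnp.where(x > 0, x, 0.0)

print('clipped =', clipped)   # values 0, 0, 0, 1.5, 3
```

Note that both alternatives are computed for every element and jnp.where only picks between the results, so it is a selection rather than a branch.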
Lastly, remember that a multi-dimensional array can be interpreted as an array of arrays. For example, an array of shape \((N, D)\) can be thought of as an array of \(N\) elements, each of which is an array with \(D\) elements. As such, indexing into an \((N, D)\) array using a 1-dimensional boolean array of length \(N\) will select rows. To show how this works, suppose we have the array below:
a = jnp.arange(12).reshape(4, 3) + 1
print('a =')
print(a)
a =
[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
And suppose we’d like to select all rows containing a number that’s divisible by 5. First, let’s see which numbers are divisible by 5:
a % 5 == 0
Array([[False, False, False],
[False, True, False],
[False, False, False],
[ True, False, False]], dtype=bool)
Next, we need to determine which rows have at least one True. We can do this using .any(axis=-1), which returns True if there’s at least one True along the given axis. Here, axis=-1 tells any to reduce over the last axis, i.e. to check each row separately:
contains_number_divisible_by_5 = (a % 5 == 0).any(axis=-1)
print('Row contains a number that is divisible by 5:', contains_number_divisible_by_5)
Row contains a number that is divisible by 5: [False True False True]
Finally, we can use contains_number_divisible_by_5 to index into a and select the rows that contain a number divisible by 5:
print('Rows with number divisible by 5:')
print(a[contains_number_divisible_by_5])
Rows with number divisible by 5:
[[ 4 5 6]
[10 11 12]]
Note that in addition to any, there’s also all, which is True only when all elements are True.
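As a brief sketch of how all compares with any, assuming a small made-up array (the names all_positive and any_positive are illustrative):

```python
import jax.numpy as jnp

a = jnp.array([
    [1, 2, 3],
    [4, -5, 6],
    [7, 8, 9],
    [-1, -2, -3],
])

all_positive = (a > 0).all(axis=-1)   # every entry in the row is positive
any_positive = (a > 0).any(axis=-1)   # at least one entry in the row is positive

print('all positive:', all_positive)
print('any positive:', any_positive)
print(a[all_positive])                # keeps rows 0 and 2
```

Row 1 passes the any test but fails the all test because of the single negative entry.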
Exercise: Indexing with Boolean Arrays
Please solve the following using Jax library calls only (no loops, no if/else!):
Part 1: Write a function that takes in a 2-dimensional array a and returns only rows whose sum is positive. For example, given,
a = jnp.array([
[1.0, 2.0, 3.0],
[-1.0, -2.0, 3.0],
[-1.0, -2.0, 4.0],
[-1.0, -2.0, -3.0],
])
the result should be:
Array([[ 1., 2., 3.],
[-1., -2., 4.]], dtype=float32)
Use the following function signature:
def boolean_indexing_q1(a):
    pass  # TODO implement
Part 2: Write a function that takes in an integer \(N > 0\), a coordinate \(x, y\), and a radius \(r\) (where \(0 \leq x < N\), \(0 \leq y < N\), and \(r > 0\)). The function should return an integer array of shape \((N, N)\) in which every element is 0 except elements that are within radius \(r\) of \((x, y)\) (i.e. every element \(i, j\) should be 1 if \((x - i)^2 + (y - j)^2 \leq r^2\) and 0 otherwise). For example, for \(N = 10\), \(x = 5\), \(y = 5\), and \(r = 2\), the function should return:
Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
Similarly, for \(N = 10\), \(x = 5\), \(y = 0\), and \(r = 2\), the function should return:
Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
Use the following function signature:
def boolean_indexing_q2(N, x, y, r):
    pass  # TODO implement
3.2. Indexing with Arrays of Indices
Sometimes, it’s useful to select elements from an array not just using a single index, but using an array of indices.
a = jnp.arange(12) ** 2
print('a =', a)
a = [ 0 1 4 9 16 25 36 49 64 81 100 121]
i = jnp.array([1, 1, 3, 8, 5]) # a 1-dimensional array of indices
print('a[i] =', a[i])
a[i] = [ 1 1 9 64 25]
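Index arrays often come from other computations. As one possible illustration (the array scores and the use of jnp.argsort are assumptions made for this sketch, not part of the tutorial), the indices returned by a sort can be used to reorder an array:

```python
import jax.numpy as jnp

scores = jnp.array([3.0, 9.0, 1.0, 5.0])

order = jnp.argsort(scores)        # indices that sort scores in ascending order
sorted_scores = scores[order]      # index with the whole array of indices
top2 = scores[order[-2:]]          # the two largest scores, in ascending order

print('order =', order)            # indices 2, 0, 3, 1
print('sorted =', sorted_scores)
print('top2 =', top2)
```

The same order array could be used to reorder any other array of the same length, which keeps paired data aligned.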
Recall again that a multi-dimensional array can be interpreted as an array of arrays. As such, indexing into an \((N, D)\) array using a 1-dimensional index array will select rows:
a = jnp.arange(12).reshape(6, 2)
print('a =')
print(a)
a =
[[ 0 1]
[ 2 3]
[ 4 5]
[ 6 7]
[ 8 9]
[10 11]]
i = jnp.array([2, 3, 5]) # a 1-dimensional array of indices
print('a[i] =')
print(a[i])
a[i] =
[[ 4 5]
[ 6 7]
[10 11]]
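Row selection with an index array has a common use in ML: building one-hot vectors. The sketch below assumes hypothetical class labels and a hypothetical num_classes; it selects rows of an identity matrix, where row k is the one-hot vector for class k.

```python
import jax.numpy as jnp

num_classes = 4                       # hypothetical number of classes
labels = jnp.array([2, 0, 3, 2])      # hypothetical integer class labels

identity = jnp.eye(num_classes)       # shape (4, 4)
one_hot = identity[labels]            # one row per label, shape (4, 4)

print(one_hot)
```

Repeated labels simply select the same row more than once, just as the repeated index 1 did in the 1-dimensional example.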
Lastly, instead of selecting whole rows, we can also index into each dimension independently. That is, we select elements from an array by giving it the “coordinates” of the elements we want as follows:
a = jnp.arange(12).reshape(6, 2)
print('a =')
print(a)
a =
[[ 0 1]
[ 2 3]
[ 4 5]
[ 6 7]
[ 8 9]
[10 11]]
i = jnp.array([3, 5, 2])
j = jnp.array([0, 1, 1])
print('selected =', a[(i, j)])
selected = [ 6 11 5]
By indexing into a using the tuple (i, j), we paired each element of \(i\) with the corresponding element of \(j\) (the one at the same position) to select items at indices: \((3, 0)\), \((5, 1)\), and \((2, 1)\).
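To contrast position-wise pairing with selecting every combination of rows and columns, here is a small sketch. It relies on index arrays being combined by broadcasting (see Section 3.3); the names j_paired, j_cols, and grid are illustrative.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(6, 2)
i = jnp.array([3, 5, 2])

# Paired: i and j_paired have the same shape and are matched position by position.
j_paired = jnp.array([0, 1, 1])
paired = a[i, j_paired]               # picks (3, 0), (5, 1), (2, 1)

# Grid: i[:, None] has shape (3, 1) and j_cols has shape (2,), so the index
# arrays broadcast to shape (3, 2) and every row index meets every column index.
j_cols = jnp.array([0, 1])
grid = a[i[:, None], j_cols]

print('paired =', paired)
print('grid =')
print(grid)
```

Here grid contains rows 3, 5, and 2 of a in that order, while paired contains one element from each of those rows.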
Exercise: Indexing with Arrays of Indices
Please solve the following using Jax library calls only (no loops, no if/else!):
Part 1: Write a function that, given a square matrix a, returns the elements along its diagonal. For example, given,
a = [[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]]
the function should return [0, 5, 10, 15]. Note that there already exists a function in Jax that does exactly this: jnp.diagonal. Please do not use it for the sake of this exercise!
Use the following function signature:
def integer_indexing_q1(a):
    pass  # TODO implement
Part 2: Extend the function from the previous part to include a second, keyword argument, offset=0, to offset the diagonal. Continuing with the above example:
When offset=0, the function’s behavior should not change from the previous part.
When offset=1, the function should return: [1, 6, 11].
When offset=2, the function should return: [2, 7].
When offset=-1, the function should return: [4, 9, 14].
When offset=-2, the function should return: [8, 13].
Again, this functionality is already built into jnp.diagonal, but for the purpose of the exercise, implement this yourself.
Use the following function signature:
def integer_indexing_q2(a, offset=0):
    pass  # TODO implement
3.3. Broadcasting
The term broadcasting describes how Jax treats arrays with different shapes during arithmetic operations. Subject to certain constraints, the smaller array is “broadcast” across the larger array so that they have compatible shapes. Broadcasting provides a means of vectorizing array operations so that looping occurs very fast under the hood without making needless copies of data. Broadcasting follows two rules:
Rule 1: If the input arrays do not all have the same number of dimensions, a “1” will be repeatedly prepended to the shapes of the smaller arrays until all the arrays have the same number of dimensions.
Example: If we were to add an array of shape \((N, D)\) to an array of shape \((D,)\), under the hood, broadcasting will change the second array’s shape to \((1, D)\). Notice that the arrays are still not the same shape, so how can we add them? This is where Rule 2 comes in.
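Rule 2: Arrays with a size of 1 along a particular dimension act as if they had the size of the array with the largest shape along that dimension. The value of the array element is assumed to be the same along that dimension for the “broadcast” array. If two arrays have different sizes along a dimension and neither size is 1, the shapes are incompatible and the operation raises an error.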
Example: If we were to add an array of shape \((N, D)\) to an array of shape \((1, D)\), under the hood, broadcasting would repeat the \((1, D)\)-array \(N\) times along its 0th axis. This would mean both arrays are now of shape \((N, D)\) and can be added elementwise.
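The two rules can be checked without performing any arithmetic. The sketch below assumes the helper functions jnp.broadcast_shapes and jnp.broadcast_to, which report the broadcast shape and materialize the broadcast array respectively; the array names are illustrative.

```python
import jax.numpy as jnp

N, D = 3, 4
a = jnp.zeros((N, D))
b = jnp.arange(D)                     # shape (D,)

# Rule 1 treats (D,) as (1, D); Rule 2 stretches the size-1 axis to N.
result_shape = jnp.broadcast_shapes(a.shape, b.shape)
b_expanded = jnp.broadcast_to(b, (N, D))

print('broadcast shape =', result_shape)   # (3, 4)
print('b expanded =')
print(b_expanded)                          # b repeated on each of the 3 rows
```

Adding a and b directly gives the same result as adding a and b_expanded; broadcasting just avoids writing out the repetition by hand.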
Here’s the example in code:
a = jnp.arange(12).reshape(3, 4)
print('a =')
print(a)
a =
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
print('a + [1, 2, 3, 4] =')
print(a + jnp.array([1, 2, 3, 4]))
a + [1, 2, 3, 4] =
[[ 1 3 5 7]
[ 5 7 9 11]
[ 9 11 13 15]]
And of course, the above addition can be replaced with any other elementwise operation, and Jax will broadcast in the same way.
While broadcasting is a powerful tool, it can sometimes catch us by surprise. For example, what would happen if we were to subtract a \((1, M)\) array from an \((N, 1)\) array? Let’s see:
a = jnp.array([[1, 2, 3, 4]])
b = jnp.array([[5], [8], [12]])
print('a =', a)
print('b =')
print(b)
a = [[1 2 3 4]]
b =
[[ 5]
[ 8]
[12]]
c = b - a
print('c =')
print(c)
c =
[[ 4 3 2 1]
[ 7 6 5 4]
[11 10 9 8]]
The resultant array c is of shape \((N, M)\), containing differences between every pair of elements in a and b. That is, c[i, j] = b[i, 0] - a[0, j]. There are many applications in ML in which computing pairwise differences is helpful, and this is one great way to do it!
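When the inputs are 1-dimensional, the same pairwise pattern can be obtained by inserting size-1 axes with None. This is a minimal sketch reusing the values from the example above; indexing with None to add an axis is standard NumPy-style syntax.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 4])       # shape (M,)
b = jnp.array([5, 8, 12])         # shape (N,)

# b[:, None] has shape (N, 1) and a[None, :] has shape (1, M),
# so the difference broadcasts to shape (N, M).
c = b[:, None] - a[None, :]

print('c.shape =', c.shape)       # (3, 4)
print(c)
```

Each entry satisfies c[i, j] = b[i] - a[j], matching the 2-dimensional example.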
However, if you forget to check the shapes of your arrays before operating on them, you may also trigger some broadcasting behavior unintentionally, resulting in bugs that are hard to find. For this reason, we recommend adding shape-checking code after every line that operates on arrays of unequal shapes. The shape-checking code should raise an error if the resultant array does not have its desired shape. How can we do this? See the section below about the library chex.
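To make the risk concrete, here is a hypothetical scenario (the names predictions and targets are made up for this sketch): a column of predictions of shape \((N, 1)\) is compared against targets of shape \((N,)\), and broadcasting silently produces an \((N, N)\) array.

```python
import jax.numpy as jnp

predictions = jnp.array([[1.0], [2.0], [3.0]])   # shape (3, 1)
targets = jnp.array([1.0, 2.0, 3.0])             # shape (3,)

errors = predictions - targets                   # broadcasts to (3, 3)
print('errors.shape =', errors.shape)            # (3, 3), not (3,)
print('buggy error =', jnp.mean(errors ** 2))    # nonzero, although every prediction is exact

# One possible fix: drop the trailing size-1 axis so that the shapes match.
fixed = predictions.squeeze(-1) - targets        # shape (3,)
print('fixed error =', jnp.mean(fixed ** 2))     # 0.0
```

No error is raised in the buggy version, which is why an explicit shape check right after the subtraction is worthwhile.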
Exercise: Broadcasting
Please solve the following using Jax library calls only (no loops, no if/else!):
Part 1: Write a function that accepts two arrays, a of shape \((N, D)\) and b of shape \((M, D)\), and computes the sum of squared differences between every pair of their rows. That is, return an array c of shape \((N, M)\), in which \(c_{i,j} = \sum\limits_{d=0}^{D-1} (b_{j,d} - a_{i,d})^2\).
For example, given,
a = jnp.array([
[1.0, 2.0],
[2.0, 4.0],
[5.0, 6.0],
])
b = jnp.array([
[5.0, 3.0],
[4.0, 1.0],
[6.0, 6.0],
[7.0, 1.0],
])
the function should return:
Array([[17., 10., 41., 37.],
[10., 13., 20., 34.],
[ 9., 26., 1., 29.]], dtype=float32)
Use the following function signature:
def broadcasting_q1(a, b):
pass # TODO implement
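Part 2: Extend your function that evaluates a polynomial so that it handles a batch of \(x\)’s at once. That is, write a function that takes in the degree of the polynomial, \(N\), an array of the \(N + 1\) coefficients, \(a\), and an array \(x\) of size \(B\). Your function should return an array \(y\) of size \(B\) in which:
\[y_b = \sum_{n=0}^{N} a_n x_b^{N-n}\]
Here the coefficients are ordered from the highest power down to the constant term, which is the ordering the test cases below follow.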
Test cases:
Given \(N = 3\), \(a = [1.0, -1.0, 0.0, 2.0]\), and \(x = [1.0, 2.0, 3.0, 4.0, 5.0]\), your function should return \([2.0, 6.0, 20.0, 50.0, 102.0]\).
Given \(N = 2\), \(a = [5.0, -2.0, 3.0]\), and \(x = [1.0, 2.0, 3.0, 4.0, 5.0]\), your function should return \([6.0, 19.0, 42.0, 75.0, 118.0]\).
Use the following function signature:
def broadcasting_q2(N, a, x):
    pass  # TODO implement
For more information on broadcasting, check out this tutorial.
3.4. Catching Bugs Early with chex
What’s chex? As we just saw, while powerful, Jax’s broadcasting can subtly sneak bugs into your code. Luckily, there’s a great Python library that can help us: chex. chex is a library for asserting Jax array shapes.
What are assertions? Assertions are lines in your code that raise an exception if something that should be true is false. They are useful for catching bugs and for debugging. For example, if you write a function that takes in argument \(p \in [0, 1]\), you can write:
assert(0.0 <= p and p <= 1.0)
In this way, whenever \(p \notin [0, 1]\), your code will raise an exception. This tells you not only that something went wrong, it also tells you where and why it went wrong, speeding up your debugging process.
Best practices: asserting array shapes. In Jax, we will use chex to ensure that all arguments passed into any function we write have the correct shapes, and that every operation we perform results in the correct shape. When chex raises an exception, the statement that it failed on will give us a clue as to where the bug is.
Why use chex assertions instead of Python’s built-in assert? As we will see later, to speed things up, Jax can compile a function by tracing it: the Python code runs once on placeholder arrays whose shapes are known but whose values are not. A built-in assert that depends on array values therefore cannot be relied upon inside compiled code. chex provides specialized assertions that are designed to work with Jax arrays, including inside compiled functions. As such, as a rule of thumb:
Use chex on Jax arrays.
Use assert on non-Jax variables (e.g. checking a float is within range, a list has the right length, etc.)
Example. Let’s give it a go! Suppose you have a function that takes an array and splits it into two parts. Specifically, the function:
Takes an array of shape \((N, D)\) and a variable \(p \in [0, 1]\) (that determines the size of the split).
Returns two arrays of shapes \((\lfloor p \cdot N \rfloor, D)\) and \((N - \lfloor p \cdot N \rfloor, D)\).
We can use chex to ensure the following properties are true:
The input is a 2-dimensional array. In chex lingo, it is a “rank-2” array.
The sizes of the 0th dimension of the two resultant arrays sum to \(N\).
We can implement this as follows:
def split_2d_array(a, p):
    # ensure 'p' is between 0 and 1
    assert(0.0 <= p and p <= 1.0)
    # ensure 'a' is a 2-dimensional array
    chex.assert_rank(a, 2)
    partition_size = int(p * a.shape[0])
    part1 = a[:partition_size]
    part2 = a[partition_size:]
    # assert that the row counts of the resultant arrays sum to a.shape[0]
    chex.assert_shape(part1, (partition_size, a.shape[1]))
    chex.assert_shape(part2, (a.shape[0] - partition_size, a.shape[1]))
    return part1, part2
If you were to pass in an array that is not 2-dimensional, or if there were a bug in the array-splitting code inside the function, chex would raise an exception (try it!).
You can find the chex documentation here. For this course, we will only focus on assertions that have to do with Jax array shapes. That is, only consider the following:
assert_shape
assert_equal_shape
assert_axis_dimension
assert_equal_shape_suffix
assert_equal_rank
assert_rank
Exercise: Assertions
Go back and add assertions to every function you implemented as part of this tutorial. Your assertions should, at the very least, check the inputs and outputs for correct shapes.