8. Maximum Likelihood: Theory#
```python
# Import some helper functions (please ignore this!)
from utils import *
```
Context: At this point, our modeling toolkit is already getting quite expressive.
We can develop simple predictive models using conditional distributions: we can specify models of the form \(p_{A | B}(a | b)\), which allow us to predict the probability that \(A = a\) given that \(B = b\). We do this by specifying a distribution over random variable (RV) \(A\), whose parameters are a function of \(b\).
We can develop simple generative models using joint distributions: we can specify models of the form \(p_{A, B}(a, b)\), which allow us to sample (or generate) data. We do this by factorizing this joint probability into a product of conditional and marginal distributions, e.g. \(p_{A, B}(a, b) = p_{A | B}(a | b) \cdot p_B(b)\), which we already know how to specify, sample from, and evaluate.
Of course, the predictive and generative models you may have heard about in the news are capable of doing more than the instances we’ve covered so far—we will build up to these fancy models over the course of the semester. What’s important for now, though, is that you understand how such models can be represented using probability distributions.
Challenge: So what stands in the way of applying our modeling tools to real-world data? First, we’ve only instantiated our models with discrete distributions. Much real-world data, however, requires continuous distributions; that is, distributions over real numbers (e.g. blood pressure, body-mass index, time spent in REM sleep, etc.). We’ll get into the details of continuous modeling a bit later. Second, we still don’t have a way of automatically fitting a model to data. So far, you’ve fit all models to data by hand via inspection—you looked at the data and tried to match the model to it. With increasing model and data complexity, it becomes prohibitively difficult to fit models by hand. Today, we’ll introduce one technique for doing this automatically: maximum likelihood estimation (MLE). This is the first algorithm we cover that allows the machine to “learn” from data.
The idea behind MLE is to find a model under which the probability of the data is highest. The intuition behind the MLE is that a model that scores the observed data as likely could have reasonably generated the data.
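To make this intuition concrete, consider a tiny example: four coin flips modeled as independent Bernoulli draws. A parameter setting that matches the observed frequencies scores the data higher than one that doesn’t. (The data and numbers below are made up purely for illustration.)

```python
import numpy as np

def bernoulli_likelihood(data, theta):
    """Probability of observing `data` (an array of 0/1 outcomes) under a
    Bernoulli model with success probability `theta`, assuming the
    observations are independent."""
    data = np.asarray(data)
    return float(np.prod(theta**data * (1.0 - theta) ** (1 - data)))

data = [1, 1, 0, 1]  # three heads, one tails (made-up data)

# A parameter matching the observed frequency of heads (3/4) makes the
# data more likely than a fair-coin parameter does.
lik_fair = bernoulli_likelihood(data, 0.5)    # 0.5**4 = 0.0625
lik_match = bernoulli_likelihood(data, 0.75)  # 0.75**3 * 0.25 ≈ 0.105
```

MLE formalizes exactly this comparison: it searches over all parameter settings for the one under which the observed data is most probable.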
Outline:
Formally introduce and motivate the MLE.
Extend the notation of directed graphical models (DGMs) to represent a full dataset instead of just one observation.
Understand theoretical properties of the MLE.
This will give us the framing we need in order to implement MLE in NumPyro.
8.1. MLE: Notation and Formalism#
The idea behind the MLE is to find the model parameters that maximize the probability of the data. Let’s introduce some notation to help us formalize what this means mathematically.
Notation for Data. Let \(\mathcal{D}\) denote all of our observed data (\(\mathcal{D}\) represents the entirety of the above table). Let \(\mathcal{D}_n\) represent observation number \(n\) (i.e. row \(n\)) from the table. \(\mathcal{D}_n\) is a tuple of the values in each of the columns: \(\mathcal{D}_n = (d_n, c_n, h_n, a_n, k_n)\). Recall that we define:
\(D\): Day-of-Week
\(C\): Condition
\(H\): Hospitalized
\(A\): Antibiotics
\(K\): Knots
Notation for Parameters. For simplicity, we’ve omitted the notation for each distribution’s parameters from the notation so far. From now on, we’ll explicitly write out the parameters as arguments to the distribution by listing them after a semi-colon.
Consider our running example: modeling the joint probability of a patient arriving on a specific day, \(D = d\), with/without intoxication, \(I = i\), via \(p_{I, D}(i, d) = p_{I | D}(i | d) \cdot p_D(d)\). Here, \(p_D(\cdot)\) is a categorical distribution, which means it relies on a parameter, \(\pi\), a 7-dimensional array consisting of the probabilities of patients arriving on each day of the week. To denote this explicitly, we write:
\(p_D(d; \pi)\)
wherein we explicitly write \(\pi\) after a semi-colon. We can then do the same for the conditional \(p_{I | D}(\cdot | d)\), which relies on a parameter we’ll call \(\rho\), a 7-dimensional array consisting of the probability of intoxication on each day of the week. We can denote this parameter explicitly via:
\(p_{I | D}(i | d; \rho)\), with \(p_{I | D}(1 | d; \rho) = \rho_d\)
Here, \(\rho_d\) is the \(d\)-th element of \(\rho\), representing the probability of intoxication on day \(d\). (Note that previously we used \(\rho(d)\) instead of \(\rho_d\)—these are two different ways of expressing the same idea).
Finally, to explicitly denote the parameters of the joint distribution, we list all of them after the semi-colon. That is:
\(p_{I, D}(i, d; \rho, \pi) = p_{I | D}(i | d; \rho) \cdot p_D(d; \pi)\)
We then group the parameters together as \(\theta = \{ \rho, \pi \}\) and say we want to learn \(\theta\) from the data.
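As a quick sketch of this notation in code, we can evaluate \(p_{I, D}(i, d; \rho, \pi) = p_{I | D}(i | d; \rho) \cdot p_D(d; \pi)\) with the parameters passed in explicitly. (The parameter values below are made up purely for illustration.)

```python
import numpy as np

# Hypothetical parameter values (not fit to any real data):
pi = np.array([0.10, 0.10, 0.10, 0.15, 0.20, 0.20, 0.15])   # p_D(d; pi)
rho = np.array([0.05, 0.05, 0.05, 0.10, 0.20, 0.30, 0.25])  # p_{I|D}(1 | d; rho)

def joint_prob(i, d, rho, pi):
    """Evaluate p_{I,D}(i, d; rho, pi) = p_{I|D}(i | d; rho) * p_D(d; pi)."""
    p_i_given_d = rho[d] if i == 1 else 1.0 - rho[d]
    return p_i_given_d * pi[d]
```

For instance, `joint_prob(1, 5, rho, pi)` multiplies the probability of intoxication on day 5 by the probability of arriving on day 5; summing over all \((i, d)\) pairs recovers 1, as any joint distribution must.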
The MLE Objective. Let \(\theta\) denote the set of all parameters used in our model for the IHH ER data. Using the above notation, \(p(\mathcal{D}; \theta)\) denotes the probability of the observed data; we will call it the joint data likelihood, since it is the joint distribution of all observations. Our goal is then to find the parameters \(\theta\) that maximize the probability of having observed \(\mathcal{D}\):
\(\theta^\text{MLE} = \underset{\theta}{\mathrm{argmax}} \; p(\mathcal{D}; \theta)\)
wherein “argmax” denotes the value of \(\theta\) that maximizes the joint probability. So what does it mean to evaluate the probability of the whole data set, \(\mathcal{D}\), under our model, \(p_{D, C, H, A, K}\)? It means evaluating the joint distribution of all observations, \(\mathcal{D}_n = (d_n, c_n, h_n, a_n, k_n)\), for every \(n = 1, \dots, N\):
\(p(\mathcal{D}; \theta) = p(\mathcal{D}_1, \dots, \mathcal{D}_N; \theta)\)
where \(N\) is the total number of observations.
Now, recall that every joint distribution can be factorized into a product of conditional and marginal distributions, and that the number of possible factorizations grows unwieldy very quickly with the number of variables. Since the number of variables in this joint distribution is a function of the number of observations, \(N\), which is large (e.g. thousands), we need some way to select a reasonable factorization. As is typical, we are going to assume that the observations are independent and identically distributed (i.i.d.). This means that one patient coming to the ER does not tell us anything about how likely other patients are to come to the ER. Now, recall that when two RVs are independent, their joint distribution equals the product of their marginals. We can therefore factorize the joint distribution as follows:
\(p(\mathcal{D}; \theta) = \prod\limits_{n=1}^N p_{D, C, H, A, K}(d_n, c_n, h_n, a_n, k_n; \theta)\)
We have now arrived at a formula for the joint distribution that we know how to compute—we’ve even written code to evaluate it in NumPyro.
Numerical Stability: Notice that since our joint is a discrete probability distribution, it outputs probabilities between 0 and 1: \(p_{D, C, H, A, K}(d_n, c_n, h_n, a_n, k_n; \theta) \in [0, 1]\). In other words, it outputs fractions. In the above formula, we then multiply \(N\) of these fractions together. But what happens when you multiply fractions together many times? Answer: the result shrinks towards 0 very quickly (try it yourself!). This is a problem, because our computer can only represent small numbers up to a finite precision. For a large \(N\), our computer will round the answer down to \(0\), which will prevent us from performing the argmax. Because of this issue, we have to transform our original MLE objective into a problem that a computer can numerically solve.
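To see this underflow concretely, here is a quick sketch (with made-up toy numbers, not the IHH data):

```python
import numpy as np

# A modest per-observation probability, multiplied N times.
p = 0.1
N = 400

# Naive product of N fractions: underflows to exactly 0.0 in float64,
# since 0.1**400 = 1e-400 is below the smallest representable float.
naive_product = float(np.prod(np.full(N, p)))

# Working in log-space instead: a sum of log-probabilities stays
# comfortably within floating-point range.
log_prob = float(np.sum(np.log(np.full(N, p))))  # 400 * log(0.1) ≈ -921
```

The product is irrecoverably rounded to zero, while the log-space computation preserves all the information we need for comparison and maximization.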
We do this by maximizing the \(\log\) of the joint probability for two reasons:
Logs turn products into sums: \(\log (x \cdot y) = \log x + \log y\). Applying this identity to our MLE objective turns the product of probabilities into a sum of log-probabilities, which is numerically stable:
\(\log p(\mathcal{D}; \theta) = \sum\limits_{n=1}^N \log p_{D, C, H, A, K}(d_n, c_n, h_n, a_n, k_n; \theta)\)
But won’t maximizing the \(\log\) of the joint probability instead give us the wrong answer? No: because the \(\log\) is a strictly increasing function, the maximizer remains in the same location. That is:
\(\underset{\theta}{\mathrm{argmax}} \; \log p(\mathcal{D}; \theta) = \underset{\theta}{\mathrm{argmax}} \; p(\mathcal{D}; \theta)\)
To illustrate point (2), check out the graph below, which shows that the argmax of a function doesn’t change if a \(\log\) is applied to it.
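We can also check this numerically. The sketch below evaluates a toy Bernoulli likelihood (hypothetical data: 3 successes, 1 failure) on a grid of parameter values and confirms that taking the log does not move the argmax:

```python
import numpy as np

# Grid of candidate parameter values in (0, 1).
thetas = np.linspace(0.01, 0.99, 99)

# Bernoulli likelihood of 3 successes and 1 failure (toy data).
likelihood = thetas**3 * (1 - thetas)
log_likelihood = np.log(likelihood)

# The log is strictly increasing, so both curves peak at the same theta.
argmax_lik = thetas[np.argmax(likelihood)]
argmax_loglik = thetas[np.argmax(log_likelihood)]
```

Both curves peak at \(\theta = 0.75\), matching the observed frequency of successes.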
Optimization: So at this point, we can compute the MLE objective for specific choices of \(\theta\), but we don’t know yet how to perform the argmax operation. We’ll introduce this concept a bit later in the course. For now, we’ll provide you with a function that can perform the maximization.
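While we defer general-purpose maximization to later in the course, it is worth noting that for some simple models the argmax has a closed form. For a categorical distribution, for example, the maximum-likelihood estimate of \(\pi\) is the vector of empirical frequencies. A minimal sketch (with hypothetical observations; this is not the course’s provided optimization function):

```python
import numpy as np

def categorical_mle(observations, num_categories):
    """Closed-form MLE for a categorical distribution's probability vector:
    the empirical frequency of each category."""
    observations = np.asarray(observations)
    counts = np.bincount(observations, minlength=num_categories)
    return counts / counts.sum()

# Hypothetical day-of-week observations (0 = first day, ..., 6 = last day).
days = [0, 0, 1, 3, 3, 3, 6]
pi_hat = categorical_mle(days, num_categories=7)  # [2/7, 1/7, 0, 3/7, 0, 0, 1/7]
```

For more complex models, no such closed form exists, which is why we will eventually need a general optimization routine.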
8.2. Graphically Representing I.I.D Observations and Model Parameters#
Before implementing the MLE in NumPyro, we will extend our Directed Graphical Model (DGM) representation to include i.i.d. observations and model parameters. This will help us in the translation process from math to code. Suppose we have a simple joint distribution over two RVs, \(A\) and \(B\), whose conditional is controlled by a parameter, \(\theta\), as follows:
\(p_{A, B}(a, b; \theta) = p_{B | A}(b | a; \theta) \cdot p_A(a)\)
Suppose further that we have \(N\) i.i.d. observations from this joint distribution. That is, we have \(\mathcal{D}_n = (a_n, b_n)\) for \(n \in 1, \dots, N\). This gives us the following joint data likelihood:
\(p(\mathcal{D}; \theta) = \prod\limits_{n=1}^N p_{B | A}(b_n | a_n; \theta) \cdot p_A(a_n)\)
How would we represent this graphically? The answer is a little messy:
Each pair \((A_n, B_n)\) gets its own arrow to signify the conditional dependence of \(B_n\) on \(A_n\). And since every pair depends on the same parameter, \(\theta\) has an arrow pointing into every \(B_n\).
Representing Parameters. In the above, notice that circles are only used for RVs. Since \(\theta\) is not an RV, it is not inside a circle—it’s represented by a dot instead.
Representing I.I.D Observations: For more complicated models, like the IHH ER model you’ve already developed, this graphical representation becomes too difficult to read. As a result, we use the following shorthand:
In this representation, we introduce a “plate” (the rectangle surrounding \(A_n\) and \(B_n\)). The plate denotes that what’s inside should be repeated \(N\) times, where \(N\) is written in the bottom-right corner. Why is this called a plate? Do you eat off of rectangular plates at home? This shall remain a mystery to us all…
A note on conditional independence: We note that in this example, in the IHH ER model you’ve developed, and generally in the models we consider in this class, the observations are only i.i.d. given the model parameters. That is, given \(\theta\), we can factorize \(p(\mathcal{D}; \theta)\) into \(\prod_{n=1}^N p(\mathcal{D}_n; \theta)\). However, if we do not “condition” on the parameters, \(\theta\), the observations do carry knowledge about one another. That is, having observed \(\mathcal{D}_1\) can tell us something about \(\mathcal{D}_2\), because it tells us something about \(\theta\), which is shared across all observations. More on that later in the course.
Exercise: Translate between DGMs and joint distributions
Part 1: Extend the DGM for the IHH ER below to represent the joint distribution of the data. Additionally, include all parameters of all distributions.
Part 2: For each of the following DGMs, write down the joint distribution of all RVs. We’ve specifically selected models that are commonly used in ML. At this point in the course, we have not covered enough materials to make the connection between the model and their actual use—so don’t expect to understand exactly what they mean yet!
(i) Predictive Models (including regression and classification).
(ii) Bayesian Gaussian Mixture Models (used for clustering).
(iii) Latent Dirichlet Allocation (used for automatically extracting topics from text).
Part 3: Draw the DGM for each of the following models.
(iv) Conditional Subspace Variational Autoencoder (used for generating synthetic data, like pictures of celebrity faces)
(v) Hidden Markov Models (used for modeling time-series data).
Note: for this DGM, you will not be able to use plate notation. Instead, please use “\(\dots\)” to indicate a repeating pattern.
8.3. Theoretical Properties of the MLE#
In this course, our main goal is to focus on model specification—the process of creating a probabilistic model—and on understanding the consequences of our modeling assumptions on downstream tasks. As such, we will not get into the nitty-gritty of how models are fit to data. Nonetheless, it is important to informally highlight several theoretical properties of the MLE. The purpose of this is to help us understand what we can expect of the MLE: will it behave the way we want in different situations?
Desiderata: Let’s first highlight three properties we typically want for our learning algorithms. Informally, we define:
Consistency: As the number of observations approaches infinity, \(N \rightarrow \infty\), the parameters we learned from the data \(\theta^\text{learned}\) approach the true parameters of the model that generated the data \(\theta^\text{true}\).
Unbiasedness: Suppose we were able to collect several data sets from the same random phenomenon. For each of the data sets, suppose we were to then fit the same model and obtain \(\theta^\text{learned}\). For an unbiased learning algorithm, averaging all of the \(\theta^\text{learned}\)’s would yield \(\theta^\text{true}\).
Low-variance: Suppose we were able to collect several data sets from the same random phenomenon. For each of the data sets, suppose we were to then fit the same model and obtain \(\theta^\text{learned}\). For a low-variance learning algorithm, the average distance between the \(\theta^\text{learned}\)’s and \(\theta^\text{true}\) is small.
Why should we care about these properties? Consistency tells us that with sufficient data, our learning algorithm should converge to the true parameters. This is important! Would you want to use a learning algorithm that, with more data, becomes more and more wrong? Unbiasedness and low-variance, together, give us a notion of how quickly our learning algorithm converges to the true parameters with more data. These two often come at a tradeoff with one another. We won’t get into this for now.
Properties of the MLE: In order for the MLE to satisfy the above desiderata, we must make several assumptions, including that:
The model is well-specified—the observed data was generated from the same model we are fitting.
The model is identifiable—there are no two different sets of parameters that represent the same model.
The data was generated i.i.d.
For any model satisfying these assumptions, the MLE is:
Consistent
Asymptotically (as \(N \rightarrow \infty\)) unbiased
Asymptotically (as \(N \rightarrow \infty\)) minimum-variance
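Though we won’t prove these properties, we can probe consistency empirically: simulate datasets of increasing size from a known Bernoulli model, fit each by maximum likelihood (for a Bernoulli, the MLE is simply the sample mean), and watch the estimate approach the true parameter. A sketch with made-up settings:

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed for reproducibility
theta_true = 0.3  # made-up "true" parameter

def bernoulli_mle(samples):
    """MLE of a Bernoulli parameter: the sample mean of the 0/1 outcomes."""
    return float(np.mean(samples))

# As N grows, the MLE's error should shrink (consistency).
errors = {}
for N in [100, 10_000, 1_000_000]:
    samples = rng.random(N) < theta_true  # N i.i.d. Bernoulli(theta_true) draws
    errors[N] = abs(bernoulli_mle(samples) - theta_true)
```

With a million observations, the estimate sits within a fraction of a percent of \(\theta^\text{true}\), exactly the behavior consistency promises as \(N \rightarrow \infty\).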