6. Joint Probability (Discrete)
# Import some helper functions (please ignore this!)
from utils import *
Context: So far, you’ve spent some time conducting a preliminary exploratory data analysis (EDA) of IHH’s ER data. You noticed that considering variables separately can result in misleading information. As a result, you decided to use conditional distributions to model the relationship between variables. Using these conditional distributions, you were able to develop predictive models (e.g. predicting the probability of intoxication given the day of the week). These predictive models are useful for the IHH administration to make decisions.
However, you’ve noticed that your modeling toolkit is still limited. The conditional distributions we introduced can model how the probability of one variable changes given a set of variables. What if we wanted to describe how the probability of a set of variables (i.e. more than one) changes given a set of variables? For example, we may want to answer questions like: “how does the probability that a patient is hospitalized for an allergic reaction change given the day of the week?” In this question, we’re inquiring about two variables—that the condition is an allergic reaction, and that the patient was hospitalized—given the day of the week.
Challenge: We need to expand our modeling toolkit to include yet another tool—joint probabilities.
Outline:
Introduce and practice the concepts, terminology, and notation behind discrete joint probability distributions (leaving continuous distributions to a later time).
Introduce a graphical representation to describe joint distributions.
Translate this graphical representation directly into code in a probabilistic programming language (using NumPyro) that we can then use to fit the data.
6.1. Terminology and Notation
We again introduce the statistical language (terminology and notation) needed to precisely specify to a computer how to model our data. We will then translate statements in this language directly into code in NumPyro that a computer can run.
Concept. The concept behind a joint probability is elegant: it allows us to build complicated distributions over many variables using simple conditional and non-conditional distributions (which we have already covered).
We can illustrate this using an example with just two variables. Suppose you have two RVs, \(A\) and \(B\). The probability that \(A = a\) and \(B = b\) both hold is called their joint probability. It is denoted by \(p_{A,B}(a, b)\). This joint distribution can be factorized into a product of conditional and non-conditional (or “marginal”) distributions as follows:
\[\begin{aligned} p_{A, B}(a, b) &= p_{A | B}(a | b) \cdot p_B(b) \\ &= p_{B | A}(b | a) \cdot p_A(a) \end{aligned} \tag{6.1}\]
Notice that the joint is now described in terms of conditional and marginal distributions, which we already know how to work with!
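To make the factorization concrete, here is a small sketch in plain Python. It uses only the standard library rather than NumPyro. It starts from a made-up joint table for two binary RVs, computes the marginals and conditionals from that table, and checks that conditional times marginal reproduces the joint in both directions. The table values and the helper names `marginal_a` and `marginal_b` are hypothetical.

```python
from fractions import Fraction

# Hypothetical joint distribution p_{A,B}(a, b) over binary A and B (sums to 1).
joint = {
    (0, 0): Fraction(3, 10), (0, 1): Fraction(1, 10),
    (1, 0): Fraction(2, 10), (1, 1): Fraction(4, 10),
}

def marginal_a(a):
    """p_A(a): sum the joint over all values of B."""
    return sum(p for (ai, _), p in joint.items() if ai == a)

def marginal_b(b):
    """p_B(b): sum the joint over all values of A."""
    return sum(p for (_, bi), p in joint.items() if bi == b)

for (a, b), p_ab in joint.items():
    p_a_given_b = p_ab / marginal_b(b)  # p_{A|B}(a | b)
    p_b_given_a = p_ab / marginal_a(a)  # p_{B|A}(b | a)
    # Both factorizations recover the joint exactly.
    assert p_a_given_b * marginal_b(b) == p_ab
    assert p_b_given_a * marginal_a(a) == p_ab

print(marginal_a(0), marginal_b(0))  # 2/5 1/2
```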
Intuition. So what’s the intuition behind this formula? Let’s depict events \(A = a\) and \(B = b\) as follows:
In this diagram, each shaded area represents the probability of an event; that is, area is proportional to probability. We use it to pictorially represent the marginal, conditional, and joint distributions. The marginal probability of \(B = b\), for example, is the ratio of the area of the blue square to the area of the whole space (the gray square):
\[ p_B(b) = \frac{\text{area of blue square}}{\text{area of gray square}} \]
Similarly, the marginal probability of \(A = a\) is the ratio of the area of the red square to the area of the gray square.
Next, the conditional \(p_{A | B}(a | b)\) is the ratio of the area of the purple intersection to the area of the blue square. This is because the blue square represents our conditioning on \(B = b\), and the purple intersection is the part of it in which we also have \(A = a\).
Finally, the joint \(p_{A, B}(a, b)\) is the ratio of the area of the purple intersection to the area of the whole space (the gray square). This is because the intersection is where both \(A = a\) and \(B = b\) hold.
Now we can see that the joint is the product of the conditional and the marginal, because the blue squares “cancel out”:
\[ p_{A, B}(a, b) = \frac{\text{purple}}{\text{gray}} = \frac{\text{purple}}{\text{blue}} \cdot \frac{\text{blue}}{\text{gray}} = p_{A | B}(a | b) \cdot p_B(b) \]
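The same cancellation can be reproduced with counts in place of areas. In this sketch the counts are hypothetical: `n_total` stands for the gray square, `n_b` for the blue square, and `n_ab` for the purple intersection. Exact fractions make the equality hold with no rounding.

```python
from fractions import Fraction

# Hypothetical counts standing in for the areas in the diagram.
n_total = 200  # gray square: all outcomes
n_b = 50       # blue square: outcomes with B = b
n_ab = 20      # purple intersection: outcomes with A = a and B = b

p_b = Fraction(n_b, n_total)       # marginal p_B(b)
p_a_given_b = Fraction(n_ab, n_b)  # conditional p_{A|B}(a | b)
p_ab = Fraction(n_ab, n_total)     # joint p_{A,B}(a, b)

# n_b cancels: (n_ab / n_b) * (n_b / n_total) == n_ab / n_total
assert p_a_given_b * p_b == p_ab
print(p_b, p_a_given_b, p_ab)  # 1/4 2/5 1/10
```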
Choice of Factorization. Lastly, notice that we have a choice: the distribution can be factorized in two ways. How do you know which one to use? Typically, we choose the factorization that is most intuitive to us and whose components we can compute.
Example: Suppose you want to model the joint distribution of the day of the week, \(D\), and whether a patient arrives with intoxication, \(I\). The joint distribution can be factorized in two ways:
\[\begin{aligned} p_{D, I}(d, i) &= p_{I | D}(i | d) \cdot p_D(d) \quad \text{(Option 1)} \\ &= p_{D | I}(d | i) \cdot p_I(i) \quad \text{(Option 2)} \end{aligned} \tag{6.2}\]
Which one makes more intuitive sense? Well, it’s a little weird to try to predict the day of the week given whether a patient arrives with intoxication; we typically know what the day of the week is and we don’t need to predict it. In contrast, given the day of the week, it makes a lot of sense to wonder about the probability of a patient arriving with intoxication. As such, Option 1 makes more sense here.
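Generalizing to More than Two RVs. So now we have the tools to work with joint distributions of two RVs. What do we do if we have three or more? The same ideas apply. The joint distribution of random variables \(A\), \(B\), and \(C\) can be factorized in a number of ways. For example, we can condition on two variables at a time:
\[\begin{aligned} p_{A, B, C}(a, b, c) &= p_{A | B, C}(a | b, c) \cdot p_{B, C}(b, c) \\ &= p_{B | A, C}(b | a, c) \cdot p_{A, C}(a, c) \\ &= p_{C | A, B}(c | a, b) \cdot p_{A, B}(a, b) \end{aligned} \tag{6.3}\]
Here we already know how to factorize \(p_{B, C}(b, c)\), \(p_{A, C}(a, c)\), and \(p_{A, B}(a, b)\), since they are joint distributions of two variables.
We can also condition on one variable at a time:
\[\begin{aligned} p_{A, B, C}(a, b, c) &= p_{A, B | C}(a, b | c) \cdot p_C(c) \\ &= p_{A, C | B}(a, c | b) \cdot p_B(b) \\ &= p_{B, C | A}(b, c | a) \cdot p_A(a) \end{aligned} \tag{6.4}\]
And how do we further factorize distributions of the form \(p_{A, B | C}(a, b | c)\)? We apply the same factorization as for a joint distribution of two variables, and simply add “conditioned on \(C\)” to each term:
\[\begin{aligned} p_{A, B | C}(a, b | c) &= p_{A | B, C}(a | b, c) \cdot p_{B | C}(b | c) \\ &= p_{B | A, C}(b | a, c) \cdot p_{A | C}(a | c) \end{aligned} \tag{6.5}\]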
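As a numerical check on these three-variable factorizations, the sketch below builds a random joint over three binary RVs. It then confirms that two different chain-rule orderings both recover every entry. The random table and the helper `prob` are illustrative assumptions, not part of the chapter's code.

```python
import itertools
import random

# A random (hypothetical) joint over three binary RVs, normalized to sum to 1.
rng = random.Random(0)
weights = {abc: rng.random() for abc in itertools.product([0, 1], repeat=3)}
total = sum(weights.values())
joint = {abc: w / total for abc, w in weights.items()}

INDEX = {"a": 0, "b": 1, "c": 2}

def prob(**fixed):
    """Probability that the named variables take the given values (others summed out)."""
    return sum(p for abc, p in joint.items()
               if all(abc[INDEX[k]] == v for k, v in fixed.items()))

for (a, b, c), p_abc in joint.items():
    p_c = prob(c=c)
    # Ordering 1: p(a | b, c) * p(b | c) * p(c)
    order_1 = (p_abc / prob(b=b, c=c)) * (prob(b=b, c=c) / p_c) * p_c
    # Ordering 2: p(b | a, c) * p(a | c) * p(c)
    order_2 = (p_abc / prob(a=a, c=c)) * (prob(a=a, c=c) / p_c) * p_c
    assert abs(order_1 - p_abc) < 1e-12
    assert abs(order_2 - p_abc) < 1e-12

print(round(prob(), 12))  # all entries together sum to 1, prints 1.0
```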
6.2. Directed Graphical Models (DGMs)
As you may have already noticed, the number of possible ways to factorize a joint distribution increases very quickly with the number of RVs. In fact, the more RVs we have, the more unwieldy it becomes for us as data analysts to specify each component in the factorization. What can we do to simplify our model? Often, we can use our domain knowledge (knowledge of the specifics of the problem) to simplify the joint distribution. Specifically, we’ll use our knowledge of “conditional independence” to do this. Let’s get started by first introducing the idea of statistical independence.
Example: Suppose you want to model the joint distribution of a patient’s true temperature, \(T\), a thermometer’s measurement of this temperature, \(M\), and the doctor’s decision whether or not the patient has a fever, \(F\). As we saw above, this joint distribution, \(p_{F, M, T}(f, m, t)\), which has three RVs, has many factorizations. Most of these, however, don’t make any sense to a domain expert. In particular, the doctor’s decision, \(F\), is based only on the thermometer’s measurement, \(M\); the doctor has no way of knowing the patient’s actual temperature, only its noisy observation through \(M\). Thus, once \(M\) is known, \(T\) tells us nothing more about \(F\): \(F\) and \(T\) are conditionally independent given \(M\). Next, let’s see what this means formally.
Statistical Independence. We say two variables \(A\) and \(B\) are statistically independent if their joint can be factorized into the product of their marginals:
\[ p_{A, B}(a, b) = p_A(a) \cdot p_B(b) \tag{6.6} \]
This equation implies that to sample \(A\) and \(B\) jointly, we don’t have to consider their relationship (since no conditional appears in the factorization); they are entirely independent.
Another way to understand this equation is to consider its implications for the conditionals. We do this by factorizing \(p_{A, B}(a, b)\) into the product of a conditional and a marginal:
\[ p_{A, B}(a, b) = p_{B | A}(b | a) \cdot p_A(a) \tag{6.7} \]
We then observe that \(p_{B | A}(b | a)\) must equal \(p_B(b)\) to satisfy our definition of statistical independence. And \(p_{B | A}(b | a) = p_B(b)\) implies that having observed \(A = a\) does not affect the probability of \(B = b\).
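The definition translates into a direct test on a joint table: compute both marginals and compare their product with every joint entry. The function `is_independent` and both tables below are hypothetical examples. A floating-point tolerance is used because products of decimals are rarely exact.

```python
import itertools

def is_independent(joint, tol=1e-12):
    """Check p_{A,B}(a, b) == p_A(a) * p_B(b) for every (a, b) in a joint table."""
    a_vals = {a for a, _ in joint}
    b_vals = {b for _, b in joint}
    p_a = {a: sum(joint.get((a, b), 0.0) for b in b_vals) for a in a_vals}
    p_b = {b: sum(joint.get((a, b), 0.0) for a in a_vals) for b in b_vals}
    return all(abs(joint.get((a, b), 0.0) - p_a[a] * p_b[b]) <= tol
               for a, b in itertools.product(a_vals, b_vals))

# Hypothetical tables: the first is built as a product of marginals, the second is not.
marg_a = {0: 0.3, 1: 0.7}
marg_b = {0: 0.6, 1: 0.4}
independent = {(a, b): marg_a[a] * marg_b[b] for a in marg_a for b in marg_b}
dependent = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}

print(is_independent(independent))  # True
print(is_independent(dependent))    # False
```

In the dependent table, \(p_A(0) = p_B(0) = 0.5\), but \(p_{A,B}(0, 0) = 0.4 \neq 0.25\). Equivalently, observing \(A = 0\) raises the probability of \(B = 0\) from 0.5 to 0.8.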
Example: Returning to the above example, we know that \(F\) should only depend on \(M\), and that \(M\) should only depend on \(T\). That is, the doctor’s decision only depends on the thermometer’s measurement, and the thermometer only depends on the true temperature. Formally, we can say that \(p_{F | M, T}(f | m, t) = p_{F | M}(f | m)\), since \(F\) only depends on the thermometer reading, \(M = m\). Similarly, \(p_{M | T, F}(m | t, f) = p_{M | T}(m | t)\) since the thermometer reading only depends on the patient’s temperature, not on what the doctor might say. Using this, we can factorize the joint as follows:
\[\begin{aligned} p_{F, M, T}(f, m, t) &= p_{F | M, T}(f | m, t) \cdot p_{M, T}(m, t) \quad (\text{factorization of the joint}) \\ &= p_{F | M}(f | m) \cdot p_{M, T}(m, t) \quad (\text{since $F$ and $T$ are conditionally independent given $M$}) \\ &= p_{F | M}(f | m) \cdot p_{M | T}(m | t) \cdot p_T(t) \quad (\text{factorizing the joint of $M$ and $T$}) \end{aligned} \tag{6.8}\]
Thus, using our domain expertise, we’ve selected the factorization of the model that makes the most sense.
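To see the difference between conditional and marginal independence, the sketch below uses a hypothetical binary version of the fever model: high or normal temperature, high or normal reading, fever or not, all with made-up probabilities. It builds the joint from the factorization \(p_{F | M}(f | m) \cdot p_{M | T}(m | t) \cdot p_T(t)\) and checks that \(p_{F | M, T}(f | m, t) = p_{F | M}(f | m)\). It also shows that \(F\) and \(T\) are still dependent when \(M\) is not observed.

```python
import itertools

# Hypothetical binary fever model:
# T = 1 if the true temperature is high, M = 1 if the thermometer reads high,
# F = 1 if the doctor says "fever".
p_T = {0: 0.8, 1: 0.2}
p_M_given_T = {0: {0: 0.9, 1: 0.1}, 1: {0: 0.15, 1: 0.85}}  # p_M_given_T[t][m]
p_F_given_M = {0: {0: 0.95, 1: 0.05}, 1: {0: 0.1, 1: 0.9}}  # p_F_given_M[m][f]

# Build the joint from the DGM factorization p(f | m) * p(m | t) * p(t).
joint = {(f, m, t): p_F_given_M[m][f] * p_M_given_T[t][m] * p_T[t]
         for f, m, t in itertools.product([0, 1], repeat=3)}

def p(pred):
    """Total probability of all (f, m, t) outcomes satisfying pred."""
    return sum(q for fmt, q in joint.items() if pred(*fmt))

for f, m, t in joint:
    p_f_given_mt = joint[(f, m, t)] / p(lambda f2, m2, t2: m2 == m and t2 == t)
    p_f_given_m = (p(lambda f2, m2, t2: f2 == f and m2 == m)
                   / p(lambda f2, m2, t2: m2 == m))
    # Once M is known, T carries no extra information about F.
    assert abs(p_f_given_mt - p_f_given_m) < 1e-12

# Without conditioning on M, F and T are dependent: p(F=1 | T=1) != p(F=1).
p_f1_given_t1 = p(lambda f2, m2, t2: f2 == 1 and t2 == 1) / p_T[1]
p_f1 = p(lambda f2, m2, t2: f2 == 1)
print(round(p_f1_given_t1, 4), round(p_f1, 4))  # 0.7725 0.2625
```

With these numbers, knowing that the true temperature is high raises the probability of a fever diagnosis from 0.2625 to 0.7725. That is why the arrow structure only removes the *direct* dependence of \(F\) on \(T\).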
Graphical Representation of Statistical Dependencies. Since reasoning about many variables jointly is difficult, we introduce a graphical representation to aid with it. This representation is called a directed graphical model (DGM), and it will help us convey which variables depend on one another in what way.
A DGM is represented using a graph (or network) in which nodes represent RVs and arrows represent conditional dependencies. For example, consider the following DGM for some hypothetical joint distribution, \(p_{A, B, C}(\cdot)\):
In this DGM, there are three nodes, corresponding to our three RVs. Our factorization then consists of one factor for each node:
The factor corresponding to \(B\) is \(p_{B | A}(\cdot)\), since there’s an arrow from \(A\) to \(B\), indicating a conditional dependence.
The factor corresponding to \(A\) is \(p_{A}(\cdot)\), since there aren’t any arrows pointing into \(A\).
The factor corresponding to \(C\) is \(p_{C}(\cdot)\), since there aren’t any arrows pointing into \(C\).
In total, the DGM represents the following factorization:
\[ p_{A, B, C}(a, b, c) = p_{B | A}(b | a) \cdot p_A(a) \cdot p_C(c) \tag{6.9} \]
Example: Continuing with our fever example, we can represent our factorization, \(p_{F, M, T}(f, m, t) = p_{F | M}(f | m) \cdot p_{M | T}(m | t) \cdot p_T(t)\), with the following DGM:
As you can see, since \(F\) depends on \(M\), there’s an arrow leading from \(M\) to \(F\). And since \(F\) does not depend directly on \(T\) (only through \(M\)), there is no arrow leading from \(T\) to \(F\).
Out of the numerous ways we could factorize \(p_{A, B, C}(a, b, c)\) or \(p_{F, M, T}(f, m, t)\), each DGM gives us just one to focus on: one that, hopefully, makes sense for our application. But what does this simplicity cost us? Expressivity. By assuming that certain variables do not have conditional dependencies, we limit the set of joint distributions we can represent. Is this a problem? Not if our dependence assumptions are reasonable!
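Reading a factorization off a DGM is mechanical: each node contributes one factor, conditioned on the nodes with arrows into it. A minimal sketch of that rule is shown below. The graph is given as a hypothetical dictionary mapping each node to its parents, and the output is a plain string in the chapter's notation.

```python
def factorization(parents):
    """Return the factorization implied by a DGM as a string.

    `parents` maps each node to the list of nodes with arrows pointing into it.
    Each node contributes p_{X | parents}(x | parents), or p_{X}(x) if it has
    no parents.
    """
    factors = []
    for node, pa in parents.items():
        x = node.lower()
        if pa:
            given = ", ".join(pa)
            given_vals = ", ".join(q.lower() for q in pa)
            factors.append(f"p_{{{node} | {given}}}({x} | {given_vals})")
        else:
            factors.append(f"p_{{{node}}}({x})")
    return " * ".join(factors)

# The example DGM: an arrow from A to B, and C on its own.
print(factorization({"B": ["A"], "A": [], "C": []}))
# p_{B | A}(b | a) * p_{A}(a) * p_{C}(c)

# The fever DGM: T -> M -> F.
print(factorization({"F": ["M"], "M": ["T"], "T": []}))
# p_{F | M}(f | m) * p_{M | T}(m | t) * p_{T}(t)
```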
Exercise: Practice with DGMs
Let’s practice converting between distribution notation and DGMs.
Part 1: Write down the factorizations implied by the DGMs below.
Part 2: Draw DGMs representing each joint distribution below.
Sampling from Joint Distributions: Sampling from a joint distribution can be done by sampling from each component of its factorization. When doing this, we must ensure that our sampling order is valid. That is, if we have a distribution \(p_{A | B}(a | b) \cdot p_B(b)\), we cannot sample from \(p_{A | B}(\cdot | b)\) first, because we don’t know what \(b\) is yet. We first have to sample from \(p_B(\cdot)\) to obtain a specific value \(b\), and then use that value of \(b\) when sampling from the conditional.
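One way to guarantee a valid order is to visit the nodes in topological order, so that every node is sampled after all of its parents. This is often called ancestral sampling. The sketch below uses the standard library's `graphlib.TopologicalSorter` (Python 3.9+). The function `ancestral_sample`, the `samplers` convention, and the probabilities 0.3, 0.9, and 0.2 are all hypothetical.

```python
import random
from graphlib import TopologicalSorter  # Python 3.9+

def ancestral_sample(parents, samplers, rng):
    """Draw one joint sample by visiting nodes parents-first.

    parents:  node -> list of parent nodes (arrows pointing into the node)
    samplers: node -> function(rng, parent_values_dict) returning a value
    """
    sample = {}
    # static_order() yields every node only after all of its parents.
    for node in TopologicalSorter(parents).static_order():
        parent_values = {pa: sample[pa] for pa in parents[node]}
        sample[node] = samplers[node](rng, parent_values)
    return sample

# Hypothetical model p_{A|B}(a | b) * p_B(b) with binary A and B.
parents = {"A": ["B"], "B": []}
samplers = {
    "B": lambda rng, pv: int(rng.random() < 0.3),
    "A": lambda rng, pv: int(rng.random() < (0.9 if pv["B"] == 1 else 0.2)),
}
print(ancestral_sample(parents, samplers, random.Random(0)))  # B is sampled before A
```

Because the order comes from the graph itself, the same function works for any DGM, provided each sampler reads only its own parents' values.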
Let’s illustrate this with some examples.
Example 1: Consider the distribution of temperature, thermometer measurement, and doctor’s decision from the above examples:
\[ p_{F, M, T}(f, m, t) = p_{F | M}(f | m) \cdot p_{M | T}(m | t) \cdot p_T(t) \tag{6.10} \]
Suppose we want to draw \(S\) samples from this joint distribution. For every \(s \in \{1, \dots, S \}\), we would follow the steps below, corresponding to the logic behind the model. That is,
First, we sample the patient’s temperature: \(t_s \sim p_T(\cdot)\).
Now that the patient’s temperature has been determined (\(T = t_s\)), we can try to measure it: \(m_s | t_s \sim p_{M | T}(\cdot | t_s)\). This is a draw from a distribution because our thermometer has some random error in its measurement.
Finally, given the reading of the thermometer (\(M = m_s\)), the doctor can decide if the patient has fever: \(f_s | m_s \sim p_{F | M}(\cdot | m_s)\).
Each time we carry out these steps, we get one sample from the joint distribution: \(f_s, m_s, t_s \sim p_{F, M, T}(\cdot, \cdot, \cdot)\).
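The three steps above map directly onto three lines of code. This sketch uses Python's `random` module rather than NumPyro. Every distribution in it is a placeholder assumption: a normal true temperature, a normal reading centered on it, and a fever probability that jumps at a 38 °C reading.

```python
import random

def sample_fever_model(rng):
    """One draw (f_s, m_s, t_s) from p_{F|M} * p_{M|T} * p_T.

    All distributions and parameters here are hypothetical placeholders.
    """
    t = rng.gauss(37.0, 0.8)               # t_s ~ p_T: true temperature (deg C)
    m = rng.gauss(t, 0.3)                  # m_s | t_s ~ p_{M|T}: noisy reading around t_s
    p_fever = 0.95 if m >= 38.0 else 0.05  # p_{F|M}(1 | m_s)
    f = int(rng.random() < p_fever)        # f_s | m_s ~ p_{F|M}
    return {"F": f, "M": m, "T": t}

rng = random.Random(42)
samples = [sample_fever_model(rng) for _ in range(5)]  # S = 5 joint samples
for s in samples:
    print(s)
```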
Example 2: Consider the distribution,
\[ p_{A, B, C}(a, b, c) = p_{B | A}(b | a) \cdot p_A(a) \cdot p_C(c) \tag{6.11} \]
Suppose we wanted to draw \(S\) samples from this joint distribution. For every \(s \in \{1, \dots, S \}\), we would have to:
\(c_s \sim p_C(\cdot)\)
\(a_s \sim p_A(\cdot)\)
\(b_s | a_s \sim p_{B | A}(\cdot | a_s)\) (where \(a_s\) was sampled in the previous step)
In carrying out these sampling steps, we would obtain \(S\) samples from the joint distribution: \(a_s, b_s, c_s \sim p_{A, B, C}(\cdot, \cdot, \cdot)\).
Notice that here, steps (1) and (2) are interchangeable, since they don’t depend on one another; i.e. it doesn’t matter whether we sample \(A\) first or \(C\) first. In contrast, step (3) had to happen after step (2), because it depends on the sampled value of \(A\).
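To check that steps (1) and (2) really are interchangeable, the sketch below samples Example 2 in both orders. It compares how often the outcome \((a, b, c) = (1, 1, 1)\) appears against that outcome's exact probability. All parameters are hypothetical Bernoulli probabilities.

```python
import random

# Hypothetical parameters for p_A, p_C, and p_{B|A}, all Bernoulli.
P_A1, P_C1 = 0.4, 0.7
P_B1_GIVEN_A = {0: 0.1, 1: 0.8}

def sample_c_first(rng):
    c = int(rng.random() < P_C1)             # step 1: c_s ~ p_C
    a = int(rng.random() < P_A1)             # step 2: a_s ~ p_A
    b = int(rng.random() < P_B1_GIVEN_A[a])  # step 3: b_s | a_s ~ p_{B|A}
    return a, b, c

def sample_a_first(rng):
    a = int(rng.random() < P_A1)             # steps 1 and 2 swapped
    c = int(rng.random() < P_C1)
    b = int(rng.random() < P_B1_GIVEN_A[a])  # still after a_s
    return a, b, c

S = 100_000
rng = random.Random(0)
freq_1 = sum(sample_c_first(rng) == (1, 1, 1) for _ in range(S)) / S
freq_2 = sum(sample_a_first(rng) == (1, 1, 1) for _ in range(S)) / S
exact = P_A1 * P_B1_GIVEN_A[1] * P_C1  # p_A(1) * p_{B|A}(1 | 1) * p_C(1) = 0.224
print(freq_1, freq_2, exact)
```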
Example 3: Consider the distribution,
\[ p_{A, B, C}(a, b, c) = p_{B | A, C}(b | a, c) \cdot p_{A | C}(a | c) \cdot p_C(c) \tag{6.12} \]
Suppose we wanted to draw \(S\) samples from this joint distribution. For every \(s \in \{1, \dots, S \}\), we would have to:
\(c_s \sim p_C(\cdot)\)
\(a_s | c_s \sim p_{A | C}(\cdot | c_s)\) (where \(c_s\) was sampled in the previous step)
\(b_s | a_s, c_s \sim p_{B | A, C}(\cdot | a_s, c_s)\) (where \(a_s, c_s\) were sampled in the previous steps)
In carrying out these sampling steps, we would obtain \(S\) samples from the joint distribution: \(a_s, b_s, c_s \sim p_{A, B, C}(\cdot, \cdot, \cdot)\).
Notice again that step (2) depends on \(c_s\), which was sampled in step (1), and that step (3) depends on both \(c_s\) and \(a_s\), sampled in steps (1) and (2). Thus, the order of the sampling does matter!
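For Example 3, only one order works, because \(B\) needs both \(a_s\) and \(c_s\). The sketch below uses hypothetical Bernoulli parameters. It samples in that order and compares the empirical frequency of each outcome with the product of the three factors.

```python
import itertools
import random
from collections import Counter

# Hypothetical Bernoulli parameters for p_C, p_{A|C}, and p_{B|A,C}.
P_C1 = 0.5
P_A1_GIVEN_C = {0: 0.2, 1: 0.7}
P_B1_GIVEN_AC = {(0, 0): 0.1, (0, 1): 0.4, (1, 0): 0.6, (1, 1): 0.9}

def bern(rng, p):
    """Return 1 with probability p, else 0."""
    return int(rng.random() < p)

def sample(rng):
    c = bern(rng, P_C1)                   # step 1: c_s ~ p_C
    a = bern(rng, P_A1_GIVEN_C[c])        # step 2: needs c_s
    b = bern(rng, P_B1_GIVEN_AC[(a, c)])  # step 3: needs a_s and c_s
    return a, b, c

def exact(a, b, c):
    """p_{B|A,C}(b | a, c) * p_{A|C}(a | c) * p_C(c)."""
    pc = P_C1 if c else 1 - P_C1
    pa = P_A1_GIVEN_C[c] if a else 1 - P_A1_GIVEN_C[c]
    pb = P_B1_GIVEN_AC[(a, c)] if b else 1 - P_B1_GIVEN_AC[(a, c)]
    return pb * pa * pc

S = 100_000
rng = random.Random(1)
counts = Counter(sample(rng) for _ in range(S))
for abc in itertools.product([0, 1], repeat=3):
    print(abc, round(counts[abc] / S, 3), round(exact(*abc), 3))
```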
Exercise: Practice the sampling order
For each of the 4 DGMs from the previous exercise, write down two different valid sampling orderings. Please use the same notation used here.
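To check your answers, you could enumerate every valid ordering of a small DGM by brute force: repeatedly pick any node whose parents have all been placed. The helper `all_sampling_orders` below is a hypothetical sketch. It is applied here to the DGMs of Examples 2 and 3 rather than to the exercise's DGMs.

```python
def all_sampling_orders(parents):
    """List every order in which each node comes after all of its parents.

    parents maps node -> list of parent nodes. This brute-force recursion is
    fine for the handful of nodes in these exercises.
    """
    orders = []

    def extend(order, remaining):
        if not remaining:
            orders.append(order)
            return
        for node in sorted(remaining):
            # A node can be sampled next only once all its parents are sampled.
            if all(pa in order for pa in parents[node]):
                extend(order + [node], remaining - {node})

    extend([], set(parents))
    return orders

# Example 2: A -> B, with C on its own.
print(all_sampling_orders({"A": [], "B": ["A"], "C": []}))
# [['A', 'B', 'C'], ['A', 'C', 'B'], ['C', 'A', 'B']]

# Example 3: C -> A, C -> B, A -> B.
print(all_sampling_orders({"C": [], "A": ["C"], "B": ["A", "C"]}))
# [['C', 'A', 'B']]
```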
6.3. Translating DGMs into Code with NumPyro
Exercise: Joint distributions are generative models
Context: Your friend is an ML researcher at a nearby university. She heard all about the interesting data you have from the IHH ER and wants to help with the analysis. However, because this is sensitive medical data, she needs to obtain the right credentials, undergo a lengthy training on secure data management, and more, before obtaining access to the data. Realistically, this means she’ll only be able to gain access to the IHH ER data in several months.
Idea: To help her out, you have an idea: instead of sending her the data directly, you will develop a generative model of the IHH ER data and send that to her instead. This generative model will allow your friend to generate (or sample) realistic data with the same characteristics as the real data, without violating any privacy constraints. But what exactly is a generative model? A generative model is a joint probability distribution over all variables in the data: \(D\), \(C\), \(H\), and \(A\), where:
\(D\): Day-of-Week
\(C\): Condition
\(H\): Hospitalized
\(A\): Antibiotics
Limitations: Note that, in general, releasing a generative model instead of a data set to ensure its anonymity requires serious care. There’s a whole field called “Differential Privacy” devoted to doing this responsibly. We will not cover these techniques in this class.
Domain Expertise: You consulted with clinical collaborators at the IHH ER and together came up with the following DGM.
Notice the conditional distributions in this DGM are ones you’ve previously learned (by hand)!
Part 1: Use what you already know about the marginal and conditional distributions of the IHH ER data, in tandem with what you learned here about joint distributions, to implement this generative model. Implement your model as a function that takes in a random generator key and outputs a Python dictionary with a single sample from the joint distribution \(p_{D, C, H, A}(\cdot)\). That is, the dictionary has keys \(D\), \(C\), \(H\), and \(A\), with values corresponding to the sample. Your function should rely on your previous implementations of the conditional distributions.
Part 2: Implement a second function that, given \(D = d\), \(C = c\), \(H = h\), and \(A = a\), computes the log probability of the joint, \(\log p_{D, C, H, A}(d, c, h, a)\). Your function should rely on your previous implementations of the conditional distributions.
Hint: \(\log (X \cdot Y) = \log X + \log Y\).
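The hint matters in practice as well as on paper. The sketch below uses made-up factor values. It checks that the log of a product equals the sum of the logs, and shows why summing logs is the safer choice: multiplying many small probabilities underflows to zero in floating point, while their log sum stays finite.

```python
import math

# Four hypothetical factor values, one per node of some DGM (values are illustrative).
factors = [1 / 7, 0.05, 0.3, 0.6]

log_of_product = math.log(math.prod(factors))
sum_of_logs = sum(math.log(p) for p in factors)
assert math.isclose(log_of_product, sum_of_logs)

# Summing logs also avoids underflow when many small factors are multiplied.
tiny = [1e-5] * 100
print(math.prod(tiny))                 # 0.0 (the true value, 1e-500, underflows)
print(sum(math.log(p) for p in tiny))  # roughly -1151.3
```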
Part 3: Draw 10,000 samples from the generative model, and evaluate the model’s log probability on each. Of all the samples you drew, print out the samples with the highest and the lowest log probability. What does each tell you?
Note: You’re welcome to use a loop for this problem (we’ll learn how to vectorize sampling in the next chapter).
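The overall shape of Part 3 can be sketched on a hypothetical two-variable toy model. This is not the IHH model, and it is written with Python's `random` module instead of NumPyro. The sketch samples in a loop, scores each sample with the joint log probability, and then takes the maximum and minimum.

```python
import math
import random

# Hypothetical toy model p_{Y|X}(y | x) * p_X(x), standing in for the IHH model.
P_X = [0.6, 0.3, 0.1]            # p_X(x) for x in {0, 1, 2}
P_Y1_GIVEN_X = [0.1, 0.5, 0.95]  # p_{Y|X}(1 | x)

def sample(rng):
    x = rng.choices(range(3), weights=P_X)[0]  # x ~ p_X
    y = int(rng.random() < P_Y1_GIVEN_X[x])    # y | x ~ p_{Y|X}
    return {"X": x, "Y": y}

def log_prob(x, y):
    p_y = P_Y1_GIVEN_X[x] if y == 1 else 1 - P_Y1_GIVEN_X[x]
    return math.log(P_X[x]) + math.log(p_y)  # log of a product = sum of logs

rng = random.Random(0)
samples = [sample(rng) for _ in range(10_000)]  # a plain loop, as the note allows
scored = [(log_prob(s["X"], s["Y"]), s) for s in samples]
best = max(scored, key=lambda t: t[0])
worst = min(scored, key=lambda t: t[0])
print("highest log probability:", best)
print("lowest log probability:", worst)
```

Here the highest-scoring sample is the most typical outcome under the model, \((x, y) = (0, 0)\) with probability 0.54. The lowest-scoring one is a rare combination that still occurred among the draws, \((2, 0)\) with probability 0.005.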
Note: NumPyro discrete distributions only work with integers, not strings. For example, instead of using \(d = \text{Monday}\), you should convert the days of the week into integers from 0 to 6 (Monday to Sunday), and use \(d = 0\) for “Monday”. We’ve created two helper functions for this conversion, convert_day_of_week_to_int and convert_condition_to_int. You can use them as follows:
convert_day_of_week_to_int(data['Day-of-Week'])
convert_condition_to_int(data['Condition'])
Please use the function signatures below:
import jax.numpy as jnp
import jax.random as jrandom
import numpyro
import numpyro.distributions as D
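def IHH_ER_generative_model_sample(key):
    # Return a dict with keys 'D', 'C', 'H', 'A' holding one joint sample.
    pass  # TODO implement

def IHH_ER_generative_model_log_prob(d, c, h, a):
    # Return log p_{D, C, H, A}(d, c, h, a) as a sum of log factors.
    pass  # TODO implement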