6. Joint Probability (Discrete)#

# Import some helper functions (please ignore this!)
from utils import * 

Context: So far, you’ve spent some time conducting a preliminary exploratory data analysis (EDA) of IHH’s ER data. You noticed that considering variables separately can result in misleading information. As a result, you decided to use conditional distributions to model the relationship between variables. Using these conditional distributions, you were able to develop predictive models (e.g. predicting the probability of intoxication given the day of the week). These predictive models are useful for the IHH administration to make decisions.

However, you’ve noticed that your modeling toolkit is still limited. The conditional distributions we introduced can model how the probability of one variable changes given a set of variables. What if we wanted to describe how the probability of a set of variables (i.e. more than one) changes given a set of variables? For example, we may want to answer questions like: “how does the probability that a patient is hospitalized for an allergic reaction change given the day of the week?” In this question, we’re inquiring about two variables—that the condition is an allergic reaction, and that the patient was hospitalized—given the day of the week.

Challenge: We need to expand our modeling toolkit to include yet another tool—joint probabilities.

Outline:

  1. Introduce and practice the concepts, terminology, and notation behind discrete joint probability distributions (leaving continuous distributions to a later time).

  2. Introduce a graphical representation to describe joint distributions.

  3. Translate this graphical representation directly into code in a probabilistic programming language (using NumPyro) that we can then use to fit the data.

6.1. Terminology and Notation#

We again introduce the statistical language—terminology and notation—needed to precisely specify to a computer how to model our data. We will then translate statements in this language directly into code in NumPyro that a computer can run.

Concept. The concept behind a joint probability is elegant; it allows us to build complicated distributions over many variables using simple conditional and non-conditional distributions (that we already covered).

We can illustrate this using an example with just two variables. Suppose you have two RVs, \(A\) and \(B\). The probability that \(A = a\) and \(B = b\) are both satisfied is called their joint probability. It is denoted by \(p_{A,B}(a, b)\). This joint distribution can be factorized to a product of conditional and non-conditional (or “marginal”) distributions as follows:

(6.1)#\[\begin{align} p_{A, B}(a, b) &= p_{A | B}(a | b) \cdot p_B(b) \quad \text{(Option 1)} \\ \underbrace{\phantom{p_{A, B}(a, b)}}_{\text{joint}} &= \underbrace{p_{B | A}(b | a)}_{\text{conditional}} \cdot \underbrace{p_A(a)}_{\text{marginal}} \quad \text{(Option 2)} \end{align}\]

Notice that the joint is now described in terms of conditional and marginal distributions, which we already know how to work with!
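To make this concrete, here is a minimal sketch using a toy (made-up) joint distribution over two binary RVs, checking that both factorizations in (6.1) recover the same joint probability:

import jax.numpy as jnp

# A toy (made-up) joint distribution over binary RVs A and B.
# Rows index a in {0, 1}, columns index b in {0, 1}; the entries sum to 1.
joint = jnp.array([[0.10, 0.30],
                   [0.20, 0.40]])

p_B = joint.sum(axis=0)             # marginal p_B(b)
p_A = joint.sum(axis=1)             # marginal p_A(a)
p_A_given_B = joint / p_B           # conditional p_{A|B}(a | b); each column sums to 1
p_B_given_A = joint / p_A[:, None]  # conditional p_{B|A}(b | a); each row sums to 1

a, b = 1, 0
print(joint[a, b])                 # the joint p_{A,B}(a, b) directly
print(p_A_given_B[a, b] * p_B[b])  # Option 1: p_{A|B}(a | b) * p_B(b)
print(p_B_given_A[a, b] * p_A[a])  # Option 2: p_{B|A}(b | a) * p_A(a)

All three lines print the same number, as the factorization promises.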

Intuition. So what’s the intuition behind this formula? Let’s depict events \(A = a\) and \(B = b\) as follows:

In this diagram, each shaded area represents the probability of an event—i.e. area is proportional to probability. We use it to pictorially represent the marginal, conditional, and joint distributions. The marginal probability of \(B = b\), for example, is the ratio of the blue square relative to the whole space (the gray square):

Similarly, the marginal probability of \(A = a\) is the ratio of the red square relative to the gray square.

Next, the conditional \(p_{A | B}(a | b)\) is the ratio of the purple intersection relative to the blue square. This is because the blue square represents our conditioning on \(B = b\), and the purple intersection represents the probability that we also have \(A = a\).

Finally, the joint \(p_{A, B}(a, b)\) is the ratio between the purple intersection and the whole space (the gray square). This is because the intersection is where both \(A = a\) and \(B = b\).

Now we can see that the joint is the product of the conditional and the marginal because the blue squares “cancel out”:

Choice of Factorization. Lastly, notice that we have a choice to factorize the distribution in two ways. How do you know which one to use? Typically, we choose a factorization that is intuitive to us and whose components we can compute.

For example, suppose you want to model the joint distribution of the day of the week, \(D\), and whether a patient arrives with intoxication, \(I\). The joint distribution can be factorized in two ways:

(6.2)#\[\begin{align} p_{D, I}(d, i) &= p_{I | D}(i | d) \cdot p_D(d) \quad \text{(Option 1)} \\ &= p_{D | I}(d | i) \cdot p_I(i) \quad \text{(Option 2)} \\ \end{align}\]

Which one makes more intuitive sense? Well, it’s a little weird to try to predict the day of the week given whether a patient arrives with intoxication; we typically know what the day of the week is and we don’t need to predict it. In contrast, given the day of the week, it makes a lot of sense to wonder about the probability of a patient arriving with intoxication. As such, Option 1 makes more sense here.

Generalizing to More than Two RVs. So now we have the tools to work with joint distributions with two RVs. What do we do if we have three or more? The same ideas apply. The joint distribution for random variables \(A\), \(B\), and \(C\) can be factorized in a number of ways. For example, we can condition on two variables at a time:

(6.3)#\[\begin{align} p_{A, B, C}(a, b, c) &= p_{A | B, C}(a | b, c) \cdot p_{B, C}(b, c) \quad \text{(Option 1)} \\ &= p_{B | A, C}(b | a, c) \cdot p_{A, C}(a, c) \quad \text{(Option 2)} \\ &= p_{C | A, B}(c | a, b) \cdot p_{A, B}(a, b) \quad \text{(Option 3)} \end{align}\]

where, in the above, we already know how to factorize \(p_{B, C}(b, c)\), \(p_{A, C}(a, c)\), and \(p_{A, B}(a, b)\) (since they are joint distributions with two variables).

We can also condition on one variable at a time:

(6.4)#\[\begin{align} p_{A, B, C}(a, b, c) &= p_{A, B | C}(a, b | c) \cdot p_C(c) \quad \text{(Option 1)} \\ &= p_{A, C | B}(a, c | b) \cdot p_B(b) \quad \text{(Option 2)} \\ &= p_{B, C | A}(b, c | a) \cdot p_A(a) \quad \text{(Option 3)} \end{align}\]

And how do we further factorize distributions of the form \(p_{A, B | C}(a, b | c)\)? We apply the same factorization for a joint distribution with two variables, and simply add a “conditioned on \(C\)” to each one:

(6.5)#\[\begin{align} p_{A, B | C}(a, b | c) &= p_{A | B, C}(a | b, c) \cdot p_{B | C}(b | c) \quad \text{(Option 1)} \\ &= p_{B | A, C}(b | a, c) \cdot p_{A | C}(a | c) \quad \text{(Option 2)} \\ \end{align}\]
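The same kind of sanity check works with three RVs. Below is a minimal sketch (again with made-up numbers) verifying that the factorization \(p_{A | B, C}(a | b, c) \cdot p_{B | C}(b | c) \cdot p_C(c)\) recovers a toy three-variable joint:

import jax.numpy as jnp

# A toy (made-up) joint over binary RVs A, B, C; axis order is (a, b, c).
joint = jnp.array([[[0.05, 0.10], [0.10, 0.05]],
                   [[0.20, 0.15], [0.15, 0.20]]])

p_C = joint.sum(axis=(0, 1))    # p_C(c)
p_BC = joint.sum(axis=0)        # p_{B,C}(b, c)
p_B_given_C = p_BC / p_C        # p_{B|C}(b | c)
p_A_given_BC = joint / p_BC     # p_{A|B,C}(a | b, c)

a, b, c = 1, 0, 1
print(joint[a, b, c])                                      # the joint directly
print(p_A_given_BC[a, b, c] * p_B_given_C[b, c] * p_C[c])   # the factorized version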

6.2. Directed Graphical Models (DGMs)#

As you may have already noticed, the number of possible ways to factorize a joint distribution increases very quickly with the number of RVs. In fact, the more RVs we have, the more unwieldy it becomes for us as data analysts to specify each component in the factorization. What can we do to simplify our model? Often, we can use our domain knowledge (knowledge of the specifics of the problem) to simplify the joint distribution. Specifically, we’ll use our knowledge of “conditional independence” to do this. Let’s get started by first introducing the idea of statistical independence.

Statistical Independence. We say two variables \(A\) and \(B\) are statistically independent if their joint can be factorized to the product of their marginals:

(6.6)#\[\begin{align} p_{A, B}(a, b) &= p_B(b) \cdot p_A(a) \end{align}\]

This equation implies that to sample \(A\) and \(B\) jointly, we don’t have to consider their relationship (since the conditional isn’t used)—they are entirely independent.

Another way to understand this equation is by thinking about its implications for the conditionals. We do this by factorizing \(p_{A, B}(a, b)\) into the product of the conditional and marginal:

(6.7)#\[\begin{align} p_{A, B}(a, b) &= \underbrace{p_{B | A}(b | a)}_{\text{must equal } p_B(b)} \cdot p_A(a) \end{align}\]

We then observe that \(p_{B | A}(b | a)\) must equal \(p_B(b)\) to satisfy our definition of statistical independence. And \(p_{B | A}(b | a) = p_B(b)\) implies that having observed \(A = a\) does not affect the probability of \(B = b\).
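Here is a minimal sketch of what this looks like numerically. We build a toy joint as an outer product of two made-up marginals, so \(A\) and \(B\) are independent by construction, and confirm that \(p_{B | A}(b | a)\) does not change with \(a\):

import jax.numpy as jnp

p_A = jnp.array([0.3, 0.7])   # made-up marginal of A
p_B = jnp.array([0.6, 0.4])   # made-up marginal of B
joint = jnp.outer(p_A, p_B)   # p_{A,B}(a, b) = p_A(a) * p_B(b), so A and B are independent

# Every row of the conditional p_{B|A}(. | a) equals p_B:
# observing A = a tells us nothing about B.
p_B_given_A = joint / joint.sum(axis=1, keepdims=True)
print(p_B_given_A)   # [[0.6, 0.4], [0.6, 0.4]]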

Graphical Representation of Statistical Dependencies. Since reasoning about many variables jointly is difficult, we introduce a graphical representation to aid with it. This representation is called a directed graphical model (DGM), and it will help us convey which variables depend on one another in what way.

A DGM is represented using a graph (or network) in which nodes represent RVs and arrows represent conditional dependencies. For example, consider the following DGM for some hypothetical joint distribution, \(p_{A, B, C}(\cdot)\):

In this DGM, there are three nodes, corresponding to our three RVs. Our factorization then consists of one factor for each node:

  • The factor corresponding to \(B\) is \(p_{B | A}(\cdot)\), since there’s an arrow from \(A\) to \(B\), indicating a conditional dependence.

  • The factor corresponding to \(A\) is \(p_{A}(\cdot)\), since there aren’t any arrows pointing into \(A\).

  • The factor corresponding to \(C\) is \(p_{C}(\cdot)\), since there aren’t any arrows pointing into \(C\).

In total, the DGM represents the following factorization:

(6.8)#\[\begin{align} p_{A, B, C}(a, b, c) &= p_{B | A}(b | a) \cdot p_A(a) \cdot p_C(c) \end{align}\]

As you can see, relative to the numerous ways we can factorize \(p_{A, B, C}(a, b, c)\) above, this factorization is much simpler: the conditional independencies implied by the DGM allow us to reduce the number of terms in the factorization.

So what does this simplicity cost us? Expressivity. By assuming certain variables do not have conditional dependencies, we are limiting the possible joint distributions we can represent. Is this a problem? Not if you make reasonable dependence assumptions based on your domain knowledge!
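To see the savings concretely, suppose (just for illustration) that \(A\), \(B\), and \(C\) are each binary. A fully general joint over three binary RVs requires \(2^3 - 1 = 7\) free probabilities, whereas the DGM's factorization requires only

\[\begin{align*} \underbrace{1}_{p_A(\cdot)} + \underbrace{2}_{p_{B | A}(\cdot | a), \text{ one per value of } a} + \underbrace{1}_{p_C(\cdot)} = 4. \end{align*}\]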

Exercise: Practice with DGMs

Let’s practice converting between distribution notation and DGMs.

Part 1: Write down the factorizations implied by the DGMs below.

Part 2: Draw DGMs representing each joint distribution below.

\[\begin{align*} p_{A, B, C, D, E}(a, b, c, d, e) = p_{C | A, B, D}(c | a, b, d) \cdot p_{A | B}(a | b) \cdot p_{D | B, E}(d | b, e) \cdot p_{B}(b) \cdot p_{E}(e) \end{align*}\]
\[\begin{align*} p_{A, B, C, D, E}(a, b, c, d, e) = p_{E | D}(e | d) \cdot p_{D | C}(d | c) \cdot p_{C | A, B}(c | a, b) \cdot p_{A | B}(a | b) \cdot p_B(b) \end{align*}\]
\[\begin{align*} p_{A, B, C, D, E}(a, b, c, d, e) = p_{E | A, B, C, D}(e | a, b, c, d) \cdot p_{D | A, B, C}(d | a, b, c) \cdot p_{C | A, B}(c | a, b) \cdot p_{B | A}(b | a) \cdot p_{A}(a) \end{align*}\]

Sampling from Joint Distributions: Sampling from a joint distribution can be done by sampling from each component in its factorization. When doing this, we must ensure our sampling order is valid. That is, if we have a distribution \(p_{A | B}(a | b) \cdot p_B(b)\), we cannot sample from \(p_{A | B}(\cdot | b)\) first because we don't know what \(b\) is yet. We first have to sample from \(p_B(\cdot)\) to obtain a specific value \(b\), and then use that value of \(b\) when sampling from the conditional.

Let’s illustrate this with some examples.

Example 1: Consider the distribution,

(6.9)#\[\begin{align} p_{A, B, C}(a, b, c) &= p_{B | A}(b | a) \cdot p_A(a) \cdot p_C(c) \end{align}\]

Suppose we wanted to draw \(S\) samples from this joint distribution. For every \(s \in \{1, \dots, S \}\), we would have to:

  1. \(c_s \sim p_C(\cdot)\)

  2. \(a_s \sim p_A(\cdot)\)

  3. \(b_s | a_s \sim p_{B | A}(\cdot | a_s)\) (wherein \(a_s\) was sampled in the previous step)

In carrying out these sampling steps, we would obtain \(S\) samples from the joint distribution: \((a_s, b_s, c_s)\).

Notice that here, steps (1) and (2) are interchangeable, since they don’t depend on one another; i.e. it doesn’t matter if we sample \(A\) first or \(C\) first. In contrast, step (3) had to happen after step (2), because it depends on the value of \(A\) sampled there.
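Here is a minimal sketch of these three steps using NumPyro distributions directly; the Bernoulli probabilities are made up purely for illustration:

import jax.numpy as jnp
import jax.random as jrandom
import numpyro.distributions as D

def sample_example_1(key):
    key_c, key_a, key_b = jrandom.split(key, 3)   # fresh randomness for each draw
    c = D.Bernoulli(probs=0.3).sample(key_c)      # step 1: c_s ~ p_C(.)
    a = D.Bernoulli(probs=0.6).sample(key_a)      # step 2: a_s ~ p_A(.)
    p_b = jnp.where(a == 1, 0.9, 0.2)             # parameter of p_{B|A}(. | a_s) depends on a_s
    b = D.Bernoulli(probs=p_b).sample(key_b)      # step 3: b_s ~ p_{B|A}(. | a_s)
    return dict(a=a, b=b, c=c)

samples = [sample_example_1(k) for k in jrandom.split(jrandom.PRNGKey(0), 5)]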

Example 2: Consider the distribution,

(6.10)#\[\begin{align} p_{A, B, C}(a, b, c) &= p_{B | A, C}(b | a, c) \cdot p_{A | C}(a | c) \cdot p_C(c) \end{align}\]

Suppose we wanted to draw \(S\) samples from this joint distribution. For every \(s \in \{1, \dots, S \}\), we would have to:

  1. \(c_s \sim p_C(\cdot)\)

  2. \(a_s | c_s \sim p_{A | C}(\cdot | c_s)\) (wherein \(c_s\) was sampled in the previous step)

  3. \(b_s | a_s, c_s \sim p_{B | A, C}(\cdot | a_s, c_s)\) (wherein \(a_s, c_s\) were sampled in the previous steps)

In carrying out these sampling steps, we would obtain \(S\) samples from the joint distribution: \((a_s, b_s, c_s)\).

Notice again that step (2) depends on \(c_s\), which was sampled in step (1), and that step (3) depends on both \(c_s\) and \(a_s\), sampled in steps (1) and (2). Thus, the order of the sampling does matter!
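And a corresponding sketch for Example 2, where each sampled value feeds into the next step (again with made-up Bernoulli probabilities):

import jax.numpy as jnp
import jax.random as jrandom
import numpyro.distributions as D

def sample_example_2(key):
    key_c, key_a, key_b = jrandom.split(key, 3)
    c = D.Bernoulli(probs=0.3).sample(key_c)                          # step 1: c_s ~ p_C(.)
    a = D.Bernoulli(probs=jnp.where(c == 1, 0.8, 0.4)).sample(key_a)  # step 2: a_s ~ p_{A|C}(. | c_s)
    p_b = 0.1 + 0.4 * a + 0.3 * c                                     # parameter of p_{B|A,C}(. | a_s, c_s)
    b = D.Bernoulli(probs=p_b).sample(key_b)                          # step 3: b_s ~ p_{B|A,C}(. | a_s, c_s)
    return dict(a=a, b=b, c=c)

sample = sample_example_2(jrandom.PRNGKey(0))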

Exercise: Practice the sampling order

For each of the 4 DGMs from the previous exercise, write down two different valid sampling orderings.

6.3. Translating DGMs into Code with NumPyro#

Exercise: Joint distributions are generative models

Context: Your friend is an ML researcher at a nearby university. She heard all about the interesting data you have from the IHH ER and wants to help with the analysis. However, because this is sensitive medical data, she needs to obtain the right credentials, undergo a lengthy training on secure data management, and more, before obtaining access to the data. Realistically, this means she’ll only be able to gain access to the IHH ER data in several months.

Idea: To help her out, you have an idea: instead of sending her the data directly, you will develop a generative model of the IHH ER data and send her that. This generative model will allow your friend to generate (or sample) realistic data with the same characteristics as the real data without violating any privacy constraints. But what exactly is a generative model? A generative model is a joint probability distribution over all variables in the data: \(D\), \(C\), \(H\), \(K\), and \(A\), where:

  • \(D\): Day-of-Week

  • \(C\): Condition

  • \(H\): Hospitalized

  • \(A\): Antibiotics

  • \(K\): Knots

Domain Expertise: You consulted with clinical collaborators at the IHH ER and together came up with the following DGM.

Notice the conditional distributions in this DGM are ones you’ve previously learned (by hand)!

Part 1: Use what you already know about the marginal and conditional distributions of the IHH ER data, in tandem with what you learned here about joint distributions, to implement this generative model. Implement your model as a function that takes in a random generator key and outputs a Python dictionary with a single sample from the joint distribution \(p_{D, C, H, K, A}(\cdot)\). That is, the dictionary has keys \(D\), \(C\), \(H\), \(K\), and \(A\), with values corresponding to the sample. Your function should rely on your previous implementations of the conditional distributions.

Part 2: Implement a second function that, given \(D = d\), \(C = c\), \(H = h\), \(K = k\), and \(A = a\), computes the log probability of the joint, \(\log p_{D, C, H, K, A}(d, c, h, k, a)\). Your function should rely on your previous implementations of the conditional distributions.

Hint: \(\log (X \cdot Y) = \log X + \log Y\).

Part 3: Draw 10,000 samples from the generative model, and evaluate the model’s log probability on each. Of all the samples you drew, print out the sample with the highest and lowest log probability. What does each tell you?

Note: You’re welcome to use a loop for this problem (we’ll learn how to vectorize sampling in the next chapter).

Note: NumPyro discrete distributions only work with integers, not strings. For example, instead of using \(d = \text{Monday}\), you should convert the days of the week into integers from 0 to 6 (Monday to Sunday), and instead use \(d = 0\) for “Monday”. We’ve created two helper functions to help you with this conversion: convert_day_of_week_to_int and convert_condition_to_int. You can use them as follows:

  • convert_day_of_week_to_int(data['Day-of-Week'])

  • convert_condition_to_int(data['Condition'])

Please use the function signatures below:

import jax.numpy as jnp
import jax.random as jrandom
import numpyro
import numpyro.distributions as D

def IHH_ER_generative_model_sample(key):
    # TODO implement: sample d, c, h, k, a in a valid order (parents before children),
    # following the DGM, and return them in a dictionary keyed by variable name
    pass

def IHH_ER_generative_model_log_prob(d, c, h, k, a):
    # TODO implement: sum the log-probabilities of each factor in the DGM's factorization
    pass