{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Classification " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import some helper functions (please ignore this!)\n", "from utils import * " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Context:** We've just learned the framework underlying predictive models: their directed graphical model and their MLE objective. We've instantiated this framework for regression, the task of predicting *real-valued* outputs from inputs. Now we will provide another instantiation: classification, which focuses on predicting *categorical-valued* outputs from inputs. For example, given chest X-rays, we may want to predict whether a patient has COVID or not; given a sample of text, we may want to predict whether it has positive or negative sentiment; given a patient's demographic information, we may want to predict whether a medication will be helpful, unhelpful, or harmful. \n", "\n", "**Challenge:** A classification model is a specific instance of a predictive model; as such, you may find it useful to review the unit on regression before continuing. As you review, you'll notice that what mainly distinguishes one type of predictive model from another is the choice of conditional distribution. That is, if our goal is to predict $Y$ from $X$, we need to define $p_{Y | X}$. Whereas for regression we used a real-valued distribution (a Gaussian), for classification we will use a Bernoulli distribution (for two classes) or a Categorical distribution (for more than two classes).\n", "\n", "**Outline:**\n", "* Instantiate classification models\n", "* Implement classification models in `NumPyro`\n", "* Interpret a classifier fit on IHH data\n", "\n", "**Data:** You're continuing your collaboration with colleagues at the IHH's Center for Telekinesis Research. This time, the researchers are interested in developing a tool that can help patients make informed treatment decisions. 
Specifically, a large segment of your patient population suffers from a lack of telekinetic control; that is, their telekinetic abilities tend to do something unexpected. When they try to move a bookshelf to the right, it moves forward; when they try to lift a mug, the mug lifts but spills on the floor. Fortunately, a newly developed medication is designed to help patients regain control of their telekinetic abilities. The problem is that the medication has negative side effects (nausea, headaches), so it's important that patients considering it can weigh the risks against the benefits. Your colleagues have therefore asked you to build a predictive model of the medication's effect. This model will be given to clinicians so they can tell patients how likely they are to recover their telekinetic control. \n", "\n", "You've just been sent the data and are eager to take a peek. Your colleagues are most excited about understanding the relationship between age, medication dosage, and telekinetic control. As you can see below, the data includes two columns relating to control: `Control-Before` indicates whether a patient had telekinetic control before receiving the medication, and `Control-After` represents the effect of the medication, i.e., whether the patient had telekinetic control after receiving it. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "| Patient ID | Age | Dose | Glow | Telekinetic-Ability | Control-Before | Control-After |\n",
"|---|---|---|---|---|---|---|\n",
"| 405 | 17.600859 | 0.764120 | 0.868124 | 0.294490 | 1 | 1 |\n",
"| 1190 | 5.332611 | 0.944742 | 0.941606 | 0.404568 | 1 | 1 |\n",
"| 1132 | 35.766937 | 0.796327 | 0.605670 | -0.053640 | 0 | 1 |\n",
"| 731 | 33.879105 | 0.802059 | 0.605779 | -0.124069 | 1 | 1 |\n",
"| 1754 | 26.795975 | 0.228325 | 0.824671 | 0.014139 | 1 | 1 |\n",
"| 1178 | 11.604771 | 0.185052 | 0.893694 | 0.393820 | 1 | 1 |\n",
"| 1533 | 3.343291 | 0.372828 | 1.005576 | 0.316230 | 1 | 1 |\n",
"| 1303 | 27.073309 | 0.497788 | 0.645080 | -0.006282 | 1 | 1 |\n",
"| 1857 | 77.514244 | 0.300860 | 0.276417 | -0.655437 | 0 | 0 |\n",
"| 18 | 15.204407 | 0.766040 | 0.762407 | 0.329108 | 1 | 1 |\n",
"| 1266 | 12.526113 | 0.655977 | 0.845298 | 0.454340 | 1 | 1 |\n",
"| 1543 | 13.561113 | 0.422801 | 0.862660 | 0.400310 | 1 | 1 |\n",
"| 249 | 18.606394 | 0.188196 | 0.748102 | 0.170958 | 1 | 1 |\n",
"| 191 | 5.620849 | 0.491923 | 0.900761 | 0.419281 | 1 | 1 |\n",
"| 721 | 21.602093 | 0.517346 | 0.797098 | 0.090256 | 1 | 1 |\n", "