Empirical inference in Pyro

Posted on August 27, 2020 by Calvin

Note: This blog post is still a rough draft. Read on with caution.

Most probabilistic programming languages on the market tend to focus on a single inference algorithm as their bread and butter: Stan promulgated the use of Hamiltonian Monte Carlo for inference of complicated hierarchical models, pyro’s API focuses heavily on stochastic variational inference, etc. However, they usually expose ways to implement other inference algorithms, often by exploiting execution traces.

Having a varied selection of inference algorithms to try out is advantageous– some algorithms are exact while others are merely approximate, some are slow to compute, while others are very fast. Also, many algorithms exploit dependency structures in the generative models, which makes them more robust to modality issues or exponential combinatorial explosions. In this blog post we will focus on trying to understand pyro’s importance sampling algorithms, eventually culminating in understanding the pyro.infer.CSIS module on inference compilation.

empirical

First we describe the Empirical distribution.

In essence, an empirical distribution (derived from a dataset \(\mathcal{D}\)) is a histogram without buckets. Instead, points are “weighted” as a proxy of binning. Mathematically, an empirical distribution can be described by the measure

\[ \mu_\text{emp}(\mathcal{D}) = \sum_{x\in\mathcal{D}} \omega_x\delta_x \]

where \(\delta_x\) is the Dirac measure with mass concentrated at the point \(x\).

As with all distributions in modern PPLs, we just need to implement sample to sample a value from the empirical, and log_prob to retrieve the log weight of a value sampled from the distribution.
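As a rough sketch (a hand-rolled stand-in with made-up names, not pyro's actual implementation), such a distribution might look like:

```python
import torch

class SimpleEmpirical:
    """Weighted atoms exposing `sample` and `log_prob` (a histogram without buckets)."""

    def __init__(self, samples, log_weights):
        self.samples = samples                                              # shape (N, ...)
        self.log_weights = log_weights - torch.logsumexp(log_weights, 0)    # normalize weights

    def sample(self):
        # pick an atom with probability proportional to its weight
        idx = torch.distributions.Categorical(logits=self.log_weights).sample()
        return self.samples[idx]

    def log_prob(self, value):
        # total (log-) weight attached to the atoms equal to `value`
        matches = (self.samples == value).reshape(len(self.samples), -1).all(-1)
        if not matches.any():
            return torch.tensor(float("-inf"))
        return torch.logsumexp(self.log_weights[matches], 0)
```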

If one needs a reminder about how tensor dimensions are used in probabilistic programming, I recommend reading this blog post by Eric Ma.

An Empirical distribution is ultimately a reification of a sampling distribution. For example, suppose we draw a batch of samples from a normal distribution.
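For instance (a quick sketch using pyro's built-in Empirical distribution, which takes a tensor of samples together with per-sample log weights; the sample count is arbitrary):

```python
import torch
import pyro.distributions as dist

# 1000 draws from a standard normal...
samples = dist.Normal(0.0, 1.0).sample((1000,))
# ...reified as an Empirical distribution with equal (zero log-) weights
empirical = dist.Empirical(samples, torch.zeros(1000))
```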

We can then visualize our empirical distribution as a histogram of its samples.

posterior predictives

The point of probabilistic programming is to make it easy to perform inference of latent variables and parameters in generative models. A PPL is equipped with a suite of inference algorithms that allow us to determine a posterior distribution over unobserved variables, given a prior model and a set of observations.

In pyro, a prior on data is described by a model, such as:
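A sketch of such a model (the site names "p" and "x" and the toy dataset are my own choices here):

```python
import torch
import pyro
import pyro.distributions as dist

def model(data):
    # latent success probability with a Beta(2, 2) prior
    p = pyro.sample("p", dist.Beta(2.0, 2.0))
    # conditionally independent Bernoulli observations
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Bernoulli(p), obs=data)

data = torch.tensor([1.0, 0.0, 1.0, 1.0, 1.0])
```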

As a generative story, this model describes the prior

\[ \begin{align} p &\sim \text{Beta}(2, 2) \\ x_i | p &\sim \text{Bernoulli}(p) \end{align} \]

Now, given a dataset, say \(\mathcal{D}=\{1,0,1,1,1\}\), what is the posterior? That is, we want to compute the distribution over latents (here, just p) that best describes the observed data \(\mathcal{D}\). This is the inference process. pyro is equipped with a number of such algorithms, but the one it is specially designed for is variational inference, in which we construct a guide \(q(z\mid x;\phi)\) and use stochastic optimization on an information-theoretic bound to choose parameters \(\phi\) that make the resulting guide a good approximation to the true posterior of the model.

For example, a guide we can use for the above model is given by \(q(p\mid x;\alpha,\beta) =\text{Beta}(p\mid \alpha,\beta)\):
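A sketch of such a guide, with the variational parameters named a and b and both initialized at 1 (so the untrained guide is a uniform Beta(1, 1)):

```python
from torch.distributions import constraints

def guide(data):
    # variational parameters of the Beta family
    a = pyro.param("a", torch.tensor(1.0), constraint=constraints.positive)
    b = pyro.param("b", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("p", dist.Beta(a, b))
```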

Once we’ve performed the stochastic variational inference procedure and determined the best parameters a and b above, we get a guide function that acts as the “best” approximation (in the sense of maximizing the ELBO) to the true posterior of the latent variables in the model. How do we get a grasp of this posterior as a distribution, though?

Sample! In PPLs, we often are concerned with the space of execution traces that are produced by a generative story– that is, the sequence of all stochastic choices and conditions encountered during a single execution of the model. pyro captures the traces automatically in the background using an algebraic effects system called poutine. We described algebraic effect systems and their use in Pyro in previous posts on this blog.

For example, to capture a single execution trace of the model, we just need to wrap the model in an effect handler trace that will keep a logged record of the execution of every pyro.sample site that gets activated in the program.
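Roughly (assuming the model and data defined above):

```python
import pyro.poutine as poutine

# record one execution of the model; every pyro.sample site shows up in trace.nodes
trace = poutine.trace(model).get_trace(data)
for name, node in trace.nodes.items():
    if node["type"] == "sample":
        print(name, node["value"])
```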

Given traces from the approximate posterior (guide), we can sample from the generative model with the latents tuned to their posterior distributional values by replaying a trace of the guide in the model, and capturing the resulting values in a new trace.
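A minimal sketch of what a Predictor could look like, built from poutine.trace and poutine.replay (the exact class body is my own reconstruction):

```python
class Predictor:
    """Run the model with its latents pinned to a fresh draw from the guide."""

    def __init__(self, model, guide):
        self.model = model
        self.guide = guide

    def get_trace(self, *args, **kwargs):
        # draw approximate-posterior values for the latents from the guide...
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        # ...then replay those choices in the model and record the full execution
        replayed_model = poutine.replay(self.model, trace=guide_trace)
        return poutine.trace(replayed_model).get_trace(*args, **kwargs)
```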

Getting a joint prediction then just involves running the model through this predictor:
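For example (assuming the Predictor sketch above):

```python
predictor = Predictor(model, guide)
trace = predictor.get_trace(data)
# the latent site "p" now carries a value drawn from the guide
print(trace.nodes["p"]["value"])
```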

Why is this called Predictor? Because this describes the posterior predictive distribution– it is a function that can actually produce predictions from the generative model, where the latents are now distributed according to the approximate posterior given by the guide.

If we want to isolate the marginal distribution of a single sampling site (let’s say we only care about the posterior for p), we can sample a bunch of traces from the predictive and only extract the values corresponding to the site in question.
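A sketch of such a marginal extractor, named VariationalMarginal to match the closing section; it leans on pyro's Empirical distribution (samples plus per-sample log weights), and the constructor signature is my own choice:

```python
class VariationalMarginal:
    """Empirical marginal of one sample site under the guide-replayed model."""

    def __init__(self, predictor, site, num_samples, *args, **kwargs):
        values = []
        for _ in range(num_samples):
            trace = predictor.get_trace(*args, **kwargs)
            values.append(trace.nodes[site]["value"])
        # every trace is an unweighted draw, so all log-weights are zero
        self.dist = dist.Empirical(torch.stack(values), torch.zeros(num_samples))

    def sample(self):
        return self.dist.sample()
```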

Running this looks much like running the predictor above:
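Assuming the sketch above:

```python
marginal = VariationalMarginal(Predictor(model, guide), "p", 500, data)
p_samples = torch.stack([marginal.sample() for _ in range(1000)])
# a histogram of p_samples visualizes the (so far untrained) approximate posterior over p
```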

Of course, we didn’t train the guide, so the resulting histogram looks roughly uniform and is not very useful to us.

importance sampling

Stochastic variational inference is one of the main ways to produce a good posterior distribution in pyro, but it isn’t the only way. Importance sampling is another inference method: it computes the posterior by using a guide as a proposal distribution for Monte Carlo sampling, producing a weighted collection of execution traces that together represent the posterior.

In this way, the importance sampler represents the posterior not as its own generative model, but as a collection of traces, which in pyro is formalized as the abstract TracePosterior class.

An importance sampler (represented in the Importance class) is a TracePosterior in which the way we generate traces is via the importance sampling method, which we recall here.

In importance sampling, our goal is to sample from a distribution \(x\sim p(x)\) (here, \(p\) is our model as prior). This may be difficult, so instead we sample from a proposal \(x\sim q(x)\). This may introduce a bias, so we attach a weight that acts as metadata denoting how confident the sampler is that it’s a sample from \(p\). So for example, for \(x\) our weight would be

\[ \omega_x = \frac{p(x)}{q(x)} \]

Given samples \(x_i\sim q(x)\) for \(i=1,\dots,N\), the empirical distribution measure

\[ \mu_\text{importance} = \frac{1}{N}\sum_{i=1}^N \omega_{x_i} \delta_{x_i} \]

converges to \(p(x)\) as \(N\to\infty\).

To represent this in pyro, we simulate a trace from the guide (proposal) and compute the importance weight by replaying that trace in the model and computing log-joints (which is also provided in the computed execution traces).
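A sketch of that computation, mirroring what pyro's Importance sampler does internally (the function name is my own):

```python
def importance_trace(model, guide, *args, **kwargs):
    # draw a proposal trace z ~ q(z) from the guide...
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    # ...and replay those choices in the model to score them under p
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)
    ).get_trace(*args, **kwargs)
    # log importance weight: log p(x, z) - log q(z)
    log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
    return guide_trace, log_weight
```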

Running the importance sampler will load a collection of traces from the guide and assign importance weights to them.
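Using pyro's built-in Importance class, that looks roughly like:

```python
from pyro.infer import Importance

importance = Importance(model, guide=guide, num_samples=1000)
posterior = importance.run(data)   # collects importance-weighted traces
```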

The following lines of code extract the p-marginal distribution from the importance-weighted guide traces.
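For instance, via pyro's EmpiricalMarginal (a sketch; the number of samples drawn is arbitrary):

```python
from pyro.infer import EmpiricalMarginal

p_marginal = EmpiricalMarginal(posterior, sites="p")
p_samples = torch.stack([p_marginal.sample() for _ in range(1000)])
print(p_samples.mean())   # should land near the analytic posterior mean
```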

The empirical distribution of this marginal has the following histogram:

Since our prior is \(\text{Beta}(2, 2)\) and our dataset contains 4 successes and 1 failure, the posterior should analytically be a \(\text{Beta}(6, 3)\) with mean at \(6/9=0.666...\), which seems to agree with the marginal distribution above.

inference compilation

Importance sampling gives consistent (asymptotically exact) posterior estimates, but it suffers from a problem similar to variational inference: the choice of proposal distribution is a hard one to make, and it greatly affects the quality of the posterior samples generated. For example, we often desire proposal distributions for importance sampling to have thicker tails than the target distribution, otherwise they don’t capture tail behavior well. This can be crucial, especially in high dimensions where probability mass tends to be concentrated away from the mode.

However, recent work in deep learning has shown that deep neural networks can expressively focus on high-density portions of generative models, letting them generalize well even in high-dimensional settings. An obvious question that comes from this is: can we use deep neural nets to learn good proposal distributions for importance sampling? This is the idea behind inference compilation. For more details, see the paper by Wood et al.

In a nutshell, we wish to construct an expressive guide (proposal distribution) \(q(z\mid \text{NN}(x;\phi))\) where \(\text{NN}(-;\phi)\) is a neural network that takes in observations generated by the model and returns the parameters of a family of guides. This sounds like amortized variational inference, but there are two main distinctions: 1) in variational inference, we are concerned with choosing parameters that minimize the KL divergence

\[ \phi_\text{VI} = \text{arg}\min_{\phi} \text{D}_\text{KL}\left(q(z|\text{NN}(x;\phi))\;\|\;p(z|x)\right) \]

This objective encourages the proposal density to be narrow, fitting into a single mode of the true posterior. Inference compilation, however, optimizes the KL divergence with its arguments reversed

\[ \phi_\text{IC} = \text{arg}\min_{\phi} \text{D}_\text{KL}\left(p(z|x)\;\|\;q(z|\text{NN}(x;\phi))\right) \]

which encourages the proposal to be mass-covering (mean-seeking). This typically induces the heavier-tailed behavior that is desired for importance sampling. We want this divergence to be small across many observations \(x\), so the loss we actually minimize is

\[ \mathcal{L}(\phi) = \mathbf{E}_{p(x)}\left[\text{D}_\text{KL}\left(p(z|x)\;\|\;q(z|\text{NN}(x;\phi))\right)\right] = \mathbf{E}_{p(x,z)}\left[-\log{q(z|\text{NN}(x;\phi))}\right] + \text{const} \]

And 2) we are still doing importance sampling, so the overall procedure remains a consistent (asymptotically exact) inference method, unlike variational inference. Decoupled from the importance sampler, the learned proposal by itself is not a good approximate posterior for the generative model.

To implement inference compilation, we first note that we need to expose the observation sites on a model (because ultimately the amortized proposal depends on an observation to determine its parameters), and so we make the convention that models accept an observations argument that is a mapping of observation sites \(\to\) values.
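Rewriting the coin-flip model in this convention might look like the following (the dict key "x" and the fixed plate size are my own choices):

```python
def model(observations=None):
    observations = observations if observations is not None else {"x": None}
    p = pyro.sample("p", dist.Beta(2.0, 2.0))
    with pyro.plate("data", 5):
        # if observations["x"] is None this site is sampled; otherwise it is observed
        x = pyro.sample("x", dist.Bernoulli(p), obs=observations["x"])
    return x
```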

We take as our guide a variational family of beta distributions, parameterized by a neural network.
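A sketch of such a guide: a small feed-forward network (architecture and sizes chosen arbitrarily) maps the five observed coin flips to the two positive parameters of a Beta proposal.

```python
import torch.nn as nn

class GuideNet(nn.Module):
    """Map a vector of observations to the (alpha, beta) parameters of a Beta proposal."""

    def __init__(self, n_obs=5, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_obs, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),
            nn.Softplus(),   # keep alpha and beta positive
        )

    def forward(self, x):
        out = self.net(x)
        return out[0], out[1]

guide_net = GuideNet()

def guide(observations):
    # register the network's parameters with pyro's parameter store
    pyro.module("guide_net", guide_net)
    alpha, beta = guide_net(observations["x"])
    pyro.sample("p", dist.Beta(alpha, beta))
```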

Inference compilation occurs in two stages: first we compile the inference network, teaching it to produce the right latent variable proposals for use in importance sampling. We do this by generating synthetic data from an execution of the model, and using the synthetic observation and latents as the training data for the guide network. We train the net to minimize the loss \(\mathcal{L}(\phi)\) described above.

The second stage is the importance sampling step, where we use our compiled inference network to generate importance samples for the posterior distribution, utilizing the true (non-synthetic) dataset we have for Bayesian inference.

Like the SVI class in pyro, training of the inference network proceeds through repeated step() calls.
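Here is a sketch of such a class (I call it InferenceCompilation to match the closing section; the parameter-capture pattern mirrors what SVI.step does internally):

```python
class InferenceCompilation:
    """Pedagogical sketch of inference-compilation training."""

    def __init__(self, model, guide, optim, batch_size=8):
        self.model = model
        self.guide = guide
        self.optim = optim
        self.batch_size = batch_size

    def step(self):
        # capture every pyro.param touched while computing the loss
        with poutine.trace(param_only=True) as param_capture:
            loss = self.loss_and_grads()
        params = set(
            site["value"].unconstrained()
            for site in param_capture.trace.nodes.values()
        )
        self.optim(params)        # gradient step on the captured parameters
        for p in params:          # reset gradients for the next step
            if p.grad is not None:
                p.grad.zero_()
        return loss
```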

The structure of this function looks similar to the SVI.step() function, where we use the same parameter capture trick as in the previous blog post. The real work is in the implementation of the loss function. While in stochastic variational inference our loss function was the ELBO, here we wish to minimize the averaged KL divergence, whose \(\phi\)-dependent part is

\[ \mathcal{L}(\phi) = \mathbf{E}_{p(x,z)}\left[-\log{q(z|\text{NN}(x;\phi))}\right] \]

Since the compilation of the inference network is separate from the actual inference of the model using our dataset, we need to produce a dataset to train our neural network guide with. Luckily, we have a way to produce an infinite number of latent variables and observations from the model– just run the model repeatedly and record traces!

The function _get_joint_trace performs a run of the model and records its trace freely, without any conditioning involved.
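A minimal version (a method on the InferenceCompilation class sketched above):

```python
    def _get_joint_trace(self):
        # run the model with no observations passed, so every site (latents and
        # would-be observations alike) is sampled from the prior and recorded
        return poutine.trace(self.model).get_trace()
```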

Given a trace of the model \((x,z)\sim p(x,z)\), we need to “plug it into” the guide function \(q(z\mid\text{NN}(x;\phi))\). The action of plugging the trace into the guide is effectively the replay effect handler, except we need to be a bit careful around observed nodes.
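One way this could look (again as a method on the class above; the helper name is my own):

```python
    def _guide_trace_from(self, model_trace):
        # hand the synthetic observation to the guide through its `observations` argument...
        observations = {"x": model_trace.nodes["x"]["value"]}
        # ...and replay the latent choices of the model trace in the guide, so that
        # guide_trace.log_prob_sum() evaluates log q(z | NN(x; phi)) at those latents;
        # the observed site "x" has no counterpart in the guide, so replay leaves it alone
        replayed_guide = poutine.replay(self.guide, trace=model_trace)
        return poutine.trace(replayed_guide).get_trace(observations)
```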

In our loss function, we begin by looping through the training batch we just synthesized: for each element we draw a fresh joint trace from the model and plug it into the guide.

The remainder of our loss function is clear– we compute the Monte Carlo estimate of the loss function above using log_prob_sum() of the guide trace, and manually backpropagate the gradients.

For clarity, the full loss function is here:
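(This is my reconstruction, written as a method of the InferenceCompilation class sketched above.)

```python
    def loss_and_grads(self):
        loss = 0.0
        for _ in range(self.batch_size):
            # one synthetic training example (x, z) ~ p(x, z)
            model_trace = self._get_joint_trace()
            # evaluate log q(z | NN(x; phi)) for that example
            guide_trace = self._guide_trace_from(model_trace)
            # Monte Carlo estimate of E_{p(x,z)}[-log q(z | NN(x; phi))]
            loss = loss - guide_trace.log_prob_sum() / self.batch_size
        loss.backward()   # manually backpropagate into the guide network's parameters
        return loss.item()
```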

Running the inference comes down to first compiling the inference network…
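Something along these lines (the optimizer, learning rate, and number of steps are arbitrary choices):

```python
from pyro.optim import Adam

ic = InferenceCompilation(model, guide, Adam({"lr": 1e-3}), batch_size=8)
for _ in range(2000):
    ic.step()
```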

…and then running the importance sampler.
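Reusing pyro's Importance class as before, now with the compiled guide and the real dataset passed in through the observations convention:

```python
posterior = Importance(model, guide=guide, num_samples=1000).run({"x": data})
```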

Again, computing the marginal posterior for p, we get samples
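For example, just as before:

```python
p_marginal = EmpiricalMarginal(posterior, sites="p")
p_samples = torch.stack([p_marginal.sample() for _ in range(1000)])
```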

and a histogram similar to the one from pure importance sampling.

closing

If you’ve read the Pyro documentation you may have noticed that Predictor, VariationalMarginal and InferenceCompilation go by different names there. This is mostly because the versions here are simplified, pedagogical implementations of the actual algorithms in Pyro. However, I hope that this post gives the reader a certain sense of power, knowing that it is actually fairly straightforward to implement these inference algorithms, and the confidence to try implementing their own as they see fit.

The field of Bayesian inference and probabilistic programming is constantly in flux. Seemingly every week, new inference algorithms are invented for a wider variety of models, and data scientists are often at the mercy of library maintainers for their favorite new algorithm to be implemented. Understanding the machinery behind probabilistic programming languages, and how to work with it, helps to bridge that gap and adds a powerful tool to the arsenal.