Probabilistic effects programming with Pyro

Posted on July 30, 2020 by Calvin

Note: This blog post is still a rough draft. Read on with caution.

A lot of machine learning that we do is probabilistic, even though we don’t phrase it as such. Indeed, one can think of regularized linear regression as a hierarchical Bayesian model where the outcomes are normally-distributed with a fixed (yet unknown) variance. More often than not, thinking in terms of the probabilistic perspective leads us to generalizations of our favorite algorithms. As a trite example, if we wanted a robust version of the above regression, we might replace our normal distribution with a distribution with heavier tails, such as the Student \(t\)-distribution.

However, most of the time we ignore this probabilistic bounty and resort to maximum likelihood techniques (or, if we end up being fancy, MAP!). By resorting to point estimates we throw away a huge amount of relevant information and drift away from principled optimality. Fully exploiting that information leads us to the Bayesian framework, which is characterized by computational complexity and inferential madness. Is there a way to just write models and have the hard computations done for you automatically?

probabilistic programming

Enter probabilistic programming. Probabilistic programming languages (PPLs) have for years been promising data scientists a beautiful story: write your models, press a button, and off you go! There have been bumps along the way. Early PPLs were like early TensorFlow: models were constructed in an embedded DSL with huge limitations as to what you could do, since each new control flow feature had to be implemented within the confines of the language. Take, for instance, the generative story for a geometric distribution \(x\sim\text{Geom}(p)\): keep flipping a coin that lands heads with probability \(p\), and let \(x\) be the number of tails seen before the first head.
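A minimal sketch of that generative story in plain Python, using only the standard library. The function name `geometric` and the convention of counting tails before the first head are illustrative assumptions, not taken from any particular PPL.

```python
import random


def geometric(p, rng):
    """Sample x ~ Geom(p): the number of tails seen before the first head."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must lie in (0, 1]")
    tails = 0
    # The number of loop iterations is itself random, so a single draw
    # has no fixed-size computational graph.
    while rng.random() >= p:
        tails += 1
    return tails


rng = random.Random(0)
draws = [geometric(0.5, rng) for _ in range(20000)]
mean = sum(draws) / len(draws)  # should be close to (1 - p) / p = 1
```

The loop length depends on the random draws themselves, which is exactly the kind of control flow a fixed-graph DSL struggles to express.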

This simple distribution has very complicated control flow! Inference on this model requires traversing computational graphs of variable size, so we need PPLs that can deal with this and other kinds of control flow in a natural way.

I’ve been an advocate of implementing this kind of control flow via algebraic effect systems, where computations emit effects that are captured and handled by outside effect handlers. While many (functional) languages deal with computational effects by passing them around in monadic types, algebraic effect languages use bounding constructs like delimited continuations to restrict the scope of effects.

It turns out this strategy gets a lot of mileage, as it is the main idea behind the construction of Uber’s probabilistic programming language Pyro. While the full strength of delimited continuations is difficult to replicate in Python, one-shot continuations are easier to implement by passing global state around. To see examples of full delimited continuations, I have some blog posts dedicated to them.

minipyro, annotated

In this post I want to walk through the implementation of Pyro, starting with its main effects system. This is basically an annotated form of minipyro. I think it is an interesting case study in how future PPLs can be built in general-purpose languages, and understanding how Pyro works under the hood is a great way to think about its limitations and how to potentially get around them (for example, combining it with distributed training in pytorch is tricky, but not impossible).

pyro is a probabilistic programming language, and as such it is built with random variables and distributions as primitive objects. To perform inference on data, distributions expose an API of two functions: sample and log_prob.

Simply, sample samples a tensor from the given distribution, and log_prob returns the log-probability of a tensor computed from the pdf or pmf of the distribution. This, along with gradients of the log-probability (which is why pyro is built with the auto-diff properties of pytorch) is enough to allow for the sampling of the posterior distribution from probabilistic models, which is what automated inference is designed for.

However, we want our inference to run automatically in the background, without the need for us to write it ourselves. This is fairly easy if all we are doing is running, say, MCMC. We run our model repeatedly, sampling from each distribution and computing log-probabilities, taking gradients and running acceptance protocols. This can be done with a giant for loop. So why does pyro make it so complicated with effect handlers?

Well, it’s because we often want to remember the entire trace of a probabilistic program. For example, suppose we have a model and want to understand what happens when we condition on a parameter taking on a certain value. We can use a giant loop with rejection sampling semantics, but we’re gonna be running this for a long time before any convergence can occur. A better idea is to rerun simulations from the middle of a sampling process instead of from the start. To do this in the background, we need to keep these effects hidden from us, using effect handlers.

First things first, we need to keep track of our two kinds of global state: effect handlers and named parameters.
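A sketch of what these two pieces of global state could look like. The names `PYRO_STACK` and `PARAM_STORE` are the ones used by the `param` code later in this post; representing them as a plain list and dict is an assumption modeled on minipyro.

```python
# Stack of currently active effect handlers; the last entry is the innermost one.
PYRO_STACK = []

# Named parameters: name -> (unconstrained_value, constraint).
PARAM_STORE = {}
```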

pyro uses the Messenger class as the main abstraction of an effect handler. Okay, what the heck is an effect? A computational effect can be thought of as any impure artifact of a computation. Intuitively, we can pretend python is a functional programming language, and an effects system is something that is “building something in the background” as a computation is executed, e.g. building a computational graph or recording the samples and log-probs as a probabilistic model is evaluated.

To capture the idea of a computation entering an “effectful environment” we use python context managers. Such handlers record all the effects being emitted during a computation and handle them. Multiple effect handlers can wrap a single computation, and effects emitted are bubbled up through the handlers sequentially until they are captured by a relevant one.
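The context-manager idea can be sketched as follows. This is a simplified stand-in modeled on minipyro's `Messenger`, not the full Pyro class; returning `self` from `__enter__` is a convenience assumed here.

```python
PYRO_STACK = []  # active handlers, outermost first


class Messenger:
    """Base effect handler: a context manager that registers itself on the stack."""

    def __init__(self, fn=None):
        self.fn = fn  # the computation this handler wraps (optional)

    def __enter__(self):
        # Entering the context makes this handler the innermost active one.
        PYRO_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Handlers must be exited in the reverse of the order they were entered.
        assert PYRO_STACK[-1] is self
        PYRO_STACK.pop()

    def process_message(self, msg):
        """Hook called before the effect's default behaviour runs."""

    def postprocess_message(self, msg):
        """Hook called after the effect's value has been computed."""

    def __call__(self, *args, **kwargs):
        # Calling the handler runs the wrapped computation inside its context.
        with self:
            return self.fn(*args, **kwargs)
```

Concrete handlers subclass this and override one or both hooks; the base class only manages the stack.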

An example of an effect handler (again, we think of it as an “effectful augmentation of computations”) is given by trace. In a nutshell, trace records the inputs and outputs of a function and keeps that information around with other metadata.

What is the msg that a handler receives? It’s our effect. In pyro’s implementation, effects are reified as messages (dicts) passed around the context managers, which process and post-process them.

Before we see more handlers, let’s implement the running of an effectful computation and elucidate how messages are passed around.
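Here is a sketch of the message-passing step, modeled on minipyro's `apply_stack`. The `Recorder` class is a hypothetical toy handler added only to make the order of calls visible.

```python
PYRO_STACK = []  # active handlers, outermost first


def apply_stack(msg):
    """Send an effect message through every active handler and return it."""
    pointer = -1
    # Innermost handler first: walk the stack from the top down.
    for pointer, handler in enumerate(reversed(PYRO_STACK)):
        handler.process_message(msg)
        # A handler can set msg["stop"] to hide the effect from outer handlers.
        if msg.get("stop"):
            break
    # Default behaviour: compute the value unless a handler already supplied one.
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    # Post-process with exactly the handlers that saw the message above.
    for handler in PYRO_STACK[len(PYRO_STACK) - pointer - 1:]:
        handler.postprocess_message(msg)
    return msg


class Recorder:
    """Toy handler (hypothetical) that logs when each hook fires."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def process_message(self, msg):
        self.log.append(("process", self.name))

    def postprocess_message(self, msg):
        self.log.append(("postprocess", self.name))


log = []
PYRO_STACK.extend([Recorder("outer", log), Recorder("inner", log)])
msg = apply_stack({"type": "sample", "name": "x", "fn": lambda a, b: a + b,
                   "args": (1, 2), "value": None})
PYRO_STACK.clear()
```

The message is processed from the innermost handler outwards and post-processed from the outermost handler inwards, so `log` records inner, outer, outer, inner.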

This is the core of the effect handling system in pyro! To see it in action, let’s start with the core function of a probabilistic programming system: sampling from models. Given a distribution d, calling d.sample() is an effect-free computation (ignoring randomness). To make it emit an “effect message”, we wrap all primitive distributions in a pyro.sample call, which gives samples this effectful structure.

Note: This is not a Messenger! pyro.sample is intended as an effect emitter, so it should never handle any effects itself.

To keep things straight in our heads, let’s run through a basic example.
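The following self-contained sketch puts the pieces together in plain Python. Everything here is a simplified assumption: `Normal` is a toy stand-in for a torch distribution, `apply_stack` omits the `stop` mechanism, and the body of `simple_model` is hypothetical.

```python
import math
import random

PYRO_STACK = []


class Normal:
    """Stand-in for a distribution object: exposes sample() and log_prob()."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        PYRO_STACK.append(self)
        return self

    def __exit__(self, *exc):
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class trace(Messenger):
    """Record every sample message emitted while the handler is active."""

    def __enter__(self):
        self.trace = {}
        return super().__enter__()

    def postprocess_message(self, msg):
        if msg["type"] == "sample":
            self.trace[msg["name"]] = dict(msg)  # keep a copy of the message


def apply_stack(msg):
    for handler in reversed(PYRO_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    for handler in PYRO_STACK:
        handler.postprocess_message(msg)
    return msg


def sample(name, dist):
    """Effect emitter: draw from dist directly, or emit a message if handled."""
    if not PYRO_STACK:
        return dist.sample()
    msg = {"type": "sample", "name": name, "fn": dist.sample, "args": (),
           "dist": dist, "value": None}
    return apply_stack(msg)["value"]


def simple_model():
    loc = sample("loc", Normal(0.0, 1.0))
    return sample("obs", Normal(loc, 0.5))


random.seed(0)
with trace() as tr:
    result = simple_model()
```

After the `with` block, `tr.trace` maps each site name (`"loc"`, `"obs"`) to the message recorded there, including the sampled value and the distribution it came from.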

This is a single trace of the probabilistic model simple_model above.

Now that we can capture sample traces from a model, what can we do with it? We can use it to answer probabilistic queries like, what is the joint probability of the sample? This can be given by summing up the log-probabilities of each sample from the sites and exponentiating it.
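As a sketch of that query, the snippet below sums per-site log-probabilities over a hand-written trace. The trace layout (a `value` plus a `log_prob` callable per site) and the numbers in it are assumptions made for illustration.

```python
import math


def normal_log_prob(value, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


# A hand-written trace: site name -> recorded value and its log-density function.
example_trace = {
    "loc": {"value": 0.3, "log_prob": lambda v: normal_log_prob(v, 0.0, 1.0)},
    "obs": {"value": 0.1, "log_prob": lambda v: normal_log_prob(v, 0.3, 0.5)},
}


def log_joint(trace):
    """Sum the log-probability of each recorded value over all sites."""
    return sum(site["log_prob"](site["value"]) for site in trace.values())


joint_density = math.exp(log_joint(example_trace))
```

For continuous sites the result is a joint density rather than a probability, and in practice one usually stays in log space to avoid underflow.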

Let’s bump the complexity of our probabilistic models up a bit. Recall that plate notation is a method of representing variables that repeat in a graphical model. For example, we might wish to model N normally distributed data points where the mean is itself normally-distributed:

\[ x_i \sim \text{Normal}(\alpha, 1) \text{ for }i=1,...,N \] \[ \alpha \sim \text{Normal}(0, 1) \]

To represent this without loops, we treat the plate itself as an effect handler.

Now the model above is given by

We see that without writing a loop, we sample N observations from the above distribution via pure broadcasting semantics.


Now that we have the hang of the basic objects in pyro, we can turn to the entire reason for probabilistic programming: automatic Bayesian inference. What is inference? It is the effective manipulation of our generative models to produce estimates of the latent variables in our model.

The basic ingredients we need for inference in pyro are conditioning and good inference algorithms; in our case, it is given by stochastic variational inference. We start with condition.

condition is an effect handler that effectively allows us to hold certain sampling sites at fixed values. Implementing this is easy– we postprocess any message that is emitted by a pyro.sample site by swapping out whatever was sampled by our fixed value.
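A compact sketch of such a handler in plain Python. The handler plumbing is a simplified stand-in for the machinery described earlier, and the `scale` model (prior standard deviation 1.0, measurement noise 0.75) is assumed to match the toy example from the pyro documentation.

```python
import random

PYRO_STACK = []


class Messenger:
    def __enter__(self):
        PYRO_STACK.append(self)
        return self

    def __exit__(self, *exc):
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class condition(Messenger):
    """Hold the named sample sites at fixed values."""

    def __init__(self, data):
        self.data = data  # site name -> fixed value

    def postprocess_message(self, msg):
        if msg["type"] == "sample" and msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]  # overwrite the random draw
            msg["is_observed"] = True


def sample(name, draw):
    """Emit a sample effect; `draw` is a zero-argument sampling function."""
    msg = {"type": "sample", "name": name, "fn": draw, "args": (), "value": None}
    for handler in reversed(PYRO_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    for handler in PYRO_STACK:  # outermost handler post-processes first
        handler.postprocess_message(msg)
    return msg["value"]


def scale(guess):
    weight = sample("weight", lambda: random.gauss(guess, 1.0))
    return sample("measurement", lambda: random.gauss(weight, 0.75))


with condition({"measurement": 9.5}):
    measurement = scale(8.5)
```

Because the swap happens in post-processing, a random draw is still made and then discarded, and any handler that should see the fixed value (such as a trace) has to sit inside the condition context.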

We’ll test that it works by using it to compute the log-joint probabilities of a toy model from the pyro documentation.

Often what we want to do is estimate the latent variables in a generative model that best explain the observational data the model is conditioned on. However, the sample statements in our models have nothing to optimize, so a gradient-based inference algorithm cannot operate on them directly. Instead, we need to represent the quantities we want to learn as parameters.

import weakref

import torch

# PARAM_STORE, PYRO_STACK and apply_stack are the module-level objects introduced earlier.
def param(name, init_value=None, constraint=torch.distributions.constraints.real, event_dim=None):
    if event_dim is not None:
        raise NotImplementedError("param does not support the event_dim arg")
    # a parameter needs to be optimized in an unconstrained space, while the underlying
    # parameter itself could be very constrained. we need a function that brokers the
    # passage between the two spaces
    def fn(init_value, constraint):
        # remember, PARAM_STORE maps name -> (unconstrained_value, constraint)
        if name in PARAM_STORE:
            # the parameter already exists: reuse the stored tensor and constraint
            unconstrained_value, constraint = PARAM_STORE[name]
        else:
            # first use: initialize with a constrained value
            assert init_value is not None
            with torch.no_grad():
                constrained_value = init_value.detach()
                unconstrained_value = torch.distributions.transform_to(constraint).inv(constrained_value)
            # as the unconstrained parameter needs to be optimized, make sure
            # gradients can propagate through node
            unconstrained_value.requires_grad_()
            PARAM_STORE[name] = (unconstrained_value, constraint)

        # transform from unconstrained space to constrained space.
        constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value)
        # keep a weak reference so the optimizer can find the underlying tensor
        constrained_value.unconstrained = weakref.ref(unconstrained_value)
        return constrained_value
    # if there are no active Messengers, we just compute the parameter value and return it
    if not PYRO_STACK:
        return fn(init_value, constraint)
    # else, an effect (msg) is initialized
    effect_msg = {
        "type": "param",
        "name": name,     # parameter sites are given unique names
        "fn": fn,
        "args": (init_value, constraint),
        "value": None}
    # pass it through the effect handlers
    processed_effect_msg = apply_stack(effect_msg)
    return processed_effect_msg["value"]

The complexity of param is only from the fact that 1) we need to be able to impose optimizable constraints on our parameters, and 2) we need to be able to apply autodifferentiation to optimize said parameters.
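To illustrate the first point without any torch machinery, here is a hand-rolled sketch of optimizing a positive parameter through an unconstrained one. Using exp and log as the pair of maps is an assumption made for this example; in the code above that role is played by `torch.distributions.transform_to(constraint)`. The objective and step size are likewise made up.

```python
import math


def to_constrained(u):
    """Map any real number to a strictly positive one."""
    return math.exp(u)


def to_unconstrained(x):
    """Inverse map: a positive number back to the real line."""
    return math.log(x)


# Minimize f(scale) = (scale - 2)^2 over scale > 0 by gradient descent on
# u = log(scale); the chain rule gives df/du = f'(scale) * exp(u).
u = to_unconstrained(0.5)  # start from the constrained value 0.5
for _ in range(200):
    scale = to_constrained(u)
    grad_u = 2.0 * (scale - 2.0) * scale
    u -= 0.05 * grad_u
scale = to_constrained(u)
```

The optimizer only ever touches `u`, so `scale` stays positive no matter how large a step is taken.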

Now all we need is an inference algorithm. The hallmark of pyro is its usage of stochastic variational inference, which can be easily described as trying to approximate the posterior distribution \(p(\theta|\text{data})\) of our model via a parameterized family of guides \(q(\theta; \text{params})\), where \(\text{params}\) are optimizable parameters. We do this by finding the parameters that make our guide as close as possible to the posterior, where distance here is given by the Kullback-Leibler divergence

\[ \text{D}_\text{KL}(q(\theta;\text{params})|| p(\theta|\text{data})) \]

For example, in our scale example above, the posterior \(p(\text{weight}|\text{meas}=9.5, \text{guess}=8.5)\) is some univariate distribution. We will try and approximate it with the family of normal distributions \(q(\text{weight}|a, b)\):

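This is our guide \(q(\theta;\text{params})\). Our goal is to find the parameters that minimize the KL divergence. This is equivalent to maximizing the \(\text{ELBO}\) (evidence lower bound), since the negative \(\text{ELBO}\) differs from the KL divergence only by the constant \(\log p(\text{data})\):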

\[ \text{D}_\text{KL}(q(\theta;\text{params})|| p(\theta|\text{data})) = \mathbf{E}_{\theta\sim q(\theta;\text{params})}[\log{q(\theta;\text{params})} - \log{p(\theta|\text{data})}] \]
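To make the link between this divergence and the \(\text{ELBO}\) explicit, here is a short derivation. It assumes only Bayes' rule, \(p(\theta|\text{data}) = p(\theta, \text{data}) / p(\text{data})\); the joint density \(p(\theta, \text{data})\) is notation introduced here and is what the model actually lets us evaluate.

```latex
\[ \text{D}_\text{KL}(q(\theta;\text{params})|| p(\theta|\text{data}))
   = \mathbf{E}_{\theta\sim q(\theta;\text{params})}[\log{q(\theta;\text{params})} - \log{p(\theta, \text{data})}] + \log{p(\text{data})} \]
\[ \text{ELBO}(\text{params})
   = \mathbf{E}_{\theta\sim q(\theta;\text{params})}[\log{p(\theta, \text{data})} - \log{q(\theta;\text{params})}] \]
\[ \log{p(\text{data})}
   = \text{ELBO}(\text{params}) + \text{D}_\text{KL}(q(\theta;\text{params})|| p(\theta|\text{data})) \]
```

Since \(\log p(\text{data})\) does not depend on \(\text{params}\) and the divergence is non-negative, the \(\text{ELBO}\) is a lower bound on the log evidence, and maximizing it minimizes the divergence.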

This will be computed using a 1-sample Monte Carlo estimator of the \(\text{ELBO}\), which we call Trace_ELBO. There is one technical point to be made here: once we sample a trace from our guide \(q(\theta;\text{params})\), we must use the exact same trace to compute the log-probabilities for both \(q\) and \(p\). In essence, we need to “replay” the trace for the model underlying \(p(\theta|\text{data})\) and use those values to compute the log-probability. This can be seen from the expectation above: a Monte Carlo estimate of this expectation samples \(\theta\) from \(q(\theta;\text{params})\) and then uses these same \(\theta\) in the log-probs.

This replay is itself an effect handler. Intuitively, during a replay, any sample effect msg emitted by the model is intercepted and has its value replaced by the value recorded at the same site in the trace it’s replaying.

Now we use this to compute the \(\text{ELBO}\) via Monte Carlo:

Finally, we can wrap up the training loop for our inference algorithm into a stochastic variational inference class! Since we are interested in optimizing the param tensors in our algorithm, we need a way to isolate the param messages from the sample statements in our model. We do this with a separate trace of our model that “blocks” out the sample statements. As you may have guessed, it is another effect handler: block.

All of our effect handlers culminate in our SVI class:

Here, torch optimizers need to be wrapped, since parameters can be created dynamically during the training process.

Let’s test this out on our scale model!

Plotting the loss gives

Now we can extract the posterior parameter values inferred by our variational inference procedure.

This is close to what can be inferred analytically for this problem!
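For reference, the analytic answer can be computed by hand with the standard normal-normal conjugacy formulas. This sketch assumes the scale model is the one from the pyro documentation, with weight drawn from Normal(guess, 1.0) and the measurement from Normal(weight, 0.75); those two standard deviations are assumptions, not values stated in this post.

```python
import math

guess, measurement = 8.5, 9.5
prior_std, noise_std = 1.0, 0.75  # assumed values, see the note above

# For a normal prior and a normal likelihood, precisions (inverse variances) add.
prior_precision = 1.0 / prior_std**2
noise_precision = 1.0 / noise_std**2
posterior_variance = 1.0 / (prior_precision + noise_precision)

# The posterior mean is the precision-weighted average of guess and measurement.
posterior_mean = posterior_variance * (
    prior_precision * guess + noise_precision * measurement
)
posterior_std = math.sqrt(posterior_variance)
print(posterior_mean, posterior_std)  # approximately 9.14 and 0.6
```

Under these assumptions the exact posterior over the weight is normal with mean about 9.14 and standard deviation about 0.6, which is the target the fitted guide parameters should approach.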

As we see, flexible algebraic effect systems allow us to build complicated inference algorithms without thinking too much about the emitted effects. For example, with multi-shot delimited effects we can build inference algorithms that allow us to fully enumerate over discrete latent variables. This is the idea behind the construction of pyro.poutine.queue.


Note: Write this section at some point when I have time.