Reparametrization Trick

Backpropagating through continuous and discrete samples

Keywords: reparametrization trick, Gumbel max trick, Gumbel softmax, Concrete distribution, score function estimator, REINFORCE


In the context of deep learning, we often want to backpropagate a gradient through samples $x \sim p_\theta(x)$, where $p_\theta$ is a learned parametric distribution.

For example, we might want to train a variational autoencoder. Conditioned on the input $x$, the latent representation $z$ is not a single value but a distribution $q_\theta(z \mid x)$. This is generally a Gaussian distribution whose parameters are given by an (inference) neural network with parameters $\theta$. When learning to maximize (a lower bound on) the likelihood of the data, we need to backpropagate the loss to the parameters $\theta$ of the inference network, across the distribution $q_\theta(z \mid x)$ or across samples $z \sim q_\theta(z \mid x)$.



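More specifically, we want to minimize an expected cost

$$L(\theta, \phi) = \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\right]$$

using gradient descent, which requires computing the gradients $\nabla_\phi L(\theta, \phi)$ and $\nabla_\theta L(\theta, \phi)$.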


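Under certain regularity conditions, Leibniz's integral rule states that the gradient and the expectation can be swapped, resulting in

$$\nabla_\phi \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\right] = \mathbb{E}_{x \sim p_\theta(x)}\left[\nabla_\phi f_\phi(x)\right]$$

which can be estimated using Monte-Carlo:

$$\nabla_\phi \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\right] \approx \frac{1}{S} \sum_{s=1}^{S} \nabla_\phi f_\phi(x^{(s)})$$

with iid samples $x^{(1)}, \dots, x^{(S)} \sim p_\theta(x)$.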

So computing $\nabla_\phi L$ is fairly straightforward and only requires that:

  • we can sample from $p_\theta(x)$
  • $f_\phi(x)$ is differentiable w.r.t. $\phi$
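A minimal sketch of this Monte-Carlo estimate, assuming a toy problem that is not taken from the text: $p_\theta = \mathcal{N}(\theta, 1)$ and $f_\phi(x) = (x - \phi)^2$. Then $L = (\theta - \phi)^2 + 1$ and the exact gradient $\nabla_\phi L = -2(\theta - \phi)$ is available for comparison. The function name `grad_phi_mc` is hypothetical. Note that the code only samples from $p_\theta$ and differentiates $f_\phi$ w.r.t. $\phi$; it never differentiates through $x$.

```python
import random

# Assumed toy problem (not from the text): p_theta = N(theta, 1) and
# f_phi(x) = (x - phi)^2, hence grad_phi f_phi(x) = -2 * (x - phi).
def grad_phi_mc(theta, phi, num_samples, seed=0):
    """Monte-Carlo estimate of grad_phi E_{x ~ p_theta}[f_phi(x)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(theta, 1.0)   # iid sample x^(s) ~ p_theta
        total += -2.0 * (x - phi)   # grad_phi f_phi(x^(s)); no gradient through x needed
    return total / num_samples

theta, phi = 1.5, 0.5
estimate = grad_phi_mc(theta, phi, num_samples=100_000)
exact = -2.0 * (theta - phi)        # since L = (theta - phi)^2 + 1
print(f"MC estimate: {estimate:.3f}, exact: {exact:.3f}")
```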


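Computing the gradient $\nabla_\theta L$ is much harder, because $\theta$ parametrizes the distribution over which the expectation is taken. We can rewrite the expectation as an integral over $x$ and use Leibniz's rule again:

$$\nabla_\theta \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\right] = \nabla_\theta \int p_\theta(x)\, f_\phi(x)\, dx = \int \nabla_\theta p_\theta(x)\, f_\phi(x)\, dx$$

but now the integral does not have the form of an expectation (since $\nabla_\theta p_\theta(x)$ is not a probability density), so we cannot use Monte-Carlo to estimate its value.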

So computing $\nabla_\theta L$ is not straightforward. However, notice that:

  • we only need the distribution $p_\theta(x)$ to be differentiable w.r.t. $\theta$
  • there is no requirement that $f_\phi(x)$ be differentiable w.r.t. $x$, so there is no need to backpropagate through it

In the rest of the article, we review several tricks for estimating $\nabla_\theta L$, depending on the particular application.

All methods

The table below sums up some ways to deal with samples in a computation graph. Everything in bold is either more powerful or less constraining. In the context of deep learning, the most important attribute is that the loss be differentiable w.r.t. $\theta$, so that the parameters can be learned using gradient descent.

| Method | Continuous or discrete | Backpropable (differentiable w.r.t. $\theta$) | Follows exact distribution | $g$ must exist |
| --- | --- | --- | --- | --- |
| Score function estimator | Continuous and discrete | Yes | Yes | No |
| Reparameterization trick | Continuous | Yes | Yes | Yes |
| Gumbel-max trick | Discrete | No | Yes | |
| Gumbel-softmax trick | Discrete | Yes | No (continuous relaxation) | Yes |
| ST-Gumbel estimator | Discrete | Yes | Yes on forward pass; no on backward pass (continuous relaxation) | |
| REBAR | Discrete | Yes | Yes on forward pass; no on backward pass (continuous relaxation) | |

Score function estimator (trick)

The score function estimator (SF), also called the likelihood-ratio estimator, or REINFORCE when applied to reinforcement learning, transforms the integral into an expectation.

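More specifically, using the property $\nabla_\theta p_\theta(x) = p_\theta(x)\, \nabla_\theta \log p_\theta(x)$, we can rewrite the gradient as an expectation:

$$\nabla_\theta \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\right] = \int p_\theta(x)\, \nabla_\theta \log p_\theta(x)\, f_\phi(x)\, dx = \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\, \nabla_\theta \log p_\theta(x)\right]$$

We can now use Monte-Carlo to estimate the gradient:

$$\nabla_\theta L \approx \frac{1}{S} \sum_{s=1}^{S} f_\phi(x^{(s)})\, \nabla_\theta \log p_\theta(x^{(s)}), \qquad x^{(s)} \sim p_\theta(x)$$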
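As an illustration, here is a sketch of the score function estimator on an assumed toy problem: $p_\theta = \mathcal{N}(\theta, 1)$ with $f(x) = x^2$. The score is $\nabla_\theta \log p_\theta(x) = x - \theta$, and the exact gradient is $2\theta$. The helper names are hypothetical. Note that $f$ is only evaluated, never differentiated, which is what allows this estimator to handle non-differentiable costs.

```python
import random

# Assumed toy problem: p_theta = N(theta, 1) and f(x) = x^2, so that
# L(theta) = theta^2 + 1 and the exact gradient is 2 * theta.
def f(x):
    return x * x  # only evaluated, never differentiated

def score(x, theta):
    return x - theta  # grad_theta log N(x; theta, 1)

def score_function_grad(theta, num_samples, seed=0):
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(theta, 1.0)        # x ~ p_theta
        total += f(x) * score(x, theta)  # f(x) * grad_theta log p_theta(x)
    return total / num_samples

theta = 1.5
sf_estimate = score_function_grad(theta, num_samples=200_000)
print(f"SF estimate: {sf_estimate:.3f}, exact: {2 * theta:.3f}")
```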

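This estimator is known to suffer from high variance. This problem can be alleviated by subtracting a control variate, or baseline, $b$ from $f_\phi(x)$ and adding its mean back:

$$\nabla_\theta L = \mathbb{E}_{x \sim p_\theta(x)}\left[\left(f_\phi(x) - b\right) \nabla_\theta \log p_\theta(x)\right] + b\, \mathbb{E}_{x \sim p_\theta(x)}\left[\nabla_\theta \log p_\theta(x)\right]$$

For a baseline $b$ that does not depend on $x$, the last term is zero, since $\mathbb{E}_{x \sim p_\theta(x)}\left[\nabla_\theta \log p_\theta(x)\right] = \nabla_\theta \int p_\theta(x)\, dx = \nabla_\theta 1 = 0$. A well-chosen $b$ therefore lowers the variance without introducing bias.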
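To see the effect of a baseline, the sketch below reuses the same assumed toy problem ($p_\theta = \mathcal{N}(\theta, 1)$, $f(x) = x^2$). It compares the per-sample terms with $b = 0$ and with $b = \mathbb{E}[f(x)] = \theta^2 + 1$. Using the exact mean as a baseline is only possible because the toy problem is analytic. The added-back term is dropped because it is zero for a constant $b$. `sf_terms` is a hypothetical helper name.

```python
import random
import statistics

# Same assumed toy problem: p_theta = N(theta, 1), f(x) = x^2, exact gradient 2 * theta.
def sf_terms(theta, b, num_samples, seed=0):
    """Per-sample terms (f(x) - b) * grad_theta log p_theta(x)."""
    rng = random.Random(seed)
    terms = []
    for _ in range(num_samples):
        x = rng.gauss(theta, 1.0)
        terms.append((x * x - b) * (x - theta))
    # The added-back term b * E[score] is zero for a constant b, so it is omitted.
    return terms

theta = 1.5
no_baseline = sf_terms(theta, b=0.0, num_samples=100_000)
# b = E[f(x)] = theta^2 + 1 is known here only because the toy problem is analytic.
with_baseline = sf_terms(theta, b=theta ** 2 + 1, num_samples=100_000)

for name, terms in [("no baseline", no_baseline), ("with baseline", with_baseline)]:
    print(f"{name}: mean {statistics.fmean(terms):.3f}, variance {statistics.variance(terms):.1f}")
```

Both means should be close to $2\theta = 3$. For this toy problem, the per-sample variance is analytically $\theta^4 + 14\theta^2 + 15 \approx 52$ without the baseline and $8\theta^2 + 10 = 28$ with it.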


  • Extreme value theory
  • Reinforcement learning (known as REINFORCE)

Reparameterization trick

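Sometimes the random variable $x \sim p_\theta(x)$ can be reparameterized as a deterministic function $g$ of $\theta$ and of a random variable $\epsilon$, where $\epsilon$ does not depend on $\theta$:

$$x = g(\theta, \epsilon), \qquad \epsilon \sim p(\epsilon)$$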

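For instance, the Gaussian variable $x \sim \mathcal{N}(\mu, \sigma^2)$, with $\theta = (\mu, \sigma)$, can be rewritten as a function of a standard Gaussian variable $\epsilon \sim \mathcal{N}(0, 1)$, such that $x = \mu + \sigma \epsilon$.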

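In that case, the gradient can be rewritten as

$$\nabla_\theta \mathbb{E}_{x \sim p_\theta(x)}\left[f_\phi(x)\right] = \nabla_\theta \mathbb{E}_{\epsilon \sim p(\epsilon)}\left[f_\phi(g(\theta, \epsilon))\right] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\left[\nabla_\theta f_\phi(g(\theta, \epsilon))\right]$$

which can be estimated by Monte-Carlo by sampling $\epsilon$. This requires that: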
  • $f_\phi$ must be differentiable w.r.t. its input $x$. This was not required for the score function estimator.
  • $g(\theta, \epsilon)$ must exist and be differentiable w.r.t. $\theta$. This is not obvious for discrete categorical variables $x$. However, for discrete variables, we will see that:
    • the Gumbel-max trick does provide a $g$, although it is non-differentiable w.r.t. $\theta$
    • the Gumbel-softmax trick is a relaxation of the Gumbel-max trick that provides a $g$ differentiable w.r.t. $\theta$, at the cost of replacing $x$ with a continuous relaxation
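A sketch of the reparameterization trick for the Gaussian example above, on an assumed objective $f(x) = x^2$. For this objective, $L = \mu^2 + \sigma^2$, $\partial L / \partial \mu = 2\mu$ and $\partial L / \partial \sigma = 2\sigma$. The derivatives of $f$ and $g$ are written out by hand here. In a deep learning framework, automatic differentiation would compute them through $x = g(\theta, \epsilon)$.

```python
import random

# Assumed toy problem: x ~ N(mu, sigma^2) reparameterized as
# x = g(theta, eps) = mu + sigma * eps with eps ~ N(0, 1), and f(x) = x^2.
# Exact values: L = mu^2 + sigma^2, dL/dmu = 2 * mu, dL/dsigma = 2 * sigma.
def reparam_grad(mu, sigma, num_samples, seed=0):
    rng = random.Random(seed)
    grad_mu = grad_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)   # noise that does not depend on (mu, sigma)
        x = mu + sigma * eps        # x = g(theta, eps)
        df_dx = 2.0 * x             # requires f to be differentiable w.r.t. x
        grad_mu += df_dx * 1.0      # chain rule with dx/dmu = 1
        grad_sigma += df_dx * eps   # chain rule with dx/dsigma = eps
    return grad_mu / num_samples, grad_sigma / num_samples

d_mu, d_sigma = reparam_grad(mu=1.5, sigma=0.5, num_samples=100_000)
print(f"dL/dmu ~ {d_mu:.3f} (exact 3.0), dL/dsigma ~ {d_sigma:.3f} (exact 1.0)")
```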



Gumbel-max trick

In the next sections, we will interchangeably use the integer representation $x \in \{1, \dots, K\}$ and the one-hot representation $x \in \{0, 1\}^K$ for the same discrete categorical variable $x$.

The Gumbel-max trick was proposed by Gumbel, Julius, Lieblein (1954) - Statistical theory of extreme values[...] to express a discrete categorical variable as a deterministic function of the class probabilities and independent random variables, called Gumbel variables.

Let $x$ be a discrete categorical variable, which can take $K$ values and is parameterized by the class probabilities $\pi = (\pi_1, \dots, \pi_K)$, so that $p(x = k) = \pi_k$. The obvious way to sample $x$ is to apply the inverse of its cumulative distribution function to a uniform random variable. However, we would like to use the reparametrization trick.
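For comparison, here is a sketch of the inverse-CDF sampler mentioned above. Classes are 0-indexed in the code, and `sample_inverse_cdf` is a hypothetical name. The returned index is a piecewise-constant function of $\pi$, so its gradient w.r.t. $\pi$ is zero almost everywhere. This is why the sampler does not lend itself to the reparametrization trick.

```python
import bisect
import itertools
import random

def sample_inverse_cdf(pi, rng):
    """Sample a class index in {0, ..., K-1} with probabilities pi (inverse CDF)."""
    cdf = list(itertools.accumulate(pi))   # cumulative distribution function
    u = rng.random()                       # u ~ Uniform[0, 1)
    # Smallest k with u < cdf[k]; min() guards against round-off making cdf[-1] < 1.
    return min(bisect.bisect_right(cdf, u), len(pi) - 1)

rng = random.Random(0)
pi = [0.2, 0.5, 0.3]
n = 100_000
counts = [0] * len(pi)
for _ in range(n):
    counts[sample_inverse_cdf(pi, rng)] += 1
freqs = [c / n for c in counts]
print(freqs)   # empirical frequencies, close to pi
```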

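Another way is to define $K$ independent variables $\epsilon_1, \dots, \epsilon_K$ that follow a standard Gumbel distribution, which can be obtained as $\epsilon_k = -\log(-\log u_k)$ where $u_k \sim \mathrm{Uniform}(0, 1)$. Then the random variable

$$x = \arg\max_{k \in \{1, \dots, K\}} \left(\log \pi_k + \epsilon_k\right)$$

follows the correct categorical distribution $p(x = k) = \pi_k$.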
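A minimal sketch of Gumbel-max sampling with Python's standard library, assuming all $\pi_k > 0$ so that $\log \pi_k$ is finite. Python's `random` module has no Gumbel sampler, so one is built from uniforms as described above. The empirical class frequencies should be close to $\pi$. Classes are 0-indexed, and the function names are hypothetical.

```python
import math
import random

def sample_gumbel(rng):
    """Standard Gumbel sample -log(-log u) with u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:   # random() can return 0.0; keep u in (0, 1) so both logs are finite
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max(pi, rng):
    """x = argmax_k (log pi_k + eps_k), assuming every pi_k > 0."""
    scores = [math.log(p) + sample_gumbel(rng) for p in pi]
    return max(range(len(pi)), key=scores.__getitem__)

rng = random.Random(0)
pi = [0.2, 0.5, 0.3]
n = 100_000
counts = [0] * len(pi)
for _ in range(n):
    counts[gumbel_max(pi, rng)] += 1
freqs = [c / n for c in counts]
print(freqs)   # empirical frequencies, close to pi
```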

However, we cannot apply the reparametrization trick, because $\arg\max$ is non-differentiable w.r.t. the parameters $\pi$ that we want to optimize. We now present the Gumbel-softmax trick, which relaxes the Gumbel-max trick to make $x$ differentiable w.r.t. $\pi$.


  • Extreme-value theory?


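Gumbel-softmax trick

The idea of replacing the $\arg\max$ of the Gumbel-max trick with a $\mathrm{softmax}$ was concurrently presented by Jang, Gu, Poole (2017) - Categorical reparameterization with Gumbel Softmax (under the name Gumbel-softmax) and Maddison, Mnih, Teh (2017) - The Concrete Distribution (under the name Concrete distribution). More precisely, define a Gumbel-softmax random variable $x = (x_1, \dots, x_K)$ by

$$x_k = \frac{\exp\left((\log \pi_k + \epsilon_k) / \tau\right)}{\sum_{j=1}^{K} \exp\left((\log \pi_j + \epsilon_j) / \tau\right)}, \qquad k = 1, \dots, K$$

where $\tau > 0$ is a temperature parameter, and $\epsilon_k \sim \mathrm{Gumbel}(0, 1)$ as before. The references give an analytical expression for the density of the Gumbel-softmax.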
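The definition above translates directly into code. This is a sketch, again assuming $\pi_k > 0$ and using hypothetical function names. Subtracting the maximum before exponentiating is a standard numerical-stability step that leaves the softmax unchanged.

```python
import math
import random

def sample_gumbel(rng):
    """Standard Gumbel sample -log(-log u) with u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:   # random() can return 0.0; keep u in (0, 1)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax(pi, tau, rng):
    """One Gumbel-softmax sample: softmax((log pi_k + eps_k) / tau)."""
    z = [(math.log(p) + sample_gumbel(rng)) / tau for p in pi]
    m = max(z)                        # shift for numerical stability
    e = [math.exp(zk - m) for zk in z]
    s = sum(e)
    return [ek / s for ek in e]

rng = random.Random(0)
x = gumbel_softmax([0.2, 0.5, 0.3], tau=0.5, rng=rng)
print(x)   # nonnegative entries summing to 1 (a point on the simplex)
```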

Note that the previous expression gives the value of $x$ as a deterministic function of $\pi$ and $\epsilon$, not the distribution $p(x)$. So $x$ is actually a continuous value supported on the simplex $\Delta^{K-1} = \{x \in \mathbb{R}^K : x_k \geq 0, \sum_k x_k = 1\}$.

The above authors show interesting properties of the Gumbel-softmax:

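  • When $\tau \to 0$, the vector $x$ becomes one-hot, and as expected, the hot component follows the categorical distribution $\pi$.
  • When $\tau \to \infty$, the vector $x$ becomes uniform, $x = (1/K, \dots, 1/K)$, and all samples look the same.
  • $p(\arg\max_k x_k = i) = \pi_i$ for any $\tau > 0$, since the softmax keeps the relative ordering of the $\log \pi_k + \epsilon_k$.
  • When $\tau \leq 1/(K-1)$, the probability density $p(x)$ becomes convex.
    • when $p(x)$ is convex, the modes are concentrated on the corners of the simplex $\Delta^{K-1}$, which means samples will tend to be one-hot.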
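The first three properties above can be checked empirically. The sketch below uses a toy $\pi$ and temperatures chosen arbitrarily for illustration. It estimates the argmax frequencies at $\tau = 0.5$, which should be close to $\pi$. It also estimates the average largest component at a low and a high temperature, which should be close to 1 for $\tau = 0.05$ (nearly one-hot) and close to $1/K = 1/3$ for $\tau = 100$ (nearly uniform).

```python
import math
import random

def sample_gumbel(rng):
    u = rng.random()
    while u == 0.0:   # keep u in (0, 1)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax(pi, tau, rng):
    z = [(math.log(p) + sample_gumbel(rng)) / tau for p in pi]
    m = max(z)
    e = [math.exp(zk - m) for zk in z]
    s = sum(e)
    return [ek / s for ek in e]

def argmax(v):
    return max(range(len(v)), key=v.__getitem__)

rng = random.Random(0)
pi = [0.2, 0.5, 0.3]
n = 50_000

# The argmax of a Gumbel-softmax sample follows pi at any tau > 0 (here tau = 0.5).
counts = [0] * len(pi)
for _ in range(n):
    counts[argmax(gumbel_softmax(pi, 0.5, rng))] += 1
argmax_freqs = [c / n for c in counts]

def mean_max_component(tau):
    """Average size of the largest component of a Gumbel-softmax sample."""
    return sum(max(gumbel_softmax(pi, tau, rng)) for _ in range(n)) / n

low_tau = mean_max_component(0.05)    # tau -> 0: samples approach one-hot
high_tau = mean_max_component(100.0)  # tau -> infinity: samples approach uniform
print(argmax_freqs, round(low_tau, 3), round(high_tau, 3))
```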

Now we can write $x = g(\pi, \epsilon)$ with $\epsilon = (\epsilon_1, \dots, \epsilon_K)$, and $g$ is differentiable w.r.t. $\pi$. We can use the reparameterization trick!

However, note that $x$ does not exactly follow the categorical distribution $\pi$. There is a tradeoff between two regimes. Low temperatures give accurate, nearly one-hot samples with badly conditioned, high-variance gradients. High temperatures give smoother samples with lower-variance gradients. In practice, the authors start with a high temperature and anneal it towards small non-zero temperatures, so as to approach the categorical distribution.
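To make the temperature tradeoff concrete, the sketch below estimates the variance of one reparameterized gradient, $\partial x_1 / \partial \log \pi_1 = x_1 (1 - x_1) / \tau$, treating $\log \pi_1$ as a free logit. It compares a low and a high temperature, with $K = 2$ and $\pi = (0.5, 0.5)$ chosen for illustration. The function names are hypothetical.

```python
import math
import random
import statistics

def sample_gumbel(rng):
    u = rng.random()
    while u == 0.0:   # keep u in (0, 1)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax(pi, tau, rng):
    z = [(math.log(p) + sample_gumbel(rng)) / tau for p in pi]
    m = max(z)
    e = [math.exp(zk - m) for zk in z]
    s = sum(e)
    return [ek / s for ek in e]

def grad_samples(pi, tau, n, seed=0):
    """Per-sample d x_1 / d log pi_1 = x_1 (1 - x_1) / tau (x_1 is x[0] here)."""
    rng = random.Random(seed)
    grads = []
    for _ in range(n):
        x = gumbel_softmax(pi, tau, rng)
        grads.append(x[0] * (1.0 - x[0]) / tau)
    return grads

pi = [0.5, 0.5]
var_low = statistics.variance(grad_samples(pi, tau=0.1, n=20_000))
var_high = statistics.variance(grad_samples(pi, tau=1.0, n=20_000))
print(f"gradient variance: tau=0.1 -> {var_low:.4f}, tau=1.0 -> {var_high:.4f}")
```

At $\tau = 0.1$, most samples are almost one-hot and give a near-zero gradient. The rare samples close to the decision boundary give gradients of order $1/\tau$, so the variance should come out much larger than at $\tau = 1$.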



Straight-Through Gumbel-softmax estimator

For non-zero temperatures, a Gumbel-softmax variable does not exactly follow the categorical distribution $\pi$. If in the forward pass we replace $x$ by the one-hot encoding of its $\arg\max$, then we get a one-hot variable that follows $\pi$ exactly. However, in order to backpropagate the gradient, we can still use the original, continuous $x$ in the backward pass.

This is called Straight-Through-Gumbel-softmax in Jang's paper, and builds on ideas from Bengio, Leonard, Courville (2013) - Estimating or Propagating Gradients
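Here is a framework-free sketch of the straight-through estimator, with the backward pass written out by hand. The forward value is the one-hot argmax. The returned `backward` function applies the transposed Jacobian of the soft sample w.r.t. the logits $\log \pi_k$. The function names and the upstream gradient are hypothetical.

```python
import math
import random

def sample_gumbel(rng):
    u = rng.random()
    while u == 0.0:   # keep u in (0, 1)
        u = rng.random()
    return -math.log(-math.log(u))

def st_gumbel_softmax(logits, tau, rng):
    """Straight-through Gumbel-softmax with an explicit, hand-written backward pass.

    Returns the hard one-hot sample (forward pass) and a function mapping an
    upstream gradient dL/dx to dL/dlogits through the soft sample (backward pass).
    """
    z = [(logit + sample_gumbel(rng)) / tau for logit in logits]
    m = max(z)
    e = [math.exp(zk - m) for zk in z]
    s = sum(e)
    y_soft = [ek / s for ek in e]
    k = max(range(len(y_soft)), key=y_soft.__getitem__)
    y_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]

    def backward(grad_out):
        # Transposed softmax Jacobian: dy_i/dlogit_j = y_i * (delta_ij - y_j) / tau
        dot = sum(g * y for g, y in zip(grad_out, y_soft))
        return [y_soft[j] * (grad_out[j] - dot) / tau for j in range(len(y_soft))]

    return y_hard, backward

rng = random.Random(0)
logits = [math.log(p) for p in (0.2, 0.5, 0.3)]
x, backward = st_gumbel_softmax(logits, tau=0.5, rng=rng)
grad_logits = backward([1.0, 0.0, 0.0])   # hypothetical upstream gradient dL/dx
print(x, grad_logits)
```

In an autodiff framework, the same effect is usually written as `y_hard - stop_gradient(y_soft) + y_soft`. PyTorch's `torch.nn.functional.gumbel_softmax` exposes it through its `hard=True` argument.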


  • Why not just sum over all discrete values?
  • How does it actually work? ST vs Non-ST is there some kind of -weighted sum?


  • OpenReview: Jang+ 2017
  • Eric Jang's tutorial lets you play with the Gumbel-softmax distribution, and includes code for a discrete VAE on MNIST in TensorFlow.
