Keywords: reparametrization trick, Gumbel max trick, Gumbel softmax, Concrete distribution, score function estimator, REINFORCE
In the context of deep learning, we often want to backpropagate a gradient through samples $x \sim p_\phi(x)$, where $p_\phi$ is a learned parametric distribution.
For example, we might want to train a variational autoencoder. Conditioned on the input $x$, the latent representation $z$ is not a single value but a distribution $q_\phi(z \mid x)$, generally a Gaussian distribution whose parameters are given by an (inference) neural network with parameters $\phi$. When learning to maximize (a lower bound on) the likelihood of the data, we need to backpropagate the loss to the parameters $\phi$ of the inference network, across the distribution of $z$ or across samples $z \sim q_\phi(z \mid x)$.
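More specifically, we want to minimize an expected cost
$$L(\theta, \phi) = \mathbb{E}_{x \sim p_\phi(x)}\big[f_\theta(x)\big]$$
using gradient descent, which requires computing the gradients $\nabla_\theta L(\theta, \phi)$ and $\nabla_\phi L(\theta, \phi)$.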
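Under certain conditions, Leibniz's rule states that the gradient and expectation can be swapped, resulting in
$$\nabla_\theta L(\theta, \phi) = \mathbb{E}_{x \sim p_\phi(x)}\big[\nabla_\theta f_\theta(x)\big],$$
which can be estimated using Monte-Carlo:
$$\nabla_\theta L(\theta, \phi) \approx \frac{1}{S} \sum_{s=1}^{S} \nabla_\theta f_\theta\big(x^{(s)}\big),$$
with iid samples $x^{(s)} \sim p_\phi(x)$.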
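A minimal sketch of this Monte-Carlo estimate in plain Python, under a toy setup that is an assumption and not part of the article: $p_\phi = \mathcal{N}(\phi, 1)$ and $f_\theta(x) = (x - \theta)^2$, so that $\nabla_\theta f_\theta(x) = -2(x - \theta)$ and the exact gradient is $2(\theta - \phi)$. The function name `grad_theta_mc` is hypothetical.

```python
import random

def grad_theta_mc(theta, phi, num_samples, rng):
    """Monte-Carlo estimate of grad_theta E_{x ~ N(phi, 1)}[(x - theta)^2].

    Assumed toy setup: p_phi = N(phi, 1) and f_theta(x) = (x - theta)^2.
    """
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(phi, 1.0)       # iid sample x^(s) ~ p_phi
        total += -2.0 * (x - theta)   # grad_theta f_theta evaluated at the sample
    return total / num_samples

rng = random.Random(0)
estimate = grad_theta_mc(theta=1.0, phi=3.0, num_samples=100_000, rng=rng)
exact = 2.0 * (1.0 - 3.0)            # analytic gradient 2 (theta - phi)
print(estimate, exact)
```

Only samples from $p_\phi$ and the derivative of $f_\theta$ w.r.t. $\theta$ are needed; $\phi$ never has to be differentiated.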
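So computing $\nabla_\theta L(\theta, \phi)$ is fairly straightforward and requires only that we can sample from $p_\phi(x)$ and that $f_\theta(x)$ is differentiable w.r.t. $\theta$.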
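Computing $\nabla_\phi L(\theta, \phi)$ is much harder because $\phi$ parametrizes the distribution over which the expectation is taken. Naturally, we can rewrite the expectation as an integral over $x$ and use Leibniz's rule again:
$$\nabla_\phi L(\theta, \phi) = \nabla_\phi \int p_\phi(x)\, f_\theta(x)\, dx = \int \nabla_\phi p_\phi(x)\, f_\theta(x)\, dx,$$
but now the integral does not have the form of an expectation (since $\nabla_\phi p_\phi(x)$ is not a probability density), so we cannot use Monte-Carlo to estimate its value.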
So computing $\nabla_\phi L(\theta, \phi)$ is not straightforward. However, notice that:
In the rest of the article we review several tricks for estimating this gradient, each suited to a particular application.
The table below summarizes some ways to deal with samples in a computation graph. Everything in bold is either more powerful or less constraining. In the context of deep learning, the most important attribute is that the loss is differentiable w.r.t. $\phi$, so that the parameters can be learned using gradient descent.
| Method | Continuous or discrete | Backpropable (differentiable w.r.t. $\phi$) | Follows exact distribution | $\nabla_x f_\theta(x)$ must exist |
|---|---|---|---|---|
| Score function estimator | **Continuous and discrete** | **Yes** | **Yes** | **No** |
| Reparameterization trick | Continuous | **Yes** | **Yes** | Yes |
| Gumbel-max trick | Discrete | No | **Yes** | |
| Gumbel-softmax trick | Discrete | **Yes** | No (continuous relaxation) | Yes |
| ST-Gumbel estimator | Discrete | **Yes** | Yes on forward pass, no on backward pass (continuous relaxation) | Yes |
| REBAR | Discrete | **Yes** | Yes on forward pass, no on backward pass (continuous relaxation) | ? |
The score function estimator (SF), also called the likelihood-ratio estimator, or REINFORCE when applied to reinforcement learning, transforms the integral into an expectation.
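More specifically, using the property that $\nabla_\phi p_\phi(x) = p_\phi(x)\, \nabla_\phi \log p_\phi(x)$, we can rewrite the gradient as an expectation:
$$\nabla_\phi L(\theta, \phi) = \int p_\phi(x)\, \nabla_\phi \log p_\phi(x)\, f_\theta(x)\, dx = \mathbb{E}_{x \sim p_\phi(x)}\big[f_\theta(x)\, \nabla_\phi \log p_\phi(x)\big].$$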
We can now use Monte-Carlo to estimate the gradient.
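A minimal sketch of the score function estimator, assuming (for illustration only) $p_\mu = \mathcal{N}(\mu, 1)$ with $\phi = \mu$, whose score is $\nabla_\mu \log p_\mu(x) = x - \mu$. The cost is a hypothetical step function $f(x) = \mathbb{1}[x > 0]$, chosen because $\nabla_x f$ is zero or undefined everywhere: the estimator only evaluates $f$, matching the "No" in the last column of the table. The exact gradient is $\frac{d}{d\mu} P(x > 0) = \varphi(\mu)$, the standard normal density at $\mu$.

```python
import math
import random

def score_function_grad(mu, f, num_samples, rng):
    """Score-function (REINFORCE) estimate of d/dmu E_{x ~ N(mu, 1)}[f(x)]."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(mu, 1.0)       # x ~ p_mu = N(mu, 1)
        total += f(x) * (x - mu)     # f(x) * d/dmu log p_mu(x)
    return total / num_samples

def step(x):
    # Non-differentiable cost: its derivative w.r.t. x is zero or undefined.
    return 1.0 if x > 0.0 else 0.0

rng = random.Random(0)
mu = 0.5
estimate = score_function_grad(mu, step, 200_000, rng)
# Exact gradient: d/dmu P(x > 0) = standard normal pdf evaluated at mu.
exact = math.exp(-mu * mu / 2.0) / math.sqrt(2.0 * math.pi)
print(estimate, exact)
```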
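This estimator is known to suffer from high variance. This problem can be alleviated by subtracting a control variate, or baseline, $c(x)$ from $f_\theta(x)$ and adding its mean back:
$$\nabla_\phi L(\theta, \phi) = \mathbb{E}_{x \sim p_\phi(x)}\big[(f_\theta(x) - c(x))\, \nabla_\phi \log p_\phi(x)\big] + \nabla_\phi\, \mathbb{E}_{x \sim p_\phi(x)}\big[c(x)\big].$$
For a constant baseline $c(x) = b$, the second term vanishes, because $\mathbb{E}_{x \sim p_\phi(x)}[\nabla_\phi \log p_\phi(x)] = \nabla_\phi \int p_\phi(x)\, dx = 0$.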
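To illustrate the effect of a constant baseline, here is a sketch under assumed settings: $x \sim \mathcal{N}(\mu, 1)$ with $\mu = 1$, and a hypothetical cost $f(x) = x^2 + 10$ whose large offset inflates the variance. The exact gradient is $2\mu = 2$. For simplicity the baseline is set to the exact mean $b = \mathbb{E}[f(x)] = \mu^2 + 11$; in practice it would have to be estimated, for example by a running average. The helper name `sf_terms` is hypothetical.

```python
import random
import statistics

def sf_terms(mu, f, baseline, num_samples, rng):
    """Per-sample score-function terms (f(x) - b) * d/dmu log N(x; mu, 1)."""
    terms = []
    for _ in range(num_samples):
        x = rng.gauss(mu, 1.0)
        terms.append((f(x) - baseline) * (x - mu))
    return terms

def f(x):
    return x * x + 10.0  # hypothetical cost with a large constant offset

mu = 1.0
rng = random.Random(0)
plain = sf_terms(mu, f, 0.0, 100_000, rng)
with_baseline = sf_terms(mu, f, mu * mu + 11.0, 100_000, rng)

mean_plain, var_plain = statistics.fmean(plain), statistics.variance(plain)
mean_baseline, var_baseline = statistics.fmean(with_baseline), statistics.variance(with_baseline)
print(mean_plain, var_plain)
print(mean_baseline, var_baseline)
```

Both estimators target the same gradient, but for this toy problem the per-sample variance drops from about 210 to about 18 (these are analytic values for the assumed setup, not measured results).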
Applications:
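Sometimes the random variable $x \sim p_\phi(x)$ can be reparameterized as a deterministic function $g$ of $\phi$ and of a random variable $\epsilon \sim p(\epsilon)$, where $p(\epsilon)$ does not depend on $\phi$:
$$x = g(\phi, \epsilon).$$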
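For instance, the Gaussian variable $x \sim \mathcal{N}(\mu, \sigma^2)$, with $\phi = (\mu, \sigma)$, can be rewritten as a function of a standard Gaussian variable $\epsilon \sim \mathcal{N}(0, 1)$, such that $x = \mu + \sigma \epsilon$.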
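In that case the gradient rewrites as
$$\nabla_\phi L(\theta, \phi) = \mathbb{E}_{\epsilon \sim p(\epsilon)}\big[\nabla_\phi f_\theta(g(\phi, \epsilon))\big] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\Big[\nabla_x f_\theta(x)\big|_{x = g(\phi, \epsilon)}\, \nabla_\phi g(\phi, \epsilon)\Big],$$
which can again be estimated by Monte-Carlo, since the expectation is now taken over a distribution that does not depend on $\phi$.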
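A minimal sketch of the reparameterized gradient for the Gaussian example, assuming a hypothetical cost $f(x) = x^2$. Since $\mathbb{E}[x^2] = \mu^2 + \sigma^2$, the exact gradients are $2\mu$ and $2\sigma$. With $x = \mu + \sigma\epsilon$, the chain rule gives $\partial x / \partial \mu = 1$ and $\partial x / \partial \sigma = \epsilon$. The function name `reparam_grads` is hypothetical.

```python
import random

def reparam_grads(mu, sigma, num_samples, rng):
    """Reparameterized estimate of the gradient of E_{x ~ N(mu, sigma^2)}[x^2]
    w.r.t. (mu, sigma), using x = g(phi, eps) = mu + sigma * eps."""
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)    # eps ~ N(0, 1), independent of phi
        x = mu + sigma * eps         # x = g(phi, eps)
        df_dx = 2.0 * x              # grad_x f(x) for f(x) = x^2
        g_mu += df_dx * 1.0          # dx/dmu = 1
        g_sigma += df_dx * eps       # dx/dsigma = eps
    return g_mu / num_samples, g_sigma / num_samples

rng = random.Random(0)
grad_mu, grad_sigma = reparam_grads(mu=1.5, sigma=0.5, num_samples=100_000, rng=rng)
print(grad_mu, grad_sigma)
```

Unlike the score function estimator, this sketch needs $\nabla_x f$; in exchange, on problems like this one it typically has much lower variance.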
Requirements:
Applications:
Links:
In the next sections we will interchangeably use the integer representation $x \in \{1, \dots, K\}$ and the one-hot representation $x \in \{0, 1\}^K$ for the same discrete categorical variable $x$.
The Gumbel-max trick was proposed by Gumbel, Julius, Lieblein (1954) - Statistical theory of extreme values[...] to express a discrete categorical variable as a deterministic function of the class probabilities and independent random variables, called Gumbel variables.
Let $x$ be a discrete categorical variable, which can take $K$ values and is parameterized by the class probabilities $\pi = (\pi_1, \dots, \pi_K)$, i.e. $p(x = k) = \pi_k$. The obvious way to sample $x$ is to use its cumulative distribution function to invert a uniform random variable. However, we would like to use the reparametrization trick.
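For reference, the inverse-CDF sampler mentioned above can be sketched as follows. Indices are 0-based in the code, whereas the article numbers classes $1, \dots, K$; the function name is hypothetical.

```python
import bisect
import itertools
import random

def sample_categorical_inverse_cdf(probs, rng):
    """Sample an index k in {0, ..., K-1} with P(k) = probs[k] by inverting the CDF."""
    cdf = list(itertools.accumulate(probs))  # cdf[k] = probs[0] + ... + probs[k]
    u = rng.random() * cdf[-1]               # u ~ Uniform(0, total); scaling absorbs rounding in sum(probs)
    return bisect.bisect_right(cdf, u)       # first k with cdf[k] > u

probs = [0.2, 0.5, 0.3]
rng = random.Random(0)
n = 100_000
counts = [0] * len(probs)
for _ in range(n):
    counts[sample_categorical_inverse_cdf(probs, rng)] += 1
freqs = [cnt / n for cnt in counts]
print(freqs)
```

The returned index is piecewise constant in `probs`: a small change either leaves the sample unchanged or makes it jump, so it provides no useful gradient.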
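Another way is to define $K$ independent variables $G_1, \dots, G_K$ that follow a Gumbel distribution, which can be obtained as $G_k = -\log(-\log u_k)$ where $u_k \sim \mathrm{Uniform}(0, 1)$. Then the random variable
$$x = \arg\max_{k \in \{1, \dots, K\}} \big(\log \pi_k + G_k\big)$$
follows the correct categorical distribution $\mathrm{Cat}(\pi)$, i.e. $p(x = k) = \pi_k$.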
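A minimal sketch of the Gumbel-max trick in plain Python, with the empirical class frequencies compared to $\pi$. The class probabilities are illustrative and the helper names are hypothetical. Because adding a constant to every $\log \pi_k$ does not change the $\arg\max$, unnormalized log-probabilities (logits) would work as well.

```python
import math
import random

def sample_gumbel(rng):
    """Draw G = -log(-log u) with u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:          # random() can return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(log_probs, rng):
    """Return argmax_k (log pi_k + G_k), a 0-based sample from Cat(pi)."""
    perturbed = [lp + sample_gumbel(rng) for lp in log_probs]
    return max(range(len(perturbed)), key=perturbed.__getitem__)

probs = [0.2, 0.5, 0.3]
log_probs = [math.log(p) for p in probs]
rng = random.Random(0)
n = 100_000
counts = [0] * len(probs)
for _ in range(n):
    counts[gumbel_max_sample(log_probs, rng)] += 1
freqs = [cnt / n for cnt in counts]
print(freqs)
```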
However, we still cannot apply the reparametrization trick: as a function of the parameters $\pi$ that we want to optimize, the $\arg\max$ is piecewise constant, so its gradient is zero almost everywhere and undefined at the boundaries. We now present the Gumbel-softmax trick, which relaxes the Gumbel-max trick to make $x$ differentiable.
Applications:
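The idea of replacing the $\arg\max$ of the Gumbel-max trick with a $\mathrm{softmax}$ was concurrently presented by Jang, Gu, Poole (2017) - Categorical reparameterization with Gumbel Softmax (under the name Gumbel-softmax) and Maddison, Mnih, Teh (2017) - The Concrete Distribution (under the name Concrete distribution). More precisely, define a Gumbel-softmax random variable $x \in \mathbb{R}^K$ with components
$$x_k = \frac{\exp\big((\log \pi_k + G_k)/\tau\big)}{\sum_{j=1}^{K} \exp\big((\log \pi_j + G_j)/\tau\big)}, \qquad k = 1, \dots, K,$$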
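where $\tau > 0$ is a temperature parameter and $G_k = -\log(-\log u_k)$, with $u_k \sim \mathrm{Uniform}(0, 1)$, are Gumbel variables as before. The references give an analytical expression for the density of the Gumbel-softmax.
Note that the previous expression gives the value of $x$ as a deterministic function of $\pi$ and $G = (G_1, \dots, G_K)$, not the distribution $p(x)$. So $x$ is actually a continuous value supported on the simplex $\Delta^{K-1} = \{x \in \mathbb{R}^K : x_k \ge 0,\ \sum_k x_k = 1\}$, rather than a one-hot vector.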
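A minimal sketch of drawing one Gumbel-softmax sample in plain Python. The class probabilities and temperature are illustrative assumptions, and the helper names are hypothetical. The softmax subtracts the maximum before exponentiating, which leaves the result unchanged but avoids overflow at low temperatures.

```python
import math
import random

def sample_gumbel(rng):
    """Draw G = -log(-log u) with u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:          # random() can return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def gumbel_softmax_sample(log_probs, tau, rng):
    """x_k = softmax_k((log pi_k + G_k) / tau): a point in the simplex, not a one-hot vector."""
    return softmax([(lp + sample_gumbel(rng)) / tau for lp in log_probs])

rng = random.Random(0)
x = gumbel_softmax_sample([math.log(p) for p in [0.2, 0.5, 0.3]], tau=0.5, rng=rng)
print(x)
```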
The above authors show interesting properties of the Gumbel-softmax:
Now we can write $x = g(\pi, G)$, where $g$ is differentiable w.r.t. $\pi$. We can use the reparameterization trick!
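To make the differentiability concrete, here is a sketch that fixes the Gumbel noise and differentiates a hypothetical linear downstream cost $f(x) = \sum_k c_k x_k$ w.r.t. the logits $\ell_k = \log \pi_k$. It uses the softmax Jacobian $\partial x_k / \partial \ell_j = x_k(\delta_{kj} - x_j)/\tau$, which gives $\partial f / \partial \ell_j = x_j (c_j - f(x))/\tau$, and checks it against central finite differences. The weights `c` and all helper names are assumptions for illustration.

```python
import math
import random

def sample_gumbel(rng):
    u = rng.random()
    while u == 0.0:          # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def relaxed_sample(logits, gumbels, tau):
    """x = g(logits, G): deterministic and differentiable in the logits once G is fixed."""
    return softmax([(l + g) / tau for l, g in zip(logits, gumbels)])

def linear_cost(x, c):
    return sum(ck * xk for ck, xk in zip(c, x))

def pathwise_grad(logits, gumbels, tau, c):
    """d f / d logit_j = x_j (c_j - f(x)) / tau for f(x) = sum_k c_k x_k."""
    x = relaxed_sample(logits, gumbels, tau)
    fx = linear_cost(x, c)
    return [xj * (cj - fx) / tau for xj, cj in zip(x, c)]

rng = random.Random(0)
logits = [math.log(p) for p in [0.2, 0.5, 0.3]]
gumbels = [sample_gumbel(rng) for _ in logits]   # noise drawn once, then held fixed
tau = 0.5
c = [1.0, -2.0, 0.5]                             # hypothetical downstream weights
analytic = pathwise_grad(logits, gumbels, tau, c)

# Central finite differences as a sanity check of the analytic gradient.
h = 1e-6
numeric = []
for j in range(len(logits)):
    up, down = list(logits), list(logits)
    up[j] += h
    down[j] -= h
    diff = linear_cost(relaxed_sample(up, gumbels, tau), c) - linear_cost(relaxed_sample(down, gumbels, tau), c)
    numeric.append(diff / (2 * h))
print(analytic)
print(numeric)
```

Averaging `pathwise_grad` over many independent draws of the Gumbel noise would give a Monte-Carlo reparameterization estimate of the gradient of $\mathbb{E}[f(x)]$ under the relaxed distribution.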
However, note that $x$ does not exactly follow $\mathrm{Cat}(\pi)$. There is a tradeoff between having nearly one-hot samples but badly conditioned, high-variance gradients (low temperature), and having smoother samples with lower gradient variance (high temperature). In practice, the authors start with a high temperature and anneal to small non-zero temperatures, so as to approach the categorical distribution in the limit.
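The tradeoff can be observed numerically with a small sketch. Under assumed settings (illustrative class probabilities and a hypothetical linear cost with weights `c`), it measures, for several temperatures, how close samples are to one-hot (mean of the largest component) and the variance of the pathwise gradient w.r.t. the first logit across Gumbel draws. The function name `temperature_stats` is hypothetical.

```python
import math
import random
import statistics

def sample_gumbel(rng):
    u = rng.random()
    while u == 0.0:          # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def temperature_stats(logits, tau, c, num_samples, rng):
    """Return (mean largest component of x, variance of d f / d logit_0) at temperature tau,
    for the hypothetical linear cost f(x) = sum_k c_k x_k."""
    max_components = []
    grads0 = []
    for _ in range(num_samples):
        x = softmax([(l + sample_gumbel(rng)) / tau for l in logits])
        fx = sum(ck * xk for ck, xk in zip(c, x))
        max_components.append(max(x))
        grads0.append(x[0] * (c[0] - fx) / tau)  # pathwise gradient w.r.t. logit 0
    return statistics.fmean(max_components), statistics.variance(grads0)

rng = random.Random(0)
logits = [math.log(p) for p in [0.2, 0.5, 0.3]]
c = [1.0, -2.0, 0.5]
results = {tau: temperature_stats(logits, tau, c, 20_000, rng) for tau in (0.1, 1.0, 5.0)}
for tau, (mean_max, grad_var) in results.items():
    print(f"tau={tau}: mean max component={mean_max:.3f}, grad variance={grad_var:.4f}")
```

Lower temperatures push the mean largest component toward 1 (nearly one-hot samples), while the gradient variance grows.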
Applications:
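For non-zero temperatures, a Gumbel-softmax variable does not exactly follow $\mathrm{Cat}(\pi)$. If in the forward pass we replace $x$ by its argmax, encoded as a one-hot vector, we get a variable that follows $\mathrm{Cat}(\pi)$ exactly, since $\arg\max_k x_k = \arg\max_k (\log \pi_k + G_k)$ is precisely the Gumbel-max trick. However, in order to backpropagate the gradient, we keep the original, continuous $x$ in the backward pass. The resulting gradient estimate is therefore biased.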
This is called Straight-Through-Gumbel-softmax in Jang's paper, and builds on ideas from Bengio, Leonard, Courville (2013) - Estimating or Propagating Gradients
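A framework-free sketch of the straight-through idea: the forward value is computed from the hard one-hot sample, while the gradient w.r.t. the logits is taken through the soft sample. The linear cost with weights `c` and all function names are hypothetical. For a nonlinear cost, $\nabla_x f$ would be evaluated at the hard sample and then multiplied by the soft Jacobian; in autodiff frameworks this is often written as `hard + soft - stop_gradient(soft)`.

```python
import math
import random

def sample_gumbel(rng):
    u = rng.random()
    while u == 0.0:          # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def straight_through_sample(logits, tau, rng):
    """Return (hard, soft): hard is used on the forward pass, soft on the backward pass."""
    soft = softmax([(l + sample_gumbel(rng)) / tau for l in logits])
    k = max(range(len(soft)), key=soft.__getitem__)   # same argmax as Gumbel-max
    hard = [1.0 if j == k else 0.0 for j in range(len(soft))]
    return hard, soft

def st_forward_backward(logits, tau, c, rng):
    """Forward: f(hard) for f(x) = sum_k c_k x_k. Backward: biased surrogate gradient through soft."""
    hard, soft = straight_through_sample(logits, tau, rng)
    value = sum(ck * hk for ck, hk in zip(c, hard))
    f_soft = sum(ck * sk for ck, sk in zip(c, soft))
    grad = [sk * (ck - f_soft) / tau for sk, ck in zip(soft, c)]  # softmax Jacobian applied to c
    return value, hard, grad

rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
logits = [math.log(p) for p in probs]
c = [1.0, -2.0, 0.5]
n = 50_000
counts = [0] * len(probs)
for _ in range(n):
    value, hard, grad = st_forward_backward(logits, 0.5, c, rng)
    counts[hard.index(1.0)] += 1
freqs = [cnt / n for cnt in counts]
print(freqs)
```

The hard samples follow $\mathrm{Cat}(\pi)$ exactly, as in the Gumbel-max trick, while `grad` comes from the continuous relaxation.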