Note: The IPython notebook for this post can be seen here.

Here we will try to understand the reparameterization trick used by Kingma and Welling (2014) [1] to train their variational autoencoder.

Assume we have a normal distribution $q$ that is parameterized by $\theta$, specifically $q_\theta(x) = \mathcal{N}(\theta, 1)$. We want to solve the problem below:

$$
\min_\theta \; \mathbb{E}_q[x^2]
$$

This is, of course, a rather silly problem, and the optimal $\theta$ is obvious. Here we just want to understand how the reparameterization trick helps in calculating the gradient of this objective, $\nabla_\theta \mathbb{E}_q[x^2]$.
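For reference, the objective has a closed form, because the second moment of a unit-variance Gaussian is its variance plus its squared mean:

$$
\mathbb{E}_q[x^2] = \operatorname{Var}_q[x] + \left(\mathbb{E}_q[x]\right)^2 = 1 + \theta^2,
\qquad
\nabla_\theta \mathbb{E}_q[x^2] = 2\theta .
$$

So the optimum is $\theta^* = 0$, and $2\theta$ is the exact gradient that any estimator below should recover on average.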

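The usual way to calculate $\nabla_\theta \mathbb{E}_q[x^2]$ is as follows:

$$
\begin{aligned}
\nabla_\theta \mathbb{E}_q[x^2] &= \nabla_\theta \int q_\theta(x)\, x^2 \, dx \\
&= \int x^2 \, \nabla_\theta q_\theta(x) \, dx \\
&= \int x^2 \, q_\theta(x) \, \nabla_\theta \log q_\theta(x) \, dx \\
&= \mathbb{E}_q\left[x^2 \, \nabla_\theta \log q_\theta(x)\right]
\end{aligned}
$$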

which makes use of $\nabla_\theta q_\theta = q_\theta \nabla_\theta \log q_\theta$. This trick is also the basis of the REINFORCE [2] algorithm used in policy gradient methods.
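To make the estimator concrete, here is a minimal NumPy sketch of a generic score-function (log-derivative) estimator: draw samples from $q_\theta$ and average $f(x)\,\nabla_\theta \log q_\theta(x)$. The helper name `score_function_grad` is our own, not from the original notebook. The example assumes $q_\theta = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$; for a unit-variance Gaussian the score is $\nabla_\theta \log q_\theta(x) = x - \theta$.

```python
import numpy as np


def score_function_grad(f, score, samples):
    """Monte Carlo estimate of grad_theta E_q[f(x)] via the log-derivative trick.

    f:       the integrand, applied elementwise to the samples
    score:   grad_theta log q_theta(x), applied elementwise
    samples: draws x ~ q_theta
    """
    return np.mean(f(samples) * score(samples))


def square(x):
    return x**2


rng = np.random.default_rng(0)
theta = 2.0
x = rng.normal(loc=theta, scale=1.0, size=100_000)  # x ~ N(theta, 1)


def gaussian_score(s):
    # grad_theta log N(s; theta, 1) = s - theta
    return s - theta


g = score_function_grad(square, gaussian_score, x)
print(g)  # close to the exact gradient 2 * theta = 4, up to Monte Carlo noise
```

Note that only samples of $x$ and the score are needed; $f$ itself is never differentiated, which is why this estimator also works for non-differentiable $f$.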

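For our example, where $q_\theta(x) = \mathcal{N}(\theta, 1)$, we have $\log q_\theta(x) = -\tfrac{1}{2}(x - \theta)^2 + \text{const}$, so $\nabla_\theta \log q_\theta(x) = x - \theta$ and this method gives

$$
\nabla_\theta \mathbb{E}_q[x^2] = \mathbb{E}_q\left[x^2 (x - \theta)\right]
$$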
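Substituting $x = \theta + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, and using the standard normal moments $\mathbb{E}[\epsilon] = \mathbb{E}[\epsilon^3] = \mathbb{E}[\epsilon^5] = 0$, $\mathbb{E}[\epsilon^2] = 1$, $\mathbb{E}[\epsilon^4] = 3$, $\mathbb{E}[\epsilon^6] = 15$, shows that this estimator is unbiased but noisy:

$$
\begin{aligned}
x^2 (x - \theta) &= \theta^2 \epsilon + 2\theta \epsilon^2 + \epsilon^3, \\
\mathbb{E}\left[x^2 (x - \theta)\right] &= 2\theta, \\
\mathbb{E}\left[\left(x^2 (x - \theta)\right)^2\right] &= \theta^4 + 18\theta^2 + 15, \\
\operatorname{Var}\left[x^2 (x - \theta)\right] &= \theta^4 + 14\theta^2 + 15 .
\end{aligned}
$$

At $\theta = 2$ a single sample therefore has variance $16 + 56 + 15 = 87$, and an average of $N$ samples has variance $87/N$.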

The reparameterization trick is a way to rewrite the expectation so that the distribution with respect to which we take the expectation is independent of the parameter $\theta$. To achieve this, we need to make the stochastic element in $q$ independent of $\theta$. Hence, we write $x$ as

$$
x = \theta + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)
$$

Then, we can write

$$
\mathbb{E}_q[x^2] = \mathbb{E}_p\left[(\theta + \epsilon)^2\right]
$$

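where $p$ is the distribution of $\epsilon$, i.e., $p = \mathcal{N}(0, 1)$. Because $p$ does not depend on $\theta$, the gradient can be moved inside the expectation, and we can write the derivative of $\mathbb{E}_q[x^2]$ as follows:

$$
\nabla_\theta \mathbb{E}_q[x^2] = \nabla_\theta \mathbb{E}_p\left[(\theta + \epsilon)^2\right] = \mathbb{E}_p\left[2(\theta + \epsilon)\right]
$$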
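A minimal sketch of the pathwise (reparameterization) estimator for a general differentiable $f$, assuming we can evaluate its derivative $f'$. The helper name `reparam_grad` is hypothetical. The finite-difference check holds the sampled $\epsilon$ fixed while nudging $\theta$, which is exactly what the reparameterization makes possible: the randomness no longer moves when $\theta$ moves.

```python
import numpy as np


def reparam_grad(f_prime, theta, eps):
    """Pathwise estimate of grad_theta E_p[f(theta + eps)].

    Since p does not depend on theta, the gradient moves inside the
    expectation, and d/dtheta f(theta + eps) = f'(theta + eps) by the chain rule.
    """
    return np.mean(f_prime(theta + eps))


def f(x):
    return x**2  # the integrand of the objective


def f_prime(x):
    return 2 * x  # its derivative


rng = np.random.default_rng(0)
theta = 2.0
eps = rng.standard_normal(100_000)  # eps ~ N(0, 1), drawn independently of theta

g = reparam_grad(f_prime, theta, eps)

# Sanity check: reuse the same eps and nudge theta (central finite difference).
h = 1e-4
fd = (np.mean(f(theta + h + eps)) - np.mean(f(theta - h + eps))) / (2 * h)
print(g, fd)  # both close to the exact gradient 2 * theta = 4
```

Each per-sample term $2(\theta + \epsilon)$ has variance $4\operatorname{Var}[\epsilon] = 4$, independent of $\theta$.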

Now let us compare the two methods. Both estimators are unbiased, so what distinguishes them is their variance: we expect the first method to have high variance and the reparameterization trick to reduce it substantially.

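```python
import numpy as np

N = 1000
theta = 2.0
x = np.random.randn(N) + theta  # samples x ~ q_theta = N(theta, 1)
eps = np.random.randn(N)        # samples eps ~ p = N(0, 1)

def grad1(x):
    # Score-function estimate: average of x^2 * (x - theta)
    return np.mean(np.square(x) * (x - theta))

def grad2(eps):
    # Reparameterization estimate: average of 2 * (theta + eps)
    return np.mean(2 * (theta + eps))

print(grad1(x))
print(grad2(eps))
```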

Next, let us repeat each estimate 100 times for several sample sizes and record the mean and variance of the estimates.

```python
Ns = [10, 100, 1000, 10000, 100000]
reps = 100  # independent estimates per sample size

means1 = np.zeros(len(Ns))
vars1 = np.zeros(len(Ns))
means2 = np.zeros(len(Ns))
vars2 = np.zeros(len(Ns))

est1 = np.zeros(reps)
est2 = np.zeros(reps)
for i, N in enumerate(Ns):
    for r in range(reps):
        x = np.random.randn(N) + theta
        est1[r] = grad1(x)
        eps = np.random.randn(N)
        est2[r] = grad2(eps)
    # Mean and variance of the estimates across repetitions
    means1[i] = np.mean(est1)
    means2[i] = np.mean(est2)
    vars1[i] = np.var(est1)
    vars2[i] = np.var(est2)

print(means1)
print(means2)
print(vars1)
print(vars2)
```

```
[ 3.8409546   3.97298803  4.03007634  3.98531095  3.99579423]
[ 3.97775271  4.00232825  3.99894536  4.00353734  3.99995899]

[  6.45307927e+00   6.80227241e-01   8.69226368e-02   1.00489791e-02
[  4.59767676e-01   4.26567475e-02   3.33699503e-03   5.17148975e-04
```
```python
%matplotlib inline
import matplotlib.pyplot as plt
```


Variance of naive gradient estimate vs. estimate with reparameterization trick
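The plotting code for this figure is not included above. A minimal sketch of such a plot follows, assuming log-log axes; it recomputes the variances with a seeded generator so that it runs on its own, and it stops at $N = 10{,}000$ to keep memory use small. The axis labels, markers, and seed are our choices, not the original notebook's.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line inside a notebook
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
theta = 2.0
Ns = [10, 100, 1000, 10000]
reps = 100

vars1, vars2 = [], []
for N in Ns:
    # One row per repetition; separate draws for the two estimators
    x = theta + rng.standard_normal((reps, N))   # x ~ N(theta, 1)
    eps = rng.standard_normal((reps, N))         # eps ~ N(0, 1)
    est1 = np.mean(x**2 * (x - theta), axis=1)   # score-function estimates
    est2 = np.mean(2 * (theta + eps), axis=1)    # reparameterization estimates
    vars1.append(np.var(est1))
    vars2.append(np.var(est2))

fig, ax = plt.subplots()
ax.loglog(Ns, vars1, "o-", label="naive (score-function) estimate")
ax.loglog(Ns, vars2, "s-", label="reparameterization trick")
ax.set_xlabel("number of samples N")
ax.set_ylabel("variance of gradient estimate")
ax.legend()
```

On log-log axes both curves should fall with slope about $-1$, since each variance scales like $1/N$; the vertical gap between them is the variance ratio.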

The variance of the estimates using the reparameterization trick is more than an order of magnitude smaller than that of the first method: in the printed results, the ratio ranges from about 14 to 26.
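As a check on the size of this gap, the per-sample variances can be computed exactly from the moments of the standard normal by writing each per-sample term as a polynomial in $\epsilon$ with $x = \theta + \epsilon$. This is a sketch; `MOMENTS` and `moments_of_poly` are our own names, not part of the original notebook. Dividing by $N$ predicts the variance of an $N$-sample average.

```python
import numpy as np
from numpy.polynomial import polynomial as P

# Raw moments E[eps^k] of a standard normal, for k = 0..6
MOMENTS = np.array([1.0, 0.0, 1.0, 0.0, 3.0, 0.0, 15.0])


def moments_of_poly(coeffs):
    """Mean and variance of sum_k coeffs[k] * eps^k for eps ~ N(0, 1).

    Valid for polynomials of degree <= 3, so that the square has degree <= 6.
    """
    coeffs = np.asarray(coeffs, dtype=float)
    mean = coeffs @ MOMENTS[: len(coeffs)]
    sq = P.polymul(coeffs, coeffs)        # coefficients of the squared polynomial
    second = sq @ MOMENTS[: len(sq)]
    return mean, second - mean**2


theta = 2.0
# Score-function term x^2 (x - theta) = theta^2 eps + 2 theta eps^2 + eps^3
mean1, var1 = moments_of_poly([0.0, theta**2, 2 * theta, 1.0])
# Reparameterized term 2 (theta + eps) = 2 theta + 2 eps
mean2, var2 = moments_of_poly([2 * theta, 2.0])

Ns = np.array([10, 100, 1000, 10000, 100000])
print(mean1, var1, mean2, var2)  # means 4 and 4, variances 87 and 4
print(var1 / Ns)                 # predicted variance of the naive estimate
print(var2 / Ns)                 # predicted variance with reparameterization
print(var1 / var2)               # predicted ratio, 21.75
```

The predicted ratio of 21.75 does not depend on $N$, which is in line with the printed variances shrinking at the same $1/N$ rate for both methods.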


  1. Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. arXiv:1312.6114 

  2. Williams, R. J. (1992). Simple Statistical Gradient-following Algorithms for Connectionist Reinforcement Learning. Machine Learning, 8(3–4), 229–256.