Note: The IPython notebook for this post can be seen here.

Here we will try to understand the reparameterization trick used by Kingma and Welling (2014) [1] to train their variational autoencoder.

Assume we have a normal distribution $q$ that is parameterized by $\theta$, specifically $q_{\theta}(x) = N(\theta,1)$. We want to solve the following problem:

$$\min_{\theta} \; E_q[x^2]$$

This is of course a rather silly problem and the optimal $\theta$ is obvious. But here we just want to understand how the reparameterization trick helps in calculating the gradient of this objective $E_q[x^2]$.

The usual way to calculate $\nabla_{\theta} E_q[x^2]$ is as follows:

$$\nabla_{\theta} E_q[x^2] = \nabla_{\theta} \int q_{\theta}(x) \, x^2 \, dx = \int x^2 \, \nabla_{\theta} q_{\theta}(x) \, dx = \int x^2 \, q_{\theta}(x) \, \nabla_{\theta} \log q_{\theta}(x) \, dx = E_q[x^2 \, \nabla_{\theta} \log q_{\theta}(x)]$$

which makes use of $\nabla_{\theta} q_{\theta} = q_{\theta} \nabla_{\theta} \log q_{\theta}$. This trick is also the basis of the REINFORCE [2] algorithm used in policy gradient methods.

For our example where $q_{\theta}(x) = N(\theta,1)$, we have $\nabla_{\theta} \log q_{\theta}(x) = x - \theta$, so this method gives

$$\nabla_{\theta} E_q[x^2] = E_q[x^2 (x-\theta)]$$

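As an illustrative sketch (not a cell from the original notebook), this score-function estimator takes only a few lines of NumPy. Here $\theta = 2$ is an arbitrary choice; since $E_q[x^2] = \theta^2 + 1$, the true gradient at that point is $2\theta = 4$.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = 2.0  # arbitrary choice of parameter for illustration
N = 10000    # number of Monte Carlo samples

# Sample x ~ N(theta, 1) and average the score-function estimator
# x^2 * grad_theta log q_theta(x) = x^2 * (x - theta).
x = rng.normal(theta, 1.0, size=N)
grad_estimate = np.mean(x**2 * (x - theta))
print(grad_estimate)  # close to the true gradient 2*theta = 4
```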
The reparameterization trick is a way to rewrite the expectation so that the distribution with respect to which we take the gradient is independent of the parameter $\theta$. To achieve this, we need to make the stochastic element in $q$ independent of $\theta$. Hence, we write $x$ as

$$x = \theta + \epsilon, \quad \epsilon \sim N(0,1)$$

Then, we can write

$$E_q[x^2] = E_p[(\theta+\epsilon)^2]$$

where $p$ is the distribution of $\epsilon$, i.e., $N(0,1)$. Now we can write the derivative of $E_q[x^2]$ as follows:

$$\nabla_{\theta} E_q[x^2] = \nabla_{\theta} E_p[(\theta+\epsilon)^2] = E_p[2(\theta+\epsilon)]$$

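A matching sketch for the reparameterized estimator, under the same assumed $\theta = 2$: the randomness now comes from $\epsilon \sim N(0,1)$, which does not depend on $\theta$, so we can differentiate inside the expectation.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = 2.0  # same illustrative parameter value as before
N = 10000    # number of Monte Carlo samples

# Sample epsilon ~ N(0, 1), then x = theta + epsilon; the per-sample
# gradient is d/dtheta (theta + epsilon)^2 = 2 * (theta + epsilon).
eps = rng.normal(0.0, 1.0, size=N)
grad_estimate = np.mean(2.0 * (theta + eps))
print(grad_estimate)  # close to the true gradient 4
```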
Now let us compare the variances of the two methods; we hope to see that the first method yields high-variance estimates while the reparameterization trick reduces the variance substantially.

    First method estimate:           4.46239612174
    Reparameterization estimate:     4.1840532024


Let us now look at the mean and variance of the estimates for different sample sizes.
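One way to run this experiment (a sketch; the sample sizes and repetition count below are assumptions, not the original notebook's settings) is to repeat both estimators many times per sample size and record the variance of the resulting gradient estimates:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = 2.0
reps = 1000  # repeated estimates per sample size (assumed)

for n in [10, 100, 1000]:  # sample sizes (assumed)
    score_fn = np.empty(reps)
    reparam = np.empty(reps)
    for i in range(reps):
        # Score-function estimate from n samples of x ~ N(theta, 1).
        x = rng.normal(theta, 1.0, size=n)
        score_fn[i] = np.mean(x**2 * (x - theta))
        # Reparameterized estimate from n samples of epsilon ~ N(0, 1).
        eps = rng.normal(0.0, 1.0, size=n)
        reparam[i] = np.mean(2.0 * (theta + eps))
    print(n, score_fn.var(), reparam.var())
```

Both variances shrink as $1/n$, but at every sample size the reparameterized estimator's variance is markedly smaller.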

Means of the gradient estimates across five increasing sample sizes:

    First method:        3.8409546, 3.97298803, 4.03007634, 3.98531095, 3.99579423
    Reparameterization:  3.97775271, 4.00232825, 3.99894536, 4.00353734, 3.99995899

Variances of the gradient estimates:

    First method:        6.45307927e+00, 6.80227241e-01, 8.69226368e-02, 1.00489791e-02, 8.62396526e-04
    Reparameterization:  4.59767676e-01, 4.26567475e-02, 3.33699503e-03, 5.17148975e-04, 4.65338152e-05


The variance of the estimates obtained with the reparameterization trick is about an order of magnitude smaller than that of the first method!

#### Bibliography

1. Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. arXiv:1312.6114

2. Williams, R. J. (1992). Simple Statistical Gradient-following Algorithms for Connectionist Reinforcement Learning. Machine Learning, 8(3–4), 229–256.