# Reparameterization Trick

Note: The IPython notebook for this post can be seen here.

Here we will try to understand the reparameterization trick used by Kingma and Welling (2014)^{1} to train their variational autoencoder.

Assume we have a normal distribution \(q\) that is parameterized by \(\theta\), specifically \(q_{\theta}(x) = N(\theta,1)\). We want to solve the problem below:

\[\min_{\theta} \quad E_q[x^2]\]

This is of course a rather silly problem, and the optimal \(\theta\) is obvious. But here we just want to understand how the reparameterization trick helps in calculating the gradient of this objective \(E_q[x^2]\).

The usual way to calculate \(\nabla_{\theta} E_q[x^2]\) is as follows

\[\nabla_{\theta} E_q[x^2] = \nabla_{\theta} \int q_{\theta}(x) x^2 dx = \int x^2 \nabla_{\theta} q_{\theta}(x) \frac{q_{\theta}(x)}{q_{\theta}(x)} dx = \int q_{\theta}(x) \nabla_{\theta} \log q_{\theta}(x) \, x^2 dx = E_q[x^2 \nabla_{\theta} \log q_{\theta}(x)]\]

This makes use of the identity \(\nabla_{\theta} q_{\theta} = q_{\theta} \nabla_{\theta} \log q_{\theta}\). This trick is also the basis of the REINFORCE^{2} algorithm used in policy gradient methods.

For our example where \(q_{\theta}(x) = N(\theta,1)\), this method gives

\[\nabla_{\theta} E_q[x^2] = E_q[x^2 (x-\theta)]\]

since \(\nabla_{\theta} \log q_{\theta}(x) = x - \theta\) for our Gaussian \(q\). The reparameterization trick is a way to rewrite the expectation so that the distribution with respect to which we take the expectation is independent of the parameter \(\theta\). To achieve this, we need to **make the stochastic element in** \(q\) **independent of** \(\theta\). Hence, we write \(x\) as

\[x = \theta + \epsilon, \quad \epsilon \sim N(0,1)\]

Then, we can write

\[E_q[x^2] = E_p[(\theta+\epsilon)^2]\]

where \(p\) is the distribution of \(\epsilon\), i.e., \(N(0,1)\). Now we can write the derivative of \(E_q[x^2]\) as follows:

\[\nabla_{\theta} E_q[x^2] = \nabla_{\theta} E_p[(\theta+\epsilon)^2] = E_p[2(\theta+\epsilon)]\]

Now let us compare the variances of the two methods; we are hoping to see that the first method has high variance while the reparameterization trick decreases the variance substantially.

```
4.46239612174
4.1840532024
```
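The two numbers above were presumably produced by notebook code along the following lines (a sketch; \(\theta = 2\) and the sample size are assumptions, and the exact values depend on the random draw — since \(E_q[x^2] = \theta^2 + 1\), both estimates should be close to the true gradient \(2\theta = 4\)):

```python
import numpy as np

np.random.seed(0)
theta = 2.0   # assumed value; the true gradient is then 2 * theta = 4
n = 1000      # assumed sample size

# Score-function (REINFORCE) estimate: E_q[x^2 (x - theta)], x ~ N(theta, 1)
x = np.random.randn(n) + theta
grad_score = np.mean(x**2 * (x - theta))

# Reparameterization estimate: E_p[2 (theta + eps)], eps ~ N(0, 1)
eps = np.random.randn(n)
grad_reparam = np.mean(2 * (theta + eps))

print(grad_score)
print(grad_reparam)
```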

Let us plot the variance for different sample sizes.

```
[ 3.8409546 3.97298803 4.03007634 3.98531095 3.99579423]
[ 3.97775271 4.00232825 3.99894536 4.00353734 3.99995899]
[ 6.45307927e+00 6.80227241e-01 8.69226368e-02 1.00489791e-02
8.62396526e-04]
[ 4.59767676e-01 4.26567475e-02 3.33699503e-03 5.17148975e-04
4.65338152e-05]
```
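The printed arrays above — estimator means (first two rows) and estimator variances (last two rows) across repeated runs for increasing sample sizes — could be generated with a sketch like this (the sample sizes, the number of repetitions, and \(\theta = 2\) are assumptions):

```python
import numpy as np

np.random.seed(0)
theta = 2.0                                    # assumed parameter value
reps = 100                                     # assumed number of repeated estimates
sample_sizes = [10, 100, 1000, 10000, 100000]  # assumed sample sizes

means_score, vars_score = [], []
means_reparam, vars_reparam = [], []
for n in sample_sizes:
    # reps independent score-function estimates, each averaging n samples
    x = np.random.randn(reps, n) + theta
    est = (x**2 * (x - theta)).mean(axis=1)
    means_score.append(est.mean())
    vars_score.append(est.var())

    # reps independent reparameterization estimates
    eps = np.random.randn(reps, n)
    est = (2 * (theta + eps)).mean(axis=1)
    means_reparam.append(est.mean())
    vars_reparam.append(est.var())

print(np.array(means_score))
print(np.array(means_reparam))
print(np.array(vars_score))
print(np.array(vars_reparam))
```

The variance-versus-sample-size plot would then follow from `vars_score` and `vars_reparam`, e.g. on a log–log scale.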

The variance of the estimates from the reparameterization trick is an order of magnitude smaller than that of the estimates from the first method!