Attend, Infer, Repeat

Implementation of a continuous relaxation of the AIR framework proposed in "Attend, Infer, Repeat: Fast Scene Understanding with Generative Models" (Eslami et al., 2016). This work was done in equal collaboration with Alexander Prams. The model is implemented in TensorFlow.

Noisy gradients of the discrete z<sub>pres</sub> (a Bernoulli random variable sampled to predict the presence of another digit on the canvas: 1 meaning “yes”, 0 meaning “no”) caused severe stability issues when training the model. NVIL (Mnih & Gregor, 2014) was originally used to alleviate the gradient noise, but it did not make training stable enough. Instead, a Concrete (Gumbel-Softmax) random variable (Maddison et al., 2016; Jang et al., 2016) – a continuous relaxation of the discrete random variable – was employed to improve training stability.
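For reference, sampling a binary Concrete (relaxed Bernoulli) variable only requires pushing logistic noise through a tempered sigmoid. A minimal NumPy sketch (function name is illustrative, not from this repository):

```python
import numpy as np

def sample_binary_concrete(log_odds, temperature=1.0, rng=np.random):
    """Draw a relaxed Bernoulli (binary Concrete) sample in (0, 1).

    As temperature -> 0 the sample approaches a hard 0/1 Bernoulli
    draw with the given log-odds, recovering the discrete variable.
    """
    u = rng.uniform(1e-6, 1.0 - 1e-6)             # avoid log(0)
    logistic_noise = np.log(u) - np.log(1.0 - u)  # Logistic(0, 1) sample
    return 1.0 / (1.0 + np.exp(-(log_odds + logistic_noise) / temperature))
```

Because the sample is a differentiable function of the noise, gradients flow through it via the reparameterization trick, which is what removes the need for score-function estimators like NVIL.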

The discrete z<sub>pres</sub> was replaced by a continuous analogue sampled from a Concrete distribution with temperature 1.0, taking values between 0 and 1. Correspondingly, the original Bernoulli KL divergence was replaced by an MC sample of the Concrete KL divergence. Two additional adaptations were made. First, VAE reconstructions are scaled by z<sub>pres</sub> before being added to the reconstruction canvas; this pushes continuous samples toward 0 or 1 when the model wants to stop or to attend to another digit, respectively. Second, inspired by ACT (Graves, 2016), the stopping criterion was reformulated: the model stops once the running sum of (1 – z<sub>pres</sub>) values over the time steps exceeds a configurable threshold (0.99 in the experiments). A threshold below 1 allows stopping during the very first time step, which is essential for empty images that should not be attended to at all. As a result, in the limit of Concrete z<sub>pres</sub> samples taking the extreme values 0 and 1, this relaxed model reduces to the original AIR with discrete z<sub>pres</sub>.
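The two adaptations can be sketched as a small loop (an illustration with hypothetical names, not the repository's actual code; in particular, the order of the stopping check relative to the canvas update is one possible choice):

```python
import numpy as np

def run_relaxed_air_steps(z_pres_samples, glimpse_reconstructions,
                          threshold=0.99):
    """Accumulate z_pres-scaled glimpse reconstructions onto a canvas,
    stopping once the running sum of (1 - z_pres) exceeds the threshold."""
    canvas = np.zeros_like(glimpse_reconstructions[0])
    stop_sum = 0.0
    steps_used = 0
    for z_pres, recon in zip(z_pres_samples, glimpse_reconstructions):
        stop_sum += 1.0 - z_pres
        if stop_sum > threshold:    # threshold < 1 permits stopping at step 1
            break
        canvas += z_pres * recon    # scaling pushes z_pres toward 0 or 1
        steps_used += 1
    return canvas, steps_used
```

With z<sub>pres</sub> ≈ 0 at the first step the running sum immediately exceeds 0.99 and nothing is drawn, which is the desired behaviour for empty images.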

After applying the continuous relaxation, 10 out of 10 training runs in a row converged to 98% digit count accuracy, on average within 25,000 iterations. All 10 runs were trained for 300 epochs (276k iterations) with the default hyperparameters from training.py, among them: 256 LSTM cells, a learning rate of 10<sup>-4</sup>, gradient clipping with a global norm of 1.0, and a smooth exponential decay of the z<sub>pres</sub> prior log-odds from 10<sup>4</sup> to 10<sup>-9</sup> during the first 40,000 iterations. The charts below show digit count accuracy for the entire validation set (top) and for its subsets of 0-, 1-, and 2-digit images, left to right (bottom):

*(charts: digit count accuracy, full validation set and 0-/1-/2-digit subsets)*
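A smooth exponential decay between two values is typically a geometric interpolation in log space. A sketch of the z<sub>pres</sub> prior log-odds schedule described above (the exact form used in training.py may differ):

```python
def zpres_prior_log_odds(step, start=1e4, end=1e-9, decay_steps=40000):
    """Geometric decay from `start` to `end` over `decay_steps` iterations,
    held at `end` afterwards."""
    t = min(step / decay_steps, 1.0)   # progress in [0, 1]
    return start * (end / start) ** t
```

Starting with very high log-odds encourages the model to attend early in training, and annealing toward near-zero presence probability later forces it to explain the canvas with as few steps as possible.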

Samples of attention/reconstruction produced by an AIR model trained with training.py (in each pair: original on the left, reconstruction on the right; the red attention window corresponds to the first time step, the green one to the second):

*(attention/reconstruction samples)*