Meliad
This is not an officially supported Google product.
This code is provided "as-is" to the broader research community. Google does not promise to maintain or otherwise support this code in any way.
Introduction
The Meliad library is a collection of models that are being developed as part of ongoing research into various architectural improvements in deep learning. The name "meliad" is the Greek word for a tree nymph; a long-term goal of this research is to design architectures that can understand recursive and compositional structures, i.e. trees.
The library currently consists of several transformer variations, which explore ways in which the popular transformer architecture can be extended to better support language modeling over long sequences.
Transformer-XL with sliding window
This model is provided as a baseline. It is similar to the Transformer-XL architecture, but uses a T5-style relative position bias. A long sequence, such as a book, is divided into segments of fixed length, e.g. 4096 tokens. The segments are processed in order, with one segment per training step.
Attention within a segment is done locally using a sliding window that is typically smaller than the segment length. A causal mask ensures that each token can attend to exactly W previous tokens, where W is the window size, e.g. 512 or 1024. The complexity of attention is quadratic with respect to window size, but linear with respect to segment length, so the segment length is limited only by available device memory. Like Transformer-XL, the model caches the keys and values from the last window for use on the next training step, and thus implements truncated backpropagation through time over very long (book-length) works.
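The masking rule described above can be sketched in plain Python. This is an illustration only, not the Meliad implementation: the function names are hypothetical, and the sketch assumes a token sees itself plus up to `window` earlier positions in the same segment (keys that would come from the cache are left out).

```python
def sliding_window_mask(segment_len, window):
    """Boolean mask: mask[i][j] is True if query i may attend to key j.

    Assumption (not from the Meliad source): a token sees itself plus up to
    `window` earlier positions in the same segment. Keys from before the
    segment start would come from the cache and are left out here.
    """
    return [
        [0 <= i - j <= window for j in range(segment_len)]
        for i in range(segment_len)
    ]


def attended_pairs(mask):
    """Count the (query, key) pairs that the mask allows."""
    return sum(sum(row) for row in mask)


# For a fixed window, doubling the segment length roughly doubles the work.
short = attended_pairs(sliding_window_mask(64, 8))
long = attended_pairs(sliding_window_mask(128, 8))
```

The number of allowed pairs grows with `segment_len * window` rather than `segment_len ** 2`, which is why the segment length can be much larger than the window.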
If the window and segment lengths are the same, then there is no sliding window (just the T-XL cache), and this model will behave like Transformer-XL. However, the cache is not differentiable, whereas the sliding window is, so there is some benefit to using segments that are longer than the window length. Gradients with the sliding window can potentially be backpropagated across the length of the entire segment.
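As a rough sketch of how segments and the cache line up across training steps, the generator below pairs each segment with the cache carried over from the previous one. The name `segments_with_cache` is hypothetical. The real model caches keys and values rather than tokens, and how a short final segment is treated here is an assumption.

```python
def segments_with_cache(tokens, segment_len, window):
    """Yield (cache, segment) pairs in training-step order.

    `cache` holds the last `window` items of the previous segment. The real
    model caches keys and values; tokens stand in for them here. `window`
    is assumed to be at least 1.
    """
    cache = []
    for start in range(0, len(tokens), segment_len):
        segment = tokens[start:start + segment_len]
        yield cache, segment
        # Keep only the tail of this segment for the next step.
        cache = segment[-window:]


steps = list(segments_with_cache(list(range(10)), segment_len=4, window=2))
```

When `window` equals `segment_len`, the cache is the whole previous segment, which corresponds to the Transformer-XL case described above.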
Memorizing Transformer
The Memorizing Transformer equips one layer of the transformer with a large external memory that stores prior (key, value) pairs. Typical memory sizes are 32k or 64k tokens. In addition to local attention, the model can do k-nearest-neighbor lookup into the external memory, which allows it to handle long-range dependencies; the range is limited only by the size of the memory.
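A toy version of the lookup may help make this concrete. The class below is a sketch under stated assumptions, not Meliad's memory: `ExternalMemory` is a hypothetical name, similarity is a plain dot product, the search is exact rather than approximate, and dropping the oldest entries when the memory is full is an assumption.

```python
import heapq
from collections import deque


class ExternalMemory:
    """Fixed-capacity store of (key, value) pairs with exact top-k lookup."""

    def __init__(self, capacity):
        # Assumption: when full, the oldest entry is dropped first.
        self.entries = deque(maxlen=capacity)

    def add(self, key, value):
        self.entries.append((key, value))

    def top_k(self, query, k):
        """Return the k entries whose keys have the largest dot product with query."""
        def score(entry):
            key, _ = entry
            return sum(q * x for q, x in zip(query, key))
        return heapq.nlargest(k, self.entries, key=score)


memory = ExternalMemory(capacity=3)
memory.add([1.0, 0.0], "a")
memory.add([0.0, 1.0], "b")
memory.add([1.0, 1.0], "c")
memory.add([2.0, 0.0], "d")  # capacity reached, so "a" is dropped
neighbors = memory.top_k([1.0, 0.0], k=2)
```

In this sketch the values are labels for readability; in the model they would be value vectors that are combined with the results of local attention.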
The external memory, like the T-XL cache, is not differentiable. Memory and the T-XL cache work well together; the memory is used for long-range lookups, while the cache is used for short-range lookups. However, memory should not be used with a sliding window, so the window and segment length should be the same.
Block-Recurrent Transformer
The Block-Recurrent Transformer equips one layer of the transformer with a recurrent cell. The cell is structured similarly to an LSTM cell, but it is several orders of magnitude larger, and operates on blocks of tokens and blocks of recurrent state vectors. Recurrence is integrated with the sliding window mechanism; the block size is the same as the window size.
Recurrence serves a similar role to external memory, but is faster. The recurrent state has a fixed capacity, but unlimited range (in theory).
Installation instructions
Create and activate a Python virtual environment. (The commands given are for Linux.)
python -m venv my_env
source my_env/bin/activate
Install the required packages into the Python virtual environment. If you want to use GPUs, then JAX must be upgraded to use CUDA. Installing t5 after upgrading JAX may be necessary to avoid link errors (we don't know why).
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install t5
On Unix systems, you may need to ensure that PYTHONPATH includes the current directory. All module names are given relative to the meliad root.
export PYTHONPATH=.:$PYTHONPATH
Run a small baseline model on a synthetic test dataset.
python transformer/ht_main.py --alsologtostderr \
--gin_file=base_htrans.gin \
--gin_file=size/small_test.gin
Configuring and running the model
Meliad uses gin to configure the model.
The first gin file should always be base_htrans.gin, which supplies a default configuration. Other options are specified as additional files in the configs directory. Most options are orthogonal, but in some cases the order matters; inspect the contents of the gin files to determine the correct order.
Some important options are:
size/medium150M.gin: The 150M parameter model in the paper.
options/positions_t5.gin: Use a T5-style relative position bias.
options/seq_4096.gin: Use a segment length of 4096 tokens.
options/window_1024.gin: Use a sliding window of size 1024. (The default is 512.)
options/lr_cosine_decay.gin: Cosine decay learning rate schedule.
Tasks are also defined in gin files:
tasks/pg19_tokens.gin: Run on PG19 with the default T5 sentencepiece vocabulary.
Other important command-line options:
--alsologtostderr: View the progress of the model.
--workdir=/my/work/directory: For checkpoints and tensorboard.
--load_dir=/location/of/pretrained/model: For finetuning.
--default_data_dir=/location/of/tfds/datasets: For tensorflow datasets.
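When launching many runs, it can be convenient to assemble the command line in a script. The helper below is a hypothetical convenience, not part of Meliad. It uses only the script path, gin file names and flags mentioned in this README, and it puts base_htrans.gin first, as recommended above.

```python
import shlex


def build_command(gin_files, workdir=None, extra_flags=()):
    """Assemble an ht_main.py command line as a list of arguments.

    base_htrans.gin is always placed first; the remaining gin files keep
    the order in which they were given.
    """
    files = ["base_htrans.gin"] + [f for f in gin_files if f != "base_htrans.gin"]
    cmd = ["python", "transformer/ht_main.py", "--alsologtostderr"]
    cmd += [f"--gin_file={name}" for name in files]
    if workdir is not None:
        cmd.append(f"--workdir={workdir}")
    cmd += list(extra_flags)
    return cmd


# The small baseline run from the installation section.
baseline = build_command(["size/small_test.gin"])
print(shlex.join(baseline))
```

The returned list can be passed to `subprocess.run` as is. Since the order of gin files can matter, the helper does not reorder anything other than moving base_htrans.gin to the front.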
For the Memorizing Transformer:
size/medium150M.gin: The 150M parameter model in the paper.
options/positions_t5.gin: Use a T5-style relative position bias.
options/seq_512.gin: Segment length of 512. (Window is 512 by default.)
options/external_memory_32k.gin: Memorizing Transformer with a memory size of 32k.
For the Block-Recurrent Transformer:
size/medium150M.gin: The 150M parameter model in the paper.
options/positions_t5.gin: Use a T5-style relative position bias.
options/seq_4096.gin: Segment length of 4096. (Window is 512 by default.)
recurrent/bias_skip.gin: The fixed:skip configuration.