STaR
Code for STaR: Bootstrapping Reasoning With Reasoning (NeurIPS 2022). This library is built on top of mesh-transformer-jax and incorporates masked training from this repo. In order to run it, launch iteration_train.py with any desired arguments. The README is left mostly unchanged, as iteration_train.py largely wraps around device_train.py, device_inference.py, and create_finetune_tfrecords.py.
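For orientation, here is a minimal sketch of the outer loop that the STaR paper describes and that iteration_train.py orchestrates through those scripts (sampling, then building fine-tuning records, then training). It is not the code in iteration_train.py: the callables `generate`, `sample_with`, and `finetune` are hypothetical stand-ins for inference and training. The sketch follows the paper's recipe: keep rationales that reach the gold answer, retry failures with the gold answer as a hint ("rationalization"), and fine-tune starting from the base model on each round's examples.

```python
from typing import Callable, List, Optional, Tuple

Problem = Tuple[str, str]        # (question, gold_answer)
Example = Tuple[str, str, str]   # (question, rationale, answer)


def star_iteration(
    problems: List[Problem],
    generate: Callable[[str, Optional[str]], Tuple[str, str]],
    rationalize: bool = True,
) -> List[Example]:
    """Collect one round of fine-tuning examples.

    `generate(question, hint)` is a hypothetical stand-in for few-shot sampling
    with the current model; it returns (rationale, predicted_answer). A non-None
    `hint` reveals the gold answer to the model (rationalization).
    """
    kept: List[Example] = []
    for question, gold in problems:
        rationale, answer = generate(question, None)
        if answer == gold:
            # Keep rationales that lead to the correct answer.
            kept.append((question, rationale, gold))
        elif rationalize:
            # Retry with the gold answer as a hint; keep it only if it now matches.
            # The stored example omits the hint, as if generated without it.
            rationale, answer = generate(question, gold)
            if answer == gold:
                kept.append((question, rationale, gold))
    return kept


def star_loop(problems, base_model, sample_with, finetune, n_iters=3):
    """Outer loop: each round samples with the latest model but fine-tunes the base model."""
    model = base_model
    for _ in range(n_iters):
        examples = star_iteration(problems, lambda q, h: sample_with(model, q, h))
        model = finetune(base_model, examples)
    return model
```

Restarting from the base model each round, rather than continuing from the previous round's weights, follows the paper's description; the sketch simply makes that choice explicit in `star_loop`.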
Mesh Transformer JAX
A haiku library using the xmap/pjit operators in JAX for model parallelism of transformers.
The parallelism scheme is similar to the original Megatron-LM, which is efficient on TPUs due to the high-speed 2D mesh network. There is also an experimental model version which implements ZeRO-style sharding.
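To illustrate the Megatron-style idea for a feedforward block: the first weight matrix is split by columns and the second by rows, so each device computes a partial output from its own slices, and a single sum (an all-reduce across devices) recovers the full result. The sketch below is a single-process, pure-Python illustration under that assumption, not this library's xmap/pjit code; ReLU stands in for the model's activation, and any elementwise activation works the same way because it never mixes hidden units.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(r, s)] for r, s in zip(a, b)]


def ffn(x: Matrix, w_in: Matrix, w_out: Matrix) -> Matrix:
    """Unsharded feedforward block: act(x @ w_in) @ w_out."""
    return matmul(relu(matmul(x, w_in)), w_out)


def split_cols(m: Matrix, n: int) -> List[Matrix]:
    assert len(m[0]) % n == 0, "hidden size must divide evenly across shards"
    k = len(m[0]) // n
    return [[row[i * k:(i + 1) * k] for row in m] for i in range(n)]


def split_rows(m: Matrix, n: int) -> List[Matrix]:
    assert len(m) % n == 0, "hidden size must divide evenly across shards"
    k = len(m) // n
    return [m[i * k:(i + 1) * k] for i in range(n)]


def sharded_ffn(x: Matrix, w_in: Matrix, w_out: Matrix, n_shards: int) -> Matrix:
    # Each "device" holds a column slice of w_in and the matching row slice of w_out,
    # and computes a partial output without talking to the others.
    partials = [
        ffn(x, wi, wo)
        for wi, wo in zip(split_cols(w_in, n_shards), split_rows(w_out, n_shards))
    ]
    # The only communication step: sum the partial outputs (an all-reduce).
    out = partials[0]
    for p in partials[1:]:
        out = add(out, p)
    return out
```

For example, `sharded_ffn(x, w_in, w_out, 2)` returns the same matrix as `ffn(x, w_in, w_out)`, with the input `x` replicated on every shard.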
This library is designed for scalability up to approximately 40B parameters on TPUv3s, beyond which different parallelism strategies should be used. For larger models, see other implementations such as GPT-NeoX or DeepSpeed.
One future direction for research is integrating this codebase with swarm-jax, to achieve further scalability with pipeline parallelism.
Updates
12-07-21: Added guide to fine-tuning
Pretrained Models
GPT-J-6B
A 6 billion parameter, autoregressive text generation model trained on The Pile.
Links
Slim weights (bf16 weights only, for inference, 9GB)
Full weights (including optimizer params, 61GB)
Acknowledgments
This project would not have been possible without compute generously provided by the TPU Research Cloud with assistance from EleutherAI.
Thanks to the Cloud TPU team at Google for providing early access to the Cloud TPU VM alpha (now publicly available!)
Thanks to everyone who has helped out one way or another (listed alphabetically):
- Aran Komatsuzaki for advice with experiment design and writing the blog posts.
- James Bradbury for valuable assistance with debugging JAX issues.
- Janko Prester for creating the web demo frontend.
- Laurence Golding for adding some features to the web demo.
- Leo Gao for running zero-shot evaluations for the baseline models for the table.
License
The weights of GPT-J-6B are licensed under version 2.0 of the Apache License.
Model Details
Hyperparameter | Value |
---|---|
n_parameters | 6,053,381,344 |
n_layers | 28* |
d_model | 4,096 |
d_ff | 16,384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2,048 |
n_vocab | 50,257 (same tokenizer as GPT-2/3) |
position encoding | Rotary position encodings (RoPE) |
RoPE dimensions | 64 |
* Each layer consists of one feedforward block and one self-attention block.
The model consists of 28 layers with a model dimension of 4096 and a feedforward dimension of 16384. The model dimension is split into 16 heads, each with a dimension of 256. Rotary position encodings (RoPE) were applied to 64 dimensions of each head. The model is trained with a tokenization vocabulary of 50257, using the same set of BPEs as GPT-2/GPT-3.
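The head size follows from the table: d_head = d_model / n_heads = 4096 / 16 = 256. The sketch below shows what "RoPE applied to 64 dimensions of each head" means for a single head vector at one position. Several details are assumptions rather than statements from this README: the rotated dimensions are taken to be the first 64, they are rotated in adjacent pairs, and the frequency base is 10000. The remaining 192 dimensions pass through unchanged.

```python
import math
from typing import List


def rope(x: List[float], pos: int, rotary_dim: int = 64, base: float = 10000.0) -> List[float]:
    """Apply rotary position encoding to the first `rotary_dim` entries of one head vector.

    Assumed layout: adjacent pairs (x[2i], x[2i+1]) are rotated by the angle
    pos * base**(-2i / rotary_dim); entries past `rotary_dim` are copied unchanged.
    """
    assert rotary_dim % 2 == 0 and rotary_dim <= len(x)
    out = list(x)
    for i in range(rotary_dim // 2):
        theta = pos * base ** (-2 * i / rotary_dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


D_MODEL, N_HEADS, ROTARY_DIM = 4096, 16, 64
D_HEAD = D_MODEL // N_HEADS  # 256, matching the d_head row of the table
```

Because each pair is rotated by an angle proportional to its position, the dot product between a rotated query at position m and a rotated key at position n depends only on m - n. Attention scores therefore encode relative position without a separate position embedding.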
Zero-Shot Evaluations
Models roughly sorted by performance, or by FLOPs if not available.
Model | Weights | Training FLOPs | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Dataset Size (GB) |
---|---|---|---|---|---|---|---|---|
Chance | ✔ | 0 | ~a lot | ~0% | 50% | 25% | 50% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 |
Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
Gopher 280B* | ✘ | 6.31e23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
* Evaluation numbers reported by their respective authors. All other numbers were obtained by running the lm-evaluation-harness, either with the released weights or with API access. Due to subtle implementation differences as well as different zero-shot task framing, these might not be directly comparable. See this blog post for more details.
† The Megatron-11B model provides no comparable metrics, and several implementations using the released weights do not reproduce the generation quality and evaluations (see 1, 2, 3). Thus, evaluation was not attempted.
‡ These models have been trained with data which may contain test set contamination. The OpenAI GPT-3 models failed to deduplicate training data for certain test sets, while the GPT-Neo models, as well as this one, are trained on The Pile, which has not been deduplicated against any test sets.
Architecture and Usage
Most scripts in this repository are designed to be run on TPUs, which under the TPU-VM architecture are virtual machines that can run arbitrary code. These scripts typically spin up a TPU, SSH into it to set up the dependencies and copy code over from the local directory, and then start a Ray worker which can accept RPC calls.
The TPU VMs handle running model training steps, evaluation, and checkpoint saving and loading, while the driver Python program handles data loading and general orchestration (such as deciding when to save checkpoints).
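The division of labour can be pictured with a small sketch: workers own model state and execute steps, while the driver owns the data stream and decides when to checkpoint. The repository uses Ray workers for this; to stay standard-library only, the sketch swaps in `concurrent.futures` threads as a stand-in for remote calls. The `Worker` class, its methods, and `drive` are hypothetical names, not this repository's API, and the "training step" only fakes a loss.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple


class Worker:
    """Hypothetical stand-in for one TPU host that holds (a shard of) the model state."""

    def __init__(self, rank: int):
        self.rank = rank
        self.step = 0

    def train(self, batch: List[int]) -> float:
        # A real worker would run one sharded training step; here we fake a loss.
        self.step += 1
        return sum(batch) / len(batch)

    def save_ckpt(self, path: str) -> str:
        # A real worker would write its shard to storage; here we return a label.
        return f"{path}/shard_{self.rank}@step{self.step}"


def drive(workers: List[Worker], batches: List[List[int]], ckpt_every: int, path: str) -> Tuple[List[float], List[str]]:
    """Driver: feeds each batch to every worker in parallel and decides when to save."""
    losses: List[float] = []
    saved: List[str] = []
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        for i, batch in enumerate(batches, start=1):
            # Fan the same step out to all workers (as in model parallelism), then gather.
            futures = [pool.submit(w.train, batch) for w in workers]
            losses.append(sum(f.result() for f in futures) / len(futures))
            if i % ckpt_every == 0:
                saved += list(pool.map(lambda w: w.save_ckpt(path), workers))
    return losses, saved
```

Keeping orchestration on the driver means the workers stay stateless with respect to scheduling: they only react to calls, which is why the driver machine should sit close to the TPUs to keep call latency low.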
This means that most scripts (train.py, eval_harness.py, etc.) expect to be run on a GCE virtual machine in the same region as the TPUs, to minimize RPC latency and data transfer cost. Other scripts (usually ones which don't take a --tpu argument, such as device_sample.py, device_serve.py, or device_train.py) expect to be run directly on a TPU VM. The device_* scripts only work on a v3-8 and not on larger pods.
Furthermore, there is an example (resharding_example.py) of how to convert the provided checkpoints (which have 8 shards in the case of GPT-J-6B) down to a smaller number of shards, for example when running on GPU(s).
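Conceptually, reducing the shard count for a tensor-parallel weight means concatenating consecutive shards along the axis on which that weight was split. The sketch below shows only that concatenation step on plain nested lists. It is not resharding_example.py. In a real checkpoint, each parameter may be split along a different axis, and some parameters may be replicated rather than split, in which case one copy would be kept instead of concatenating. The function name `reshard` is hypothetical.

```python
from typing import List

Shard = List[List[float]]  # one 2D parameter block


def reshard(shards: List[Shard], n_out: int, axis: int) -> List[Shard]:
    """Merge len(shards) equal blocks into n_out blocks along `axis` (0 = rows, 1 = columns).

    Consecutive shards are concatenated, so the original shard order is preserved.
    """
    n_in = len(shards)
    if n_in % n_out:
        raise ValueError(f"cannot merge {n_in} shards into {n_out}")
    group = n_in // n_out
    out: List[Shard] = []
    for g in range(n_out):
        chunk = shards[g * group:(g + 1) * group]
        if axis == 0:
            # Stack the row blocks on top of each other.
            merged = [row for s in chunk for row in s]
        else:
            # Join each row across the column blocks, left to right.
            merged = [[v for s in chunk for v in s[r]] for r in range(len(chunk[0]))]
        out.append(merged)
    return out
```

For example, going from 8 shards to 1 (`n_out=1`) reconstructs the full matrix, which is the typical case for a single GPU.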
Fine-tuning
To fine-tune the model, run device_train.py on a TPU VM. Using a TPU v3-8, you can fine-tune at a rate of ~5000 tokens/second, which should be sufficient for small-to-medium-size datasets.
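As a worked example of what ~5000 tokens/second means: 5000 × 3600 = 18,000,000 tokens per hour, so one pass over a 100M-token dataset takes about 100,000,000 / 5000 = 20,000 seconds, or roughly 5.6 hours. The helper below is a hypothetical convenience for this arithmetic. It assumes throughput stays constant and ignores startup, evaluation, and checkpointing overhead.

```python
def finetune_hours(n_tokens: int, epochs: float = 1.0, tokens_per_sec: float = 5000.0) -> float:
    """Rough wall-clock estimate in hours: total tokens processed divided by throughput."""
    return n_tokens * epochs / tokens_per_sec / 3600.0


# finetune_hours(18_000_000) -> 1.0 (one hour at ~5000 tokens/s)
```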
Please read the step-by-step guide for thorough fine-tuning instructions.
JAX Dependency
Note that this library has specific requirements for the JAX version. To use the v1 models (including GPT-J 6B), jax==0.2.12 is required, which in turn depends on jaxlib==0.1.68. If these versions are not installed, you will get cryptic xmap errors.
However, to use the v2 model code (no publicly released weights), the newest JAX version can be used.
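Since a version mismatch surfaces only as cryptic xmap errors, it can help to check the installed versions before launching. Below is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+). The helper names are hypothetical, and it assumes an exact match to the pins above is wanted. It ignores a local build label after `+` (as some jaxlib GPU wheels carry), which is an assumption about how your jaxlib was installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, List, Optional

# Pins stated above for the v1 models (including GPT-J-6B).
V1_PINS: Dict[str, str] = {"jax": "0.2.12", "jaxlib": "0.1.68"}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is not installed."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def pin_mismatches(installed: Dict[str, Optional[str]], pins: Dict[str, str] = V1_PINS) -> List[str]:
    """List human-readable problems; an empty list means every pin is satisfied."""
    problems: List[str] = []
    for name, wanted in pins.items():
        have = installed.get(name)
        # Ignore a PEP 440 local version label (the part after '+').
        base = have.split("+")[0] if have else None
        if base != wanted:
            problems.append(f"{name}: need {wanted}, found {have or 'nothing'}")
    return problems


if __name__ == "__main__":
    issues = pin_mismatches(installed_versions(V1_PINS))
    print("\n".join(issues) if issues else "JAX pins OK for v1 models")
```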
Citation
To cite this repository:
@inproceedings{
zelikman2022star,
title={{ST}aR: Bootstrapping Reasoning With Reasoning},
author={Eric Zelikman and Yuhuai Wu and Jesse Mu and Noah Goodman},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=_3ELRdg2sgI}
}
To cite the base repository:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
To cite the weights of GPT-J-6B:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
If you use this repository or any of the pretrained weights to do something cool, we would love to hear about it. Feel free to open a GitHub issue or reach out over email (in profile).
TODO
- disentangle heads and shards
- test/benchmark on TPU
- implement gradient checkpointing
- fix initialization
- mixed precision
- deal with preemptible TPUs
- test and validate generation
- shard activations instead of replicating for memory efficiency (in v2)
- support ZeRO style sharding (in v2)