minGPT in JAX

This is Karpathy’s minGPT re-written in JAX.

The GPT model is completely 1:1 with Karpathy's. It has been tested by loading PyTorch weights into JAX.
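
How that transfer works in general: Haiku parameters are just a nested dict of arrays, so loading PyTorch weights mostly amounts to matching names and converting tensors. The sketch below is a minimal illustration of that idea, not the repo's actual conversion code; the checkpoint path and the key/module names are hypothetical.

# a minimal sketch of PyTorch -> Haiku weight transfer; the checkpoint path and
# the key/module names below are hypothetical, not the repo's actual mapping
import numpy as np
import jax.numpy as jnp
import torch

pt_state = torch.load('mingpt_pytorch.pt', map_location='cpu')  # hypothetical checkpoint file

def to_jax(t):
    # detach a PyTorch tensor and convert it into a JAX array
    return jnp.asarray(np.asarray(t.detach().cpu()))

# Haiku params are a nested mapping {module_name: {param_name: array}}; note that
# PyTorch nn.Linear stores weights as [out, in] while hk.Linear expects [in, out],
# so linear weights need a transpose on the way over
params = {
    'gpt/embed': {'embeddings': to_jax(pt_state['tok_emb.weight'])},  # hypothetical names
    # ... remaining modules mapped analogously ...
}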

The Trainer is a bit different. It trains as well as the PyTorch version, but to get there I had to halve the batch size. Worth exploring!

mingpt

A JAX re-implementation of GPT training. minGPT tries to be small, clean, interpretable and educational, as most of the currently available ones are a bit sprawling. GPT is not a complicated model and this implementation is appropriately about 300 lines of code, including boilerplate and a totally unnecessary custom causal self-attention module. Anyway, all that's going on is that a sequence of indices goes into a sequence of transformer blocks, and a probability distribution of the next index comes out. The rest of the complexity is just being clever with batching (both across examples and over sequence length) so that training is efficient.
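
To make the "indices in, distribution out" flow concrete, here is a deliberately tiny Haiku sketch of that computation. It is illustrative only: the layer sizes are toy, the block is a stand-in rather than real causal self-attention, and none of the names correspond to the actual model.py.

# an illustrative, simplified sketch of the indices-in / next-index-distribution-out flow;
# not the actual model.py (the real blocks use causal self-attention)
import haiku as hk
import jax
import jax.numpy as jnp

def tiny_gpt(idx):                                   # idx: [T] integer token indices
    vocab_size, n_embd, n_layer = 1000, 128, 2       # toy sizes
    T = idx.shape[0]
    pos_emb = hk.get_parameter('pos_emb', [T, n_embd], init=jnp.zeros)
    x = hk.Embed(vocab_size, n_embd)(idx) + pos_emb  # token + positional embeddings
    for _ in range(n_layer):
        # stand-in for a transformer block (causal self-attention + MLP in the real model)
        x = x + hk.Linear(n_embd)(jax.nn.gelu(hk.LayerNorm(-1, True, True)(x)))
    logits = hk.Linear(vocab_size)(hk.LayerNorm(-1, True, True)(x))
    return jax.nn.softmax(logits, axis=-1)           # distribution over the next index, per position

forward = hk.transform(tiny_gpt)
params = forward.init(jax.random.PRNGKey(0), jnp.arange(8))
probs = forward.apply(params, None, jnp.arange(8))   # shape [8, vocab_size]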

The core minGPT "library" (hah) is two files: mingpt/model.py contains the actual Transformer model definition and mingpt/trainer.py is (GPT-independent) JAX boilerplate that trains the model. The attached Jupyter notebooks then show how the "library" (hah) can be used to train sequence models.
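
Under the hood, "JAX boilerplate that trains the model" usually boils down to a jitted update step: compute the loss and its gradients, then apply an optimizer update. The sketch below is an assumption about what such a step looks like (using optax and the hk_loss_fn from the example further down), not a copy of trainer.py.

# a hedged sketch of a jitted training step; optax, the learning rate, and the
# hk_loss_fn.apply(params, rng, inputs, targets) signature are all assumptions
import jax
import optax

optimizer = optax.adam(3e-4)                 # arbitrary learning rate
# opt_state = optimizer.init(params)         # done once, after trainer.init_params()

@jax.jit
def update(params, opt_state, rng, inputs, targets):
    loss, grads = jax.value_and_grad(
        lambda p: hk_loss_fn.apply(p, rng, inputs, targets))(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss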

With a BPE encoder, distributed training, and maybe fp16, this implementation may be able to reproduce GPT-1/GPT-2 results, though I haven't tried $$$. GPT-3 is likely out of reach, as my understanding is that it does not fit into GPU memory and requires a more careful model-parallel treatment.
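
For the "distributed training" part, the natural JAX idiom is data parallelism with jax.pmap: replicate the parameters, shard the batch across devices, and average gradients with jax.lax.pmean. The sketch below is purely speculative and reuses the assumed names from the training-step sketch above; nothing like it ships with this code.

# a speculative data-parallel variant of the update step; hk_loss_fn and optimizer
# are the assumed objects from the sketch above
import jax
import optax
from functools import partial

@partial(jax.pmap, axis_name='devices')
def parallel_update(params, opt_state, rng, inputs, targets):
    loss, grads = jax.value_and_grad(
        lambda p: hk_loss_fn.apply(p, rng, inputs, targets))(params)
    grads = jax.lax.pmean(grads, axis_name='devices')    # average gradients across devices
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

# inputs/targets carry an extra leading device axis, e.g. [jax.local_device_count(), B, T],
# and params/opt_state are replicated across devices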

Example usage

This code is simple enough to just hack inline, not "used", but the current API looks something like:


# imports used throughout this example
import haiku as hk
import jax.numpy as jnp
from functools import partial

# you're on your own to define a class that returns individual examples as PyTorch LongTensors
from torch.utils.data import Dataset
train_dataset = MyDataset(...)
test_dataset = MyDataset(...)

# construct a GPT model
from mingpt.model import gpt, loss_fn, GPTConfig
gpt_config = GPTConfig(vocab_size, block_size, n_layer=12, n_head=12, n_embd=768) # a GPT-1
hk_loss_fn = hk.transform(partial(loss_fn, config=gpt_config, is_training=True))

# construct a trainer
from mingpt.trainer import Trainer, TrainerConfig
tconf = TrainerConfig(max_epochs=10, batch_size=256)
trainer = Trainer(hk_loss_fn, train_dataset, test_dataset, tconf)
params = trainer.init_params() 
params, _ = trainer.train(params)
# (... enjoy the show for a while... )

# sample from the model. no need for a dummy batch dimension, it works on a single example. use jax.vmap for batch processing
from mingpt.utils import sample
model = hk.transform(partial(gpt, config=gpt_config, is_training=False))
model = hk.without_apply_rng(model).apply
x = jnp.array([1, 2, 3]) # context 
y = sample(params, model, gpt_config, x, steps=30, temperature=1.0, sample=True, top_k=5)
print(y) # our model filled in the integer sequence with 30 additional likely integers
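
As the comment above suggests, batching inference is a one-liner with jax.vmap. A minimal sketch, assuming model(params, x) maps a [T] context to per-position outputs as in the example:

# batch several contexts with jax.vmap: share params, map over the leading batch axis
import jax
batched_model = jax.vmap(model, in_axes=(None, 0))
xs = jnp.stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])])  # [B, T] batch of contexts
batched_out = batched_model(params, xs)                       # batch dimension handled for you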

References

Code:

Papers + some implementation notes:

Improving Language Understanding by Generative Pre-Training (GPT-1)

Language Models are Unsupervised Multitask Learners (GPT-2)

Language Models are Few-Shot Learners (GPT-3)

Generative Pretraining from Pixels (Image GPT)

License

Apache 2.0