RecurrentGemma

RecurrentGemma is a family of open-weights Language Models by Google DeepMind, based on the novel Griffin architecture. This architecture achieves fast inference when generating long sequences by replacing global attention with a mixture of local attention and linear recurrences.
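The following toy sketch illustrates the two mechanisms named above. It does not reproduce the Griffin layers themselves, which add learned projections and gating. It shows which positions a token can see under causal local attention, and one step of an element-wise linear recurrence. Both keep the per-token cost of generation bounded: local attention only looks back over a fixed `window` of tokens, and the recurrence carries a fixed-size state. The function names and window semantics are illustrative assumptions, not the repository's API.

```python
def local_attention_window(t, window):
    """Positions that query token t may attend to under causal local attention.

    Only the most recent `window` tokens (including t itself) are visible, so the
    cache needed per generation step is bounded by `window`, not by t.
    """
    return list(range(max(0, t - window + 1), t + 1))


def linear_recurrence_step(h, x, a, b):
    """One step of an element-wise linear recurrence: h_t = a * h_{t-1} + b * x_t.

    The state h has a fixed size, so each new token costs the same amount of
    work no matter how many tokens came before it.
    """
    return [a_i * h_i + b_i * x_i for h_i, x_i, a_i, b_i in zip(h, x, a, b)]


if __name__ == "__main__":
    print(local_attention_window(t=10, window=4))
    state = [0.0, 0.0]
    for x_t in ([1.0, 2.0], [3.0, 4.0]):
        state = linear_recurrence_step(state, x_t, a=[0.5, 0.9], b=[1.0, 1.0])
    print(state)
```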

This repository contains the model implementation and examples for sampling and fine-tuning. We recommend most users adopt the Flax implementation, which is highly optimized. We also provide an un-optimized PyTorch implementation for reference.

Learn more about RecurrentGemma

Quick start

Installation

Using Poetry

RecurrentGemma uses Poetry for dependency management.

To install dependencies for the full project:

poetry install -E full

If you only need a subset of the dependencies, use one of the library-specific commands below.

Using pip

If you want to use pip instead of Poetry, first create and activate a virtual environment:

python -m venv recurrentgemma-demo
. recurrentgemma-demo/bin/activate

Then install the dependencies:

pip install .[full]

Installing library-specific packages

JAX

To install dependencies only for the JAX pathway use: poetry install -E jax (or pip install .[jax]).

PyTorch

To install dependencies only for the PyTorch pathway use: poetry install -E torch (or pip install .[torch]).

Tests

To install the dependencies required for running unit tests use: poetry install -E test (or pip install .[test]).
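To see which of these optional dependency sets ended up in the active environment, you can probe for their top-level modules. This is a minimal sketch that is not part of the repository. The mapping from each extra to the modules it installs (jax and flax for the JAX pathway, torch for PyTorch, pytest for tests) is an assumption; pyproject.toml is the authoritative source.

```python
import importlib.util

# Assumed mapping from each optional extra to the top-level modules it provides.
# Check pyproject.toml for the authoritative list of packages per extra.
EXTRA_MODULES = {
    "jax": ["jax", "flax"],
    "torch": ["torch"],
    "test": ["pytest"],
}


def installed_extras():
    """Return {extra: bool}, True when every module of that extra is importable."""
    return {
        extra: all(importlib.util.find_spec(mod) is not None for mod in modules)
        for extra, modules in EXTRA_MODULES.items()
    }


if __name__ == "__main__":
    for extra, ok in installed_extras().items():
        print(f"{extra:>5}: {'installed' if ok else 'missing'}")
```

Because importlib.util.find_spec only locates a module without importing it, the check stays fast even when JAX or PyTorch is installed.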

Downloading the models

The model checkpoints are available on Kaggle at http://kaggle.com/models/google/recurrentgemma. Select either the Flax or the PyTorch model variation, click the ⤓ button to download the model archive, then extract its contents to a local directory.

In both cases, the archive contains the model weights and the tokenizer.

Running the unit tests

To run the tests, install the optional [test] dependencies (e.g. using pip install .[test]) from the root of the source tree, then:

pytest .

Examples

To run the example sampling script, pass the paths to the weights directory and tokenizer:

python examples/sampling_jax.py \
  --path_checkpoint=/path/to/archive/contents/2b/ \
  --path_tokenizer=/path/to/archive/contents/tokenizer.model
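If you launch the example from another script, a small helper can check the extracted archive before building the command. This is a hypothetical sketch, not part of the repository. It assumes the layout used in the command above: the checkpoint lives in <archive>/2b/ and the tokenizer at <archive>/tokenizer.model. Only the two flags shown above are passed.

```python
import sys
from pathlib import Path


def sampling_command(archive_dir, variant="2b"):
    """Build argv for examples/sampling_jax.py from an extracted Kaggle archive.

    Assumed layout (matching the example command): <archive_dir>/<variant>/ holds
    the checkpoint and <archive_dir>/tokenizer.model holds the tokenizer.
    """
    root = Path(archive_dir)
    checkpoint = root / variant
    tokenizer = root / "tokenizer.model"
    # Fail early with a clear message instead of inside the sampling script.
    if not checkpoint.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {checkpoint}")
    if not tokenizer.is_file():
        raise FileNotFoundError(f"tokenizer not found: {tokenizer}")
    return [
        sys.executable,
        "examples/sampling_jax.py",
        f"--path_checkpoint={checkpoint}/",
        f"--path_tokenizer={tokenizer}",
    ]


if __name__ == "__main__":
    import subprocess

    # Run from the root of the source tree so the relative script path resolves.
    subprocess.run(sampling_command("/path/to/archive/contents"), check=True)
```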

Colab notebook tutorials

To run these notebooks, you need a Kaggle account, and you must first read and accept the Gemma license terms and conditions on the RecurrentGemma page. After that, the notebooks automatically download the weights and tokenizer from Kaggle.

The notebooks currently support the following hardware:

Hardware | T4 | P100 | V100 | A100 | TPUv2 | TPUv3+
Sampling in Jax
Sampling in PyTorch
Finetuning in Jax

System Requirements

RecurrentGemma code can run on CPU, GPU or TPU. The Flax implementation is optimized for TPU: it contains a low-level Pallas kernel that performs the linear scan in the recurrent layers.
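The linear scan mentioned above evaluates a recurrence of the form h_t = a_t * h_{t-1} + b_t along the sequence. Composing two such affine updates yields another affine update, so the scan can also run in about log2(n) parallel rounds instead of n sequential steps. The following pure-Python sketch illustrates that idea on scalars. The function names are hypothetical. This is not the Pallas kernel, whose strategy may differ.

```python
def combine(earlier, later):
    """Compose two affine updates h -> a*h + b, applying `earlier` first."""
    a1, b1 = earlier
    a2, b2 = later
    # later(earlier(h)) = a2 * (a1 * h + b1) + b2
    return (a2 * a1, a2 * b1 + b2)


def sequential_scan(a, b, h0=0.0):
    """Reference scan: h_t = a_t * h_{t-1} + b_t, one step at a time."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def parallel_scan(a, b, h0=0.0):
    """Same result via a Hillis-Steele inclusive scan over `combine`.

    The updates within each round are independent of one another, so on
    parallel hardware the scan takes about log2(n) rounds instead of n steps.
    """
    elems = list(zip(a, b))
    offset = 1
    while offset < len(elems):
        elems = [
            elems[i] if i < offset else combine(elems[i - offset], elems[i])
            for i in range(len(elems))
        ]
        offset *= 2
    # elems[t] is now the composed update for steps 0..t; apply it to h0.
    return [a_t * h0 + b_t for a_t, b_t in elems]


if __name__ == "__main__":
    a = [0.5, 0.9, 0.1, 1.0, 0.3]
    b = [1.0, 2.0, 3.0, 4.0, 5.0]
    print(sequential_scan(a, b))
    print(parallel_scan(a, b))
```

Hillis-Steele performs O(n log n) total work. Work-efficient alternatives exist, and JAX exposes a generic parallel scan over an associative combine function as jax.lax.associative_scan.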

Contributing

We are open to bug reports and issues. Please see CONTRIBUTING.md for details on PRs.

License

Copyright 2024 DeepMind Technologies Limited

This code is licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Disclaimer

This is not an official Google product.