evosax: JAX-Based Evolution Strategies 🦎

<a href="https://github.com/RobertTLange/evosax/blob/main/docs/logo.png?raw=true"><img src="https://github.com/RobertTLange/evosax/blob/main/docs/logo.png?raw=true" width="170" align="right" /></a>

Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? evosax lets you use JAX, XLA compilation and auto-vectorization/parallelization to scale ES to your favorite accelerators. The API is based on the classical `ask`, `evaluate`, `tell` cycle of ES. Both `ask` and `tell` calls are compatible with `jit`, `vmap`/`pmap` and `lax.scan`. It includes a vast set of both classic (e.g. CMA-ES, Differential Evolution) and modern neuroevolution (e.g. OpenAI-ES, Augmented Random Search) strategies. You can get started here 👉 Colab

Basic evosax API Usage 🍲

```Python
import jax
from evosax import CMA_ES

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

# Run ask-eval-tell loop - NOTE: By default minimization!
num_generations = 100  # Number of ask-eval-tell iterations
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = ...  # Your population evaluation fct
    state = strategy.tell(x, fitness, state, es_params)

# Get best overall population member & its fitness
state.best_member, state.best_fitness
```
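The `fitness = ...` placeholder can be any function that maps the `(popsize, num_dims)` candidate matrix to a vector of scores. As a minimal illustration (the `sphere` helper below is not part of the evosax API), a quadratic test fitness evaluated in parallel via `vmap` could look like this:

```Python
import jax.numpy as jnp

def sphere(candidate):
    """Simple quadratic test fitness (minimum at the origin)."""
    return jnp.sum(candidate ** 2)

# Evaluate every population member in parallel
fitness = jax.vmap(sphere)(x)  # shape: (popsize,)
```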

Implemented Evolution Strategies 🦎

| Strategy | Reference | Import | Example |
|---|---|---|---|
| OpenAI-ES | Salimans et al. (2017) | `OpenES` | Colab |
| PGPE | Sehnke et al. (2010) | `PGPE` | Colab |
| ARS | Mania et al. (2018) | `ARS` | Colab |
| ESMC | Merchant et al. (2021) | `ESMC` | Colab |
| Persistent ES | Vicol et al. (2021) | `PersistentES` | Colab |
| Noise-Reuse ES | Li et al. (2023) | `NoiseReuseES` | Colab |
| xNES | Wierstra et al. (2014) | `XNES` | Colab |
| SNES | Wierstra et al. (2014) | `SNES` | Colab |
| CR-FM-NES | Nomura & Ono (2022) | `CR_FM_NES` | Colab |
| Guided ES | Maheswaranathan et al. (2018) | `GuidedES` | Colab |
| ASEBO | Choromanski et al. (2019) | `ASEBO` | Colab |
| CMA-ES | Hansen & Ostermeier (2001) | `CMA_ES` | Colab |
| Sep-CMA-ES | Ros & Hansen (2008) | `Sep_CMA_ES` | Colab |
| BIPOP-CMA-ES | Hansen (2009) | `BIPOP_CMA_ES` | Colab |
| IPOP-CMA-ES | Auger & Hansen (2005) | `IPOP_CMA_ES` | Colab |
| Full-iAMaLGaM | Bosman et al. (2013) | `Full_iAMaLGaM` | Colab |
| Independent-iAMaLGaM | Bosman et al. (2013) | `Indep_iAMaLGaM` | Colab |
| MA-ES | Beyer & Sendhoff (2017) | `MA_ES` | Colab |
| LM-MA-ES | Loshchilov et al. (2017) | `LM_MA_ES` | Colab |
| RmES | Li & Zhang (2017) | `RmES` | Colab |
| Simple Genetic | Such et al. (2017) | `SimpleGA` | Colab |
| SAMR-GA | Clune et al. (2008) | `SAMR_GA` | Colab |
| GESMR-GA | Kumar et al. (2022) | `GESMR_GA` | Colab |
| MR15-GA | Rechenberg (1978) | `MR15_GA` | Colab |
| LGA | Lange et al. (2023b) | `LGA` | Colab |
| Simple Gaussian | Rechenberg (1978) | `SimpleES` | Colab |
| DES | Lange et al. (2023a) | `DES` | Colab |
| LES | Lange et al. (2023a) | `LES` | Colab |
| EvoTF | Lange et al. (2024) | `EvoTF_ES` | Colab |
| Diffusion Evolution | Zhang et al. (2024) | `DiffusionEvolution` | Colab |
| SV-OpenAI-ES | Liu et al. (2017) | `SV_OpenES` | Colab |
| SV-CMA-ES | Braun et al. (2024) | `SV_CMA_ES` | Colab |
| Particle Swarm Optimization | Kennedy & Eberhart (1995) | `PSO` | Colab |
| Differential Evolution | Storn & Price (1997) | `DE` | Colab |
| GLD | Golovin et al. (2019) | `GLD` | Colab |
| Simulated Annealing | Rasdi Rere et al. (2015) | `SimAnneal` | Colab |
| Population-Based Training | Jaderberg et al. (2017) | `PBT` | Colab |
| Random Search | Bergstra & Bengio (2012) | `RandomSearch` | Colab |
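All strategies expose the same `default_params`/`initialize`/`ask`/`tell` interface, so swapping the optimizer is usually a one-line change. As a sketch (the hyperparameters below are purely illustrative), Differential Evolution can be dropped into the loop from above:

```Python
import jax
from evosax import DE

rng = jax.random.PRNGKey(0)
strategy = DE(popsize=20, num_dims=2)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)
# ... then run the same ask-eval-tell loop as before
```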

Installation ⏳

The latest evosax release can be installed directly from PyPI:

```
pip install evosax
```

If you want the most recent commit, please install directly from the repository:

```
pip install git+https://github.com/RobertTLange/evosax.git@main
```

For details on how to set up JAX for your accelerators (GPU/TPU), please refer to the JAX documentation.
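The exact accelerator command depends on your hardware and changes between JAX releases, so treat the following purely as an illustration and defer to the JAX installation guide; at the time of writing, a CUDA 12 setup typically looks like:

```
pip install -U "jax[cuda12]"
```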

Examples 📖

Key Features 💵

- **Vectorization across hyperparameters**: `jax.vmap` can map the `initialize`, `ask` and `tell` calls over batches of evolution hyperparameters, e.g. different initial perturbation stds:

```Python
import jax
import jax.numpy as jnp
from evosax.strategies.ars import ARS, EvoParams

# E.g. vectorize over different initial perturbation stds
strategy = ARS(popsize=100, num_dims=20)
es_params = EvoParams(sigma_init=jnp.array([0.1, 0.01, 0.001]), sigma_decay=0.999, ...)

# Specify how to map over ES hyperparameters
map_dict = EvoParams(sigma_init=0, sigma_decay=None, ...)

# Vmap-composed batch initialize, ask and tell functions
batch_init = jax.vmap(strategy.initialize, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
```
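A rough usage sketch of the vmapped functions (assuming `sigma_init` is the only batched hyperparameter, with the three values above):

```Python
rng = jax.random.PRNGKey(0)
# One evolution state per sigma_init setting
batch_state = batch_init(rng, es_params)
# One population per setting: x has shape (3, popsize, num_dims)
x, batch_state = batch_ask(rng, batch_state, es_params)
```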
- **`lax.scan` through whole generations**: Since both `ask` and `tell` are `jit`-compatible, the full ES loop can be compiled end-to-end:

```Python
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run evolution ask-eval-tell loop."""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(state_input, tmp):
        """Helper es step to lax.scan through."""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = ...  # Your population evaluation fct
        state = strategy.tell(x, fitness, state, es_params)
        return [rng, state], fitness[jnp.argmin(fitness)]

    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               jnp.zeros(num_steps))
    return jnp.min(scan_out)
```
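Assuming the fitness placeholder inside `es_step` has been filled in (e.g. with the sphere evaluation above), the whole loop compiles and runs in a single call:

```Python
rng = jax.random.PRNGKey(0)
best_fitness = run_es_loop(rng, 200)  # best fitness seen over 200 generations
```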
- **`ParameterReshaper` for neuroevolution**: Flat candidate vectors can be reshaped into stacked network parameter pytrees, so that a whole population of networks can be evaluated in parallel:

```Python
import jax
import jax.numpy as jnp
from flax import linen as nn
from evosax import ParameterReshaper

class MLP(nn.Module):
    num_hidden_units: int
    ...

    @nn.compact
    def __call__(self, obs):
        ...
        return ...

rng = jax.random.PRNGKey(0)
network = MLP(64)
net_params = network.init(rng, jnp.zeros((4,)))

# Initialize reshaper based on placeholder network shapes
param_reshaper = ParameterReshaper(net_params)

# Get population candidates & reshape into stacked pytrees
x = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
```
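To connect the reshaper to a search strategy, the flat search space size has to match the number of network parameters. A minimal sketch, assuming the `total_params` attribute exposed by `ParameterReshaper` in recent evosax versions:

```Python
from evosax import OpenES

# Search dimensionality = number of flat network parameters
strategy = OpenES(popsize=100, num_dims=param_reshaper.total_params)
```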
- **`FitnessShaper` utilities**: Raw fitness scores can be rank-transformed, z-scored, weight-decayed and/or flipped for maximization before being passed to `tell`:

```Python
from evosax import FitnessShaper

# Instantiate jittable fitness shaper (e.g. for OpenAI-ES)
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=False,
                           weight_decay=0.01,
                           maximize=True)

# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness)
```
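The shaped scores are then what gets passed to the strategy update in place of the raw fitness, e.g.:

```Python
state = strategy.tell(x, fit_shaped, state, es_params)
```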
<details>
<summary>Additional Work-In-Progress</summary>

- **Strategy Restart Wrappers**: *Work-in-progress*. You can also choose from a set of restart mechanisms, which relaunch a strategy (e.g. with a new population size) based on termination criteria. Note: For all restart strategies that alter the population size, the `ask` and `tell` methods have to be re-compiled when the size changes. All strategies can also be executed without explicitly providing `es_params`; in that case the default parameters are used.
```Python
import jax
from evosax import CMA_ES
from evosax.restarts import BIPOP_Restarter

rng = jax.random.PRNGKey(0)
popsize, num_dims, elite_ratio = 20, 2, 0.5

# Define a termination criterion (kwargs - fitness, state, params)
def std_criterion(fitness, state, params):
    """Restart strategy if fitness std across population is small."""
    return fitness.std() < 0.001

# Instantiate base CMA-ES & wrap with BIPOP restarts
# Pass strategy-specific kwargs separately (e.g. elite_ratio or opt_name)
strategy = CMA_ES(popsize=popsize, num_dims=num_dims, elite_ratio=elite_ratio)
re_strategy = BIPOP_Restarter(
    strategy,
    stop_criteria=[std_criterion],
    strategy_kwargs={"elite_ratio": elite_ratio},
)
state = re_strategy.initialize(rng)

# ask/tell loop - restarts are automatically handled
rng, rng_gen, rng_eval = jax.random.split(rng, 3)
x, state = re_strategy.ask(rng_gen, state)
fitness = ...  # Your population evaluation fct
state = re_strategy.tell(x, fitness, state)
```

- **Batch Strategy Rollouts**: *Work-in-progress*. We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols.

```Python
import jax
from evosax.experimental.subpops import BatchStrategy

rng = jax.random.PRNGKey(0)

# Instantiates 5 CMA-ES subpops of 20 members each
strategy = BatchStrategy(
    strategy_name="CMA_ES",
    num_dims=4096,
    popsize=100,
    num_subpops=5,
    strategy_kwargs={"elite_ratio": 0.5},
    communication="best_subpop",
)

state = strategy.initialize(rng)
# Ask for evaluation candidates of the different subpopulation ES
rng, rng_iter = jax.random.split(rng)
x, state = strategy.ask(rng_iter, state)
fitness = ...  # Your population evaluation fct
state = strategy.tell(x, fitness, state)
```

- **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. due to harder estimation of covariances). One potential way to alleviate this challenge is to use indirect parameter encodings in a lower-dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`:

```Python
from evosax.experimental.decodings import RandomDecoder, HyperDecoder

# For arbitrary network architectures / search spaces
num_encoding_dims = 6
param_reshaper = RandomDecoder(num_encoding_dims, net_params)
x_shaped = param_reshaper.reshape(x)

# For MLP-based models we also support a HyperNetwork en/decoding
param_reshaper = HyperDecoder(
    net_params,
    hypernet_config={
        "num_latent_units": 3,  # Latent units per module kernel/bias
        "num_hidden_units": 2,  # Hidden dimensionality of a_i^j embedding
    },
)
x_shaped = param_reshaper.reshape(x)
```
</details>

Resources & Other Great JAX-ES Tools 📝

Acknowledgements & Citing evosax ✏️

If you use evosax in your research, please cite the following paper:

```
@article{evosax2022github,
  author = {Robert Tjarko Lange},
  title = {evosax: JAX-based Evolution Strategies},
  journal = {arXiv preprint arXiv:2212.04180},
  year = {2022},
}
```

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via `python -m pytest -vv --all`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.

Disclaimer ⚠️

This repository contains an independent reimplementation of LES and DES based on the corresponding ICLR 2023 publication (Lange et al., 2023). It is unrelated to Google or DeepMind. The implementation has been tested to roughly reproduce the official results on a range of tasks.