MLXU: Machine Learning eXperiment Utilities

This library provides a collection of utilities for machine learning experiments. MLXU is a thin wrapper on top of absl-py, ml_collections and wandb. It also provides some convenient JAX utilities.

This library includes the following modules:

Installation

MLXU can be installed via pip. To install from PyPI:

pip install mlxu

To install the latest version from GitHub:

pip install git+https://github.com/young-geng/mlxu.git

Examples

Here are some examples of the utilities provided in MLXU.

Command Line Flags and Logging

MLXU provides convenient wrappers around absl-py and wandb to make command line arg parsing and logging easy.

import mlxu


class ConfigurableModule(object):
    # Define a configurable module with a default configuration. This module
    # can be directly configured from the command line when plugged into
    # the FLAGS.

    @staticmethod
    def get_default_config(updates=None):
        config = mlxu.config_dict()
        config.integer_value = 10
        config.float_value = 1.0
        config.string_value = 'hello'
        config.boolean_value = True
        return mlxu.update_config_dict(config, updates)

    def __init__(self, config):
        self.config = self.get_default_config(config)


# Define absl command line flags in one function, with automatic type inference.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    name='example_experiment',          # string flag
    seed=42,                            # integer flag
    learning_rate=1e-3,                 # floating point flag
    use_mlxu=True,                      # boolean flag
    num_gpus=int,                       # integer flag without default value
    weight_decay=float,                 # floating point flag without default value
    save_checkpoints=bool,              # boolean flag without default value
    epochs=(10, 'Number of epochs'),    # we can also specify help strings
    network_architecture=mlxu.config_dict(
        activation='relu',
        hidden_dim=128,
        hidden_layers=5,
    ),                                  # nested ml_collections config_dict
    configurable_module=ConfigurableModule.get_default_config(),  # nested custom config_dict
    logger=mlxu.WandBLogger.get_default_config(),  # logger configuration
)


def main(argv):
    # Print the command line flags
    mlxu.print_flags(FLAGS, FLAGS_DEF)

    # Access the flags
    name = FLAGS.name
    seed = FLAGS.seed

    # Access nested flags
    activation = FLAGS.network_architecture.activation
    hidden_dim = FLAGS.network_architecture.hidden_dim

    configurable_module = ConfigurableModule(FLAGS.configurable_module)

    # Create logger and log metrics
    logger = mlxu.WandBLogger(FLAGS.logger, mlxu.get_user_flags(FLAGS, FLAGS_DEF))
    logger.log({'step': 1, 'loss': 10.5})
    logger.save_pickle([1, 2, 4, 5], 'checkpoint.pkl')


# Run the main function
if __name__ == "__main__":
    mlxu.run(main)

The flags can be passed in via command line arguments:

python examples/cli_logging.py \
    --name='example' \
    --seed=24 \
    --learning_rate=1.0 \
    --use_mlxu=True \
    --network_architecture.activation='gelu' \
    --network_architecture.hidden_dim=126 \
    --network_architecture.hidden_layers=2 \
    --configurable_module.integer_value=20 \
    --configurable_module.float_value=2.0 \
    --logger.online=True \
    --logger.project='mlxu_example'

Specifically, the logger.online option controls whether the logger will upload the data to W&B, and the logger.project option controls the name of the W&B project.
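
To illustrate how dotted flags such as `--network_architecture.hidden_dim=126` can address fields of a nested configuration, here is a minimal sketch using plain dictionaries and the standard library only. It is not MLXU's implementation: the `apply_overrides` helper is hypothetical, and the type check is a simplified assumption about how overrides might be validated against defaults.

```python
import ast
import copy


def apply_overrides(defaults, argv):
    """Hypothetical helper: apply `--a.b=value` overrides to a nested dict."""
    config = copy.deepcopy(defaults)  # leave the defaults untouched
    for arg in argv:
        path, _, raw = arg.lstrip('-').partition('=')
        *parents, leaf = path.split('.')
        node = config
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f'unknown option: {path}')
        try:
            value = ast.literal_eval(raw)  # parses 126, 2.0, True, ...
        except (ValueError, SyntaxError):
            value = raw  # bare words such as gelu stay strings
        default = node[leaf]
        # Allow an integer literal where the default is a float.
        if isinstance(default, float) and type(value) is int:
            value = float(value)
        if type(value) is not type(default):
            raise TypeError(f'{path} expects {type(default).__name__}')
        node[leaf] = value
    return config


defaults = {
    'seed': 42,
    'network_architecture': {
        'activation': 'relu',
        'hidden_dim': 128,
        'hidden_layers': 5,
    },
}
config = apply_overrides(defaults, [
    '--seed=24',
    '--network_architecture.activation=gelu',
    '--network_architecture.hidden_dim=126',
])
```

Options that are not overridden keep their default values, so `config['network_architecture']['hidden_layers']` is still 5 after the call.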

JAX Random Number Generator

MLXU also provides a convenient wrapper around JAX's random number generators to make them much easier to use:

import jax
import jax.numpy as jnp
import mlxu
import mlxu.jax_utils as jax_utils


@jax.jit
def sum_of_random_uniform(rng_key):
    # Wrap the RNG key in a stateful key generator.
    # As long as the JaxRNG object is not passed across the function
    # boundary, the function is still pure and jittable.
    # The JaxRNG object also supports the same tuple and dictionary usage as
    # the jax_utils.next_rng function.
    rng_generator = jax_utils.JaxRNG(rng_key)
    output = jnp.zeros((2, 2))
    for i in range(4):
        # Each call returns a new key, altering the internal state of rng_generator
        output += jax.random.uniform(rng_generator(), (2, 2))

    return output


def main(argv):
    # Setup global rng generator
    jax_utils.init_rng(42)

    # Get an rng key
    rng_key = jax_utils.next_rng()
    print(rng_key)

    # Get a new rng key; it should differ from the previous one
    rng_key = jax_utils.next_rng()
    print(rng_key)

    # You can also get a tuple of N rng keys
    k1, k2, k3 = jax_utils.next_rng(3)
    print(k1, ', ', k2, ', ', k3)

    # Dictionary of keys is also supported
    rng_key_dict = jax_utils.next_rng(['k1', 'k2'])
    print(rng_key_dict)

    # Call a jitted function that makes use of stateful JaxRNG object internally
    x = sum_of_random_uniform(jax_utils.next_rng())
    print(x)


if __name__ == "__main__":
    mlxu.run(main)
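
The idea behind a stateful key generator can be sketched in plain JAX. The `KeySplitter` class below is a hypothetical stand-in written for illustration, not MLXU's `JaxRNG` implementation: it assumes the generator simply stores a key and calls `jax.random.split` on each request, keeping one half as its new state and returning the other.

```python
import jax
import jax.numpy as jnp


class KeySplitter:
    """Hypothetical stateful key generator (illustration only)."""

    def __init__(self, key):
        self.key = key

    def __call__(self):
        # Keep one half of the split as the new state, hand out the other.
        self.key, subkey = jax.random.split(self.key)
        return subkey


@jax.jit
def sum_of_uniform(key):
    # The splitter is created and used entirely inside the function, so the
    # result depends only on `key` and the function stays pure.
    rng = KeySplitter(key)
    total = jnp.zeros((2, 2))
    for _ in range(4):
        total = total + jax.random.uniform(rng(), (2, 2))
    return total


key = jax.random.PRNGKey(0)
first = sum_of_uniform(key)
second = sum_of_uniform(key)
# The same input key yields the same output, as expected of a pure function.
same_result = bool(jnp.allclose(first, second))
```

Because the mutable state never leaves the jitted function, calling it twice with the same key gives the same array, while each `rng()` call inside the loop draws from a fresh subkey.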