MLXU: Machine Learning eXperiment Utilities
This library provides a collection of utilities for machine learning experiments. MLXU is a thin wrapper on top of absl-py, ml_collections and wandb, and it also provides some convenient JAX utilities.
This library includes the following modules:
- config: experiment configuration and command-line flag utilities
- logging: W&B logging utilities
- jax_utils: JAX-specific utilities
Installation
MLXU can be installed via pip. To install from PyPI:
pip install mlxu
To install the latest version from GitHub:
pip install git+https://github.com/young-geng/mlxu.git
Examples
Here are some examples of the utilities provided in MLXU.
Command Line Flags and Logging
MLXU provides convenient wrappers around absl-py and wandb that make command-line argument parsing and logging easy.
import mlxu


class ConfigurableModule(object):
    # Define a configurable module with a default configuration. This module
    # can be directly configured from the command line when plugged into
    # the FLAGS.

    @staticmethod
    def get_default_config(updates=None):
        config = mlxu.config_dict()
        config.integer_value = 10
        config.float_value = 1.0
        config.string_value = 'hello'
        config.boolean_value = True
        return mlxu.update_config_dict(config, updates)

    def __init__(self, config):
        self.config = self.get_default_config(config)


# Define absl command line flags in one function, with automatic type inference.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    name='example_experiment',        # string flag
    seed=42,                          # integer flag
    learning_rate=1e-3,               # floating point flag
    use_mlxu=True,                    # boolean flag
    num_gpus=int,                     # integer flag without default value
    weight_decay=float,               # floating point flag without default value
    save_checkpoints=bool,            # boolean flag without default value
    epochs=(10, 'Number of epochs'),  # we can also specify help strings
    network_architecture=mlxu.config_dict(
        activation='relu',
        hidden_dim=128,
        hidden_layers=5,
    ),  # nested ml_collections config_dict
    configurable_module=ConfigurableModule.get_default_config(),  # nested custom config_dict
    logger=mlxu.WandBLogger.get_default_config(),  # logger configuration
)


def main(argv):
    # Print the command line flags
    mlxu.print_flags(FLAGS, FLAGS_DEF)

    # Access the flags
    name = FLAGS.name
    seed = FLAGS.seed

    # Access nested flags
    activation = FLAGS.network_architecture.activation
    hidden_dim = FLAGS.network_architecture.hidden_dim
    configurable_module = ConfigurableModule(FLAGS.configurable_module)

    # Create logger and log metrics
    logger = mlxu.WandBLogger(FLAGS.logger, mlxu.get_user_flags(FLAGS, FLAGS_DEF))
    logger.log({'step': 1, 'loss': 10.5})
    logger.save_pickle([1, 2, 4, 5], 'checkpoint.pkl')


# Run the main function
if __name__ == "__main__":
    mlxu.run(main)
The flags can be passed in via command line arguments:
python examples/cli_logging.py \
--name='example' \
--seed=24 \
--learning_rate=1.0 \
--use_mlxu=True \
--network_architecture.activation='gelu' \
--network_architecture.hidden_dim=126 \
--network_architecture.hidden_layers=2 \
--configurable_module.integer_value=20 \
--configurable_module.float_value=2.0 \
--logger.online=True \
--logger.project='mlxu_example'
Specifically, the logger.online option controls whether the logger will upload the data to W&B, and the logger.project option controls the name of the W&B project.
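
The same options can also be set programmatically instead of through command line flags. The following is a minimal sketch, not taken from the examples above, that reuses the ConfigurableModule class and the MLXU functions already shown; it assumes that update_config_dict accepts a config_dict of overrides and that the second argument of WandBLogger can be any dict of hyperparameters, as the example suggests.

import mlxu


# Override ConfigurableModule defaults in Python rather than via flags
# (assumes update_config_dict merges a config_dict of overrides into the defaults).
overrides = mlxu.config_dict()
overrides.integer_value = 20
overrides.float_value = 2.0
module = ConfigurableModule(overrides)   # ConfigurableModule as defined above
print(module.config.integer_value)       # 20
print(module.config.string_value)        # 'hello' (unchanged default)

# Configure the W&B logger without going through FLAGS. Only the online and
# project fields shown above are used; the hyperparameter dict is hypothetical.
logger_config = mlxu.WandBLogger.get_default_config()
logger_config.online = False             # keep the run local; set True to upload to W&B
logger_config.project = 'mlxu_example'
logger = mlxu.WandBLogger(logger_config, {'seed': 42})
logger.log({'step': 0, 'loss': 1.0})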
JAX Random Number Generator
MLXU also provides a convenient wrapper around the JAX random number generator to make it easier to use:
import jax
import jax.numpy as jnp

import mlxu
import mlxu.jax_utils as jax_utils


@jax.jit
def sum_of_random_uniform(rng_key):
    # Wrap the RNG key in a stateful JaxRNG generator. As long as the JaxRNG
    # object is not passed across the function boundary, the function remains
    # pure and jittable. A JaxRNG object also supports the same tuple and
    # dictionary usage as the jax_utils.next_rng function.
    rng_generator = jax_utils.JaxRNG(rng_key)
    output = jnp.zeros((2, 2))
    for i in range(4):
        # Each call returns a new key and advances the internal state of rng_generator.
        output += jax.random.uniform(rng_generator(), (2, 2))
    return output


def main(argv):
    # Set up the global rng generator
    jax_utils.init_rng(42)

    # Get an rng key
    rng_key = jax_utils.next_rng()
    print(rng_key)

    # Get a new rng key; this key should be different from the previous one
    rng_key = jax_utils.next_rng()
    print(rng_key)

    # You can also get a tuple of N rng keys
    k1, k2, k3 = jax_utils.next_rng(3)
    print(k1, ', ', k2, ', ', k3)

    # A dictionary of named keys is also supported
    rng_key_dict = jax_utils.next_rng(['k1', 'k2'])
    print(rng_key_dict)

    # Call a jitted function that uses a stateful JaxRNG object internally
    x = sum_of_random_uniform(jax_utils.next_rng())
    print(x)


if __name__ == "__main__":
    mlxu.run(main)
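
As noted in the comments above, a JaxRNG instance accepts the same tuple and dictionary arguments as jax_utils.next_rng. The following minimal sketch illustrates that usage inside a jitted function; it assumes, mirroring next_rng, that calling the generator with an integer returns that many keys as a tuple and calling it with a list of names returns a dictionary of named keys.

import jax

import mlxu.jax_utils as jax_utils


@jax.jit
def rng_usage_example(rng_key):
    rng_generator = jax_utils.JaxRNG(rng_key)
    # Tuple usage: request two fresh keys at once (mirrors next_rng(2)).
    k1, k2 = rng_generator(2)
    a = jax.random.uniform(k1, (2, 2))
    b = jax.random.uniform(k2, (2, 2))
    # Dictionary usage: request named keys (mirrors next_rng(['dropout', 'params'])).
    keys = rng_generator(['dropout', 'params'])
    c = jax.random.normal(keys['dropout'], (2, 2))
    d = jax.random.normal(keys['params'], (2, 2))
    return a + b + c + d


jax_utils.init_rng(0)
print(rng_usage_example(jax_utils.next_rng()))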