Makani: Massively parallel training of machine-learning based weather and climate models

Overview | Getting started | More information | Known issues | Contributing | Further reading | References

Makani (the Hawaiian word for wind 🍃🌺) is an experimental library designed to enable the research and development of machine-learning based weather and climate models in PyTorch. Makani is used for ongoing research. Stable features are regularly ported to the NVIDIA Modulus framework, a framework used for training Physics-ML models in Science and Engineering.

<div align="center"> <img src="https://github.com/NVIDIA/makani/blob/main/images/sfno_rollout.gif" height="388px"> </div>

Overview

Makani was started by engineers and researchers at NVIDIA and NERSC to train FourCastNet, a deep-learning based weather prediction model.

Makani is a research code built for massively parallel training of weather and climate prediction models on 100+ GPUs and to enable the development of the next generation of weather and climate models. Among others, Makani was used to train Spherical Fourier Neural Operators (SFNO) [1] and Adaptive Fourier Neural Operators (AFNO) [2] on the ERA5 dataset. Makani is written in PyTorch and supports various forms of model- and data-parallelism, asynchronous loading of data, unpredicted channels, autoregressive training and much more.

Getting started

Makani can be installed by running

git clone git@github.com:NVIDIA/makani.git
cd makani
pip install -e .

Training:

Training is launched by calling train.py and passing it the necessary CLI arguments to specify the configuration file --yaml_config and the configuration target --config:

mpirun -np 8 --allow-run-as-root python -u -m makani.train --yaml_config="config/sfnonet.yaml" --config="sfno_linear_73chq_sc3_layers8_edim384_asgl2"

:warning: architectures with complex-valued weights will currently fail. See Known issues for more information.

Makani supports various optimizations to fit large models into GPU memory and enable computationally efficient training. An overview of these features and the corresponding CLI arguments is provided in the following table:

| Feature | CLI argument | Options |
|---|---|---|
| Automatic Mixed Precision | --amp_mode | none, fp16, bf16 |
| Just-in-time compilation | --jit_mode | none, script, inductor |
| CUDA graphs | --cuda_graph_mode | none, fwdbwd, step |
| Activation checkpointing | --checkpointing_level | 0,1,2,3 |
| Data parallelism | --batch_size | 1,2,3,... |
| Channel parallelism | --fin_parallel_size, --fout_parallel_size | 1,2,3,... |
| Spatial model parallelism | --h_parallel_size, --w_parallel_size | 1,2,3,... |
| Multistep training | --multistep_count | 1,2,3,... |

Larger models in particular are made possible by combining these techniques. Spatial model parallelism splits both the model and the data across multiple GPUs, reducing both the memory footprint of the model and the I/O load, as each rank only needs to read a fraction of the data. A typical "large" training run of SFNO can be launched by running

mpirun -np 256 --allow-run-as-root python -u -m makani.train --amp_mode=bf16 --cuda_graph_mode=fwdbwd --multistep_count=1 --run_num="ngpu256_sp4" --yaml_config="config/sfnonet.yaml" --config="sfno_linear_73chq_sc3_layers8_edim384_asgl2" --h_parallel_size=4 --w_parallel_size=1 --batch_size=64

Here we train the model on 256 GPUs, split horizontally across 4 ranks with a batch size of 64, which amounts to a local batch size of 1/4. Memory requirements are further reduced by the use of bf16 automatic mixed precision.
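The arithmetic behind this layout can be sketched in a few lines of Python. This is an illustrative sketch, not Makani's own code: the function names are hypothetical, channel parallelism is ignored, `--batch_size` is assumed to be the global batch size (as the example above implies), and the near-even split of latitude rows is only one plausible way a spatial decomposition could distribute the grid. The grid height of 720 rows is likewise an assumed value.

```python
from fractions import Fraction


def parallel_layout(world_size, batch_size, h_parallel_size=1, w_parallel_size=1):
    """Return (data-parallel groups, samples per group, samples per rank).

    Hypothetical helper: assumes batch_size is the global batch size and
    ignores channel parallelism.
    """
    model_parallel_size = h_parallel_size * w_parallel_size
    if world_size % model_parallel_size != 0:
        raise ValueError("world size must be divisible by the model-parallel size")
    data_parallel_size = world_size // model_parallel_size
    if batch_size % data_parallel_size != 0:
        raise ValueError("batch size must be divisible by the data-parallel size")
    samples_per_group = batch_size // data_parallel_size
    # Each sample is split across the ranks of one model-parallel group.
    samples_per_rank = Fraction(samples_per_group, model_parallel_size)
    return data_parallel_size, samples_per_group, samples_per_rank


def local_lat_range(nlat, h_parallel_size, h_rank):
    """Return the [start, stop) latitude rows one rank would read.

    Assumption: rows are split as evenly as possible, with the first
    ranks taking one extra row when nlat is not divisible.
    """
    base, rem = divmod(nlat, h_parallel_size)
    start = h_rank * base + min(h_rank, rem)
    stop = start + base + (1 if h_rank < rem else 0)
    return start, stop


if __name__ == "__main__":
    groups, per_group, per_rank = parallel_layout(256, 64, h_parallel_size=4)
    print(groups, per_group, per_rank)  # 64 1 1/4
    print(local_lat_range(720, 4, 1))   # (180, 360)
```

With 256 ranks and 4-way spatial parallelism there are 64 data-parallel groups, so a batch size of 64 gives each group one sample and each rank a quarter of it.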

Inference:

In a similar fashion to training, inference is launched from the CLI by calling inference.py and is handled by inferencer.py. To launch inference on the out-of-sample dataset, we can call:

mpirun -np 256 --allow-run-as-root python -u -m makani.inference --amp_mode=bf16 --cuda_graph_mode=fwdbwd --multistep_count=1 --run_num="ngpu256_sp4" --yaml_config="config/sfnonet.yaml" --config="sfno_linear_73chq_sc3_layers8_edim384_asgl2" --h_parallel_size=4 --w_parallel_size=1 --batch_size=64

By default, the inference script will perform inference on the out-of-sample dataset specified

More about Makani

Project structure

The project is structured as follows:

makani
├── ...
├── config                      # configuration files, also known as recipes
├── data_process                # data pre-processing such as computation of statistics
├── datasets                    # dataset utility scripts
├── docker                      # scripts for building a docker image for training
├── makani                      # Main directory containing the package
│   ├── inference               # contains the inferencer
│   ├── mpu                     # utilities for model parallelism
│   ├── networks                # networks, contains definitions of various ML models
│   ├── third_party/climt       # third party modules
│   │   └── zenith_angle.py     # computation of zenith angle
│   ├── utils                   # utilities
│   │   ├── dataloaders         # contains various dataloaders
│   │   ├── metrics             # metrics folder contains routines for scoring and benchmarking.
│   │   ├── ...
│   │   ├── comm.py             # comms module for orthogonal communicator infrastructure
│   │   ├── dataloader.py       # dataloader interface
│   │   ├── metric.py           # centralized metrics handler
│   │   ├── trainer_profile.py  # copy of trainer.py used for profiling
│   │   └── trainer.py          # main file for handling training
│   ├── ...
│   ├── inference.py            # CLI script for launching inference
│   └── train.py                # CLI script for launching training
├── tests                       # test files
└── README.md                   # this file

Model and Training configuration

Model training in Makani is specified through the use of .yaml files located in the config folder. The corresponding models are located in networks and registered in the get_model routine in networks/models.py. The following table lists the most important configuration options.

| Configuration Key | Description | Options |
|---|---|---|
| nettype | Network architecture. | SFNO, FNO, AFNO, ViT |
| loss | Loss function. | l2, geometric l2, ... |
| optimizer | Optimizer to be used. | Adam, AdamW |
| lr | Initial learning rate. | float > 0.0 |
| batch_size | Batch size. | integer > 0 |
| max_epochs | Number of epochs to train for. | integer |
| scheduler | Learning rate scheduler to be used. | None, CosineAnnealing, ReduceLROnPlateau, StepLR |
| lr_warmup_steps | Number of warmup steps for the learning rate scheduler. | integer >= 0 |
| weight_decay | Weight decay. | float |
| train_data_path | Directory path which contains the training data. | string |
| test_data_path | Directory path which contains the test data. | string |
| exp_dir | Directory path for outputs such as model checkpoints. | string |
| metadata_json_path | Path to the metadata file data.json. | string |
| channel_names | Channels to be used for training. | List[string] |

For a more comprehensive overview, we suggest looking into existing .yaml configurations. More details about the available configurations can be found in this file.
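As a rough illustration of the constraints in the table, the following Python sketch checks a configuration dictionary before it is used. It is not part of Makani: the function name `check_training_options` and the example paths are hypothetical, only the numeric and type constraints stated in the table are checked, and treating every checked key as required is an assumption.

```python
def check_training_options(cfg):
    """Return a list of problems found in a configuration dictionary.

    Hypothetical helper; only mirrors the constraints listed in the table.
    """
    problems = []

    lr = cfg.get("lr")
    if not isinstance(lr, float) or lr <= 0.0:
        problems.append("lr must be a float > 0.0")

    batch_size = cfg.get("batch_size")
    if not isinstance(batch_size, int) or batch_size <= 0:
        problems.append("batch_size must be an integer > 0")

    warmup = cfg.get("lr_warmup_steps")
    if not isinstance(warmup, int) or warmup < 0:
        problems.append("lr_warmup_steps must be an integer >= 0")

    for key in ("train_data_path", "test_data_path", "exp_dir", "metadata_json_path"):
        if not isinstance(cfg.get(key), str):
            problems.append(f"{key} must be a string")

    names = cfg.get("channel_names")
    if not (isinstance(names, list) and all(isinstance(n, str) for n in names)):
        problems.append("channel_names must be a list of strings")

    return problems


# Example values; all paths are placeholders, not real locations.
example_cfg = {
    "lr": 1e-3,
    "batch_size": 64,
    "lr_warmup_steps": 0,
    "train_data_path": "/path/to/train",
    "test_data_path": "/path/to/test",
    "exp_dir": "/path/to/runs",
    "metadata_json_path": "/path/to/data.json",
    "channel_names": ["t2m", "u10", "v10"],
}

if __name__ == "__main__":
    print(check_training_options(example_cfg))  # []
```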

Training data

Makani expects the training/test data in HDF5 format, where each file contains the data for an entire year. The dataloaders in Makani then load the input inp and the target tar, which correspond to the state of the atmosphere at a given point in time and at a later point in time, respectively. The offset between input and target is set by the parameter dt, which specifies how many time steps apart the two are. The physical time difference then follows from the temporal resolution dhours of the dataset.
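The relationship between dt, dhours and the input/target pairs can be illustrated with a short Python sketch. The helper names are hypothetical and the pairing shown here (every time index paired with the index dt steps later, within a single file) is an assumption for illustration; the actual indexing in Makani's dataloaders may differ.

```python
def lead_time_hours(dt, dhours):
    """Physical time difference between input and target, in hours."""
    return dt * dhours


def input_target_pairs(n_times, dt):
    """Return (input index, target index) pairs along the time axis.

    Assumption: the target lies exactly dt steps after the input, and
    inputs without a valid target inside the file are skipped.
    """
    if dt < 1:
        raise ValueError("dt must be at least 1")
    return [(i, i + dt) for i in range(n_times - dt)]


if __name__ == "__main__":
    print(lead_time_hours(dt=1, dhours=6))       # 6
    print(input_target_pairs(n_times=5, dt=2))   # [(0, 2), (1, 3), (2, 4)]
```

With dhours set to 6, a value of dt=1 corresponds to a 6-hour prediction step and dt=4 to a 24-hour step.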

Makani requires a metadata file named data.json, which describes important properties of the dataset such as the HDF5 variable name that contains the data. Another example is the set of channels to load in the dataloader, which are specified via channel names. The metadata file has the following structure:

{
    "dataset_name": "give this dataset a name",     # name of the dataset
    "attrs": {                                      # optional attributes, can contain anything you want
        "description": "description of the dataset",
        "location": "location of your dataset"
    },
    "h5_path": "fields",                            # variable name of the data inside the hdf5 file
    "dims": ["time", "channel", "lat", "lon"],      # dimensions of fields contained in the dataset
    "dhours": 6,                                    # temporal resolution in hours
    "coord": {                                      # coordinates and channel descriptions
        "grid_type": "equiangular",                 # type of grid used in dataset: currently supported choices are 'equiangular' and 'legendre-gauss'
        "lat": [0.0, 0.1, ...],                     # latitudinal grid coordinates
        "lon": [0.0, 0.1, ...],                     # longitudinal grid coordinates
        "channel": ["t2m", "u10", "v10", ...]       # names of the channels contained in the dataset
    }
}
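Note that the `#` annotations above are explanatory only; a strict JSON parser rejects comments, so the actual file must not contain them. The Python sketch below parses a small comment-free metadata string and resolves channel names to their positions along the channel dimension. The helper names, the tiny coordinate arrays and the choice of which keys to treat as required are assumptions made for illustration, not Makani's own logic.

```python
import json

# Keys treated as required here; "attrs" is described as optional above.
REQUIRED_KEYS = ("dataset_name", "h5_path", "dims", "dhours", "coord")

EXAMPLE_METADATA = """
{
    "dataset_name": "example",
    "h5_path": "fields",
    "dims": ["time", "channel", "lat", "lon"],
    "dhours": 6,
    "coord": {
        "grid_type": "equiangular",
        "lat": [90.0, 0.0, -90.0],
        "lon": [0.0, 120.0, 240.0],
        "channel": ["t2m", "u10", "v10"]
    }
}
"""


def load_metadata(text):
    """Parse metadata JSON and check that the expected keys are present."""
    metadata = json.loads(text)
    missing = [key for key in REQUIRED_KEYS if key not in metadata]
    if missing:
        raise KeyError(f"missing metadata keys: {missing}")
    return metadata


def channel_indices(metadata, channel_names):
    """Map channel names to their index along the channel dimension."""
    available = metadata["coord"]["channel"]
    # list.index raises ValueError for a name that is not in the dataset.
    return [available.index(name) for name in channel_names]


if __name__ == "__main__":
    meta = load_metadata(EXAMPLE_METADATA)
    print(channel_indices(meta, ["u10", "t2m"]))  # [1, 0]
```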

The ERA5 dataset can be downloaded here.

Model packages

By default, Makani will save a model package when training starts. Model packages contain all the data necessary to run the model. This includes statistics used to normalize inputs and outputs, unpredicted static channels and even the code which appends celestial features such as the cosine of the solar zenith angle. Read more about model packages here.

Known Issues

:warning: architectures with complex-valued weights: Training some architectures with complex-valued weights requires yet to be released patches to PyTorch. A hotfix that addresses these issues is available in the makani/third_party/torch folder. Overwriting the corresponding files in the PyTorch installation will resolve these issues.

Contributing

Thanks for your interest in contributing. There are many ways to contribute to this project.

While this is a research project, we aim to have functional unit tests with decent coverage. We kindly ask you to implement unit tests if you add a new feature and it can be tested.

Further reading

Authors

<img src="https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/01-nvidia-logo-horiz-500x200-2c50-d@2x.png" height="120px"><img src="https://www.nersc.gov/assets/Logos/NERSClogocolor.png" height="120px">

The code was developed by Thorsten Kurth, Boris Bonev, Jean Kossaifi, Animashree Anandkumar, Kamyar Azizzadenesheli, Noah Brenowitz, Ashesh Chattopadhyay, Yair Cohen, David Hall, Peter Harrington, Pedram Hassanzadeh, Christian Hundt, Alexey Kamenev, Karthik Kashinath, Zongyi Li, Morteza Mardani, Jaideep Pathak, Mike Pritchard, David Pruitt, Sanjeev Raja, Shashank Subramanian.

References

<a id="#sfno_paper">[1]</a> Bonev B., Kurth T., Hundt C., Pathak J., Baust M., Kashinath K., Anandkumar A.; Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere; arXiv 2306.03838, 2023.

<a id="1">[2]</a> Pathak J., Subramanian S., Harrington P., Raja S., Chattopadhyay A., Mardani M., Kurth T., Hall D., Li Z., Azizzadenesheli K., Hassanzadeh P., Kashinath K., Anandkumar A.; FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators; arXiv 2202.11214, 2022.

Citation

If you use this package, please cite

@InProceedings{bonev2023sfno,
    title={Spherical {F}ourier Neural Operators: Learning Stable Dynamics on the Sphere},
    author={Bonev, Boris and Kurth, Thorsten and Hundt, Christian and Pathak, Jaideep and Baust, Maximilian and Kashinath, Karthik and Anandkumar, Anima},
    booktitle={Proceedings of the 40th International Conference on Machine Learning},
    pages={2806--2823},
    year={2023},
    volume={202},
    series={Proceedings of Machine Learning Research},
    month={23--29 Jul},
    publisher={PMLR},
}

@article{pathak2022fourcastnet,
    title={Fourcastnet: A global data-driven high-resolution weather model using adaptive fourier neural operators},
    author={Pathak, Jaideep and Subramanian, Shashank and Harrington, Peter and Raja, Sanjeev and Chattopadhyay, Ashesh and Mardani, Morteza and Kurth, Thorsten and Hall, David and Li, Zongyi and Azizzadenesheli, Kamyar and Hassanzadeh, Pedram and Kashinath, Karthik and Anandkumar, Animashree},
    journal={arXiv preprint arXiv:2202.11214},
    year={2022}
}