FID computation in Jax/Flax
This is a port of mseitzer/pytorch-fid, which is itself a port of the original TensorFlow FID implementation (bioinf-jku/TTUR).
The InceptionV3 weights are taken from mseitzer/pytorch-fid, and the resulting FID scores are almost identical to those of pytorch-fid (absolute difference around 1e-7).
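For reference, FID is the Fréchet distance between two Gaussians fitted to Inception activations, with means mu1, mu2 and covariances sigma1, sigma2:

FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))

The NumPy sketch below evaluates this formula. It is an illustration, not the jax_fid source. The helper name frechet_distance is hypothetical, and instead of a general matrix square root it relies on the fact that sigma1 sigma2 has the same eigenvalues as the symmetric matrix sigma1^(1/2) sigma2 sigma1^(1/2).

```python
import numpy as np


def _sqrt_psd(mat: np.ndarray) -> np.ndarray:
    """Symmetric square root of a positive semi-definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # drop tiny negative values caused by round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """FID between N(mu1, sigma1) and N(mu2, sigma2) (hypothetical helper)."""
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)

    diff = mu1 - mu2
    # Tr((sigma1 @ sigma2)^(1/2)) is the sum of the square roots of the
    # eigenvalues of the symmetric PSD matrix sqrt(sigma1) @ sigma2 @ sqrt(sigma1).
    root1 = _sqrt_psd(sigma1)
    inner = root1 @ sigma2 @ root1
    inner = (inner + inner.T) / 2.0  # re-symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)).sum()

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

For two Gaussians with identity covariance and means (0, 0) and (3, 4), only the mean term contributes, so frechet_distance([0, 0], np.eye(2), [3, 4], np.eye(2)) gives 25 up to floating-point round-off.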
The only difference is that mseitzer/pytorch-fid resizes images to 299x299 by default, whereas this implementation does not resize them by default. To resize images, use the --img_size argument.
Installation
You will need Python 3.7 or later.
- For GPU usage, follow the Jax installation instructions with CUDA support (https://github.com/google/jax#installation).
- Then install:
> pip install jax-fid
For CPU-only usage, you can skip the first step.
Usage
Compute FID score
> CUDA_VISIBLE_DEVICES=N python -m jax_fid --path1 /path/to/dataset1 --path2 /path/to/dataset2
where N is the GPU index.
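Either path may be an .npz file of pre-computed statistics. A minimal sketch for inspecting such a file with NumPy follows; it assumes the arrays are stored under the keys mu and sigma (the convention of mseitzer/pytorch-fid), and load_stats is a hypothetical helper, not part of the jax_fid API.

```python
import numpy as np


def load_stats(npz_path: str):
    """Load pre-computed (mu, sigma) from an .npz file (hypothetical helper).

    Assumes the keys "mu" and "sigma"; adjust if your files use other names.
    """
    with np.load(npz_path) as data:
        mu, sigma = data["mu"], data["sigma"]
    # mu should be a feature vector and sigma a matching square covariance matrix.
    if mu.ndim != 1 or sigma.shape != (mu.shape[0], mu.shape[0]):
        raise ValueError(f"unexpected shapes: mu {mu.shape}, sigma {sigma.shape}")
    return mu, sigma
```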
Pre-compute statistics for an image directory
> CUDA_VISIBLE_DEVICES=N python -m jax_fid --precompute --img_dir /path/to/dataset --out_dir /path/to/stats
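The pre-computed statistics are the mean and covariance of the Inception activations over all images in the directory. Assuming the activations are already available as an (N, D) array (D = 2048 for the pool3 features of InceptionV3), computing and saving them could look like the sketch below. compute_stats and save_stats are hypothetical names, and the mu/sigma keys are an assumption.

```python
import numpy as np


def compute_stats(activations: np.ndarray):
    """Mean and covariance of activations with shape (num_images, num_features)."""
    acts = np.asarray(activations, dtype=np.float64)
    mu = acts.mean(axis=0)
    sigma = np.cov(acts, rowvar=False)  # np.cov normalizes by N - 1 by default
    return mu, sigma


def save_stats(out_path: str, mu: np.ndarray, sigma: np.ndarray) -> None:
    # The key names "mu" and "sigma" are an assumption (pytorch-fid convention).
    np.savez(out_path, mu=mu, sigma=sigma)
```

np.savez appends the .npz extension if out_path does not already end with it.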
Arguments
--path1
- Path to image directory or .npz file containing pre-computed statistics.
--path2
- Path to image directory or .npz file containing pre-computed statistics.
--batch_size
- Batch size per device for computing the Inception activations.
--img_size
- Resize images to this size. The format is (height, width).
--precompute
- If True, pre-compute statistics for the image directory given by --img_dir.
--img_dir
- Path to image directory for pre-computing statistics.
--out_dir
- Path where pre-computed statistics are stored.
--mmap
- If True, use a memory-mapped file (mmap) when computing statistics.
--mmap_file
- Name of the mmap file. Only used if --mmap is True.
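The --mmap option keeps activations in a memory-mapped file instead of RAM, which matters for large image directories. The sketch below illustrates that general idea with numpy.memmap. The helper stats_via_memmap, the float32 storage, and the chunked covariance are assumptions for illustration, not jax_fid's implementation.

```python
import numpy as np


def stats_via_memmap(batches, num_images, feat_dim, mmap_path, chunk=1024):
    """Store activation batches in a disk-backed array, then compute (mu, sigma).

    Hypothetical helper illustrating the idea behind --mmap; not jax_fid code.
    """
    # Disk-backed array: only the slices being read or written occupy RAM.
    acts = np.memmap(mmap_path, dtype=np.float32, mode="w+",
                     shape=(num_images, feat_dim))
    start = 0
    for batch in batches:
        end = start + batch.shape[0]
        acts[start:end] = batch
        start = end
    if start != num_images:
        raise ValueError(f"expected {num_images} rows, got {start}")
    acts.flush()

    # Accumulate sum(x) and sum(x x^T) chunk by chunk in float64.
    total = np.zeros(feat_dim)
    outer = np.zeros((feat_dim, feat_dim))
    for i in range(0, num_images, chunk):
        block = np.asarray(acts[i:i + chunk], dtype=np.float64)
        total += block.sum(axis=0)
        outer += block.T @ block
    mu = total / num_images
    # Unbiased covariance, matching np.cov's default normalization (N - 1).
    sigma = (outer - num_images * np.outer(mu, mu)) / (num_images - 1)
    return mu, sigma
```

The one-pass sum-of-outer-products formula is simple but can lose precision when the mean is large relative to the spread; a two-pass variant (subtract mu first, then accumulate) is more robust at the cost of reading the file twice.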
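Since --batch_size is given per device, each step consumes batch_size x (number of devices) images. The sketch below pads a global batch and splits it along a leading device axis, the layout that jax.pmap maps over. Whether jax_fid pads batches this way is an assumption, and shard_batch is a hypothetical helper; in practice num_devices would come from jax.local_device_count(), which respects CUDA_VISIBLE_DEVICES.

```python
import numpy as np


def shard_batch(images: np.ndarray, num_devices: int, batch_size: int):
    """Pad and reshape images to (num_devices, batch_size, H, W, C).

    Returns the sharded array and the number of real (unpadded) images.
    """
    global_batch = num_devices * batch_size  # batch_size is per device
    n = images.shape[0]
    if n > global_batch:
        raise ValueError(f"got {n} images, at most {global_batch} fit in one step")
    pad = global_batch - n
    if pad:
        # Zero-pad the last, incomplete batch; padded rows must be dropped later.
        filler = np.zeros((pad, *images.shape[1:]), dtype=images.dtype)
        images = np.concatenate([images, filler])
    return images.reshape(num_devices, batch_size, *images.shape[1:]), n
```

Padding keeps every device's input the same shape, which avoids recompiling for the final partial batch. The padded rows have to be discarded before computing statistics, which is why the number of real images is returned.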
License
Apache-2.0 License