InceptionV3 - Burn

This project provides an implementation of the InceptionV3 model as described in the paper Rethinking the Inception Architecture for Computer Vision.

The implementation is almost a one-to-one translation of the PyTorch implementation (also see the hub page).

Pre-trained weights for this model can be downloaded either from PyTorch (using torchvision) or from mseitzer/pytorch-fid.

Downloading FID Weights

The FID weights provided by mseitzer/pytorch-fid use PyTorch's legacy serialization format, which is not supported by Burn (more precisely, by Candle, which Burn uses in the background). The script download_fid_weights.py is therefore provided: it downloads the weights and re-saves them in the current PyTorch format.
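For readers curious about what such a conversion involves, here is a minimal sketch of the download-and-re-save step. It is an illustration under assumptions, not the contents of download_fid_weights.py, which remains the authoritative version. The helper names (`default_weights_path`, `resave_fid_weights`, `parse_args`) are hypothetical, and the URL is assumed to be the release asset published by mseitzer/pytorch-fid.

```python
import argparse
from pathlib import Path

# Assumed location of the FID Inception weights released by mseitzer/pytorch-fid.
FID_WEIGHTS_URL = (
    "https://github.com/mseitzer/pytorch-fid/releases/download/"
    "fid_weights/pt_inception-2015-12-05-6726825d.pth"
)


def default_weights_path() -> Path:
    """Default output location named in this README."""
    return (
        Path.home()
        / ".cache"
        / "inception-v3-burn"
        / "pt_inception-2015-12-05-6726825d.pth"
    )


def resave_fid_weights(target: Path) -> Path:
    """Download the legacy-format weights and write them back in the current format."""
    # torch is imported lazily so the path helpers work without it installed.
    import torch

    # Downloads (and caches) the checkpoint, then deserializes it on the CPU.
    state_dict = torch.hub.load_state_dict_from_url(
        FID_WEIGHTS_URL, map_location="cpu", progress=True
    )
    target.parent.mkdir(parents=True, exist_ok=True)
    # torch.save uses the zipfile-based format by default (PyTorch 1.6 and later).
    torch.save(state_dict, target)
    return target


def parse_args(argv=None) -> argparse.Namespace:
    """Mirror the `--file` option shown in this README (hypothetical reimplementation)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=Path, default=default_weights_path())
    return parser.parse_args(argv)
```

Calling `resave_fid_weights(parse_args().file)` would perform the conversion. This requires PyTorch to be installed and network access on the first run.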

To run the script:

# If no arguments are provided, the weights file will be saved to the default location:
# `~/.cache/inception-v3-burn/pt_inception-2015-12-05-6726825d.pth`
python download_fid_weights.py

# Alternatively, you can provide a custom path.
python download_fid_weights.py --file PATH_TO_FILE

Then, add the model to your dependencies:

[dependencies]
inception-v3-burn = { git = "https://github.com/varonroy/inception-v3-burn", features = ["pretrained"] }

Finally, initialize the model using the weights prepared in the previous steps:

use inception_v3_burn::model::{
    weights::{downloader::InceptionV3PretrainedLoader, WeightsSource},
    InceptionV3,
};

fn main() {
    type B = burn::backend::NdArray;
    let device = burn::backend::ndarray::NdArrayDevice::default();

    // If you have saved the model to a location other than the default one,
    // replace `None` with `Some(<fid-weights-file-path>)`.
    let (config, model) = InceptionV3::<B>::pretrained(WeightsSource::fid(None), &device).unwrap();
}

License

This implementation is licensed under the MIT license.

For the pre-trained weights' licenses, please refer to their original sources: