Official code for using and reproducing TRIM, from the paper "Transformation Importance with Applications to Cosmology" (ICLR 2020 Workshop). The code includes examples and wrappers for calculating feature importance in a transformed feature space.

This repo is actively maintained. For any questions, please file an issue.

trim

examples/documentation

Attribution to different scales in cosmological images
Fake news attribution to different topics
Attribution to different NMF components in MNIST classification
Attribution to different frequencies in audio classification

sample usage

import torch
import torch.nn as nn
from trim import TrimModel
from functools import partial

# setup a trim model
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1)) # orig model
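# note: torch.rfft / torch.irfft with signal_ndim were removed in PyTorch 1.8,
# so this snippet targets earlier PyTorch versions
transform = partial(torch.rfft, signal_ndim=1, onesided=False) # fft
inv_transform = partial(torch.irfft, signal_ndim=1, onesided=False) # inverse fft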
model_trim = TrimModel(model=model, inv_transform=inv_transform) # trim model

# get a data point
x = torch.randn(1, 10)
s = transform(x)

# can now use any attribution method on the trim model
# get (input_x_gradient) attribution in the fft space
s.requires_grad = True
model_trim(s).backward()
input_x_gradient = s.grad * s
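
The same idea can be sketched without the trim package on a recent PyTorch that provides the torch.fft module. This is only an illustration under stated assumptions: `InverseThenModel` is a hypothetical stand-in written for this example, not the library's `TrimModel`, and it assumes the wrapper simply applies the inverse transform before calling the original model. The FFT coefficients are stored as a real tensor with a trailing (real, imag) dimension so that ordinary real-valued gradients can be taken.

```python
import torch
import torch.nn as nn


class InverseThenModel(nn.Module):
    """Hypothetical wrapper (not the trim API): maps s back to x, then runs the model."""

    def __init__(self, model, inv_transform):
        super().__init__()
        self.model = model
        self.inv_transform = inv_transform

    def forward(self, s):
        return self.model(self.inv_transform(s))


def transform(x):
    # full (two-sided) FFT, stored as a real tensor of shape (..., n, 2)
    return torch.view_as_real(torch.fft.fft(x, dim=-1))


def inv_transform(s):
    # inverse FFT; the imaginary part is numerically zero for real signals
    return torch.fft.ifft(torch.view_as_complex(s), dim=-1).real


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))  # orig model
wrapped = InverseThenModel(model, inv_transform)

# get a data point and its representation in the fft space
x = torch.randn(1, 10)
s = transform(x).detach().requires_grad_(True)

# input_x_gradient attribution for each fft coefficient
wrapped(s).sum().backward()
input_x_gradient = s.grad * s.detach()
```

Because the wrapped model takes `s` as its input, any gradient-based attribution method can be applied to it unchanged; only the pair of transform functions needs to be swapped to attribute in a different space.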

related work

reference

@article{singh2020transformation,
    title={Transformation Importance with Applications to Cosmology},
    author={Singh, Chandan and Ha, Wooseok and Lanusse, Francois and Boehm, Vanessa and Liu, Jia and Yu, Bin},
    journal={arXiv preprint arXiv:2003.01926},
    year={2020},
    url={https://arxiv.org/abs/2003.01926},
}