Weight Fixing Networks

<img src="https://user-images.githubusercontent.com/13983188/177870968-d26c4c87-6259-493a-b67e-8ccbc82dccbf.png" width="700">

This repo contains the PyTorch + PyTorch Lightning code for applying the method proposed in 'Weight Fixing Networks', a paper accepted at ECCV 2022.

<img src="https://user-images.githubusercontent.com/13983188/177872026-dc25192e-f218-4b98-a68d-26332d6d47b9.png" width="700">

Quantized Model Saves

Below we link to the quantised model saves quoted in the results section of the paper.

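| Model | δ | Unique Param Count | Entropy | Acc (%) | Link |
|---|---|---|---|---|---|
| ResNet-18 | 0.0075 | 193 | 4.15 | 70.3 | link |
| ResNet-18 | 0.01 | 164 | 3.01 | 69.7 | link |
| ResNet-18 | 0.015 | 90 | 2.72 | 67.3 | link |
| ResNet-34 | 0.0075 | 233 | 3.87 | 73.0 | link |
| ResNet-34 | 0.01 | 164 | 3.48 | 72.6 | link |
| ResNet-34 | 0.015 | 117 | 2.83 | 72.2 | link |
| ResNet-50 | 0.0075 | 261 | 4.11 | 76.0 | link |
| ResNet-50 | 0.01 | 199 | 4.00 | 75.4 | link |
| ResNet-50 | 0.015 | 125 | 3.55 | 75.1 | link |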
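
The Unique Param Count and Entropy columns summarise how many distinct weight values a model uses and how evenly its weights are spread across them. Below is a minimal sketch of computing both for a flat list of weights. It assumes that Entropy means the Shannon entropy (in bits) of the empirical distribution over unique weight values; check the paper for the exact definition used in the table. The helper name weight_stats is hypothetical and not part of this repo.

```python
import math
from collections import Counter


def weight_stats(weights):
    """Return (unique value count, entropy in bits) for a flat list of weights.

    Hypothetical helper: assumes entropy is the Shannon entropy of the
    empirical frequency of each distinct weight value.
    """
    counts = Counter(weights)
    total = sum(counts.values())
    probs = [c / total for c in counts.values()]
    entropy = sum(-p * math.log2(p) for p in probs)
    return len(counts), entropy


# Toy example: 8 weights taking 4 distinct values with frequencies 1/2, 1/4, 1/8, 1/8.
toy = [0.5, 0.5, 0.5, 0.5, -0.25, -0.25, 0.125, 0.0]
print(weight_stats(toy))  # → (4, 1.75)
```

For a PyTorch model, the flat list could be built with torch.cat([p.detach().flatten() for p in model.parameters()]).tolist(). Whether biases and batch-norm parameters are included (see --bn_inc) will change both numbers.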

To Run

  1. First, install the requirements listed in requirements.txt.

  2. If you want to run the ImageNet experiments, update data_dir (see Setting ImageNet File Locations).

  3. Run python pretrained_model_experiments.py, changing any of the following optional arguments as needed:

Optional arguments:

--distance_allowed : δ in the paper

--percentages : the percentage of weights clustered in each iteration

--optimizer : the training optimizer

--experiment_name : the name used when saving TensorBoard logs and model saves

--scheduler : the learning rate scheduler

--lr : the learning rate

--first_epoch : the number of training iterations before any clustering (set to zero for pre-trained models)

--fixing_epochs : the number of training epochs within a single clustering iteration (3 was used in the paper)

--model : the name of the model to train; see get_model() for the dataset-model combinations supported out of the box

--dataset : the dataset to train on; CIFAR-10 and ImageNet are currently supported

--zero_distance : $\gamma_0$ in the paper; any weight whose absolute value is below this is set to zero and pruned

--regularisation_ratio : the weighting of the $\mathcal{L}_{reg}$ term

--bn_inc : whether to quantize batch-norm layers

--resume : if continuing training, set this to the iteration you wish to continue from

--calculation_type : the distance measure to use; relative and euclidean are supported
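
To make --distance_allowed (δ), --calculation_type and --zero_distance ($\gamma_0$) more concrete, here is a minimal sketch of the per-weight decision they control. It assumes that the relative distance between a weight w and a cluster centre c is |w - c| / |c| and that the euclidean distance is |w - c|; the helpers within_distance and fix_weight are hypothetical and not taken from this repo, so check the clustering code for the exact rules.

```python
def within_distance(w, c, delta, calculation_type="relative"):
    """Return True if weight w is close enough to centre c to be fixed to it.

    Assumption: 'relative' means |w - c| / |c| <= delta and 'euclidean'
    means |w - c| <= delta (for scalars the euclidean distance is |w - c|).
    """
    diff = abs(w - c)
    if calculation_type == "relative":
        return c != 0 and diff / abs(c) <= delta
    if calculation_type == "euclidean":
        return diff <= delta
    raise ValueError(f"unknown calculation_type: {calculation_type}")


def fix_weight(w, centres, delta, zero_distance, calculation_type="relative"):
    """Prune w to zero if |w| < zero_distance, else snap it to the nearest allowed centre.

    Returns None when no centre is within delta, i.e. the weight stays trainable.
    """
    if abs(w) < zero_distance:
        return 0.0
    candidates = [c for c in centres if within_distance(w, c, delta, calculation_type)]
    if not candidates:
        return None
    return min(candidates, key=lambda c: abs(w - c))


centres = [0.25, 0.5, 1.0]  # arbitrary example centres
print(fix_weight(0.251, centres, delta=0.01, zero_distance=1e-3))  # → 0.25
print(fix_weight(0.30, centres, delta=0.01, zero_distance=1e-3))   # → None
print(fix_weight(5e-4, centres, delta=0.01, zero_distance=1e-3))   # → 0.0
```

Under these assumed definitions, the relative measure scales the tolerance with the centre's magnitude, so a fixed δ admits a wider band around large centres than around small ones, while the euclidean measure uses the same absolute band everywhere.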

Applying WFN to Different Types of Models

To add a new model for WFN quantisation, go to the get_model function in pre_trained_model_experiments.py and follow the existing format. Once added there, the new model is automatically converted into a weight_fix_base, which contains all the functionality needed to apply the clustering.


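import torchvision.models as models
# All_Conv_4 and the CIFAR-10 resnet18 are imported from elsewhere in the repo (not shown).


def get_model(model_name, data):
    """Define the models here; to use a new model, add a branch for it below."""
    if model_name == 'conv4':
        # load_from_checkpoint is a LightningModule classmethod that returns a new instance
        model = All_Conv_4.load_from_checkpoint(
            checkpoint_path="Pretrained_Models/PyTorch_CIFAR10/cifar10_models/state_dicts/all_conv4")
    elif model_name == 'resnet18' and data == 'cifar10':
        model = resnet18(pretrained=True)
    elif model_name == 'resnet34' and data == 'imnet':
        model = models.resnet34(pretrained=True)
    else:
        raise ValueError(f"Unsupported model/dataset combination: {model_name}/{data}")
    return model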
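
The conversion into weight_fix_base is handled by the repo. As a rough illustration of what such a wrapper needs from a new model, the sketch below lists the layers whose weights would be candidates for clustering. It assumes that convolutional and linear layers are always clustered and batch-norm layers only when --bn_inc is set; quantisable_layers is a hypothetical helper, not part of this repo.

```python
import torch.nn as nn


def quantisable_layers(model, bn_inc=False):
    """Return (name, module) pairs whose weights a WFN-style wrapper could cluster.

    Hypothetical helper: conv and linear layers are always included, batch-norm
    layers only when bn_inc is True (mirroring the --bn_inc option).
    """
    kinds = (nn.Conv2d, nn.Linear)
    if bn_inc:
        kinds += (nn.BatchNorm2d,)
    return [(name, m) for name, m in model.named_modules() if isinstance(m, kinds)]


# A small stand-in model; a model returned by get_model would be inspected the same way.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
print([name for name, _ in quantisable_layers(toy)])               # → ['0', '4']
print([name for name, _ in quantisable_layers(toy, bn_inc=True)])  # → ['0', '1', '4']
```

Running a check like this on a candidate model before adding it to get_model shows whether its layers are standard nn modules; custom layers that hold their own parameters would likely need extra handling.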

Setting ImageNet File Locations

Open the ImageNet module in Datasets/imagenet.py and edit the self.data_dir line so that it points to your ImageNet data directory.

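import pytorch_lightning as pl
from torchvision import transforms


class ImageNet_Module(pl.LightningDataModule):
    def __init__(self, data_dir='Datasets/', shuffle_pixels=False, shuffle_labels=False, random_pixels=False):
        super().__init__()
        # Edit this line: the data_dir argument is not used, so the path must be set here.
        self.data_dir = 'Your data directory here'
        # Standard ImageNet per-channel mean and std used for normalisation
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.normalise = transforms.Normalize(mean=self.mean, std=self.std)
        # transform_select, test_transform and target_transform_select are other methods of this class (not shown)
        self.transform = self.transform_select(shuffle_pixels, random_pixels)
        self.test_trans = self.test_transform()
        self.target_transform = self.target_transform_select(shuffle_labels)
        self.targets = 1000  # number of ImageNet classes
        self.dims = (3, 224, 224)  # input shape (channels, height, width)
        self.bs = 64  # batch size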
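
Rather than hard-coding the path, one option is to resolve it once and fail early if it is wrong. The sketch below reads a hypothetical IMAGENET_DIR environment variable (not something this repo defines) and checks that the directory contains the expected subfolders. The train/ and val/ subfolder names are also an assumption about the layout, so match them to what Datasets/imagenet.py actually loads.

```python
import os
from pathlib import Path


def resolve_imagenet_dir(default="Datasets/", subdirs=("train", "val")):
    """Return the ImageNet root as a Path, raising early if it looks wrong.

    Hypothetical helper: IMAGENET_DIR and the train/val subfolder names are
    assumptions for this sketch, not settings defined by the repo.
    """
    root = Path(os.environ.get("IMAGENET_DIR", default)).expanduser()
    missing = [d for d in subdirs if not (root / d).is_dir()]
    if missing:
        raise FileNotFoundError(f"{root} is missing expected subfolders: {missing}")
    return root
```

Inside ImageNet_Module.__init__, the hard-coded line could then become self.data_dir = str(resolve_imagenet_dir(data_dir)), which also puts the otherwise unused data_dir argument to work.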