# Compressors
Warning! Alpha version! This is not a production-ready solution yet.
Compressors is a library of pipelines for model compression without a significant loss of performance.
## Why Compressors?
Compressors provides many ways to compress your model. You can use it for both CV and NLP tasks.
The library is separated into three parts:
- Distillation
- Pruning
- Quantization
There are two ways to use Compressors: with Catalyst runners or through the functional API.
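The runner-based examples below use Catalyst; the functional route means calling the individual distillation losses yourself inside a plain training loop. As a rough illustration of that idea (a minimal plain-PyTorch sketch of Hinton-style soft-target distillation, not Compressors' actual function names), one training step could look like this:

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, batch, optimizer, T=4.0, alpha=0.9):
    """One step: weighted sum of the task loss and a soft-target KL term."""
    inputs, targets = batch
    with torch.no_grad():
        t_logits = teacher(inputs)  # teacher is frozen during distillation
    s_logits = student(inputs)
    task_loss = F.cross_entropy(s_logits, targets)
    kd_loss = F.kl_div(
        F.log_softmax(s_logits / T, dim=-1),
        F.softmax(t_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # scale by T^2 as in Hinton et al.
    loss = alpha * kd_loss + (1 - alpha) * task_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```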
## Install
```bash
pip install git+https://github.com/elephantmipt/compressors.git
```
## Features
### Distillation
Name | References | Status |
---|---|---|
KL-divergence | Hinton et al. | Implemented |
MSE | Hinton et al. | Implemented |
Probabilistic KT | Passalis et al. | Implemented |
Cosine | ??? | Implemented |
Attention Transfer | Zagoruyko et al. | Implemented |
Contrastive Representation Distillation | Tian et al. | Implemented (without dataset) |
Probability Shift | Wen et al. | Implemented and tested |
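As a reference point for the methods above, here is a minimal plain-PyTorch sketch of the attention transfer loss (Zagoruyko et al.). It illustrates the idea only and is not the library's implementation; feature maps are assumed to have matching spatial sizes.

```python
import torch.nn.functional as F

def attention_map(feature_map):
    # (batch, channels, H, W) -> L2-normalized spatial attention map (batch, H*W)
    attn = feature_map.pow(2).mean(dim=1).flatten(start_dim=1)
    return F.normalize(attn, dim=1)

def attention_transfer_loss(student_feats, teacher_feats):
    # sum of squared distances between attention maps taken at matching depths
    return sum(
        (attention_map(s) - attention_map(t)).pow(2).mean()
        for s, t in zip(student_feats, teacher_feats)
    )
```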
### Pruning
Name | References | Status |
---|---|---|
Lottery ticket hypothesis | Frankle et al. | Implemented |
Iterative pruning | Paganini et al. | Implemented |
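For intuition, here is a rough sketch of one iterative-pruning session using PyTorch's built-in pruning utilities (illustration only, not the library's code; `finetune` below stands in for your own fine-tuning loop):

```python
import torch
import torch.nn.utils.prune as prune

def prune_linear_layers(model, amount=0.2):
    # zero out the `amount` fraction of the smallest-magnitude remaining weights
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)

# iterative pruning: prune a little, fine-tune, repeat
# for session in range(num_sessions):
#     prune_linear_layers(model, amount=0.2)
#     finetune(model, loaders)  # hypothetical helper
```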
## Minimal Examples
### Distillation
#### MNIST
```python
from itertools import chain
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from catalyst.contrib.datasets import MNIST
from catalyst.callbacks import AccuracyCallback, OptimizerCallback
from compressors.distillation.runners import EndToEndDistilRunner
from compressors.models import MLP
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
# the teacher is a deeper MLP than the student
teacher = MLP(num_layers=4)
student = MLP(num_layers=3)
datasets = {
"train": Wrp(MNIST("./data", train=True, download=True, transform=T.ToTensor())),
"valid": Wrp(MNIST("./data", train=False, transform=T.ToTensor())),
}
loaders = {
dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
for dl_key, dataset in datasets.items()
}
optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))
# first trains the teacher for 5 epochs, then distills it into the student
# with an additional MSE loss between hidden states
runner = EndToEndDistilRunner(
hidden_state_loss="mse",
num_train_teacher_epochs=5
)
runner.train(
model=torch.nn.ModuleDict({"teacher": teacher, "student": student}),
loaders=loaders,
optimizer=optimizer,
num_epochs=4,
callbacks=[
OptimizerCallback(metric_key="loss"),
AccuracyCallback(input_key="logits", target_key="targets")
],
valid_metric="accuracy01",
minimize_valid_metric=False,
logdir="./logs",
valid_loader="valid",
criterion=torch.nn.CrossEntropyLoss()
)
```
#### CIFAR100 ResNet
```python
from catalyst.callbacks import (
AccuracyCallback,
ControlFlowCallback,
CriterionCallback,
OptimizerCallback,
SchedulerCallback,
)
import torch
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from compressors.distillation.callbacks import (
AttentionHiddenStatesCallback,
KLDivCallback,
MetricAggregationCallback,
)
from compressors.distillation.runners import DistilRunner
from compressors.models.cv import resnet_cifar_8, resnet_cifar_56
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
transform_train = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
transform_test = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
datasets = {
"train": Wrp(CIFAR100(root=".", train=True, download=True, transform=transform_train)),
"valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
}
loaders = {
k: DataLoader(v, batch_size=32, shuffle=k == "train", num_workers=2)
for k, v in datasets.items()
}
# pretrained CIFAR-100 ResNet-56 checkpoint used as the teacher
teacher_sd = load_state_dict_from_url(
"https://github.com/chenyaofo/CIFAR-pretrained-models/releases/download/resnet/cifar100-resnet56-2f147f26.pth"
)
teacher_model = resnet_cifar_56(num_classes=100)
teacher_model.load_state_dict(teacher_sd)
student_model = resnet_cifar_8(num_classes=100)
optimizer = torch.optim.SGD(
student_model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
# probability shift (Wen et al.) is applied to the teacher outputs during distillation
runner = DistilRunner(apply_probability_shift=True)
runner.train(
model={"teacher": teacher_model, "student": student_model},
loaders=loaders,
optimizer=optimizer,
scheduler=scheduler,
valid_metric="accuracy",
minimize_valid_metric=False,
logdir="./cifar100_logs",
callbacks=[
# attention transfer between hidden states and KL divergence between logits,
# computed on the train loader only
ControlFlowCallback(AttentionHiddenStatesCallback(), loaders="train"),
ControlFlowCallback(KLDivCallback(temperature=4), loaders="train"),
CriterionCallback(input_key="s_logits", target_key="targets", metric_key="cls_loss"),
ControlFlowCallback(
MetricAggregationCallback(
prefix="loss",
metrics={
"attention_loss": 1000,
"kl_div_loss": 0.9,
"cls_loss": 0.1,
},
mode="weighted_sum",
),
loaders="train",
),
AccuracyCallback(input_key="s_logits", target_key="targets"),
OptimizerCallback(metric_key="loss", model_key="student"),
SchedulerCallback(),
],
valid_loader="valid",
num_epochs=200,
criterion=torch.nn.CrossEntropyLoss(),
)
```
#### AG News BERT (transformers)
```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from catalyst.callbacks import ControlFlowCallback, OptimizerCallback
from catalyst.callbacks.metric import LoaderMetricCallback
from compressors.distillation.callbacks import (
HiddenStatesSelectCallback,
KLDivCallback,
LambdaPreprocessCallback,
MetricAggregationCallback,
MSEHiddenStatesCallback,
)
from compressors.distillation.runners import HFDistilRunner
from compressors.metrics.hf_metric import HFMetric
from compressors.runners.hf_runner import HFRunner
datasets = load_dataset("ag_news")
tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-4_H-128_A-2")
datasets = datasets.map(
lambda e: tokenizer(e["text"], truncation=True, padding="max_length", max_length=128),
batched=True,
)
datasets = datasets.map(lambda e: {"labels": e["label"]}, batched=True)
datasets.set_format(
type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
)
loaders = {
"train": DataLoader(datasets["train"], batch_size=64, shuffle=True),
"valid": DataLoader(datasets["test"], batch_size=64),
}
metric_callback = LoaderMetricCallback(
metric=HFMetric(metric=load_metric("accuracy")), input_key="logits", target_key="labels",
)
################### Teacher Training #####################
teacher_model = AutoModelForSequenceClassification.from_pretrained(
"google/bert_uncased_L-4_H-128_A-2", num_labels=4
)
runner = HFRunner()
runner.train(
model=teacher_model,
loaders=loaders,
optimizer=torch.optim.Adam(teacher_model.parameters(), lr=1e-4),
callbacks=[metric_callback],
num_epochs=5,
valid_metric="accuracy",
minimize_valid_metric=False,
verbose=True
)
############### Distillation ##################
# keep only teacher hidden states 1 and 3 so they can be matched with the two-layer student
slct_callback = ControlFlowCallback(
HiddenStatesSelectCallback(hiddens_key="t_hidden_states", layers=[1, 3]), loaders="train",
)
lambda_hiddens_callback = ControlFlowCallback(
LambdaPreprocessCallback(
lambda s_hiddens, t_hiddens: (
[c_s[:, 0] for c_s in s_hiddens],
[t_s[:, 0] for t_s in t_hiddens],  # keep only the [CLS] token from each hidden state
)
),
loaders="train",
)
mse_hiddens = ControlFlowCallback(MSEHiddenStatesCallback(), loaders="train")
kl_div = ControlFlowCallback(KLDivCallback(temperature=4), loaders="train")
aggregator = ControlFlowCallback(
MetricAggregationCallback(
prefix="loss",
metrics={"kl_div_loss": 0.2, "mse_loss": 0.2, "task_loss": 0.6},
mode="weighted_sum",
),
loaders="train",
)
runner = HFDistilRunner()
student_model = AutoModelForSequenceClassification.from_pretrained(
"google/bert_uncased_L-2_H-128_A-2", num_labels=4
)
metric_callback = LoaderMetricCallback(
metric=HFMetric(metric=load_metric("accuracy")), input_key="s_logits", target_key="labels",
)
runner.train(
model=torch.nn.ModuleDict({"teacher": teacher_model, "student": student_model}),
loaders=loaders,
optimizer=torch.optim.Adam(student_model.parameters(), lr=1e-4),
callbacks=[
metric_callback,
slct_callback,
lambda_hiddens_callback,
mse_hiddens,
kl_div,
aggregator,
OptimizerCallback(metric_key="loss"),
],
num_epochs=5,
valid_metric="accuracy",
minimize_valid_metric=False,
valid_loader="valid",
verbose=True
)
```
### Pruning
```python
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from catalyst.callbacks import (
PruningCallback,
OptimizerCallback,
CriterionCallback,
AccuracyCallback,
ControlFlowCallback
)
from catalyst.contrib.datasets import MNIST
from compressors.distillation.callbacks import MetricAggregationCallback
from compressors.distillation.callbacks import KLDivCallback
from compressors.models import MLP
from compressors.pruning.runners import PruneRunner
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
model = MLP(num_layers=3)
model.load_state_dict(torch.load("trained_model.pth"))  # load_state_dict updates the model in place
datasets = {
"train": Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
"valid": Wrp(MNIST("./data", train=False, transform=ToTensor())),
}
loaders = {
dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
for dl_key, dataset in datasets.items()
}
optimizer = torch.optim.Adam(model.parameters())
# ten iterative prune -> fine-tune sessions
runner = PruneRunner(num_sessions=10)
runner.train(
model=model,
loaders=loaders,
optimizer=optimizer,
criterion=torch.nn.CrossEntropyLoss(),
callbacks=[
# prune 20% of the remaining weights (L1 magnitude) each session and keep the masks between sessions
PruningCallback(pruning_fn="l1_unstructured", amount=0.2, remove_reparametrization_on_stage_end=False),
OptimizerCallback(metric_key="loss"),
CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
AccuracyCallback(input_key="logits", target_key="targets"),
],
logdir="./pruned_model",
valid_loader="valid",
valid_metric="accuracy",
minimize_valid_metric=False,
)
```