CVPR 2022
Title: Masking Adversarial Damage: Finding Adversarial Saliency for Robust and Sparse Network
Authors: Byung-Kwan Lee*, Junho Kim*, and Yong Man Ro (*: equally contributed)
Affiliation: School of Electrical Engineering, Korea Advanced Institute of Science and Technology (KAIST)
Email: leebk@kaist.ac.kr, arkimjh@kaist.ac.kr, ymro@kaist.ac.kr
This is the official PyTorch implementation of the paper "Masking Adversarial Damage: Finding Adversarial Saliency for Robust and Sparse Network", accepted to CVPR 2022. To bridge adversarial robustness and model compression, we propose a novel adversarial pruning method, Masking Adversarial Damage (MAD), which employs second-order information of the adversarial loss. With it, we can accurately estimate the adversarial saliency of model parameters and determine which parameters can be pruned without weakening adversarial robustness.
<p align="center"> <img src="figure/adversarial saliency.png" width="760" height="200"> </p>

Furthermore, we reveal that the model parameters of the initial layers are highly sensitive to adversarial examples, and we show that the compressed feature representation retains semantic information for the target objects.
<p align="center"> <img src="figure/semantic information.png" width="465" height="350"> </p>

Through extensive experiments on public datasets, we demonstrate that MAD effectively prunes adversarially trained networks without losing adversarial robustness and outperforms previous adversarial pruning methods. For more details, please refer to our paper, which will be publicly available soon!
<p align="center"> <img src="figure/pruning ratio.png" width="720" height="300"> </p>

Adversarial attacks can cause negative impacts on various DNN applications, which already suffer from high computational cost and fragility. By pruning model parameters without weakening adversarial robustness, our work makes an important societal contribution in this research area. Furthermore, based on our promising observation that the model parameters of the initial layers are highly sensitive to the adversarial loss, we hope future work explores exploiting this property to further enhance adversarial robustness.
In conclusion, to achieve adversarial robustness and model compression concurrently, we propose a novel adversarial pruning method, Masking Adversarial Damage (MAD). By exploiting second-order information with mask optimization and Block-wise K-FAC, we can precisely estimate the adversarial saliency of all model parameters. Through extensive validation, we corroborate that pruning model parameters in order of low adversarial saliency retains adversarial robustness and incurs less performance degradation than previous adversarial pruning methods.
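As a rough sketch of the approximation behind MAD (our shorthand here, not the paper's exact notation): the increase of the adversarial loss caused by a pruning perturbation Δw is expanded to second order, the first-order term is treated as negligible for an adversarially trained model, and the Hessian H is approximated layer-wise with the Kronecker-factored (K-FAC) activation and gradient covariances, so the adversarial saliency comes from the quadratic term:

$$
\Delta L_{\mathrm{adv}}(\Delta w) \;\approx\; g^{\top}\Delta w \;+\; \tfrac{1}{2}\,\Delta w^{\top} H\,\Delta w \;\approx\; \tfrac{1}{2}\,\Delta w^{\top} H\,\Delta w
$$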
Datasets
- CIFAR-10 (32x32, 10 classes)
- SVHN (32x32, 10 classes)
- Tiny-ImageNet (64x64, 200 classes)
Networks
- VGG-16 (models/vgg.py)
- ResNet-18 (models/resnet.py)
- WideResNet-28-10 (models/wide.py)
Masking Adversarial Damage (MAD)
Step 1. Finding Adversarial Saliency
- Run `compute_saliency.py` (saves a pickle file with the adversarial saliency of all model parameters; you then need a folder, e.g. a `pickle` folder, in which the pickle file is stored)
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)
parser.add_argument('--batch_size', default=128, type=float)
# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
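For example, assuming the default arguments above, the adversarial saliency of a PGD-trained VGG-16 on CIFAR-10 would be computed with a command along these lines (checkpoint and output paths follow your own setup):

```bash
python compute_saliency.py --dataset cifar10 --network vgg --depth 16 --attack pgd --eps 0.03 --steps 10
```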
Within `compute_saliency.py`, the following code represents the major contribution of our work: the procedure for computing adversarial saliency, realized with Block-wise K-FAC. Note the factors `block1` and `block2` below; they are what make Block-wise K-FAC dramatically reduce computation.
def _compute_delta_L(self):
    delta_L_list = []
    mask_list = []
    for idx, m in enumerate(self.modules):
        # K-FAC factors for this layer: m_aa (input-activation covariance) and m_gg (output-gradient covariance)
        m_aa, m_gg = self.m_aa[m], self.m_gg[m]
        w = fetch_mat_weights(m)
        mask = fetch_mat_mask_weights(m)
        # Δw: the weight change induced by masking (pruning) the parameters
        w_mask = w - operator(w, mask)
        double_grad_L = torch.empty_like(w_mask)
        # Block-wise K-FAC estimate of 1/2 * Δw^T * H * Δw, computed row by row
        for i in range(m_gg.shape[0]):
            block1 = 0.5 * m_gg[i, i] * w_mask.t()[:, i].view(-1, 1)
            block2 = w_mask[i].view(1, -1) @ m_aa
            block = block1 @ block2
            # keep only the per-parameter (diagonal) contributions of this row
            double_grad_L[i, :] = block.diag()
        delta_L = double_grad_L
        delta_L_list.append(delta_L.detach())
        mask_list.append(f(mask).detach())
    return delta_L_list, mask_list
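To illustrate how the per-parameter saliency scores returned above could later be turned into pruning masks, here is a minimal, hypothetical sketch (the helper `build_masks` and the global-threshold strategy are our own simplification, not the logic of `main_mad_pretrain.py`):

```python
import torch

def build_masks(delta_L_list, sparsity=0.9):
    # Hypothetical sketch: prune the `sparsity` fraction of parameters with the
    # lowest adversarial saliency and keep the rest (mask value 1 = keep).
    flat = torch.cat([d.flatten() for d in delta_L_list])
    k = max(1, int(sparsity * flat.numel()))       # number of parameters to prune
    threshold = torch.kthvalue(flat, k).values     # global saliency threshold
    return [(d > threshold).float() for d in delta_L_list]
```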
Step 2. Pruning Low Adversarial Saliency
- Run `main_mad_pretrain.py` (requires the pickle file generated in Step 1)
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)
# mad parameter
parser.add_argument('--percnt', default=0.9, type=float) # 0.99 (Sparsity)
parser.add_argument('--pruning_mode', default='element', type=str) # 'random' (randomly pruning)
parser.add_argument('--largest', default='false', type=str2bool) # 'true' (pruning high adversarial saliency)
# learning parameter
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=float)
parser.add_argument('--epoch', default=60, type=int)
# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
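For instance, with the defaults above, pruning 90% of the parameters with the lowest adversarial saliency could be launched as follows (assuming the Step 1 pickle file is already in place):

```bash
python main_mad_pretrain.py --dataset cifar10 --network vgg --depth 16 --percnt 0.9 --pruning_mode element --largest false
```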
Adversarial Training (+ Recent Adversarial Defenses)
- AT (main_adv_pretrain.py)
- TRADES (main_trades_pretrain.py)
- MART (main_mart_pretrain.py)
- FAST for Tiny-ImageNet (refer to FGSM_train class in attack/attack.py)
Running Adversarial Training
To obtain an adversarially trained model, we first train a standard model with [1] and then perform adversarial training (AT) with [2], starting from the trained standard model. For the recent adversarial defenses, the AT model produced by [2] is a useful starting point for training TRADES or MART via [3-1] or [3-2], as sketched below.
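A hypothetical end-to-end sequence for VGG-16 on CIFAR-10 could therefore look like this (flags follow the per-script defaults listed below):

```bash
python main_pretrain.py        --dataset cifar10 --network vgg --depth 16   # [1] plain training
python main_adv_pretrain.py    --dataset cifar10 --network vgg --depth 16   # [2] PGD adversarial training
python main_trades_pretrain.py --dataset cifar10 --network vgg --depth 16   # [3-1] TRADES (or [3-2] MART)
```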
- [1] Plain (Plain Training)
  - Run `main_pretrain.py`
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--epoch', default=200, type=int)
parser.add_argument('--device', default='cuda:0', type=str)
# learning parameter
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=float)
- [2] AT (PGD Adversarial Training)
  - Run `main_adv_pretrain.py`
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)
# learning parameter
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=float)
parser.add_argument('--epoch', default=60, type=int)
# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
- [3-1] TRADES (Recent defense method)
  - Run `main_trades_pretrain.py`
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)
# learning parameter
parser.add_argument('--learning_rate', default=1e-3, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=float)
parser.add_argument('--epoch', default=10, type=int)
# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
- [3-2] MART (Recent defense method)
  - Run `main_mart_pretrain.py`
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)
# learning parameter
parser.add_argument('--learning_rate', default=1e-3, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=float)
parser.add_argument('--epoch', default=60, type=int)
# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
Adversarial Attacks (by torchattacks)
- Fast Gradient Sign Method (FGSM)
- Projected Gradient Descent (PGD)
- Carlini & Wagner (CW)
- AutoPGD (AP)
- AutoAttack (AA)
The implementation details for the adversarial attacks are described in `attack/attack.py`.
# torchattacks
if attack == "fgsm":
    return torchattacks.FGSM(model=net, eps=eps)
elif attack == "fgsm_train":
    return FGSM_train(model=net, eps=eps)
elif attack == "pgd":
    return torchattacks.PGD(model=net, eps=eps, alpha=eps/steps*2.3, steps=steps, random_start=True)
elif attack == "cw_linf":
    return CW_Linf(model=net, eps=eps, lr=0.1, steps=30)
elif attack == "apgd":
    return torchattacks.APGD(model=net, eps=eps, loss='ce', steps=30)
elif attack == "auto":
    return torchattacks.AutoAttack(model=net, eps=eps, n_classes=n_classes)
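As a brief usage note (a generic torchattacks sketch, not necessarily how the evaluation scripts invoke it), the attack object returned above is simply called on a batch of images and labels to obtain adversarial examples:

```python
import torchattacks

# assuming `net`, `images`, and `labels` already live on the same device
attack = torchattacks.PGD(model=net, eps=0.03, alpha=0.03 / 10 * 2.3, steps=10, random_start=True)
adv_images = attack(images, labels)   # perturbed batch with the same shape as `images`
outputs = net(adv_images)             # evaluate robustness on the adversarial examples
```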
Testing Adversarial Robustness
- Measuring the robustness of an adversarially trained model
  - Run `test.py`
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--baseline', default='adv', type=str) # 'trades', 'mart', 'mad'
parser.add_argument('--device', default='cuda:0', type=str)
# mad parameter
parser.add_argument('--percnt', default=0.9, type=float) # 0.99 (Sparsity)
parser.add_argument('--pruning_mode', default='el', type=str) # 'rd' (random)
parser.add_argument('--largest', default='false', type=str2bool) # 'true' (pruning high adversarial saliency)
# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=30, type=int)
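For example, assuming the defaults above, the robustness of a MAD-pruned model under PGD could be evaluated with something like:

```bash
python test.py --dataset cifar10 --network vgg --depth 16 --baseline mad --percnt 0.9 --attack pgd --eps 0.03 --steps 30
```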