Title: Masking Adversarial Damage: Finding Adversarial Saliency for Robust and Sparse Network

Authors: Byung-Kwan Lee*, Junho Kim*, and Yong Man Ro (*: equally contributed)

Affiliation: School of Electrical Engineering, Korea Advanced Institute of Science and Technology (KAIST)

Email: leebk@kaist.ac.kr, arkimjh@kaist.ac.kr, ymro@kaist.ac.kr

This is the official PyTorch implementation of the paper "Masking Adversarial Damage: Finding Adversarial Saliency for Robust and Sparse Network", accepted at CVPR 2022. To bridge adversarial robustness and model compression, we propose a novel adversarial pruning method, Masking Adversarial Damage (MAD), that employs second-order information of the adversarial loss. With it, we can accurately estimate the adversarial saliency of model parameters and determine which parameters can be pruned without weakening adversarial robustness.

<p align="center"> <img src="figure/adversarial saliency.png" width="760" height="200"> </p>

Furthermore, we reveal that the model parameters of the initial layer are highly sensitive to adversarial examples, and show that the compressed feature representation retains semantic information for the target objects.

<p align="center"> <img src="figure/semantic information.png" width="465" height="350"> </p>

Through extensive experiments on public datasets, we demonstrate that MAD effectively prunes adversarially trained networks without losing adversarial robustness and outperforms previous adversarial pruning methods. For more detail, you can refer to our paper, which will be publicly accessible soon.

<p align="center"> <img src="figure/pruning ratio.png" width="720" height="300"> </p>

DNN applications are computationally expensive and fragile, so adversarial attacks can have a negative impact on them. By pruning model parameters without weakening adversarial robustness, our work can have a meaningful societal impact in this research area. Furthermore, based on our observation that the model parameters of the initial layer are highly sensitive to the adversarial loss, we hope that future work can exploit this property to enhance adversarial robustness.

In conclusion, to achieve adversarial robustness and model compression concurrently, we propose a novel adversarial pruning method, Masking Adversarial Damage (MAD). By exploiting second-order information with mask optimization and Block-wise K-FAC, we can precisely estimate the adversarial saliency of all parameters. Through extensive validation, we confirm that pruning model parameters in order of increasing adversarial saliency retains adversarial robustness while incurring less performance degradation than previous adversarial pruning methods.



Masking Adversarial Damage (MAD)

Step 1. Finding Adversarial Saliency

# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)
parser.add_argument('--batch_size', default=128, type=int)

# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
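
As a quick check of how these options combine, the sketch below rebuilds the Step 1 parser and parses an argument list in-process. `build_parser` is a hypothetical helper, not part of the repository, and `--batch_size` is typed as `int` here on the assumption that it is used as a batch size for data loading; the remaining options mirror the ones listed above.

```python
import argparse


def build_parser():
    """Rebuild the Step 1 options shown above (illustrative copy)."""
    parser = argparse.ArgumentParser()

    # model parameter
    parser.add_argument('--dataset', default='cifar10', type=str)  # 'svhn', 'tiny'
    parser.add_argument('--network', default='vgg', type=str)      # 'resnet', 'wide'
    parser.add_argument('--depth', default=16, type=int)           # 18 (ResNet), 28 (WideResNet)
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--batch_size', default=128, type=int)     # assumption: int, not float

    # attack parameter
    parser.add_argument('--attack', default='pgd', type=str)
    parser.add_argument('--eps', default=0.03, type=float)
    parser.add_argument('--steps', default=10, type=int)
    return parser


# Select ResNet-18 on SVHN and keep every other default.
args = build_parser().parse_args(['--dataset', 'svhn', '--network', 'resnet', '--depth', '18'])
```

Options that are not passed keep their defaults, so `args.attack` stays `'pgd'` and `args.eps` stays `0.03` in this example.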

Among the code executed by compute_saliency, the following function is the major contribution of our work: it computes adversarial saliency using Block-wise K-FAC. Note the factors block1 and block2 below; they realize Block-wise K-FAC, which dramatically reduces computation.

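    def _compute_delta_L(self):
        """Estimate per-weight adversarial saliency for every tracked module."""

        delta_L_list = []
        mask_list = []
        for idx, m in enumerate(self.modules):

            # K-FAC factors stored for this module
            m_aa, m_gg = self.m_aa[m], self.m_gg[m]

            w = fetch_mat_weights(m)
            mask = fetch_mat_mask_weights(m)
            # Δw: the weights minus their masked counterpart
            w_mask = w - operator(w, mask)

            double_grad_L = torch.empty_like(w_mask)

            # 1/2 * Δ𝑤^𝑇 *𝐻 * Δ𝑤, evaluated one output row (block) at a time
            for i in range(m_gg.shape[0]):
                block1 = 0.5 * m_gg[i, i] * w_mask.t()[:, i].view(-1, 1)
                block2 = w_mask[i].view(1, -1) @ m_aa
                block = block1 @ block2
                double_grad_L[i, :] = block.diag()

            delta_L = double_grad_L
            # Collect per-module results (assumed; this excerpt does not show the step)
            delta_L_list.append(delta_L)
            mask_list.append(mask)

        return delta_L_list, mask_list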
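
Read directly off the loop above, the value written to row $i$, column $j$ of `double_grad_L` can be stated in closed form. The symbols $G$ (for `m_gg`), $A$ (for `m_aa`) and $\Delta W$ (for `w_mask`) are notation introduced here for illustration, and reading `m_gg` and `m_aa` as the K-FAC gradient and activation factors is an assumption based on their names.

```latex
\delta L_{ij} \;=\; \tfrac{1}{2}\, G_{ii}\, \Delta W_{ij} \sum_{k} \Delta W_{ik}\, A_{kj},
\qquad
\sum_{j} \delta L_{ij} \;=\; \tfrac{1}{2}\, G_{ii}\, \Delta W_{i,:}\, A\, \Delta W_{i,:}^{\top}
```

The second identity shows that each row's entries sum to the quadratic form $\tfrac{1}{2}\Delta w^{\top} H \Delta w$ restricted to that row's block, with only the diagonal entry $G_{ii}$ of $G$ involved. Since `block.diag()` of the outer product `block1 @ block2` is just the elementwise product of the two vectors, the full `block` matrix is never strictly needed to obtain these values.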

Step 2. Pruning Low Adversarial Saliency

# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)

# mad parameter
parser.add_argument('--percnt', default=0.9, type=float) # 0.99 (Sparsity)
parser.add_argument('--pruning_mode', default='element', type=str) # 'random' (randomly pruning)
parser.add_argument('--largest', default='false', type=str2bool) # 'true' (pruning high adversarial saliency)

# learning parameter
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epoch', default=60, type=int)

# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
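
The `--percnt` and `--largest` options suggest a ranking-based selection: order the parameters by adversarial saliency and prune the requested fraction, from the low end by default. The following is a minimal sketch of that idea in plain Python, not the repository's implementation. `build_prune_mask` and its inputs are hypothetical, and both the tie handling and the choice to rank one flat list (rather than layer by layer) are assumptions.

```python
def build_prune_mask(saliency, percnt=0.9, largest=False):
    """Return a 0/1 keep-mask that prunes a `percnt` fraction of entries.

    saliency: flat list of per-parameter saliency scores (hypothetical input).
    largest:  False prunes the lowest scores, True prunes the highest.
    """
    if not 0.0 <= percnt <= 1.0:
        raise ValueError("percnt must lie in [0, 1]")
    n_prune = int(len(saliency) * percnt)
    # Sort indices so that the entries to prune come first.
    order = sorted(range(len(saliency)), key=lambda i: saliency[i], reverse=largest)
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(saliency))]


scores = [0.05, 0.90, 0.10, 0.40, 0.02, 0.75, 0.30, 0.60, 0.01, 0.20]
mask_low = build_prune_mask(scores, percnt=0.5)                 # prune the 5 lowest scores
mask_high = build_prune_mask(scores, percnt=0.5, largest=True)  # prune the 5 highest scores
```

With these toy scores the two masks are exact complements, which mirrors the role of `--largest` as a way to compare pruning low-saliency parameters against pruning high-saliency ones.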

Adversarial Training (+ Recent Adversarial Defenses)

Running Adversarial Training

To easily obtain an adversarially trained model, we first train a standard model with [1] and then perform adversarial training (AT) with [2], starting from the trained standard model. For recent adversarial defenses, the AT model produced by [2] is a helpful starting point for training TRADES or MART through [3-1] or [3-2].

Adversarial Attacks (by torchattacks)

The implementation details of the adversarial attacks are in attack/attack.py.

# torchattacks
if attack == "fgsm":
    return torchattacks.FGSM(model=net, eps=eps)

elif attack == "fgsm_train":
    return FGSM_train(model=net, eps=eps)

elif attack == "pgd":
    return torchattacks.PGD(model=net, eps=eps, alpha=eps/steps*2.3, steps=steps, random_start=True)

elif attack == "cw_linf":
    return CW_Linf(model=net, eps=eps, lr=0.1, steps=30)

elif attack == "apgd":
    return torchattacks.APGD(model=net, eps=eps, loss='ce', steps=30)

elif attack == "auto":
    return torchattacks.AutoAttack(model=net, eps=eps, n_classes=n_classes)
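
The PGD entry above sets `alpha=eps/steps*2.3`, so the total step budget `steps * alpha` is 2.3 times `eps`. That is enough to cross the whole eps-ball (width `2 * eps`) from any random start. The sketch below illustrates this with a toy L_inf PGD in plain Python; `pgd_linf` and the linear loss are hypothetical stand-ins rather than the torchattacks implementation, and clamping to a valid image range is left out for brevity.

```python
import random


def pgd_linf(x, grad_fn, eps=0.03, steps=10, random_start=True, seed=0):
    """Toy L_inf PGD on a list of floats; grad_fn(x) returns d(loss)/dx."""
    alpha = eps / steps * 2.3  # same step-size rule as the PGD line above
    rng = random.Random(seed)
    adv = list(x)
    if random_start:
        adv = [xi + rng.uniform(-eps, eps) for xi in x]
    for _ in range(steps):
        grad = grad_fn(adv)
        # Ascend the loss along the sign of the gradient.
        adv = [a + alpha * ((g > 0) - (g < 0)) for a, g in zip(adv, grad)]
        # Project back into the eps-ball around the clean input.
        adv = [min(max(a, xi - eps), xi + eps) for a, xi in zip(adv, x)]
    return adv


# Toy linear loss L(x) = w . x, whose gradient is simply w (illustrative only).
w = [1.0, -2.0, 0.5]
x_clean = [0.2, 0.5, 0.8]
x_adv = pgd_linf(x_clean, grad_fn=lambda x: w)
```

For this linear loss every coordinate ends on the boundary of the eps-ball, at `x_clean[i] + eps` where `w[i]` is positive and `x_clean[i] - eps` where it is negative.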

Testing Adversarial Robustness