ConceptPrune
Code for the paper ConceptPrune: Concept Editing in Diffusion Models via Skilled Neuron Pruning (arXiv preprint)
Introduction
While large-scale text-to-image diffusion models have demonstrated impressive image-generation capabilities, there are significant concerns about their potential misuse for generating unsafe content, violating copyright, and perpetuating societal biases. Recently, the text-to-image generation community has begun addressing these concerns by editing or unlearning undesired concepts from pre-trained models. However, these methods often involve data-intensive and inefficient fine-tuning or utilize various forms of token remapping, rendering them susceptible to adversarial jailbreaks. In this paper, we present a simple and effective training-free approach, ConceptPrune, wherein we first identify critical regions within pre-trained models responsible for generating undesirable concepts, thereby facilitating straightforward concept unlearning via weight pruning. Experiments across a range of concepts including artistic styles, nudity, object erasure, and gender debiasing demonstrate that target concepts can be efficiently erased by pruning a tiny fraction, approximately 0.12% of total weights, enabling multi-concept erasure and robustness against various white-box and black-box adversarial attacks.
Experiments
Environment Setup
Create the environment from the environment.yml file:
conda env create -f environment.yml
conda activate concept-prune
We recommend using diffusers v0.29.2, as results may change with other versions.
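Because results may differ across diffusers versions, it can help to confirm the installed version before running experiments. Below is a minimal sketch using only the Python standard library; the helper name check_diffusers_version is hypothetical and not part of this repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Version recommended in this README
RECOMMENDED_DIFFUSERS = "0.29.2"


def check_diffusers_version(recommended: str = RECOMMENDED_DIFFUSERS) -> bool:
    """Return True if the installed diffusers package matches `recommended`.

    Hypothetical helper for illustration; not part of the ConceptPrune code base.
    """
    try:
        installed = version("diffusers")
    except PackageNotFoundError:
        print("diffusers is not installed in this environment")
        return False
    if installed != recommended:
        print(f"Found diffusers {installed}; the README recommends {recommended}")
        return False
    return True


if __name__ == "__main__":
    check_diffusers_version()
```

Running it inside the activated concept-prune environment prints a warning only when the version differs or diffusers is missing.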
Code Structure
The file structure is as follows:
configs
- Contains .yaml files with basic arguments. These arguments can be changed within the scripts using argument parsers.
datasets
- Contains .txt or .csv files with prompts for different concepts
neuron_receivers
- Contains classes that hook the Feed-Forward Network (FFN) modules within the UNet to record neuron activations
wanda
- Contains scripts to calculate the WANDA pruning metric introduced by Sun et al. for FFN weights
utils
- Basic helper functions
benchmarking
- Scripts to run all the benchmarks in the paper for different concepts
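For reference, the WANDA metric of Sun et al. scores each weight of a linear layer by its magnitude times the L2 norm of the corresponding input feature over the recorded activations: S_ij = |W_ij| * ||X_j||_2. The NumPy sketch below illustrates that formula only; the function name and array shapes are assumptions, and the actual scripts in wanda/ may organize the computation differently (e.g., accumulating norms over prompts and timesteps).

```python
import numpy as np


def wanda_scores(weight: np.ndarray, activations: np.ndarray) -> np.ndarray:
    """Compute WANDA importance scores |W_ij| * ||X_j||_2.

    weight:      (out_features, in_features) weight matrix of a linear layer
    activations: (num_tokens, in_features) inputs recorded for that layer
    Returns an array with the same shape as `weight`.
    Hypothetical helper for illustration only.
    """
    # L2 norm of each input feature across all recorded tokens -> (in_features,)
    feature_norms = np.linalg.norm(activations, axis=0)
    # Broadcast the per-input-feature norm over every output row
    return np.abs(weight) * feature_norms[None, :]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.standard_normal((4, 3))   # stand-in for an FFN weight matrix
    X = rng.standard_normal((10, 3))  # stand-in for recorded activations
    print(wanda_scores(W, X).shape)
```

Scoring by activation norm as well as weight magnitude means a small weight attached to a strongly activated input can still rank as important.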
Pruning the model using WANDA
To obtain a pruned model for a concept <target>, follow the steps below.
- Discover skilled neurons for a concept
python wanda/wanda.py --target <target> --skill_ratio 0.01
<target> is the concept that we want to erase. Replace <target> with any of the following:
1. Artist styles: Van Gogh, Monet, Pablo Picasso, Da Vinci, Salvador Dali. Example: base prompt = "a cat" and target prompt = "a cat in the style of Van Gogh".
2. Nudity: naked. Example: base prompt = "a photo of a man" and target prompt = "a photo of a naked man".
3. Objects (Imagenette classes): parachute, golf ball, garbage truck, cassette player, church, tench, english springer, french horn, chain saw, gas pump. Example: base prompt = "a room" and target prompt = "a parachute in a room".
4. Gender reversal: male, female. Example: base prompt = "a son" and target prompt = "a daughter" for female-to-male reversal.
5. Memorization: memorize_$i$. For this concept, prompts are loaded from the corresponding file in datasets/ (e.g., datasets/memorize_0.txt). Please pass --target_file memorize_$i$ for this concept.
The argument skill_ratio denotes the sparsity level, i.e., the top-k% of neurons considered for WANDA pruning. This command saves the skilled neurons discovered for every timestep and layer as a sparse matrix in a separate .pkl file.
- Check whether removing skilled neurons from all timesteps and layers removes the concept.
We first check whether hyper-parameters such as skill_ratio used in the previous step are optimal for concept removal. We attach hook functions to the FFN layers at every timestep and apply the pruning mask. The following command saves the images generated after the skilled neurons are removed:
python wanda/remove_neurons.py --target <target> --skill_ratio <skill_ratio>
- Next, we take a union over the skilled neurons of the first few timesteps.
Run the following command to obtain the pruned model:
python wanda/save_union_over_time.py --target <target> --timesteps <tau> --skill_ratio <skill_ratio>
We provide the values of these hyper-parameters in Table 7 in the Appendix for every concept.
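The three steps above can be pictured on a single weight matrix as: keep the top skill_ratio fraction of entries by score as "skilled", take the logical OR of those masks over the first tau timesteps, and zero the masked weights. The NumPy sketch below is a simplified illustration under stated assumptions, not the repository's implementation: it takes one precomputed score matrix per timestep (how scores from base and target prompts are combined is not shown here), and all function names are hypothetical.

```python
import numpy as np


def skilled_mask(scores: np.ndarray, skill_ratio: float) -> np.ndarray:
    """Boolean mask marking the top `skill_ratio` fraction of entries by score.

    Hypothetical helper; assumes `scores` is one precomputed score matrix.
    """
    k = max(1, int(round(skill_ratio * scores.size)))
    # Indices of the k largest scores in the flattened array
    top = np.argpartition(scores.ravel(), -k)[-k:]
    mask = np.zeros(scores.size, dtype=bool)
    mask[top] = True
    return mask.reshape(scores.shape)


def union_over_time(masks_per_timestep: list, tau: int) -> np.ndarray:
    """Logical OR of the skilled-neuron masks of the first `tau` timesteps (tau >= 1)."""
    return np.logical_or.reduce(masks_per_timestep[:tau])


def prune(weight: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return a copy of `weight` with the masked (skilled) weights set to zero."""
    pruned = weight.copy()
    pruned[mask] = 0.0
    return pruned


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((8, 8))
    # One hypothetical score matrix per timestep for this layer
    scores = [rng.random((8, 8)) for _ in range(5)]
    masks = [skilled_mask(s, skill_ratio=0.05) for s in scores]
    mask = union_over_time(masks, tau=2)
    print("weights pruned:", int(mask.sum()), "of", weight.size)
    pruned = prune(weight, mask)
```

Taking the union rather than the intersection means any weight that is skilled at one of the early timesteps gets removed, which trades a slightly larger pruned set for more reliable erasure.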
Benchmarks
Baselines
Train the concept-erasure baselines UCE, FMN, ESD, and Concept-Ablation using their respective repositories. In our code base, we provide code to evaluate these baselines on concept-erasure benchmarks for different concepts. In the following experiments, pass uce, fmn, esd, or concept-ablation as <baseline> to run the corresponding baseline.
Download Checkpoints
We will provide checkpoints on Hugging Face soon!
Evaluate ConceptPrune
- Artist Styles
To evaluate artist-style erasure with ConceptPrune for Van Gogh, Monet, Pablo Picasso, Da Vinci, and Salvador Dali, run:
python benchmarking/artist_erasure.py --target <target> --baseline concept-prune --ckpt_name <path to checkpoint>
We used ChatGPT to create a dataset of 50 prompts for different artists, such that each prompt contains the name of a painting along with the name of the artist. These prompts are available in datasets/. The script saves the images and a JSON file with the CLIP metric reported in the paper to the results/ folder.
- Nudity
To evaluate nudity erasure with ConceptPrune on the I2P dataset, run
python benchmarking/nudity_eval.py --eval_dataset i2p --baseline 'concept-prune' --gpu 0 --ckpt_name <path to checkpoint>
To run ConceptPrune on the black-box adversarial prompt datasets MMA and Ring-A-Bell, replace i2p with mma and ring-a-bell, respectively. We evaluate nudity in images using the NudeNet detector. The script saves the images and a JSON file with the NudeNet scores reported in the paper to the results/ folder.
- Object Erasing
To evaluate object erasure with ConceptPrune, run
python benchmarking/object_erase.py --target <object> --baseline concept-prune --removal_mode erase --ckpt_name <path to checkpoint>
To check interference of concept removal with unrelated classes, run
python benchmarking/object_erase.py --target <object> --baseline concept-prune --removal_mode keep --ckpt_name <path to checkpoint>
where <object> is the name of one of the Imagenette classes. The script saves the images and a JSON file with the ResNet50 accuracies reported in the paper to the results/ folder.
- Gender reversal
To evaluate gender reversal from Female to Male, run
python benchmarking/gender_reversal.py --target male --ckpt_name <path to checkpoint>
Replace male with female to reverse gender from Male to Female. We measure the success of gender reversal by using CLIP to classify generated images as male or female. The script saves images for 250 seeds to the results/ folder.
- COCO evaluation
To evaluate ConceptPrune on COCO dataset, run
python benchmarking/eval_coco.py --target <target> --baseline concept-prune --ckpt_name <path to checkpoint>
- Memorization
To evaluate ConceptPrune on memorized prompts, run
python benchmarking/inference_mem.py --target memorize_$i$ --baseline concept-prune --ckpt_name <path to checkpoint>
This saves the images, computes SSCD and CLIP scores, and stores the results in a JSON file. We run this script with 10 different seeds for every model and report the average performance.
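The gender-reversal evaluation described above relies on CLIP zero-shot classification: each generated image is assigned the class whose text embedding is most similar to the image embedding, and success is the fraction of images assigned the desired gender. The NumPy sketch below shows only that classification and scoring step, assuming image and text embeddings have already been computed with a CLIP model; the function names and the example class prompts are hypothetical, and benchmarking/gender_reversal.py may use different prompts or aggregation.

```python
import numpy as np


def zero_shot_labels(image_emb: np.ndarray, text_emb: np.ndarray) -> np.ndarray:
    """Assign each image the index of its most similar text embedding.

    image_emb: (num_images, dim) embeddings from a CLIP image encoder (assumed precomputed)
    text_emb:  (num_classes, dim) embeddings of class prompts, e.g. hypothetical
               prompts such as "a photo of a man" and "a photo of a woman"
    """
    # Cosine similarity = dot product of L2-normalized vectors
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    return np.argmax(img @ txt.T, axis=1)


def reversal_success_rate(labels: np.ndarray, desired: int) -> float:
    """Fraction of generated images classified as the desired (reversed) gender."""
    return float(np.mean(labels == desired))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    images = rng.standard_normal((250, 512))  # stand-in for 250 seeds' image embeddings
    texts = rng.standard_normal((2, 512))     # stand-in for two class-prompt embeddings
    labels = zero_shot_labels(images, texts)
    print("success rate:", reversal_success_rate(labels, desired=0))
```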
Cite us!
If you find our paper useful, please consider citing our work.
@article{chavhan2024conceptpruneconcepteditingdiffusion,
title={ConceptPrune: Concept Editing in Diffusion Models via Skilled Neuron Pruning},
author={Ruchika Chavhan and Da Li and Timothy Hospedales},
year={2024},
journal={ArXiv}
}
@article{chavhan2024memorizedimagesdiffusionmodels,
title={Memorized Images in Diffusion Models share a Subspace that can be Located and Deleted},
author={Ruchika Chavhan and Ondrej Bohdal and Yongshuo Zong and Da Li and Timothy Hospedales},
year={2024},
journal={ArXiv}
}
Contact
Please contact ruchika.chavhan@ed.ac.uk for any questions!