An implementation of DiffFit: Unlocking Transferability of Large Diffusion Models via Simple Parameter-Efficient Fine-Tuning using d🧨ffusers.

⚠️ This is a work in progress. You may need to adjust the hyperparameters, especially the learning rate.

What is DiffFit?

DiffFit is an extension of BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models, an existing parameter-efficient fine-tuning (PEFT) method in which only the bias terms of the model are trainable. In addition to the bias terms, DiffFit trains learnable scaling factors added to each transformer block, as well as the LayerNorm (LN) parameters, for better adaptation.
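As a rough illustration of which parameters each method leaves trainable, here is a minimal PyTorch sketch on a toy model. It assumes each scaling factor `gamma` multiplies a block's residual branch and is initialised to 1 so the block starts out identical to the pretrained one. `ScaledBlock` and `mark_trainable` are hypothetical names for illustration, not this repo's API.

```python
import torch
from torch import nn


class ScaledBlock(nn.Module):
    """Hypothetical residual block with a DiffFit-style learnable scale `gamma`."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)
        # Assumption: one scale per channel, initialised to 1 so the block
        # initially behaves like the unmodified pretrained block.
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the residual branch is scaled; the skip connection is untouched.
        return x + self.gamma * self.mlp(self.norm(x))


def mark_trainable(model: nn.Module, bitfit: bool = False) -> list[str]:
    """Freeze everything except the parameters BitFit/DiffFit would train (sketch)."""
    for name, param in model.named_parameters():
        leaf = name.split(".")[-1]
        if bitfit:
            # BitFit: bias terms only.
            param.requires_grad = leaf == "bias"
        else:
            # DiffFit: bias terms plus the added scale factors...
            param.requires_grad = leaf in {"bias", "gamma"}
    if not bitfit:
        # ...plus all LayerNorm parameters (weight and bias).
        for module in model.modules():
            if isinstance(module, nn.LayerNorm):
                for param in module.parameters():
                    param.requires_grad = True
    return [name for name, p in model.named_parameters() if p.requires_grad]


model = nn.Sequential(ScaledBlock(8), ScaledBlock(8), nn.Linear(8, 4))
difffit_names = mark_trainable(model)

# Only the unfrozen parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
print(difffit_names)
```

Passing only the parameters with `requires_grad=True` to the optimizer keeps its state small; `lr=1e-3` simply mirrors `--learning_rate=1e-3` from the training command.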

<img src="difffit-overview.png" alt= "overview" width="400">

The paper applies DiffFit to DiT rather than to UNet-based diffusion models (DMs), but notes that it can be generalized to other DMs, e.g. text-to-image models such as Stable Diffusion. This repo therefore extends DiffFit to Stable Diffusion in a domain-tuning setting.

You can also try BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models in this repo.


```bash
$ git clone https://github.com/mkshing/DiffFit-pytorch.git
$ cd DiffFit-pytorch
$ pip install -r requirements.txt
```


```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_difffit.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=1e-3 \
  --add_vlb_loss \
  --vlb_lambda=0.001 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=500
```

To add the prior preservation loss, append `--class_prompt="a photo of dog" --with_prior_preservation --prior_loss_weight=1.0` to the command. To also add the extra trainable parameters to the text encoder, append `--train_text_encoder`.

Add `--bitfit` to train with BitFit instead of DiffFit.
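With `--with_prior_preservation`, DreamBooth-style scripts such as the diffusers `train_dreambooth.py` example concatenate instance and class images into one batch and add the class-image loss weighted by `--prior_loss_weight`. The sketch below assumes `train_difffit.py` follows the same convention; `dreambooth_loss` is a hypothetical helper, not a function from this repo.

```python
import torch
import torch.nn.functional as F


def dreambooth_loss(model_pred, target, with_prior_preservation=True, prior_loss_weight=1.0):
    """Denoising MSE loss with an optional prior-preservation term (sketch)."""
    if not with_prior_preservation:
        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    # Assumed batch layout: first half instance images, second half class images.
    pred_instance, pred_prior = torch.chunk(model_pred, 2, dim=0)
    target_instance, target_prior = torch.chunk(target, 2, dim=0)
    instance_loss = F.mse_loss(pred_instance.float(), target_instance.float(), reduction="mean")
    prior_loss = F.mse_loss(pred_prior.float(), target_prior.float(), reduction="mean")
    return instance_loss + prior_loss_weight * prior_loss


# Toy tensors shaped like SD latents: one instance example and one class example.
pred = torch.zeros(2, 4, 64, 64)
target = torch.cat([torch.ones(1, 4, 64, 64), 2 * torch.ones(1, 4, 64, 64)])
loss = dreambooth_loss(pred, target, prior_loss_weight=1.0)
print(loss.item())
```

Under this convention, `--train_batch_size=1` would mean each step sees one instance image and one class image.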


```python
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from difffit_pytorch import load_unet_for_difffit, load_text_encoder_for_difffit, load_config_for_difffit

# The base model should normally match the one used for training.
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
# Directory containing the trained DiffFit/BitFit weights.
efficient_weights_ckpt = "ckpt-dir-path"

# The saved training config records whether the weights were trained with --bitfit.
training_args = load_config_for_difffit(efficient_weights_ckpt)
unet = load_unet_for_difffit(pretrained_model_name_or_path, efficient_weights_ckpt=efficient_weights_ckpt, is_bitfit=training_args["bitfit"], subfolder="unet")
text_encoder = load_text_encoder_for_difffit(pretrained_model_name_or_path, efficient_weights_ckpt=efficient_weights_ckpt, is_bitfit=training_args["bitfit"], subfolder="text_encoder")

# Load the pipeline with the fine-tuned UNet and text encoder swapped in.
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    unet=unet,
    text_encoder=text_encoder,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```