Support LoRA for CLIP text encoder in diffusers #2479

Closed
265 changes: 210 additions & 55 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -15,6 +15,8 @@
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""

import argparse
import itertools
import json
import logging
import math
import os
@@ -195,6 +197,44 @@ def parse_args():
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")

# lora args
parser.add_argument("--use_peft", action="store_true", help="Whether to use peft to support lora")
parser.add_argument("--lora_r", type=int, default=4, help="Lora rank, only used if use_lora is True")
parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if lora is True")
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True")
parser.add_argument(
"--lora_bias",
type=str,
default="none",
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
)
parser.add_argument(
"--lora_text_encoder_r",
type=int,
default=4,
help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_alpha",
type=int,
default=32,
help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_dropout",
type=float,
default=0.0,
help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_bias",
type=str,
default="none",
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
)

parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
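As background for the LoRA hyperparameters added above: `lora_r` is the rank of the low-rank update and `lora_alpha` scales that update by `alpha / r`, so the defaults (r=4, alpha=32) give a scale of 8. The following standalone sketch is illustrative only and is not part of this diff; it shows what a LoRA-augmented linear layer computes:

import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Minimal, hypothetical illustration of a LoRA-augmented linear layer (not this PR's code).
    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 32, dropout: float = 0.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # the pretrained weight stays frozen
        self.lora_a = nn.Linear(base.in_features, r, bias=False)   # low-rank "down" projection
        self.lora_b = nn.Linear(r, base.out_features, bias=False)  # low-rank "up" projection
        nn.init.zeros_(self.lora_b.weight)  # start as a no-op so training begins at the base model
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r  # controlled by --lora_alpha and --lora_r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))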
@@ -429,11 +469,6 @@ def main():
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)

text_encoder.requires_grad_(False)

# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
@@ -443,43 +478,79 @@
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

if args.use_peft:

from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict

UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]

config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=UNET_TARGET_MODULES,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
)
unet = LoraModel(config, unet)

vae.requires_grad_(False)
if args.train_text_encoder:

config = LoraConfig(
r=args.lora_text_encoder_r,
lora_alpha=args.lora_text_encoder_alpha,
target_modules=TEXT_ENCODER_TARGET_MODULES,
lora_dropout=args.lora_text_encoder_dropout,
bias=args.lora_text_encoder_bias,
)
text_encoder = LoraModel(config, text_encoder)
Contributor:

The problem is that we cannot really load LoraModel later for inference currently as it's created in a somewhat hacky way here: https://github.com/huggingface/peft/blob/8358b2744555e8c18262f7befd7ef040527a6f0f/src/peft/tuners/lora.py#L90

Could we maybe move everything to the research_folder project instead of adding it to the "easy" LoRA example script?

Contributor Author:

Makes sense to me.

else:

# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)

text_encoder.requires_grad_(False)

# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
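In both branches only the newly added LoRA parameters should end up trainable, while the base unet/vae/text_encoder weights stay frozen. A quick illustrative check, not part of the diff:

trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")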

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)

unet.set_attn_processor(lora_attn_procs)
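The 32-layer figure worked out in the comments above can be sanity-checked directly; a small illustrative snippet, assuming a Stable Diffusion 1.x UNet is loaded as `unet`:

n_self = sum(name.endswith("attn1.processor") for name in unet.attn_processors)   # self-attention
n_cross = sum(name.endswith("attn2.processor") for name in unet.attn_processors)  # cross-attention
print(n_self, n_cross, len(unet.attn_processors))  # expect 16, 16 and 32 in total for SD 1.x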

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
@@ -493,8 +564,6 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

lora_layers = AttnProcsLayers(unet.attn_processors)

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
@@ -518,13 +587,28 @@
else:
optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
lora_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.use_peft:
# Optimizer creation
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
optimizer = optimizer_cls(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
else:
optimizer = optimizer_cls(
lora_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
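Note that the peft branch hands the optimizer every parameter of the wrapped models, including the frozen base weights; AdamW simply never updates parameters whose `.grad` stays `None`, so this is harmless. An explicit filter (an illustrative alternative, not what the diff does) would be:

params_to_optimize = [
    p
    for p in itertools.chain(unet.parameters(), text_encoder.parameters())
    if p.requires_grad
]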

# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -645,9 +729,19 @@ def collate_fn(examples):
)

# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
)
if args.use_peft:
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
else:
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -705,6 +799,8 @@ def collate_fn(examples):

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
@@ -751,7 +847,14 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
if args.use_peft:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
else:
params_to_clip = lora_layers.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -786,6 +889,7 @@ def collate_fn(examples):
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
torch_dtype=weight_dtype,
)
@@ -821,8 +925,24 @@ def collate_fn(examples):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
if args.use_peft:
lora_config = {}
state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
if args.train_text_encoder:
text_encoder_state_dict = get_peft_model_state_dict(
text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
)
text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
state_dict.update(text_encoder_state_dict)
lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)

accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
json.dump(lora_config, f)
else:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
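To double-check what the peft branch wrote to disk, the checkpoint and config can be inspected along these lines (illustrative only; the file names follow the pattern used in the save code above):

sd = torch.load(os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
print(len(sd), list(sd)[:3])  # number of LoRA tensors and a few example key names
with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json")) as f:
    print(json.load(f).keys())  # 'peft_config', plus 'text_encoder_peft_config' if the text encoder was trained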

if args.push_to_hub:
save_model_card(
Expand All @@ -839,10 +959,45 @@ def collate_fn(examples):
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
if args.use_peft:

def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f:
lora_config = json.load(f)
print(lora_config)

checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt"
lora_checkpoint_sd = torch.load(checkpoint)
unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
text_encoder_lora_ds = {
k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
}

unet_config = LoraConfig(**lora_config["peft_config"])
pipe.unet = LoraModel(unet_config, pipe.unet)
set_peft_model_state_dict(pipe.unet, unet_lora_ds)

if "text_encoder_peft_config" in lora_config:
text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)

if dtype in (torch.float16, torch.bfloat16):
pipe.unet.half()
pipe.text_encoder.half()

pipe.to(device)
return pipe

pipeline = load_and_set_lora_ckpt(
pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype
)

else:
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
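Once either branch above has restored the LoRA weights into `pipeline`, inference proceeds as usual; a minimal sketch (the prompt and step count are placeholders):

image = pipeline("a photo of a yellow flower", num_inference_steps=30, generator=generator).images[0]
image.save("lora_sample.png")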