Stable Diffusion 1.4 Fine-tuning with LoRA: Technical Implementation

Hacktoberfest 2024

This document outlines the technical implementation of fine-tuning Stable Diffusion 1.4 using Low-Rank Adaptation (LoRA). It provides a detailed guide for beginners to understand and contribute to the project.

Project Structure

sd-lora-finetuning/
├── CONTRIBUTING.md
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── requirements.txt
├── src/  (Example implementation)
│   ├── Dataset/
│   │   ├── ImageCaptions/
│   │   │   └── example1.txt
│   │   └── Images/
│   │       └── example1.png
│   ├── dataset.py
│   ├── generate.py
│   ├── lora.py
│   ├── main.py
│   ├── train.py
│   └── utils.py
└── CONTRIBUTIONS/
    └── Example1/
        ├── Dataset/
        │   ├── ImageCaptions/
        │   │   └── example1.txt
        │   └── Images/
        │       └── example1.png
        └── src/
            ├── dataset.py
            ├── generate.py
            ├── lora.py
            ├── main.py
            ├── train.py
            └── utils.py

  • src/: Contains the example implementation (refer to this for your contribution)
  • CONTRIBUTIONS/: Directory where participants should add their implementations
  • CONTRIBUTING.md and CODE_OF_CONDUCT.md: Guidelines and help for contributing (MUST READ!)
  • Other files in the root directory are for project documentation and setup

Technical Overview

1. LoRA Implementation (lora.py)

LoRA is implemented as follows:

a) LoRALayer class:

import math
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=1):
        super().__init__()
        # Low-rank factors: the update BA has rank at most `rank`
        self.lora_A = nn.Parameter(torch.zeros((rank, in_features)))
        self.lora_B = nn.Parameter(torch.zeros((out_features, rank)))
        self.scale = alpha / rank
        # A gets a random init; B stays zero so the initial LoRA update is zero
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        # (x @ A^T) @ B^T computes the low-rank update BAx, scaled by alpha/rank
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scale

b) apply_lora_to_model function:

def apply_lora_to_model(model, rank=4, alpha=1):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            lora_layer = LoRALayer(module.in_features, module.out_features, rank, alpha)
            setattr(module, 'lora', lora_layer)
            # Patch the layer's forward so the LoRA update is actually added to its output
            orig_forward = module.forward
            module.forward = (lambda orig, lora: lambda x: orig(x) + lora(x))(orig_forward, lora_layer)
    return model

Key concept: LoRA adds trainable low-rank matrices to existing layers, allowing for efficient fine-tuning.
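
Only the injected matrices should receive gradients; the base weights stay frozen. A minimal sketch with a toy model (the model and dimensions are illustrative, not part of the project code):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))
model = apply_lora_to_model(model, rank=4, alpha=1)

# Freeze everything, then re-enable gradients only on the LoRA parameters
for name, param in model.named_parameters():
    param.requires_grad = 'lora' in name

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Training {trainable:,} of {total:,} parameters")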

2. Dataset Handling (dataset.py)

The CustomDataset class:

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CustomDataset(Dataset):
    def __init__(self, img_dir, caption_dir=None, transform=None):
        self.img_dir = img_dir
        self.caption_dir = caption_dir
        # Default transform: resize to SD 1.4's native resolution, scale pixels to [-1, 1]
        self.transform = transform or transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        self.images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg', '.webp'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)

        # Pair each image with the caption file of the same name (example1.png -> example1.txt)
        if self.caption_dir:
            caption_path = os.path.join(self.caption_dir, self.images[idx].rsplit('.', 1)[0] + '.txt')
            with open(caption_path, 'r') as f:
                caption = f.read().strip()
        else:
            caption = ""

        return image, caption
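
A minimal usage sketch (the paths are illustrative and follow the example layout above):

from torch.utils.data import DataLoader

dataset = CustomDataset(
    img_dir='src/Dataset/Images',
    caption_dir='src/Dataset/ImageCaptions',
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

images, captions = next(iter(dataloader))
print(images.shape)  # torch.Size([batch, 3, 512, 512])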

3. Training Process (train.py)

The train_loop function implements the core training logic:

import torch
import torch.nn.functional as F

def train_loop(dataloader, unet, text_encoder, vae, noise_scheduler, optimizer, device, num_epochs):
    for epoch in range(num_epochs):
        for batch in dataloader:
            # captions are assumed to be pre-tokenized input IDs (see the setup sketch below)
            images, captions = batch
            # Encode images to latents; detach so no gradients flow into the frozen VAE
            latents = vae.encode(images.to(device)).latent_dist.sample().detach()
            latents = latents * 0.18215  # SD 1.x latent scaling factor
            # Sample noise and a random timestep per example, then corrupt the latents
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                      (latents.shape[0],), device=device)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Condition the U-Net on the text embeddings and predict the added noise
            text_embeddings = text_encoder(captions.to(device))[0]
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample

            loss = F.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Note: this reports the loss of the last batch in the epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

Key concept: We're training the model to denoise latent representations, conditioned on text embeddings.
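
Wiring the loop together requires the pre-trained Stable Diffusion 1.4 components. A hedged setup sketch (hyperparameters are illustrative; the collate function tokenizes the raw caption strings yielded by CustomDataset into the input IDs train_loop consumes):

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "CompVis/stable-diffusion-v1-4"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

vae.requires_grad_(False)           # VAE and text encoder stay frozen
text_encoder.requires_grad_(False)

unet = apply_lora_to_model(unet)
for name, param in unet.named_parameters():  # train only the LoRA matrices
    param.requires_grad = 'lora' in name
lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

def collate_fn(batch):
    # Tokenize caption strings into padded input IDs for the CLIP text encoder
    images, texts = zip(*batch)
    input_ids = tokenizer(list(texts), padding="max_length", truncation=True,
                          max_length=tokenizer.model_max_length,
                          return_tensors="pt").input_ids
    return torch.stack(images), input_ids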

4. Image Generation (generate.py)

import torch

def generate_image(prompt, pipeline, num_inference_steps=50):
    # Inference only: no gradients needed
    with torch.no_grad():
        image = pipeline(prompt, num_inference_steps=num_inference_steps).images[0]
    return image
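
A usage sketch, assuming a standard diffusers pipeline (swapping in the LoRA-adapted UNet from training is shown as a hypothetical step):

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.unet = unet  # hypothetical: reuse the LoRA-adapted UNet from training
pipeline = pipeline.to("cuda")

image = generate_image("a photo of an astronaut riding a horse", pipeline)
image.save("output.png")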

Setup and Installation

  1. Clone the repository:

    git clone https://github.com/your-username/sd-lora-finetuning.git
    cd sd-lora-finetuning
    
  2. Create and activate a virtual environment:

    python -m venv venv
    source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
    
  3. Install dependencies:

    pip install -r requirements.txt
    

Contributing

  1. Fork the repository and clone your fork.
  2. Create a new folder in the CONTRIBUTIONS directory with your username or project name.
  3. Implement your version of the LoRA fine-tuning following the structure in the src directory.
  4. Ensure you include a Dataset folder with example images and captions.
  5. Create a pull request with your contribution.

Refer to the src directory for an example of how to structure your contribution.

Refer to CONTRIBUTING.md for a detailed overview, if you're a beginner!

Technical Deep Dive

LoRA Mechanism

LoRA adapts the model by injecting trainable rank decomposition matrices into existing layers:

  1. For a layer with weight W ∈ R^(d×k), LoRA adds a low-rank update BA, where B ∈ R^(d×r), A ∈ R^(r×k), and r ≪ min(d, k)
  2. The output is computed as: h = Wx + (α/r)·BAx, where α/r is the scale factor from LoRALayer
  3. Only A and B are trained; the original weights W stay frozen

This is implemented in the LoRALayer class:

def forward(self, x):
    return (x @ self.lora_A.T @ self.lora_B.T) * self.scale
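
To make the saving concrete, a quick back-of-the-envelope count for an illustrative 768×768 layer at rank r = 4:

d, k, r = 768, 768, 4   # illustrative dimensions
full = d * k            # frozen weight W: 589,824 parameters
lora = r * k + d * r    # A (r x k) plus B (d x r): 6,144 parameters
print(f"LoRA adds {lora / full:.2%} of the original parameter count")  # ~1.04%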

Training Objective

The model is trained to predict the noise added to the latent representation:

  1. Images are encoded to latent space: z = Encode(x)
  2. Noise is added at a sampled timestep t according to the diffusion schedule: z_noisy = √(ᾱ_t)·z + √(1−ᾱ_t)·ε
  3. The model predicts the noise: ε_pred = Model(z_noisy, t, text_embedding)
  4. Loss is computed: L = MSE(ε_pred, ε)

This is implemented in the training loop:

noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
loss = F.mse_loss(noise_pred, noise)

By learning to denoise, the model implicitly learns to generate images conditioned on text.
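
For reference, step 2 can be checked against the scheduler directly. A small sketch, assuming a diffusers DDPMScheduler (shapes are illustrative SD latents):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (2,))

# Closed form: z_noisy = sqrt(abar_t) * z + sqrt(1 - abar_t) * eps
abar = scheduler.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
manual = abar.sqrt() * latents + (1 - abar).sqrt() * noise

assert torch.allclose(manual, scheduler.add_noise(latents, noise, timesteps), atol=1e-5)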

Customization and Extension

Customization and uniqueness are expected from each contributor.

  • Feel free to modify LoRALayer in lora.py to experiment with different LoRA architectures
  • Adjust the U-Net architecture in main.py by modifying which layers receive LoRA
  • Implement additional training techniques in train.py (e.g., gradient clipping, learning rate scheduling; see the sketch below)
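
A hedged sketch of the two techniques named above, with a stand-in parameter and loss (the clipping threshold and schedule choice are illustrative):

import torch
import torch.nn as nn

params = [nn.Parameter(torch.randn(4, 768))]   # stand-in for the LoRA parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

loss = (params[0] ** 2).mean()                 # stand-in loss
loss.backward()
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # clip before stepping
optimizer.step()
lr_scheduler.step()                            # advance the LR schedule
optimizer.zero_grad()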

Resources