Peng Xing¹,²* · Haofan Wang¹* · Yanpeng Sun² · Qixun Wang¹ · Xu Bai¹,³ · Hao Ai¹,⁴ · Renyuan Huang¹,⁵ · Zechao Li²†
¹InstantX Team · ²Nanjing University of Science and Technology · ³Xiaohongshu · ⁴Beihang University · ⁵Peking University
*equal contributions, †corresponding authors
- 2024/09/04: 🔥 We released the Gradio code. You can configure it and use it directly.
- 2024/09/03: 🔥 We released the online demo on Hugging Face.
- 2024/09/03: 🔥 We released the pre-trained weight.
- 2024/09/03: 🔥 We released the initial version of the inference code.
- 2024/08/30: 🔥 We released the technical report on arXiv.
- 2024/07/15: 🔥 We released the homepage.
- technical report
- inference code
- pre-trained weight [4_16]
- pre-trained weight [4_32]
- online demo
- pre-trained weight_v2 [4_32]
- IMAGStyle dataset
- training code
- more pre-trained weights
This repo, named CSGO, contains the official PyTorch implementation of our paper CSGO: Content-Style Composition in Text-to-Image Generation. We are actively updating and improving this repository. If you find any bugs or have suggestions, feel free to open an issue or submit a pull request (PR).
🔥 CSGO supports image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis.
🔥 For more results, visit our homepage 🔥
git clone https://github.com/instantX-research/CSGO
cd CSGO
# create env using conda
conda create -n CSGO python=3.9
conda activate CSGO
# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt
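Once the dependencies are installed, a quick check can catch a wrong interpreter or a failed install before you download several gigabytes of weights. The helper below is a hypothetical sketch, not part of the repo; it assumes the packages worth checking are the ones the inference example imports (`torch`, `diffusers`, and Pillow, which is imported as `PIL`).

```python
import importlib.util
import sys

# Hypothetical check (not part of the CSGO repo): the third-party modules the
# inference example imports. Pillow is imported under the name "PIL".
REQUIRED_MODULES = ("torch", "diffusers", "PIL")


def check_environment(required=REQUIRED_MODULES):
    """Report the running Python version and which required modules are missing."""
    missing = [name for name in required if importlib.util.find_spec(name) is None]
    return {
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "is_python_3_9": sys.version_info[:2] == (3, 9),  # the conda env above pins 3.9
        "missing": missing,
    }


if __name__ == "__main__":
    print(check_environment())
```

Run it with the `CSGO` environment activated; an empty `missing` list means the imports in the inference example should at least resolve.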
We currently release two model checkpoints; a third (v2) is coming soon.
Model | Content tokens | Style tokens | Notes |
---|---|---|---|
csgo.bin | 4 | 16 | - |
csgo_4_32.bin | 4 | 32 | DeepSpeed ZeRO-2 |
csgo_4_32_v2.bin | 4 | 32 | DeepSpeed ZeRO-2 + more (coming soon) |
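The token counts passed to `CSGO(...)` (`num_content_tokens`, `num_style_tokens`) have to agree with the checkpoint you load; pairing `csgo.bin` (16 style tokens) with `num_style_tokens=32` would most likely fail or load mismatched weights. A small lookup keyed by filename, with values copied from the table, keeps the two in sync. `CHECKPOINT_TOKENS` and `token_kwargs` are hypothetical names used only for this sketch.

```python
import os

# Hypothetical lookup (not part of the CSGO repo), values copied from the table:
# checkpoint filename -> (content tokens, style tokens).
# csgo_4_32_v2.bin is listed as "coming soon".
CHECKPOINT_TOKENS = {
    "csgo.bin": (4, 16),
    "csgo_4_32.bin": (4, 32),
    "csgo_4_32_v2.bin": (4, 32),
}


def token_kwargs(ckpt_path):
    """Return the token-count keyword arguments matching a CSGO checkpoint file."""
    name = os.path.basename(ckpt_path)
    if name not in CHECKPOINT_TOKENS:
        raise ValueError(f"unknown CSGO checkpoint: {name!r}")
    content_tokens, style_tokens = CHECKPOINT_TOKENS[name]
    return {"num_content_tokens": content_tokens, "num_style_tokens": style_tokens}
```

In the inference example you could then pass `**token_kwargs(csgo_ckpt)` to `CSGO(...)` instead of writing `num_content_tokens=4, num_style_tokens=32` by hand.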
The easiest way to download the pretrained weights is from HuggingFace:
# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
git lfs install
# clone the weights; they land in ./CSGO, where the inference code expects them
git clone https://huggingface.co/InstantX/CSGO
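Our method works with SDXL together with an SDXL VAE, a ControlNet, and an image encoder. Download these models and place them in the ./base_models folder; the inference code below expects stable-diffusion-xl-base-1.0, IP-Adapter/sdxl_models/image_encoder, sdxl-vae-fp16-fix, and TTPLanet_SDXL_Controlnet_Tile_Realistic there.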
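Before running inference, it can help to confirm that every model folder is where the code expects it. The directory names in this sketch are copied from the paths in the inference example below; the helper `missing_base_models` itself is hypothetical and assumes you run it from the repository root.

```python
from pathlib import Path

# Folders the inference example loads from, relative to the repository root.
EXPECTED_DIRS = (
    "base_models/stable-diffusion-xl-base-1.0",
    "base_models/IP-Adapter/sdxl_models/image_encoder",
    "base_models/sdxl-vae-fp16-fix",
    "base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic",
)


def missing_base_models(root="."):
    """Return the expected model folders that do not exist under `root`."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    for d in missing_base_models():
        print("missing:", d)
```

This only checks that the folders exist, not that their contents are complete.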
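Tip: to load the ControlNet directly with StableDiffusionXLControlNetPipeline, as the CSGO example does, clone the TTPLanet tile ControlNet and rename its weight file to the default diffusers filename: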
git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors
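If `mv` is unavailable (for example in a plain Windows shell), the same rename can be done from Python. `diffusion_pytorch_model.safetensors` is the filename `ControlNetModel.from_pretrained` looks for when `use_safetensors=True`; the source filename is copied from the command above, and the helper name `rename_controlnet_weights` is hypothetical.

```python
from pathlib import Path

SRC_NAME = "TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors"
DST_NAME = "diffusion_pytorch_model.safetensors"  # default diffusers weight filename


def rename_controlnet_weights(repo_dir="TTPLanet_SDXL_Controlnet_Tile_Realistic"):
    """Rename the TTPLanet fp16 weights so diffusers can load the folder directly."""
    repo = Path(repo_dir)
    dst = repo / DST_NAME
    if dst.exists():
        return dst  # already renamed on a previous run
    (repo / SRC_NAME).rename(dst)  # raises FileNotFoundError if the file is absent
    return dst
```

The helper is idempotent, so running it twice is harmless. Afterwards the folder still has to sit under ./base_models for the inference code to find it.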
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
)
from ip_adapter import CSGO
from ip_adapter.utils import BLOCKS, controlnet_BLOCKS
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
base_model_path = "./base_models/stable-diffusion-xl-base-1.0"
image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
# 4 content / 32 style tokens, matching num_content_tokens / num_style_tokens below
csgo_ckpt = "./CSGO/csgo_4_32.bin"
pretrained_vae_name_or_path = "./base_models/sdxl-vae-fp16-fix"
controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=weight_dtype)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=weight_dtype,
    add_watermarker=False,
    vae=vae,
)
pipe.enable_vae_tiling()
# target blocks (U-Net and ControlNet) for content and style injection
target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']
csgo = CSGO(
    pipe, image_encoder_path, csgo_ckpt, device,
    # token counts must match the checkpoint (see the model table above)
    num_content_tokens=4,
    num_style_tokens=32,
    target_content_blocks=target_content_blocks,
    target_style_blocks=target_style_blocks,
    controlnet=False,
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_target_content_blocks,
    controlnet_target_style_blocks=controlnet_target_style_blocks,
    content_model_resampler=True,
    style_model_resampler=True,
    load_controlnet=False,
)
style_name = 'img_0.png'
content_name = 'img_0.png'
# both inputs are passed to CSGO as RGB PIL images
style_image = Image.open('../assets/{}'.format(style_name)).convert('RGB')
content_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')
caption = 'a small house with a sheep statue on top of it'
num_sample = 4
# image-driven style transfer: the content image drives the layout (ControlNet scale 0.6)
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image,  # ControlNet condition
    controlnet_conditioning_scale=0.6,
)
# text-driven stylized synthesis: the caption defines the content, so ControlNet guidance is nearly off (0.01)
caption = 'a cat'
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image,
    controlnet_conditioning_scale=0.01,
)
# text editing-driven stylized synthesis: the caption edits the content image under moderate ControlNet guidance (0.4)
caption = 'a small house'
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image,
    controlnet_conditioning_scale=0.4,
)
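The three calls above differ only in `prompt` and `controlnet_conditioning_scale`. The ControlNet scale roughly sets how strongly the content image's structure constrains the output, from 0.6 for image-driven style transfer down to 0.01 when the caption alone should define the content. Collecting the shared arguments in one place makes the modes easier to compare; `MODE_PRESETS` and `generate_kwargs` are hypothetical names, and every value is copied from the example.

```python
# Hypothetical helper (not part of the CSGO repo): values copied from the three
# csgo.generate calls above, which differ only in prompt and ControlNet scale.
MODE_PRESETS = {
    "image_driven_style_transfer": 0.6,
    "text_driven_stylized_synthesis": 0.01,
    "text_editing_driven_stylized_synthesis": 0.4,
}

NEGATIVE_PROMPT = (
    "text, watermark, lowres, low quality, worst quality, deformed, "
    "glitch, low contrast, noisy, saturation, blurry"
)


def generate_kwargs(mode, caption, content_image, style_image, num_sample=4, seed=42):
    """Build keyword arguments for csgo.generate for one of the three modes."""
    if mode not in MODE_PRESETS:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(MODE_PRESETS)}")
    return {
        "pil_content_image": content_image,
        "pil_style_image": style_image,
        "prompt": caption,
        "negative_prompt": NEGATIVE_PROMPT,
        "content_scale": 1.0,
        "style_scale": 1.0,
        "guidance_scale": 10,
        "num_images_per_prompt": num_sample,
        "num_samples": 1,
        "num_inference_steps": 50,
        "seed": seed,
        "image": content_image,  # ControlNet condition: the content image itself
        "controlnet_conditioning_scale": MODE_PRESETS[mode],
    }
```

Usage, with `csgo`, `content_image`, and `style_image` set up as above: `images = csgo.generate(**generate_kwargs("text_driven_stylized_synthesis", "a cat", content_image, style_image))`.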
We also provide a Gradio interface for a more convenient experience. To launch it, run:
# for Linux, Windows, and macOS users
python gradio/app.py
If you don't have the resources to run it locally, try our online demo on Hugging Face.
This project is developed by the InstantX Team and Xiaohongshu; all rights reserved. We sincerely thank Xiaohongshu for providing the computing resources.
If you find CSGO useful for your research, please star this repo and cite our work using the following BibTeX:
@article{xing2024csgo,
  title={CSGO: Content-Style Composition in Text-to-Image Generation},
  author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
  year={2024},
  journal={arXiv 2408.16766},
}