diff --git a/README.md b/README.md
index 3265451e3..6e6ead427 100644
--- a/README.md
+++ b/README.md
@@ -377,3 +377,12 @@ You could consider adding a suitable metric to the training loop yourself after
     url = {https://api.semanticscholar.org/CorpusID:270391454}
 }
 ```
+
+```bibtex
+@inproceedings{Sadat2024EliminatingOA,
+    title   = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
+    author  = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
+    year    = {2024},
+    url     = {https://api.semanticscholar.org/CorpusID:273098845}
+}
+```
diff --git a/denoising_diffusion_pytorch/classifier_free_guidance.py b/denoising_diffusion_pytorch/classifier_free_guidance.py
index 6970ccec2..6e403c7ff 100644
--- a/denoising_diffusion_pytorch/classifier_free_guidance.py
+++ b/denoising_diffusion_pytorch/classifier_free_guidance.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 
 from torch.amp import autocast
 
-from einops import rearrange, reduce, repeat
+from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 from tqdm.auto import tqdm
@@ -54,6 +54,15 @@ def convert_image_to_fn(img_type, image):
         return image.convert(img_type)
     return image
 
+def pack_one_with_inverse(x, pattern):
+    packed, packed_shape = pack([x], pattern)
+
+    def inverse(x, inverse_pattern = None):
+        inverse_pattern = default(inverse_pattern, pattern)
+        return unpack(x, packed_shape, inverse_pattern)[0]
+
+    return packed, inverse
+
 # normalization functions
 
 def normalize_to_neg_one_to_one(img):
@@ -75,6 +84,19 @@ def prob_mask_like(shape, prob, device):
     else:
         return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 
+def project(x, y):
+    x, inverse = pack_one_with_inverse(x, 'b *')
+    y, _ = pack_one_with_inverse(y, 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthogonal = x - parallel
+
+    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
+
 # small helper modules
 
 class Residual(nn.Module):
@@ -357,6 +379,8 @@ def forward_with_cond_scale(
         *args,
         cond_scale = 1.,
         rescaled_phi = 0.,
+        remove_parallel_component = True,
+        keep_parallel_frac = 0.,
         **kwargs
     ):
         logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
@@ -365,7 +389,13 @@
             return logits
 
         null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
-        scaled_logits = null_logits + (logits - null_logits) * cond_scale
+        update = logits - null_logits
+
+        if remove_parallel_component:
+            parallel, orthog = project(update, logits)
+            update = orthog + parallel * keep_parallel_frac
+
+        scaled_logits = logits + update * (cond_scale - 1.)
 
         if rescaled_phi == 0.:
             return scaled_logits, null_logits
diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py
index ed6ed8969..2b6bf429f 100644
--- a/denoising_diffusion_pytorch/version.py
+++ b/denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.0.17'
+__version__ = '2.0.18'
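
Not part of the patch above: a minimal standalone sketch of the projected guidance update it introduces, for reading the math outside the diff. It swaps the einops `pack`/`unpack` plumbing for plain `torch.flatten`/`reshape`, and the tensor shapes and scale values are illustrative assumptions, not values from the repository.

```python
import torch
import torch.nn.functional as F

def project(x, y):
    # flatten everything past the batch dim so the dot product runs over
    # each sample's full prediction, mirroring the 'b *' pack pattern
    shape, dtype = x.shape, x.dtype
    x, y = x.flatten(1).double(), y.flatten(1).double()

    unit = F.normalize(y, dim = -1)                             # unit vector along y, per sample

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit  # component of x along y
    orthogonal = x - parallel                                   # remainder, perpendicular to y

    return parallel.reshape(shape).to(dtype), orthogonal.reshape(shape).to(dtype)

# toy tensors standing in for model predictions

cond_scale, keep_parallel_frac = 5., 0.

logits      = torch.randn(2, 3, 64, 64)  # conditional prediction
null_logits = torch.randn(2, 3, 64, 64)  # unconditional (null) prediction

# classifier-free guidance update, with its component parallel to the
# conditional prediction attenuated, as proposed by Sadat et al. 2024

update = logits - null_logits
parallel, orthog = project(update, logits)
update = orthog + parallel * keep_parallel_frac

scaled_logits = logits + update * (cond_scale - 1.)
```

With `keep_parallel_frac = 1.` (or `remove_parallel_component = False` in the patched method) the projection is a no-op, and the last line reduces algebraically to the previous `null_logits + (logits - null_logits) * cond_scale`, so sampling behavior only changes when the parallel component is actually attenuated.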