From 5a0e07fbaeefca9638b47dcf20872a6c1394cdc7 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 16 Aug 2024 07:36:27 -0700
Subject: [PATCH] add immiscible diffusion

---
Reviewer notes (text between `---` and the first `diff --git` is not part of
the commit message):

* This copy was recovered from a whitespace-collapsed paste. Line breaks and
  leading indentation were reconstructed from the hunk header counts and the
  Python block structure. The original width of whitespace runs inside lines
  (README bibtex alignment, spacing before trailing comments, blank context
  lines) is not recoverable; if context fails to match, try
  `git apply --ignore-whitespace`.
* noise_assignment flattens every sample to a vector, builds the pairwise
  `torch.cdist` matrix between images and noise, solves the assignment with
  scipy on a CPU copy, and returns the column indices on the original device.
  The `dist.cpu()` round trip happens on every q_sample call when
  `immiscible = True`. NOTE(review): the tensor is handed to scipy directly,
  so this presumably needs `.detach()` if the inputs can ever require grad -
  confirm.
* q_sample re-orders `noise` locally and returns only the noised sample.
  NOTE(review): any caller that passes its own `noise` and later uses that
  tensor as the regression target would be comparing against the
  un-permuted noise - verify against the loss code, which is not part of
  this patch.

 README.md                                     | 11 +++++++++++
 .../denoising_diffusion_pytorch.py            | 19 ++++++++++++++++++-
 denoising_diffusion_pytorch/version.py        |  2 +-
 setup.py                                      |  1 +
 4 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index e6074d041..20236d90c 100644
--- a/README.md
+++ b/README.md
@@ -355,3 +355,14 @@ You could consider adding a suitable metric to the training loop yourself after
     url = {https://api.semanticscholar.org/CorpusID:265659032}
 }
 ```
+
+```bibtex
+@article{Li2024ImmiscibleDA,
+    title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment},
+    author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.12303},
+    url = {https://api.semanticscholar.org/CorpusID:270562607}
+}
+```
diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index 6a9f7da92..3039e79f4 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -20,6 +20,8 @@
 from einops import rearrange, reduce, repeat
 from einops.layers.torch import Rearrange
 
+from scipy.optimize import linear_sum_assignment
+
 from PIL import Image
 from tqdm.auto import tqdm
 from ema_pytorch import EMA
@@ -488,7 +490,8 @@ def __init__(
         auto_normalize = True,
         offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
         min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
-        min_snr_gamma = 5
+        min_snr_gamma = 5,
+        immiscible = False
     ):
         super().__init__()
         assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
@@ -564,6 +567,10 @@ def __init__(
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+        # immiscible diffusion
+
+        self.immiscible = immiscible
+
         # offset noise strength - in blogpost, they claimed 0.1 was ideal
 
         self.offset_noise_strength = offset_noise_strength
@@ -759,10 +766,20 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
 
         return img
 
+    def noise_assignment(self, x_start, noise):
+        x_start, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (x_start, noise))
+        dist = torch.cdist(x_start, noise)
+        _, assign = linear_sum_assignment(dist.cpu())
+        return torch.from_numpy(assign).to(dist.device)
+
     @autocast(enabled = False)
     def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
+        if self.immiscible:
+            assign = self.noise_assignment(x_start, noise)
+            noise = noise[assign]
+
         return (
             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py
index 2c488f9d1..897e3130d 100644
--- a/denoising_diffusion_pytorch/version.py
+++ b/denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.0.12'
+__version__ = '2.0.15'
diff --git a/setup.py b/setup.py
index ec301a79f..1118bb080 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
     'numpy',
     'pillow',
     'pytorch-fid',
+    'scipy',
     'torch',
     'torchvision',
     'tqdm'