diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index 41e2cebb9..0c00c6c23 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -27,7 +27,6 @@
 from accelerate import Accelerator
 
 from denoising_diffusion_pytorch.attend import Attend
-from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation
 
 from denoising_diffusion_pytorch.version import __version__
 
@@ -145,11 +144,12 @@ def forward(self, x):
 # building block modules
 
 class Block(Module):
-    def __init__(self, dim, dim_out):
+    def __init__(self, dim, dim_out, dropout = 0.):
         super().__init__()
         self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
         self.norm = RMSNorm(dim_out)
         self.act = nn.SiLU()
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, scale_shift = None):
         x = self.proj(x)
@@ -160,17 +160,17 @@ def forward(self, x, scale_shift = None):
             x = x * (scale + 1) + shift
 
         x = self.act(x)
-        return x
+        return self.dropout(x)
 
 class ResnetBlock(Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim = None):
+    def __init__(self, dim, dim_out, *, time_emb_dim = None, dropout = 0.):
         super().__init__()
         self.mlp = nn.Sequential(
             nn.SiLU(),
             nn.Linear(time_emb_dim, dim_out * 2)
         ) if exists(time_emb_dim) else None
 
-        self.block1 = Block(dim, dim_out)
+        self.block1 = Block(dim, dim_out, dropout = dropout)
         self.block2 = Block(dim_out, dim_out)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
@@ -285,6 +285,7 @@ def __init__(
         random_fourier_features = False,
         learned_sinusoidal_dim = 16,
         sinusoidal_pos_emb_theta = 10000,
+        dropout = 0.,
         attn_dim_head = 32,
         attn_heads = 4,
         full_attn = None, # defaults to full attention only for inner most layer
@@ -336,7 +337,10 @@ def __init__(
 
         assert len(full_attn) == len(dim_mults)
 
+        # prepare blocks
+
         FullAttention = partial(Attention, flash = flash_attn)
+        resnet_block = partial(ResnetBlock, time_emb_dim = time_dim, dropout = dropout)
 
         # layers
 
@@ -350,16 +354,16 @@ def __init__(
             attn_klass = FullAttention if layer_full_attn else LinearAttention
 
             self.downs.append(ModuleList([
-                ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
-                ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
+                resnet_block(dim_in, dim_in),
+                resnet_block(dim_in, dim_in),
                 attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+        self.mid_block1 = resnet_block(mid_dim, mid_dim)
         self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+        self.mid_block2 = resnet_block(mid_dim, mid_dim)
 
         for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
             is_last = ind == (len(in_out) - 1)
@@ -367,8 +371,8 @@ def __init__(
             attn_klass = FullAttention if layer_full_attn else LinearAttention
 
             self.ups.append(ModuleList([
-                ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+                resnet_block(dim_out + dim_in, dim_out),
+                resnet_block(dim_out + dim_in, dim_out),
                 attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
@@ -376,7 +380,7 @@ def __init__(
         default_out_dim = channels * (1 if not learned_variance else 2)
         self.out_dim = default(out_dim, default_out_dim)
 
-        self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim)
+        self.final_res_block = resnet_block(dim * 2, dim)
         self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
 
     @property
@@ -954,11 +958,14 @@ def __init__(
         self.calculate_fid = calculate_fid and self.accelerator.is_main_process
 
         if self.calculate_fid:
+            from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation
+
             if not is_ddim_sampling:
                 self.accelerator.print(
                     "WARNING: Robust FID computation requires a lot of generated samples and can therefore be very time consuming."\
                     "Consider using DDIM sampling to save time."
                 )
+
             self.fid_scorer = FIDEvaluation(
                 batch_size=self.batch_size,
                 dl=self.dl,
diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
index 5da033f27..68a60d6f9 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
@@ -7,6 +7,7 @@
 
 import torch
 from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.cuda.amp import autocast
 from torch.optim import Adam
@@ -83,7 +84,7 @@ def __getitem__(self, idx):
 
 # small helper modules
 
-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
@@ -100,7 +101,7 @@ def Upsample(dim, dim_out = None):
 def Downsample(dim, dim_out = None):
     return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
 
-class RMSNorm(nn.Module):
+class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.g = nn.Parameter(torch.ones(1, dim, 1))
@@ -108,7 +109,7 @@ def __init__(self, dim):
     def forward(self, x):
         return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
 
-class PreNorm(nn.Module):
+class PreNorm(Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
@@ -120,7 +121,7 @@ def forward(self, x):
 
 # sinusoidal positional embeds
 
-class SinusoidalPosEmb(nn.Module):
+class SinusoidalPosEmb(Module):
     def __init__(self, dim, theta = 10000):
         super().__init__()
         self.dim = dim
@@ -135,7 +136,7 @@ def forward(self, x):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
 
-class RandomOrLearnedSinusoidalPosEmb(nn.Module):
+class RandomOrLearnedSinusoidalPosEmb(Module):
     """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
     """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
 
@@ -154,12 +155,13 @@ def forward(self, x):
 
 # building block modules
 
-class Block(nn.Module):
-    def __init__(self, dim, dim_out):
+class Block(Module):
+    def __init__(self, dim, dim_out, dropout = 0.):
         super().__init__()
         self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)
         self.norm = RMSNorm(dim_out)
         self.act = nn.SiLU()
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, scale_shift = None):
         x = self.proj(x)
@@ -170,17 +172,17 @@ def forward(self, x, scale_shift = None):
             x = x * (scale + 1) + shift
 
         x = self.act(x)
-        return x
+        return self.dropout(x)
 
-class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim = None):
+class ResnetBlock(Module):
+    def __init__(self, dim, dim_out, *, time_emb_dim = None, dropout = 0.):
         super().__init__()
         self.mlp = nn.Sequential(
             nn.SiLU(),
             nn.Linear(time_emb_dim, dim_out * 2)
         ) if exists(time_emb_dim) else None
 
-        self.block1 = Block(dim, dim_out)
+        self.block1 = Block(dim, dim_out, dropout = dropout)
         self.block2 = Block(dim_out, dim_out)
         self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
@@ -198,7 +200,7 @@ def forward(self, x, time_emb = None):
 
         return h + self.res_conv(x)
 
-class LinearAttention(nn.Module):
+class LinearAttention(Module):
     def __init__(self, dim, heads = 4, dim_head = 32):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -227,7 +229,7 @@ def forward(self, x):
         out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads)
         return self.to_out(out)
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 4, dim_head = 32):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -253,7 +255,7 @@ def forward(self, x):
 
 # model
 
-class Unet1D(nn.Module):
+class Unet1D(Module):
     def __init__(
         self,
         dim,
@@ -261,6 +263,7 @@ def __init__(
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        dropout = 0.,
         self_condition = False,
         learned_variance = False,
         learned_sinusoidal_cond = False,
@@ -304,33 +307,35 @@ def __init__(
             nn.Linear(time_dim, time_dim)
         )
 
+        resnet_block = partial(ResnetBlock, time_emb_dim = time_dim, dropout = dropout)
+
         # layers
 
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
+        self.downs = ModuleList([])
+        self.ups = ModuleList([])
         num_resolutions = len(in_out)
 
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)
 
-            self.downs.append(nn.ModuleList([
-                ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
-                ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
+            self.downs.append(ModuleList([
+                resnet_block(dim_in, dim_in),
+                resnet_block(dim_in, dim_in),
                 Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+        self.mid_block1 = resnet_block(mid_dim, mid_dim)
         self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads)))
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+        self.mid_block2 = resnet_block(mid_dim, mid_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
             is_last = ind == (len(in_out) - 1)
 
-            self.ups.append(nn.ModuleList([
-                ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+            self.ups.append(ModuleList([
+                resnet_block(dim_out + dim_in, dim_out),
+                resnet_block(dim_out + dim_in, dim_out),
                 Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
             ]))
@@ -338,7 +343,7 @@ def __init__(
         default_out_dim = channels * (1 if not learned_variance else 2)
         self.out_dim = default(out_dim, default_out_dim)
 
-        self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim)
+        self.final_res_block = resnet_block(dim * 2, dim)
         self.final_conv = nn.Conv1d(dim, self.out_dim, 1)
 
     def forward(self, x, time, x_self_cond = None):
@@ -407,7 +412,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
 
-class GaussianDiffusion1D(nn.Module):
+class GaussianDiffusion1D(Module):
     def __init__(
         self,
         model,
diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py
index 13ce17d8e..4b259db3e 100644
--- a/denoising_diffusion_pytorch/version.py
+++ b/denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.0.6'
+__version__ = '2.0.7'
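
Usage note (not part of the patch): the diff threads a new dropout keyword through Unet, Unet1D, Block and ResnetBlock (an nn.Dropout applied after the activation of each ResnetBlock's first Block), and defers the FIDEvaluation import until a Trainer is constructed with FID scoring enabled. Below is a minimal sketch of how the new argument could be passed, assuming the public constructors documented in the project README; the dimensions, path and the 0.1 rate are illustrative only.

    from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

    model = Unet(
        dim = 64,
        dim_mults = (1, 2, 4, 8),
        dropout = 0.1              # new in 2.0.7; defaults to 0. (no dropout)
    )

    diffusion = GaussianDiffusion(
        model,
        image_size = 128,
        timesteps = 1000
    )

    trainer = Trainer(
        diffusion,
        'path/to/your/images',     # placeholder folder of training images
        train_batch_size = 32,
        calculate_fid = True       # FIDEvaluation is now imported only when FID scoring is requested
    )
    trainer.train()

Unet1D accepts the same dropout keyword for the 1D pipeline.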