Skip to content

Commit

Permalink
address #329
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 24, 2024
1 parent b281d55 commit ec0a1c7
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions denoising_diffusion_pytorch/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ def __init__(
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
self.final_res_block = ResnetBlock(init_dim * 2, init_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

def forward_with_cond_scale(
self,
Expand Down
4 changes: 2 additions & 2 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def __init__(
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

self.final_res_block = resnet_block(dim * 2, dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
self.final_res_block = resnet_block(init_dim * 2, init_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

@property
def downsample_factor(self):
Expand Down
4 changes: 2 additions & 2 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ def __init__(
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

self.final_res_block = resnet_block(dim * 2, dim)
self.final_conv = nn.Conv1d(dim, self.out_dim, 1)
self.final_res_block = resnet_block(init_dim * 2, init_dim)
self.final_conv = nn.Conv1d(init_dim, self.out_dim, 1)

def forward(self, x, time, x_self_cond = None):
if self.self_condition:
Expand Down
4 changes: 2 additions & 2 deletions denoising_diffusion_pytorch/repaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ def __init__(
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
self.final_res_block = ResnetBlock(init_dim * 2, init_dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

@property
def downsample_factor(self):
Expand Down
4 changes: 2 additions & 2 deletions denoising_diffusion_pytorch/simple_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def __init__(
default_out_dim = input_channels
self.out_dim = default(out_dim, default_out_dim)

self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
self.final_res_block = ResnetBlock(init_dim * 2, init_dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

def forward(self, x, time):
x = self.init_img_transform(x)
Expand Down
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.10'
__version__ = '2.0.12'

0 comments on commit ec0a1c7

Please sign in to comment.