Skip to content

Commit

Permalink
add option to use cfg ++
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 16, 2024
1 parent 5a0e07f commit 4019202
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,14 @@ You could consider adding a suitable metric to the training loop yourself after
url = {https://api.semanticscholar.org/CorpusID:270562607}
}
```

```bibtex
@article{Chung2024CFGMC,
title = {CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models},
author = {Hyungjin Chung and Jeongsol Kim and Geon Yeong Park and Hyelin Nam and Jong Chul Ye},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.08070},
url = {https://api.semanticscholar.org/CorpusID:270391454}
}
```
29 changes: 22 additions & 7 deletions denoising_diffusion_pytorch/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,13 @@ def forward_with_cond_scale(
scaled_logits = null_logits + (logits - null_logits) * cond_scale

if rescaled_phi == 0.:
return scaled_logits
return scaled_logits, null_logits

std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
interpolated_rescaled_logits = rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)

return rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
return interpolated_rescaled_logits, null_logits

def forward(
self,
Expand Down Expand Up @@ -478,7 +479,8 @@ def __init__(
ddim_sampling_eta = 1.,
offset_noise_strength = 0.,
min_snr_loss_weight = False,
min_snr_gamma = 5
min_snr_gamma = 5,
use_cfg_plus_plus = False # https://arxiv.org/pdf/2406.08070
):
super().__init__()
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
Expand Down Expand Up @@ -507,6 +509,10 @@ def __init__(
timesteps, = betas.shape
self.num_timesteps = int(timesteps)

# use cfg++ when ddim sampling

self.use_cfg_plus_plus = use_cfg_plus_plus

# sampling related parameters

self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
Expand Down Expand Up @@ -604,24 +610,33 @@ def q_posterior(self, x_start, x_t, t):
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def model_predictions(self, x, t, classes, cond_scale = 6., rescaled_phi = 0.7, clip_x_start = False):
    """Run the model with classifier-free guidance and derive (pred_noise, x_start).

    Args:
        x: noised input at timestep t (presumably (batch, channels, H, W) — confirm against caller).
        t: diffusion timestep tensor.
        classes: class conditioning passed through to the model.
        cond_scale: classifier-free guidance scale.
        rescaled_phi: interpolation weight for rescaled guidance (0. disables rescaling).
        clip_x_start: clamp the predicted x_start into [-1., 1.] when True.

    Returns:
        ModelPrediction(pred_noise, x_start).

    When self.use_cfg_plus_plus is set, the *unconditional* output is used for the
    noise that renoises the sample (CFG++, https://arxiv.org/abs/2406.08070), while
    the guided output still determines x_start.
    """
    # forward_with_cond_scale returns both the guided logits and the null
    # (unconditional) logits; the latter is only consumed under CFG++.
    model_output, model_output_null = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
    maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

    if self.objective == 'pred_noise':
        # CFG++: renoise with the unconditional epsilon prediction.
        pred_noise = model_output if not self.use_cfg_plus_plus else model_output_null

        x_start = self.predict_start_from_noise(x, t, pred_noise)
        x_start = maybe_clip(x_start)

    elif self.objective == 'pred_x0':
        x_start = model_output
        x_start = maybe_clip(x_start)

        # CFG++: derive the noise from the unconditional x0 prediction instead.
        x_start_for_pred_noise = x_start if not self.use_cfg_plus_plus else maybe_clip(model_output_null)

        pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)

    elif self.objective == 'pred_v':
        v = model_output
        x_start = self.predict_start_from_v(x, t, v)
        x_start = maybe_clip(x_start)

        x_start_for_pred_noise = x_start
        if self.use_cfg_plus_plus:
            # CFG++: recover x0 from the unconditional v prediction for the noise term.
            x_start_for_pred_noise = self.predict_start_from_v(x, t, model_output_null)
            x_start_for_pred_noise = maybe_clip(x_start_for_pred_noise)

        pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)

    return ModelPrediction(pred_noise, x_start)

Expand Down
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.15'
__version__ = '2.0.16'

0 comments on commit 4019202

Please sign in to comment.