Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Signed-off-by: jtigue-bdai <[email protected]>
  • Loading branch information
jtigue-bdai authored Oct 11, 2024
1 parent 5b99208 commit 455ae61
Showing 1 changed file with 5 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def constant_noise(data: torch.Tensor, cfg: noise_cfg.ConstantNoiseCfg) -> torch

# fix tensor device for bias on first call and update config parameters
if isinstance(cfg.bias, torch.Tensor):
cfg.bias = cfg.bias.to(device=data.device)
cfg.bias = cfg.bias.to(device=data.device)

if cfg.operation == "add":
return data + cfg.bias
Expand All @@ -55,12 +55,10 @@ def uniform_noise(data: torch.Tensor, cfg: noise_cfg.UniformNoiseCfg) -> torch.T

# fix tensor device for n_max on first call and update config parameters
if isinstance(cfg.n_max, torch.Tensor):
if cfg.n_max.device is not data.device:
cfg.n_max = cfg.n_max.to(data.device)
cfg.n_max = cfg.n_max.to(data.device)
# fix tensor device for n_min on first call and update config parameters
if isinstance(cfg.n_min, torch.Tensor):
if cfg.n_min.device is not data.device:
cfg.n_min = cfg.n_min.to(data.device)
cfg.n_min = cfg.n_min.to(data.device)

if cfg.operation == "add":
return data + torch.rand_like(data) * (cfg.n_max - cfg.n_min) + cfg.n_min
Expand All @@ -85,12 +83,10 @@ def gaussian_noise(data: torch.Tensor, cfg: noise_cfg.GaussianNoiseCfg) -> torch

# fix tensor device for mean on first call and update config parameters
if isinstance(cfg.mean, torch.Tensor):
if cfg.mean.device is not data.device:
cfg.mean = cfg.mean.to(data.device)
cfg.mean = cfg.mean.to(data.device)
# fix tensor device for std on first call and update config parameters
if isinstance(cfg.std, torch.Tensor):
if cfg.std.device is not data.device:
cfg.std = cfg.std.to(data.device)
cfg.std = cfg.std.to(data.device)

if cfg.operation == "add":
return data + cfg.mean + cfg.std * torch.randn_like(data)
Expand Down

0 comments on commit 455ae61

Please sign in to comment.