diff --git a/loralib/layers.py b/loralib/layers.py
index 0e54a64b..80e45b54 100644
--- a/loralib/layers.py
+++ b/loralib/layers.py
@@ -56,8 +56,8 @@ def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
             # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.zeros_(self.lora_A)
-            nn.init.normal_(self.lora_B)
+            nn.init.normal_(self.lora_A)
+            nn.init.zeros_(self.lora_B)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)