diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c34077fa2bfaf..ee9db7048f1f6 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -187,9 +187,11 @@ def process_weights_after_loading(self): self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data, (0, 128), "constant", 0), requires_grad=False) + torch.cuda.empty_cache() self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data, (0, 128), "constant", 0), requires_grad=False) + torch.cuda.empty_cache() return # If checkpoint is fp16, quantize here.