
Commit

max_length actually not in the model config - but it's in the generation_config!
johnml1135 committed Nov 8, 2024
1 parent 744cc68 commit 5e2ac47
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs):
 
         input_tokens = model_inputs["input_tokens"]
         del model_inputs["input_tokens"]
-        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
-        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
+        if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
+            config = self.model.generation_config
+        else:
+            config = self.model.config
+        generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
+        generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
         output = self.model.generate(
             **model_inputs,
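
For reference, a minimal standalone sketch of the same fallback this commit applies, outside the pipeline: prefer the model's generation_config for the default min_length / max_length values and fall back to config otherwise. The helper name resolve_length_limits and the opus-mt checkpoint are illustrative assumptions, not part of the repository.

# Sketch of the fallback used in this commit (illustrative helper, not repo code):
# read min_length / max_length from generation_config when available, else from config.
from transformers import AutoModelForSeq2SeqLM


def resolve_length_limits(model, generate_kwargs):
    # Newer transformers versions keep generation defaults on generation_config,
    # so prefer it when present; otherwise fall back to the model config.
    if getattr(model, "generation_config", None) is not None:
        config = model.generation_config
    else:
        config = model.config
    generate_kwargs.setdefault("min_length", config.min_length)
    generate_kwargs.setdefault("max_length", config.max_length)
    return generate_kwargs


# Example checkpoint chosen for illustration only.
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-es")
print(resolve_length_limits(model, {}))  # -> {'min_length': ..., 'max_length': ...}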
