From 5e2ac4721c642b19ca325f6511ca6031875ee297 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 8 Nov 2024 14:45:04 -0500 Subject: [PATCH] max_number actually not there - but it's in the generation_config! --- .../translation/huggingface/hugging_face_nmt_engine.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 4da476c..04086af 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs): input_tokens = model_inputs["input_tokens"] del model_inputs["input_tokens"] - generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length) - generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) + if hasattr(self.model, "generation_config") and self.model.generation_config is not None: + config = self.model.generation_config + else: + config = self.model.config + generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length) + generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) output = self.model.generate( **model_inputs,