From e9a3c889d777756e755d1235dda7a7e01f13d20e Mon Sep 17 00:00:00 2001 From: gorold Date: Mon, 10 Jun 2024 23:29:55 +0800 Subject: [PATCH] fix:Moirai inference in foundation-time-series-arena --- .../xiuhmolpilli/models/foundational/moirai.py | 6 +++--- .../xiuhmolpilli/models/utils/gluonts_forecaster.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/experiments/foundation-time-series-arena/xiuhmolpilli/models/foundational/moirai.py b/experiments/foundation-time-series-arena/xiuhmolpilli/models/foundational/moirai.py index 2839e2cb..9c151e8c 100644 --- a/experiments/foundation-time-series-arena/xiuhmolpilli/models/foundational/moirai.py +++ b/experiments/foundation-time-series-arena/xiuhmolpilli/models/foundational/moirai.py @@ -21,12 +21,12 @@ def get_predictor(self, prediction_length: int) -> PyTorchPredictor: model = MoiraiForecast( module=MoiraiModule.from_pretrained(self.repo_id), prediction_length=prediction_length, - context_length=200, - patch_size="auto", + context_length=1000, + patch_size=32, num_samples=100, target_dim=1, feat_dynamic_real_dim=0, past_feat_dynamic_real_dim=0, ) - predictor = model.create_predictor(batch_size=32) + predictor = model.create_predictor(batch_size=512) return predictor diff --git a/experiments/foundation-time-series-arena/xiuhmolpilli/models/utils/gluonts_forecaster.py b/experiments/foundation-time-series-arena/xiuhmolpilli/models/utils/gluonts_forecaster.py index a879c329..ba02c106 100644 --- a/experiments/foundation-time-series-arena/xiuhmolpilli/models/utils/gluonts_forecaster.py +++ b/experiments/foundation-time-series-arena/xiuhmolpilli/models/utils/gluonts_forecaster.py @@ -58,7 +58,7 @@ def gluonts_instance_fcst_to_df( freq: str, model_name: str, ) -> pd.DataFrame: - point_forecast = fcst.mean + point_forecast = fcst.median h = len(point_forecast) dates = pd.date_range( fcst.start_date.to_timestamp(),