From ab4cff11f6a728e2b75cda4d8fbb0c9e8f292397 Mon Sep 17 00:00:00 2001 From: Jake Zwart Date: Thu, 3 Jun 2021 11:07:48 -0400 Subject: [PATCH] explicit about number of tasks for y_data_components Co-authored-by: Alison Appling --- river_dl/loss_functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/river_dl/loss_functions.py b/river_dl/loss_functions.py index 88ae7b3..5f30809 100644 --- a/river_dl/loss_functions.py +++ b/river_dl/loss_functions.py @@ -80,12 +80,8 @@ def nnse_one_var_samplewise(data, y_pred, var_idx, tasks): def y_data_components(data, y_pred, var_idx, tasks): - if tasks == 2: - weights = data[:, :, -2:] - y_true = data[:, :, :-2] - else: - weights = data[:, :, -1:] - y_true = data[:, :, :-1] + weights = data[:, :, -tasks:] + y_true = data[:, :, :-tasks] # ensure y_pred, weights, and y_true are all tensors the same data type y_true = tf.convert_to_tensor(y_true)