diff --git a/river_dl/loss_functions.py b/river_dl/loss_functions.py index b1989d0..431b972 100644 --- a/river_dl/loss_functions.py +++ b/river_dl/loss_functions.py @@ -80,12 +80,8 @@ def nnse_one_var_samplewise(data, y_pred, var_idx, tasks): def y_data_components(data, y_pred, var_idx, tasks): - if tasks == 2: - weights = data[:, :, -2:] - y_true = data[:, :, :-2] - else: - weights = data[:, :, -1:] - y_true = data[:, :, :-1] + weights = data[:, :, -tasks:] + y_true = data[:, :, :-tasks] # ensure y_pred, weights, and y_true are all tensors the same data type y_true = tf.convert_to_tensor(y_true)