Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
explicit about number of tasks for y_data_components
Browse files Browse the repository at this point in the history
Co-authored-by: Alison Appling <[email protected]>
  • Loading branch information
jzwart and aappling-usgs authored Jun 3, 2021
1 parent c582d3d commit ab4cff1
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions river_dl/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,8 @@ def nnse_one_var_samplewise(data, y_pred, var_idx, tasks):


def y_data_components(data, y_pred, var_idx, tasks):
if tasks == 2:
weights = data[:, :, -2:]
y_true = data[:, :, :-2]
else:
weights = data[:, :, -1:]
y_true = data[:, :, :-1]
weights = data[:, :, -tasks:]
y_true = data[:, :, :-tasks]

# ensure y_pred, weights, and y_true are all tensors the same data type
y_true = tf.convert_to_tensor(y_true)
Expand Down

0 comments on commit ab4cff1

Please sign in to comment.