[USGS-R#146] model def in snakefile, not predict.py
also found bug in not passing spatial/time idx names all the way through
jsadler2 committed Dec 1, 2021
1 parent ce413ce commit 8ffe6d0
Showing 2 changed files with 29 additions and 74 deletions.
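
In effect, the TF model is now constructed and restored in the project's Snakefile and the ready-made model object is passed into the predict functions, rather than predict.py rebuilding it from a model type, hidden size, and weights directory. A minimal sketch of what such a rule might look like — the rule name, paths, and hyperparameters below are hypothetical, not taken from the repository's actual Snakefile:

rule make_predictions:
    input:
        io_data = "prepped.npz",          # hypothetical prepped-data file
        weights = "trained_weights/",     # hypothetical saved-weights directory
    output:
        "tst_preds.feather"
    run:
        import numpy as np
        from river_dl.RGCN import RGCNModel
        from river_dl.predict import predict_from_io_data

        io = np.load(input.io_data, allow_pickle=True)
        # model definition now lives here, not in predict.py
        model = RGCNModel(20, A=io["dist_matrix"], num_tasks=1)
        model.load_weights(input.weights)
        # the ready-made model is handed to the prediction helper
        predict_from_io_data(
            model=model,
            io_data=io,
            partition="tst",
            outfile=output[0],
        )
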
6 changes: 3 additions & 3 deletions river_dl/postproc_utils.py
@@ -106,7 +106,7 @@ def plot_ts(pred_file, obs_file, variable, out_file):
plt.savefig(out_file)


def prepped_array_to_df(data_array, dates, ids, col_names):
def prepped_array_to_df(data_array, dates, ids, col_names, spatial_idx_name='seg_id_nat', time_idx_name='date'):
"""
convert prepped x or y_dataset data in numpy array to pandas df
(reshape and make into pandas DFs)
@@ -125,7 +125,7 @@ def prepped_array_to_df(data_array, dates, ids, col_names):
dates = np.reshape(dates, [dates.shape[0] * dates.shape[1], dates.shape[2]])
ids = np.reshape(ids, [ids.shape[0] * ids.shape[1], ids.shape[2]])
df_preds = pd.DataFrame(data_array, columns=col_names)
df_dates = pd.DataFrame(dates, columns=["date"])
df_ids = pd.DataFrame(ids, columns=["seg_id_nat"])
df_dates = pd.DataFrame(dates, columns=[time_idx_name])
df_ids = pd.DataFrame(ids, columns=[spatial_idx_name])
df = pd.concat([df_dates, df_ids, df_preds], axis=1)
return df
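
With the new keyword arguments the caller controls the index column names of the returned data frame instead of always getting "seg_id_nat" and "date"; the defaults preserve the old behavior. An illustrative call (array and variable names are made up):

# y_pred, pred_dates, pred_ids all have shape [n_sequences, seq_len, n_cols]
df = prepped_array_to_df(
    y_pred, pred_dates, pred_ids, ["seg_tave_water", "seg_outflow"],
    spatial_idx_name="comid", time_idx_name="time",
)
# df columns are then: time, comid, seg_tave_water, seg_outflow
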
97 changes: 26 additions & 71 deletions river_dl/predict.py
@@ -2,6 +2,7 @@
import numpy as np
import xarray as xr
import datetime
from numpy.lib.npyio import NpzFile

from river_dl.RGCN import RGCNModel
from river_dl.postproc_utils import prepped_array_to_df
@@ -11,7 +12,18 @@
coord_as_reshaped_array,
)
from river_dl.rnns import LSTMModel, GRUModel
from river_dl.train import get_data_if_file


def get_data_if_file(d):
"""
rudimentary check if data .npz file is already loaded. if not, load it
:param d:
:return:
"""
if isinstance(d, NpzFile) or isinstance(d, dict):
return d
else:
return np.load(d, allow_pickle=True)
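
Because of this check, the helper can be handed either a path or data that is already in memory (the file name below is illustrative):

io_data = get_data_if_file("prepped.npz")                              # loads the .npz from disk
io_data = get_data_if_file(np.load("prepped.npz", allow_pickle=True))  # returned unchanged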


def unscale_output(y_scl, y_std, y_mean, y_vars, log_vars=None):
@@ -36,70 +48,32 @@ def unscale_output(y_scl, y_std, y_mean, y_vars, log_vars=None):
return y_unscaled


def load_model_from_weights(
model_type, model_weights_dir, hidden_size, dist_matrix=None, num_tasks=1,
):
"""
load a TF model from the model weights directory
:param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru'
:param model_weights_dir: [str] directory to saved model weights
:param hidden_size: [int] the number of hidden units in model
:param dist_matrix: [np array] the distance matrix if using 'rgcn'
:param num_tasks: [int] number of tasks (variables_to_log to be predicted)
:return: TF model
"""
if model_type == "rgcn":
model = RGCNModel(hidden_size, A=dist_matrix, num_tasks=num_tasks)
elif model_type.startswith("lstm"):
model = LSTMModel(hidden_size, num_tasks=num_tasks)
elif model_type == "gru":
model = GRUModel(hidden_size, num_tasks=num_tasks)
else:
raise ValueError(
f'model_type must be "lstm", "gru" or "rgcn", (not {model_type})'
)

model.load_weights(model_weights_dir)
return model


def predict_from_io_data(
model_type,
model_weights_dir,
hidden_size,
model,
io_data,
partition,
outfile,
log_vars=False,
num_tasks=1,
trn_offset = 1.0,
tst_val_offset = 1.0,
spatial_idx_name="seg_id_nat",
time_idx_name="date"
):
"""
make predictions from trained model
:param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru'
:param model_weights_dir: [str] directory to saved model weights
:param io_data: [str] directory to prepped data file
:param hidden_size: [int] the number of hidden units in model
:param partition: [str] must be 'trn' or 'tst'; whether you want to predict
for the train or the dev period
:param outfile: [str] the file where the output data should be stored
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
prep
:param num_tasks: [int] number of tasks (variables_to_log to be predicted)
:param trn_offset: [str] value for the training offset
:param tst_val_offset: [str] value for the testing and validation offset
:return: [pd dataframe] predictions
"""
io_data = get_data_if_file(io_data)
model = load_model_from_weights(
model_type,
model_weights_dir,
hidden_size,
io_data.get("dist_matrix"),
num_tasks=num_tasks,
)

if partition == "trn":
keep_frac = trn_offset
else:
@@ -116,6 +90,8 @@ def predict_from_io_data(
keep_last_frac=keep_frac,
outfile=outfile,
log_vars=log_vars,
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name
)
return preds
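
Called outside of Snakemake, the new arguments look like this; note that spatial_idx_name and time_idx_name are now threaded all the way through to prepped_array_to_df (the bug mentioned in the commit message). Values below are illustrative:

preds = predict_from_io_data(
    model=model,                      # TF model built and restored by the caller
    io_data="prepped.npz",
    partition="trn",
    outfile="trn_preds.feather",
    spatial_idx_name="seg_id_nat",    # passed through to the output data frame columns
    time_idx_name="date",
)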

@@ -131,10 +107,12 @@ def predict(
keep_last_frac=1.0,
outfile=None,
log_vars=False,
spatial_idx_name="seg_id_nat",
time_idx_name="date"
):
"""
use trained model to make predictions
:param model: the trained TF model
:param model: [tf model] trained TF model to use for predictions
:param x_data: [np array] numpy array of scaled and centered x_data
:param pred_ids: [np array] the ids of the segments (same shape as x_data)
:param pred_dates: [np array] the dates of the segments (same shape as
@@ -159,7 +137,7 @@ def predict(
pred_ids = pred_ids[:, frac_seq_len:, :]
pred_dates = pred_dates[:, frac_seq_len:, :]

y_pred_pp = prepped_array_to_df(y_pred, pred_dates, pred_ids, y_vars,)
y_pred_pp = prepped_array_to_df(y_pred, pred_dates, pred_ids, y_vars, spatial_idx_name, time_idx_name)

y_pred_pp = unscale_output(y_pred_pp, y_stds, y_means, y_vars, log_vars,)

@@ -290,6 +268,8 @@ def predict_one_date_range(
train_io_data["y_obs_vars"],
keep_last_frac=keep_last_frac,
log_vars=log_vars,
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name
)
return predictions

@@ -299,15 +279,11 @@ def predict_from_arbitrary_data(
pred_start_date,
pred_end_date,
train_io_data,
model_weights_dir,
model_type,
hidden_size,
model,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
seq_len=365,
dist_matrix=None,
log_vars=None,
num_tasks=1,
):
"""
make predictions given raw data that is potentially independent from the
@@ -321,17 +297,12 @@
that was used to train the model. This file must contain the variables_to_log
names, the standard deviations, and the means of the X and Y variables_to_log. Only
with this information can the model be used properly
:param model_weights_dir: [str] path to the directory where the TF model
weights are stored
:param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru'
:param hidden_size: [int] the number of hidden units in model
:param model: [tf model] model to use for predictions
:param spatial_idx_name: [str] name of column that is used for spatial
index (e.g., 'seg_id_nat')
:param time_idx_name: [str] name of column that is used for temporal index
(usually 'time')
:param seq_len: [int] length of input sequences given to model
:param dist_matrix: [np array] the distance matrix if using 'rgcn'. if not
provided, will look for it in the "train_io_data" file.
:param flow_in_temp: [bool] whether the flow should be an input into temp
for the rgcn model
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
@@ -340,22 +311,6 @@
"""
train_io_data = get_data_if_file(train_io_data)

if model_type == "rgcn":
if not dist_matrix:
dist_matrix = train_io_data.get("dist_matrix")
if not isinstance(dist_matrix, np.ndarray):
raise ValueError(
"model type is 'rgcn', but there is no" "distance matrix"
)

model = load_model_from_weights(
model_type,
model_weights_dir,
hidden_size,
dist_matrix,
num_tasks=num_tasks,
)

ds = xr.open_zarr(raw_data_file)

ds_x = ds[train_io_data["x_cols"]]
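
Likewise, predict_from_arbitrary_data now expects the caller to supply the ready-made model rather than a model type, hidden size, and weights directory. An illustrative call, assuming raw_data_file is the zarr-store argument used in the function body (paths and dates are hypothetical):

preds = predict_from_arbitrary_data(
    raw_data_file="raw_inputs.zarr",     # hypothetical zarr store of raw input data
    pred_start_date="2021-01-01",
    pred_end_date="2021-12-31",
    train_io_data="prepped.npz",         # io data file that was used to train the model
    model=model,                         # built and restored outside this function
    spatial_idx_name="seg_id_nat",
    time_idx_name="date",
    seq_len=365,
)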
