[USGS-R#146] gw modification to data load
jsadler2 committed Jan 19, 2022
1 parent 4fd0e22 commit 3ed6b07
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions Snakefile_gw
@@ -14,9 +14,9 @@ from river_dl import loss_functions as lf
from river_dl.gw_utils import prep_annual_signal_data, calc_pred_ann_temp,calc_gw_metrics


out_dir = config['out_dir']
code_dir = config['code_dir']
pred_weights = config['pred_weights']
out_dir = config['out_dir'] + "_gw"
#code_dir = config['code_dir']
#pred_weights = config['pred_weights']
loss_function = lf.multitask_rmse(config['lambdas'])


@@ -116,18 +116,37 @@ def get_gw_loss(input_data, temp_var="temp_c"):
rule finetune_train:
input:
"{outdir}/prepped_withGW.npz",
"{outdir}/pretrained_weights/pretrain.done"
output:
directory("{outdir}/trained_weights/"),
directory("{outdir}/finetune_weights/"),
directory("{outdir}/best_val_weights/"),
params:
# getting the base path to put the training outputs in
        # I omit the last slash (hence '[:-1]') so the split works properly
run_dir=lambda wildcards, output: os.path.split(output[0][:-1])[0],
run:
train_model(input[0],config['ft_epochs'],config['hidden_size'],loss_func=get_gw_loss(input[0]),
out_dir=params.run_dir,model_type='rgcn',num_tasks=len(config['y_vars_finetune']),
learning_rate=0.01, dropout = config['dropout'], recurrent_dropout=config['recurrent_dropout'],train_type='finetune',early_stop_patience=config['early_stopping'], seed = config['seed'])
        io_data = np.load(input[0])
        temp_air_index = np.where(io_data['x_vars'] == 'seg_tave_air')[0]
        air_unscaled = io_data['x_trn'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + \
                       io_data['x_mean'][temp_air_index]
        y_trn_obs = np.concatenate(
            [io_data["y_obs_trn"], io_data["GW_trn_reshape"], air_unscaled], axis=2
        )
        air_val = io_data['x_val'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + \
                  io_data['x_mean'][temp_air_index]
        y_val_obs = np.concatenate(
            [io_data["y_obs_val"], io_data["GW_val_reshape"], air_val], axis=2
        )
        # Run the finetuning within the training engine on CPU for the GW loss function
        train_model(model,
                    x_trn=io_data['x_trn'],
                    y_trn=y_trn_obs,
                    epochs=config['pt_epochs'],
                    batch_size=2,
                    x_val=io_data['x_val'],
                    y_val=y_val_obs,
                    # need to add a trailing slash here, otherwise the weights
                    # get saved in the outdir
                    weight_dir=output[0] + "/",
                    best_val_weight_dir=output[1] + "/",
                    log_file=output[1],
                    time_file=output[2],
                    early_stop_patience=config['early_stopping'])


rule compile_pred_GW_stats:

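For context on the new run block: a minimal standalone sketch of its de-scaling and concatenation step, using made-up array names and shapes rather than the actual contents of prepped_withGW.npz. It shows the pattern the rule relies on: standardized inputs are converted back to physical units with the stored mean and standard deviation, then stacked onto the observation and groundwater targets as extra columns for the GW-aware loss.

import numpy as np

# Hypothetical stand-ins for the prepped .npz arrays
# (shape convention assumed here: [segments, timesteps, features/targets]).
x_vars = np.array(['seg_tave_air', 'seg_rain', 'seg_slope'])
x_trn = np.random.rand(4, 10, 3)           # standardized inputs
x_mean = np.array([10.0, 0.5, 2.0])        # per-feature means used for scaling
x_std = np.array([5.0, 0.1, 1.0])          # per-feature standard deviations
y_obs_trn = np.random.rand(4, 10, 2)       # observed targets (e.g. temp, flow)
gw_trn_reshape = np.random.rand(4, 10, 3)  # reshaped GW annual-signal targets

# De-scale the air-temperature input back to physical units: x * std + mean
air_idx = np.where(x_vars == 'seg_tave_air')[0]
air_unscaled = x_trn[:, :, air_idx] * x_std[air_idx] + x_mean[air_idx]

# Stack observations, GW targets, and the unscaled air temperature along the
# last axis so the GW loss function sees all of them in a single y array.
y_trn_obs = np.concatenate([y_obs_trn, gw_trn_reshape, air_unscaled], axis=2)
print(y_trn_obs.shape)  # (4, 10, 6)

The same pattern is repeated for the validation partition before both arrays are passed to train_model.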
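The two trailing-slash comments (the '[:-1]' in the removed params block and the output[0] + "/" in the new call) both come down to how a trailing separator is treated. A small illustration with a hypothetical output path, assuming the weight directory is used as a plain string prefix as the comment suggests:

import os

weight_dir = "results/finetune_weights/"  # hypothetical Snakemake directory() output

# os.path.split on a path ending in "/" returns an empty tail, so stripping
# the last character first (the '[:-1]' in the removed params block) is what
# yields the parent directory:
print(os.path.split(weight_dir))           # ('results/finetune_weights', '')
print(os.path.split(weight_dir[:-1])[0])   # 'results'

# Conversely, if the training call builds file paths by string concatenation,
# the trailing slash decides whether files land inside the weight directory
# or next to it in the outdir with a concatenated name:
print("results/finetune_weights" + "best_weights.h5")        # results/finetune_weightsbest_weights.h5
print("results/finetune_weights" + "/" + "best_weights.h5")  # results/finetune_weights/best_weights.h5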