From 267e12de86d19f3c447a023f8192842ae1f0c31f Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Wed, 23 Feb 2022 10:47:36 -0600 Subject: [PATCH] [#45] add replicates to baseline model --- 2a_model/src/models/0_baseline_LSTM/Snakefile | 5 +- .../models/0_baseline_LSTM/analyze_states.smk | 61 +++++++++++++++---- .../src/models/0_baseline_LSTM/config.yml | 1 + 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/2a_model/src/models/0_baseline_LSTM/Snakefile b/2a_model/src/models/0_baseline_LSTM/Snakefile index 0ec50458..8b48256b 100644 --- a/2a_model/src/models/0_baseline_LSTM/Snakefile +++ b/2a_model/src/models/0_baseline_LSTM/Snakefile @@ -17,9 +17,10 @@ loss_function = lf.multitask_rmse(config['lambdas']) rule all: input: - expand("{outdir}/{metric_type}_metrics.csv", + expand("{outdir}/rep_{rep}/{metric_type}_metrics.csv", outdir=out_dir, metric_type=['overall', 'reach'], + rep=list(range(config['num_replicates'])), ) @@ -124,7 +125,7 @@ def filter_predictions(all_preds_file, partition, out_file): df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])] - if partition == "train": + if partition == "trn": df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) & (df_preds_trn_sites.date < config['train_end_date'][0])] elif partition == "val": diff --git a/2a_model/src/models/0_baseline_LSTM/analyze_states.smk b/2a_model/src/models/0_baseline_LSTM/analyze_states.smk index 5e31605d..56756b6e 100644 --- a/2a_model/src/models/0_baseline_LSTM/analyze_states.smk +++ b/2a_model/src/models/0_baseline_LSTM/analyze_states.smk @@ -1,29 +1,34 @@ +code_dir = '../river-dl' +import sys +sys.path.insert(0, code_dir) +# if using river_dl installed with pip this is not needed + from model import LSTMModelStates from river_dl.postproc_utils import prepped_array_to_df import numpy as np import matplotlib.pyplot as plt import pandas as pd -code_dir = '../river-dl' -# if using river_dl installed with pip this is not needed -import sys -sys.path.insert(0, code_dir) out_dir = "../../../out/models/0_baseline_LSTM/analyze_states" in_dir = "../../../out/models/0_baseline_LSTM" def get_site_ids(): - df = pd.read_csv(f"{in_dir}/reach_metrics.csv", dtype={"site_id": str}) + df = pd.read_csv(f"{in_dir}/rep_0/reach_metrics.csv", dtype={"site_id": str}) return df.site_id.unique() rule all: input: - expand("{outdir}/states_{trained_or_random}_{site_id}.png", + expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png", outdir=out_dir, + rep=list(range(6)), trained_or_random = ["trained", "random"], - site_id = get_site_ids()) + site_id = get_site_ids()), + expand("{outdir}/rep_{rep}/output_weights.jpg", + outdir=out_dir, + rep=list(range(6))), model = LSTMModelStates( @@ -36,10 +41,10 @@ model = LSTMModelStates( rule write_states: input: - f"{in_dir}/prepped.npz", - f"{in_dir}/train_weights/", + f"{in_dir}/rep_{{rep}}/prepped.npz", + f"{in_dir}/rep_{{rep}}/train_weights/", output: - "{outdir}/states_{trained_or_random}.csv" + "{outdir}/rep_{rep}/states_{trained_or_random}.csv" run: data = np.load(input[0], allow_pickle=True) if wildcards.trained_or_random == "trained": @@ -62,7 +67,41 @@ rule plot_states: df_site = df.query(f"site_id == '{wildcards.site_id}'") del df_site["site_id"] df_site = df_site.set_index("date") - df_site.plot(subplots=True, figsize=(8,10)) + axs = df_site.plot(subplots=True, figsize=(8,10)) + for ax in axs.flatten(): + ax.legend(loc = "upper left") + plt.suptitle(wildcards.site_id) plt.tight_layout() plt.savefig(output[0]) + +rule plot_output_weights: + input: + f"{in_dir}/rep_{{rep}}/prepped.npz", + f"{in_dir}/rep_{{rep}}/train_weights/", + output: + "{outdir}/rep_{rep}/output_weights.jpg" + run: + data = np.load(input[0], allow_pickle=True) + m = LSTMModelStates( + config['hidden_size'], + recurrent_dropout=config['recurrent_dropout'], + dropout=config['dropout'], + num_tasks=len(config['y_vars']) + ) + m.load_weights(input[1] + "/") + m(data['x_val']) + w = m.weights + ax = plt.imshow(w[3].numpy()) + fig = plt.gcf() + cbar = fig.colorbar(ax) + cbar.set_label('weight value') + ax = plt.gca() + ax.set_yticks(list(range(10))) + ax.set_yticklabels(f"h{i}" for i in range(10)) + ax.set_ylabel('hidden state') + ax.set_xticks(list(range(3))) + ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90) + ax.set_xlabel('output variable') + plt.tight_layout() + plt.savefig(output[0], bbox_inches='tight') diff --git a/2a_model/src/models/0_baseline_LSTM/config.yml b/2a_model/src/models/0_baseline_LSTM/config.yml index 411c1ae8..f950940a 100644 --- a/2a_model/src/models/0_baseline_LSTM/config.yml +++ b/2a_model/src/models/0_baseline_LSTM/config.yml @@ -6,6 +6,7 @@ x_vars: ['seg_ccov', 'seg_rain', 'seg_slope', 'seg_tave_air', 'hru_slope', 'hru_ seed: False #random seed for training False==No seed, otherwise specify the seed +num_replicates: 6 y_vars: ['do_min', 'do_mean', 'do_max']