Skip to content

Commit

Permalink
[USGS-R#54] compile metrics into one file
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Mar 16, 2022
1 parent 2bfdaa0 commit 9581a0d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 58 deletions.
79 changes: 51 additions & 28 deletions 2a_model/src/models/0_baseline_LSTM/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ loss_function = lf.multitask_rmse(config['lambdas'])

rule all:
input:
expand("{outdir}/rep_{rep}/{metric_type}_metrics.csv",
outdir=out_dir,
metric_type=['overall', 'reach'],
rep=list(range(config['num_replicates'])),
)
expand("{outdir}/exp_{metric_type}_metrics.csv",
outdir=out_dir,
metric_type=['overall', 'reach'])


rule as_run_config:
Expand All @@ -36,7 +34,7 @@ rule prep_io_data:
"../../../out/well_observed_train_inputs.zarr",
"../../../out/well_observed_train_do.zarr",
output:
"{outdir}/prepped.npz"
"{outdir}/nstates_{nstates}/rep_{rep}/prepped.npz"
run:
prep_all_data(
x_data_file=input[0],
Expand All @@ -56,24 +54,24 @@ rule prep_io_data:
tst_val_offset = config['tst_val_offset'])


model = LSTMModel(
config['hidden_size'],
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)


# Finetune/train the model on observations
rule train:
input:
"{outdir}/prepped.npz",
"{outdir}/nstates_{nstates}/rep_{rep}/prepped.npz"
output:
directory("{outdir}/train_weights/"),
directory("{outdir}/nstates_{nstates}/rep_{rep}/train_weights/"),
#directory("{outdir}/best_val_weights/"),
"{outdir}/train_log.csv",
"{outdir}/train_time.txt",
"{outdir}/nstates_{nstates}/rep_{rep}/train_log.csv",
"{outdir}/nstates_{nstates}/rep_{rep}/train_time.txt",
run:
model = LSTMModel(
int(wildcards.nstates),
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)

optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
model.compile(optimizer=optimizer, loss=loss_function)
data = np.load(input[0], allow_pickle=True)
Expand All @@ -96,13 +94,21 @@ rule train:

rule make_predictions:
input:
"{outdir}/train_weights/",
"{outdir}/nstates_{nstates}/rep_{rep}/train_weights/",
"../../../out/well_observed_train_val_inputs.zarr",
"{outdir}/prepped.npz",
"{outdir}/nstates_{nstates}/rep_{rep}/prepped.npz",
output:
"{outdir}/preds.feather",
"{outdir}/nstates_{nstates}/rep_{rep}/preds.feather",
run:
weight_dir = input[0] + "/"

model = LSTMModel(
int(wildcards.nstates),
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)

model.load_weights(weight_dir)
preds = predict_from_arbitrary_data(raw_data_file=input[1],
pred_start_date="1980-01-01",
Expand Down Expand Up @@ -152,9 +158,9 @@ def filter_predictions(all_preds_file, partition, out_file):

rule make_filtered_predictions:
input:
"{outdir}/preds.feather"
"{outdir}/nstates_{nstates}/rep_{rep}/preds.feather"
output:
"{outdir}/{partition}_preds.feather"
"{outdir}/nstates_{nstates}/rep_{rep}/{partition}_preds.feather"
run:
filter_predictions(input[0], wildcards.partition, output[0])

Expand All @@ -173,11 +179,11 @@ def get_grp_arg(wildcards):
rule combine_metrics:
input:
"../../../out/well_observed_train_val_do.zarr",
"{outdir}/trn_preds.feather",
"{outdir}/val_preds.feather",
"{outdir}/val_times_preds.feather"
"{outdir}/nstates_{nstates}/rep_{rep}/trn_preds.feather",
"{outdir}/nstates_{nstates}/rep_{rep}/val_preds.feather",
"{outdir}/nstates_{nstates}/rep_{rep}/val_times_preds.feather"
output:
"{outdir}/{metric_type}_metrics.csv"
"{outdir}/nstates_{nstates}/rep_{rep}/{metric_type}_metrics.csv"
params:
grp_arg = get_grp_arg
run:
Expand All @@ -188,14 +194,31 @@ rule combine_metrics:
spatial_idx_name='site_id',
time_idx_name='date',
group=params.grp_arg,
id_dict={"nstates": wildcards.nstates,
"rep_id": wildcards.rep},
outfile=output[0])


rule exp_metrics:
input:
expand("{outdir}/nstates_{nstates}/rep_{rep}/{{metric_type}}_metrics.csv",
outdir=out_dir,
rep=list(range(config['num_replicates'])),
nstates=config['hidden_size'],
)
output:
"{outdir}/exp_{metric_type}_metrics.csv"
run:
all_df = pd.concat([pd.read_csv(met_file, dtype={"site_id": str}) for met_file in input])
all_df.to_csv(output[0], index=False)



rule plot_prepped_data:
input:
"{outdir}/prepped.npz",
"{outdir}/nstates_{nstates}/rep_{rep}/prepped.npz",
output:
"{outdir}/{variable}_part_{partition}.png",
"{outdir}/nstates_{nstates}/rep_{rep}/{variable}_part_{partition}.png",
run:
plot_obs(input[0],
wildcards.variable,
Expand Down
60 changes: 31 additions & 29 deletions 2a_model/src/models/0_baseline_LSTM/analyze_states.smk
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,60 @@ import matplotlib.pyplot as plt
import pandas as pd


out_dir = "../../../out/models/0_baseline_LSTM/analyze_states"
in_dir = "../../../out/models/0_baseline_LSTM"


def get_site_ids():
df = pd.read_csv(f"{in_dir}/rep_0/reach_metrics.csv", dtype={"site_id": str})
df = pd.read_csv(f"{in_dir}/nstates_10/rep_0/reach_metrics.csv", dtype={"site_id": str})
return df.site_id.unique()


rule all:
input:
expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png",
outdir=out_dir,
rep=list(range(6)),
expand("{outdir}/nstates_{nstates}/analyze_states/rep_{rep}/states_{trained_or_random}_{site_id}.png",
outdir=in_dir,
rep=list(range(config['num_replicates'])),
nstates=config['hidden_size'],
trained_or_random = ["trained", "random"],
site_id = get_site_ids()),
expand("{outdir}/rep_{rep}/output_weights.jpg",
outdir=out_dir,
rep=list(range(6))),


model = LSTMModelStates(
config['hidden_size'],
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)
expand("{outdir}/nstates_{nstates}/analyze_states/rep_{rep}/output_weights.jpg",
outdir=in_dir,
rep=list(range(config['num_replicates'])),
nstates=config['hidden_size'],
),


rule write_states:
input:
f"{in_dir}/rep_{{rep}}/prepped.npz",
f"{in_dir}/rep_{{rep}}/train_weights/",
f"{in_dir}/nstates_{{nstates}}/rep_{{rep}}/prepped.npz",
f"{in_dir}/nstates_{{nstates}}/rep_{{rep}}/train_weights/",
output:
"{outdir}/rep_{rep}/states_{trained_or_random}.csv"
"{outdir}/nstates_{nstates}/analyze_states/rep_{rep}/states_{trained_or_random}.csv"
run:
model = LSTMModelStates(
int(wildcards.nstates),
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)


data = np.load(input[0], allow_pickle=True)
if wildcards.trained_or_random == "trained":
model.load_weights(input[1] + "/")
states = model(data['x_val'])
states = model(data['x_val']).numpy()
states_df = prepped_array_to_df(states, data["times_val"], data["ids_val"],
col_names=[f"h{i}" for i in range(10)],
col_names=[f"h{i}" for i in range(int(wildcards.nstates))],
spatial_idx_name="site_id")
states_df["site_id"] = states_df["site_id"].astype(str)
states_df.to_csv(output[0], index=False)


rule plot_states:
input:
"{outdir}/states_{trained_or_random}.csv"
"{outdir}/nstates_{nstates}/analyze_states/rep_{rep}/states_{trained_or_random}.csv"
output:
"{outdir}/states_{trained_or_random}_{site_id}.png"
"{outdir}/nstates_{nstates}/analyze_states/rep_{rep}/states_{trained_or_random}_{site_id}.png"
run:
df = pd.read_csv(input[0], parse_dates=["date"], infer_datetime_format=True, dtype={"site_id": str})
df_site = df.query(f"site_id == '{wildcards.site_id}'")
Expand All @@ -77,14 +79,14 @@ rule plot_states:

rule plot_output_weights:
input:
f"{in_dir}/rep_{{rep}}/prepped.npz",
f"{in_dir}/rep_{{rep}}/train_weights/",
f"{in_dir}/nstates_{{nstates}}/rep_{{rep}}/prepped.npz",
f"{in_dir}/nstates_{{nstates}}/rep_{{rep}}/train_weights/",
output:
"{outdir}/rep_{rep}/output_weights.jpg"
"{outdir}/nstates_{nstates}/analyze_states/rep_{rep}/output_weights.jpg"
run:
data = np.load(input[0], allow_pickle=True)
m = LSTMModelStates(
config['hidden_size'],
int(wildcards.nstates),
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
Expand All @@ -97,8 +99,8 @@ rule plot_output_weights:
cbar = fig.colorbar(ax)
cbar.set_label('weight value')
ax = plt.gca()
ax.set_yticks(list(range(10)))
ax.set_yticklabels(f"h{i}" for i in range(10))
ax.set_yticks(list(range(int(wildcards.nstates))))
ax.set_yticklabels(f"h{i}" for i in range(int(wildcards.nstates)))
ax.set_ylabel('hidden state')
ax.set_xticks(list(range(3)))
ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90)
Expand Down
2 changes: 1 addition & 1 deletion 2a_model/src/models/0_baseline_LSTM/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ test_sites: ["01475530", "01475548"]


pt_epochs: 100
hidden_size: 10
hidden_size: [10, 5]

dropout: 0.2
recurrent_dropout: 0.2
Expand Down

0 comments on commit 9581a0d

Please sign in to comment.