This repository has been archived by the owner on May 28, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
analysis of baseline states #53
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
25856e7
[#45] code for looking at model states
jsadler2 4bebd77
[#45] py code for plotting hidden states
jsadler2 c56f35e
[#45] add replicates to baseline model
jsadler2 08febe2
[#45], [#48] pull aux (flow, water temp) data
jsadler2 026781c
[#48] plot water temp vs do preds/obs
jsadler2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
code_dir = '../river-dl' | ||
import sys | ||
sys.path.insert(0, code_dir) | ||
# if using river_dl installed with pip this is not needed | ||
|
||
from model import LSTMModelStates | ||
from river_dl.postproc_utils import prepped_array_to_df | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
|
||
|
||
out_dir = "../../../out/models/0_baseline_LSTM/analyze_states" | ||
in_dir = "../../../out/models/0_baseline_LSTM" | ||
|
||
|
||
def get_site_ids(): | ||
df = pd.read_csv(f"{in_dir}/rep_0/reach_metrics.csv", dtype={"site_id": str}) | ||
return df.site_id.unique() | ||
|
||
|
||
rule all: | ||
input: | ||
expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png", | ||
outdir=out_dir, | ||
rep=list(range(6)), | ||
trained_or_random = ["trained", "random"], | ||
site_id = get_site_ids()), | ||
expand("{outdir}/rep_{rep}/output_weights.jpg", | ||
outdir=out_dir, | ||
rep=list(range(6))), | ||
|
||
|
||
model = LSTMModelStates( | ||
config['hidden_size'], | ||
recurrent_dropout=config['recurrent_dropout'], | ||
dropout=config['dropout'], | ||
num_tasks=len(config['y_vars']) | ||
) | ||
|
||
|
||
rule write_states: | ||
input: | ||
f"{in_dir}/rep_{{rep}}/prepped.npz", | ||
f"{in_dir}/rep_{{rep}}/train_weights/", | ||
output: | ||
"{outdir}/rep_{rep}/states_{trained_or_random}.csv" | ||
run: | ||
data = np.load(input[0], allow_pickle=True) | ||
if wildcards.trained_or_random == "trained": | ||
model.load_weights(input[1] + "/") | ||
states = model(data['x_val']) | ||
states_df = prepped_array_to_df(states, data["times_val"], data["ids_val"], | ||
col_names=[f"h{i}" for i in range(10)], | ||
spatial_idx_name="site_id") | ||
states_df["site_id"] = states_df["site_id"].astype(str) | ||
states_df.to_csv(output[0], index=False) | ||
|
||
|
||
rule plot_states: | ||
input: | ||
"{outdir}/states_{trained_or_random}.csv" | ||
output: | ||
"{outdir}/states_{trained_or_random}_{site_id}.png" | ||
run: | ||
df = pd.read_csv(input[0], parse_dates=["date"], infer_datetime_format=True, dtype={"site_id": str}) | ||
df_site = df.query(f"site_id == '{wildcards.site_id}'") | ||
del df_site["site_id"] | ||
df_site = df_site.set_index("date") | ||
axs = df_site.plot(subplots=True, figsize=(8,10)) | ||
for ax in axs.flatten(): | ||
ax.legend(loc = "upper left") | ||
plt.suptitle(wildcards.site_id) | ||
plt.tight_layout() | ||
plt.savefig(output[0]) | ||
|
||
|
||
rule plot_output_weights: | ||
input: | ||
f"{in_dir}/rep_{{rep}}/prepped.npz", | ||
f"{in_dir}/rep_{{rep}}/train_weights/", | ||
output: | ||
"{outdir}/rep_{rep}/output_weights.jpg" | ||
run: | ||
data = np.load(input[0], allow_pickle=True) | ||
m = LSTMModelStates( | ||
config['hidden_size'], | ||
recurrent_dropout=config['recurrent_dropout'], | ||
dropout=config['dropout'], | ||
num_tasks=len(config['y_vars']) | ||
) | ||
m.load_weights(input[1] + "/") | ||
m(data['x_val']) | ||
w = m.weights | ||
ax = plt.imshow(w[3].numpy()) | ||
fig = plt.gcf() | ||
cbar = fig.colorbar(ax) | ||
cbar.set_label('weight value') | ||
ax = plt.gca() | ||
ax.set_yticks(list(range(10))) | ||
ax.set_yticklabels(f"h{i}" for i in range(10)) | ||
ax.set_ylabel('hidden state') | ||
ax.set_xticks(list(range(3))) | ||
ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90) | ||
ax.set_xlabel('output variable') | ||
plt.tight_layout() | ||
plt.savefig(output[0], bbox_inches='tight') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# --- | ||
# jupyter: | ||
# jupytext: | ||
# formats: ipynb,py:percent | ||
# text_representation: | ||
# extension: .py | ||
# format_name: percent | ||
# format_version: '1.3' | ||
# jupytext_version: 1.13.7 | ||
# kernelspec: | ||
# display_name: Python 3 (ipykernel) | ||
# language: python | ||
# name: python3 | ||
# --- | ||
|
||
# %% | ||
import sys | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
# %% | ||
sys.path.insert(0, "../../2a_model/src/models/0_baseline_LSTM/") | ||
|
||
# %% | ||
from model import LSTMModel | ||
|
||
# %% | ||
m = LSTMModel(10, 3) | ||
|
||
# %% | ||
m.load_weights("../../2a_model/out/models/0_baseline_LSTM/train_weights/") | ||
|
||
# %% | ||
data = np.load("../../2a_model/out/models/0_baseline_LSTM/prepped.npz", allow_pickle=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also confused here as to what you are loading, are these input data to use for plotting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes these are input data. They are used to just run the model once so I can get the weight values. ... now that I think of it, I'm not 100% sure this step is necessary. |
||
|
||
# %% | ||
y = m(data['x_val']) | ||
|
||
# %% | ||
w = m.weights | ||
|
||
# %% | ||
ax = plt.imshow(w[3].numpy()) | ||
fig = plt.gcf() | ||
cbar = fig.colorbar(ax) | ||
cbar.set_label('weight value') | ||
ax = plt.gca() | ||
ax.set_yticks(list(range(10))) | ||
ax.set_yticklabels(f"h{i}" for i in range(10)) | ||
ax.set_ylabel('hidden state') | ||
ax.set_xticks(list(range(3))) | ||
ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90) | ||
ax.set_xlabel('output variable') | ||
plt.tight_layout() | ||
plt.savefig('../out/hidden_states/out_weights.jpg', bbox_inches='tight') | ||
|
||
# %% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# --- | ||
# jupyter: | ||
# jupytext: | ||
# formats: ipynb,py:percent | ||
# text_representation: | ||
# extension: .py | ||
# format_name: percent | ||
# format_version: '1.3' | ||
# jupytext_version: 1.13.7 | ||
# kernelspec: | ||
# display_name: Python 3 (ipykernel) | ||
# language: python | ||
# name: python3 | ||
# --- | ||
|
||
# %% | ||
import pandas as pd | ||
import xarray as xr | ||
import matplotlib.pyplot as plt | ||
|
||
# %% [markdown] | ||
# ## load states and aux data | ||
|
||
# %% | ||
df_states = pd.read_csv("../../2a_model/out/models/0_baseline_LSTM/analyze_states/rep_0/states_trained.csv", | ||
dtype={"site_id": str}, parse_dates=["date"], infer_datetime_format=True) | ||
|
||
# %% | ||
df_aux = pd.read_csv("../../1_fetch/out/daily_aux_data.csv", | ||
dtype={"site_no": str}, parse_dates=["Date"], infer_datetime_format=True) | ||
df_aux = df_aux.rename(columns={"site_no": "site_id", "Date":"date"}) | ||
|
||
# %% | ||
site_id = "01480870" | ||
|
||
# %% | ||
df_aux_site = df_aux.query(f"site_id == '{site_id}'").set_index('date') | ||
df_states_site = df_states.query(f"site_id == '{site_id}'").set_index('date') | ||
|
||
# %% [markdown] | ||
# ## load input data | ||
|
||
# %% | ||
ds = xr.open_zarr("../../2a_model/out/well_observed_train_val_inputs.zarr/", consolidated=False) | ||
|
||
# %% | ||
df_air_temp = ds.seg_tave_air.sel(site_id=site_id).to_dataframe() | ||
|
||
# %% | ||
del df_air_temp['site_id'] | ||
del df_aux_site['site_id'] | ||
del df_states_site['site_id'] | ||
|
||
# %% | ||
df_comb = df_states_site.join(df_aux_site).join(df_air_temp) | ||
|
||
# %% [markdown] | ||
# ___ | ||
|
||
# %% [markdown] | ||
# # Comparison with Flow | ||
|
||
# %% | ||
axs = df_comb.loc[:, df_comb.columns.str.startswith('h')].plot(subplots=True, figsize=(16,20)) | ||
axs = axs.ravel() | ||
for ax in axs: | ||
ax.legend(loc="upper left") | ||
ax_twin = ax.twinx() | ||
df_comb.Flow.plot(ax=ax_twin, color="black", alpha=0.6) | ||
ax_twin.set_ylabel('flow [cfs]') | ||
plt.tight_layout() | ||
plt.savefig("../out/states_with_flow.jpg") | ||
|
||
# %% | ||
axs = df_comb.loc[:, df_comb.columns.str.startswith('h0')].plot(subplots=True, figsize=(20,5)) | ||
axs = axs.ravel() | ||
for ax in axs: | ||
ax.legend(loc="upper left") | ||
ax_twin = ax.twinx() | ||
df_comb.Flow.plot(ax=ax_twin, color="darkgray") | ||
ax_twin.set_ylabel('flow [cfs]') | ||
|
||
|
||
# %% | ||
def plot_one_state_w_flow(df_comb, state, color): | ||
axs = df_comb.loc["2018", df_comb.columns.str.startswith(state)].plot(subplots=True, figsize=(20,5), | ||
color=color, fontsize=20) | ||
axs = axs.ravel() | ||
for ax in axs: | ||
ax.legend(loc="upper left", fontsize=20) | ||
ax_twin = ax.twinx() | ||
df_comb.loc["2018", "Flow"].plot(ax=ax_twin, color="black", alpha=0.6, fontsize=20) | ||
ax_twin.set_ylabel('flow [cfs]', fontsize=20) | ||
ax.set_xlabel('date', fontsize=20) | ||
plt.tight_layout() | ||
plt.savefig(f"../out/{state}_2018_w_flow.jpg") | ||
|
||
|
||
# %% | ||
plot_one_state_w_flow(df_comb, "h0", color="#1f77b4") | ||
|
||
# %% | ||
df_comb.plot.scatter('h0', 'Flow', alpha=0.5) | ||
plt.tight_layout() | ||
plt.savefig("../out/flow_h0_scatter.jpg") | ||
|
||
# %% | ||
plot_one_state_w_flow(df_comb, "h1", "#ff7f0e") | ||
|
||
# %% [markdown] | ||
# # Comparison with Temperature | ||
|
||
# %% | ||
axs = df_comb.loc[:, df_comb.columns.str.startswith('h')].plot(subplots=True, figsize=(16,20)) | ||
axs = axs.ravel() | ||
for ax in axs: | ||
ax.legend(loc="upper left") | ||
ax_twin = ax.twinx() | ||
df_comb.seg_tave_air.plot(ax=ax_twin, color="darkgray") | ||
ax_twin.set_ylabel('avg air temp [degC]') | ||
plt.tight_layout() | ||
plt.savefig("../out/states_w_air_temp.jpg") | ||
|
||
# %% | ||
df_comb.tail() | ||
|
||
# %% |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little confused where these weights come from
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are weights produced by the snakemake workflow:
drb-do-ml/2a_model/src/models/0_baseline_LSTM/Snakefile
Line 71 in cee3da5