forked from USGS-R/drb-do-ml
Commit
Merge pull request USGS-R#40 from jsadler2/28-baseline-lstm-model
28 baseline lstm model
Showing 9 changed files with 500 additions and 17 deletions.
@@ -0,0 +1,96 @@
source("2a_model/src/model_ready_data_utils.R")

p2a_targets_list <- list(

  ## PREPARE (RENAME, JOIN) INPUT AND OUTPUT FILES ##
  # match site_ids to seg_ids
  tar_target(
    p2a_met_data_w_sites,
    match_site_ids_to_segs(p1_prms_met_data, p2_sites_w_segs)
  ),

  # match seg attributes with site_ids
  tar_target(
    p2a_seg_attr_w_sites,
    match_site_ids_to_segs(p1_seg_attr_data, p2_sites_w_segs)
  ),

  ## SPLIT SITES INTO (train) AND (train and validation) ##
  # character vector of well-observed training sites
  tar_target(
    p2a_trn_sites,
    p2_well_observed_sites[!(p2_well_observed_sites %in% val_sites) &
                             !(p2_well_observed_sites %in% tst_sites)]
  ),

  # character vector of well-observed validation and training sites
  tar_target(
    p2a_trn_val_sites,
    p2_well_observed_sites[(p2_well_observed_sites %in% p2a_trn_sites) |
                             (p2_well_observed_sites %in% val_sites)]
  ),
  # get sites that we use for training, but that also have data in the val time period
  tar_target(
    p2a_trn_sites_w_val_data,
    p2_daily_with_seg_ids %>%
      filter(site_id %in% p2a_trn_val_sites,
             !site_id %in% val_sites,
             date >= val_start_date,
             date < val_end_date) %>%
      group_by(site_id) %>%
      summarise(val_count = sum(!is.na(do_mean))) %>%
      filter(val_count > 0) %>%
      pull(site_id)
  ),

  # sites that are training sites but do not have data in the val period
  tar_target(
    p2a_trn_only,
    p2a_trn_sites[!p2a_trn_sites %in% p2a_trn_sites_w_val_data]
  ),

  ## WRITE OUT PARTITION INPUT AND OUTPUT DATA ##
  # write trn met and seg attribute data to zarr
  # note - I have to subset before passing to subset_and_write_zarr or else I
  # get a memory error on the join
  tar_target(
    p2a_trn_inputs_zarr,
    {
      trn_input <- p2a_met_data_w_sites %>%
        filter(site_id %in% p2a_trn_sites) %>%
        inner_join(p2a_seg_attr_w_sites, by = "site_id")
      subset_and_write_zarr(trn_input, "2a_model/out/well_observed_train_inputs.zarr")
    },
    format = "file"
  ),
  # write trn and val met and seg attribute data to zarr
  # note - I have to subset before passing to subset_and_write_zarr or else I
  # get a memory error on the join
  tar_target(
    p2a_trn_val_inputs_zarr,
    {
      trn_val_input <- p2a_met_data_w_sites %>%
        filter(site_id %in% p2a_trn_val_sites) %>%
        inner_join(p2a_seg_attr_w_sites, by = "site_id")
      subset_and_write_zarr(trn_val_input, "2a_model/out/well_observed_train_val_inputs.zarr")
    },
    format = "file"
  ),

  # write trn do data to zarr
  tar_target(
    p2a_trn_do_zarr,
    subset_and_write_zarr(p2_daily_with_seg_ids,
                          "2a_model/out/well_observed_train_do.zarr",
                          p2a_trn_sites),
    format = "file"
  ),
  # write trn and val do data to zarr
  tar_target(
    p2a_trn_val_do_zarr,
    subset_and_write_zarr(p2_daily_with_seg_ids,
                          "2a_model/out/well_observed_train_val_do.zarr",
                          p2a_trn_val_sites),
    format = "file"
  )
)
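
As a quick check, the zarr stores written by these targets can be opened from Python with xarray. A minimal sketch, assuming the site_id/date index layout produced by write_df_to_zarr below and the do_mean column from the daily observations:

import xarray as xr

# open the training inputs and observations written by the targets above
inputs = xr.open_zarr("2a_model/out/well_observed_train_inputs.zarr")
obs = xr.open_zarr("2a_model/out/well_observed_train_do.zarr")

print(inputs.dims)                  # expect site_id and date dimensions
print(int(obs["do_mean"].count()))  # number of non-missing DO observations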
@@ -0,0 +1,57 @@
match_site_ids_to_segs <- function(seg_data, sites_w_segs) {
  #'
  #' @description match site ids to segment data (e.g., met or attributes)
  #'
  #' @param seg_data a data frame of segment data (e.g., meteorological or
  #' attribute data) with column 'seg_id_nat'
  #' @param sites_w_segs a data frame with both segment ids ('segidnat') and site ids ('site_id')
  #'
  #' @return a data frame of seg data with site ids

  seg_data <- seg_data %>%
    left_join(sites_w_segs[, c("site_id", "segidnat")],
              by = c("seg_id_nat" = "segidnat"))
  return(seg_data)
}
write_df_to_zarr <- function(df, index_cols, out_zarr) {
  #'
  #' @description use reticulate to write an R data frame to a Zarr data store
  #' (the file format river-dl currently takes)
  #'
  #' @param df a data frame of data
  #' @param index_cols vector of strings - the column(s) that should be the index
  #' @param out_zarr where the zarr data will be written
  #'
  #' @return the out_zarr path

  # convert to a python (pandas) DataFrame so we have access to the object
  # methods (set_index and to_xarray)
  py_df <- reticulate::r_to_py(df)

  # set the index so that when we convert to an xarray dataset it is indexed properly
  py_df <- py_df$set_index(index_cols)

  # convert to an xarray dataset
  ds <- py_df$to_xarray()

  # write the store, overwriting it if it already exists
  ds$to_zarr(out_zarr, mode = 'w')

  return(out_zarr)
}
subset_and_write_zarr <- function(df, out_zarr, sites_subset = NULL) {
  #' @description write out to zarr, optionally subsetting to a set of sites
  #' first. This assumes your zarr index names will be "site_id" and "date"
  #'
  #' @param df a data frame of data
  #' @param out_zarr where the zarr data will be written
  #' @param sites_subset character vector of sites to subset to
  #'
  #' @return the out_zarr path
  if (!is.null(sites_subset)) {
    df <- df %>% filter(site_id %in% sites_subset)
  }
  write_df_to_zarr(df, c("site_id", "date"), out_zarr)
}
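
The reticulate calls in write_df_to_zarr map directly onto pandas/xarray methods. For reference, a minimal pure-Python sketch of the same round trip (the site ids and values here are made up for illustration):

import pandas as pd

# toy long-format frame: one row per site_id/date pair, as write_df_to_zarr expects
df = pd.DataFrame({
    "site_id": ["site_A", "site_A", "site_B"],
    "date": pd.to_datetime(["2019-01-01", "2019-01-02", "2019-01-01"]),
    "do_mean": [10.2, 10.4, 9.8],
})

# set_index + to_xarray pivots the long frame into a (site_id, date) cube
ds = df.set_index(["site_id", "date"]).to_xarray()

# mode="w" overwrites any existing store, matching the R wrapper
ds.to_zarr("example_do.zarr", mode="w")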
@@ -0,0 +1,206 @@
import os
import tensorflow as tf
import numpy as np
import pandas as pd

from river_dl.preproc_utils import asRunConfig
from river_dl.preproc_utils import prep_all_data
from river_dl.evaluate import combined_metrics
from river_dl.postproc_utils import plot_obs, plot_ts
from river_dl.predict import predict_from_arbitrary_data
from river_dl.train import train_model
from river_dl import loss_functions as lf
from model import LSTMModel

out_dir = os.path.join(config['out_dir'], config['exp_name'])
loss_function = lf.multitask_rmse(config['lambdas'])

rule all:
    input:
        expand("{outdir}/{metric_type}_metrics.csv",
               outdir=out_dir,
               metric_type=['overall', 'reach'],
        )

rule as_run_config:
    output:
        "{outdir}/asRunConfig.yml"
    run:
        asRunConfig(config, output[0])
rule prep_io_data:
    input:
        "../../../out/well_observed_train_inputs.zarr",
        "../../../out/well_observed_train_do.zarr",
    output:
        "{outdir}/prepped.npz"
    run:
        prep_all_data(
            x_data_file=input[0],
            y_data_file=input[1],
            x_vars=config['x_vars'],
            y_vars_finetune=config['y_vars'],
            spatial_idx_name='site_id',
            time_idx_name='date',
            train_start_date=config['train_start_date'],
            train_end_date=config['train_end_date'],
            val_start_date=config['val_start_date'],
            val_end_date=config['val_end_date'],
            test_start_date=config['test_start_date'],
            test_end_date=config['test_end_date'],
            out_file=output[0],
            trn_offset=config['trn_offset'],
            tst_val_offset=config['tst_val_offset'])


model = LSTMModel(
    config['hidden_size'],
    recurrent_dropout=config['recurrent_dropout'],
    dropout=config['dropout'],
    num_tasks=len(config['y_vars'])
)
# Finetune/train the model on observations
rule train:
    input:
        "{outdir}/prepped.npz",
    output:
        directory("{outdir}/train_weights/"),
        #directory("{outdir}/best_val_weights/"),
        "{outdir}/train_log.csv",
        "{outdir}/train_time.txt",
    run:
        optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
        model.compile(optimizer=optimizer, loss=loss_function)
        data = np.load(input[0], allow_pickle=True)
        nsegs = len(np.unique(data["ids_trn"]))
        train_model(model,
                    x_trn=data['x_trn'],
                    y_trn=data['y_obs_trn'],
                    epochs=config['pt_epochs'],
                    # batch size equals the number of training sites
                    batch_size=nsegs,
                    x_val=data['x_val'],
                    y_val=data['y_obs_val'],
                    # the trailing slash is needed here; otherwise the weights
                    # get saved in the "outdir"
                    weight_dir=output[0] + "/",
                    #best_val_weight_dir=output[1] + "/",
                    log_file=output[1],
                    time_file=output[2],
                    early_stop_patience=config['early_stopping'])
rule make_predictions:
    input:
        "{outdir}/train_weights/",
        "../../../out/well_observed_train_val_inputs.zarr",
        "{outdir}/prepped.npz",
    output:
        "{outdir}/preds.feather",
    run:
        weight_dir = input[0] + "/"
        model.load_weights(weight_dir)
        preds = predict_from_arbitrary_data(raw_data_file=input[1],
                                            pred_start_date="1980-01-01",
                                            pred_end_date="2019-01-01",
                                            train_io_data=input[2],
                                            model=model,
                                            spatial_idx_name='site_id',
                                            time_idx_name='date')
        preds.reset_index(drop=True).to_feather(output[0])
def filter_predictions(all_preds_file, partition, out_file):
    df_preds = pd.read_feather(all_preds_file)
    all_sites = df_preds.site_id.unique()
    trn_sites = all_sites[(~np.isin(all_sites, config["validation_sites"])) &
                          (~np.isin(all_sites, config["test_sites"]))]

    df_preds_trn_sites = df_preds[df_preds.site_id.isin(trn_sites)]
    df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])]

    if partition == "train":
        df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) &
                                           (df_preds_trn_sites.date < config['train_end_date'][0])]
    elif partition == "val":
        # get all of the data at the validation sites and in the validation period.
        # this assumes that the test period follows the validation period, which
        # follows the train period
        df_preds_filt_val = df_preds_val_sites[df_preds_val_sites.date < config['test_start_date'][0]]
        df_preds_filt_trn = df_preds_trn_sites[(df_preds_trn_sites.date < config['val_end_date'][0]) &
                                               (df_preds_trn_sites.date >= config['val_start_date'][0])]
        df_preds_filt = pd.concat([df_preds_filt_val, df_preds_filt_trn], axis=0)
    elif partition == "val_times":
        # get the data in just the validation times at train and val sites
        df_preds_filt_val = df_preds_val_sites[(df_preds_val_sites.date < config['val_end_date'][0]) &
                                               (df_preds_val_sites.date >= config['val_start_date'][0])]
        df_preds_filt_trn = df_preds_trn_sites[(df_preds_trn_sites.date < config['val_end_date'][0]) &
                                               (df_preds_trn_sites.date >= config['val_start_date'][0])]
        df_preds_filt = pd.concat([df_preds_filt_val, df_preds_filt_trn], axis=0)
    else:
        raise ValueError(f"unrecognized partition: {partition}")

    df_preds_filt.reset_index(drop=True).to_feather(out_file)
rule make_filtered_predictions:
    input:
        "{outdir}/preds.feather"
    output:
        "{outdir}/{partition}_preds.feather"
    run:
        filter_predictions(input[0], wildcards.partition, output[0])
def get_grp_arg(wildcards):
    if wildcards.metric_type == 'overall':
        return None
    elif wildcards.metric_type == 'month':
        return 'month'
    elif wildcards.metric_type == 'reach':
        return 'seg_id_nat'
    elif wildcards.metric_type == 'month_reach':
        return ['seg_id_nat', 'month']
rule combine_metrics:
    input:
        "../../../out/well_observed_train_val_do.zarr",
        "{outdir}/train_preds.feather",
        "{outdir}/val_preds.feather",
        "{outdir}/val_times_preds.feather"
    output:
        "{outdir}/{metric_type}_metrics.csv"
    params:
        grp_arg = get_grp_arg
    run:
        combined_metrics(obs_file=input[0],
                         pred_data={"train": input[1],
                                    "val": input[2],
                                    "val_times": input[3]},
                         spatial_idx_name='site_id',
                         time_idx_name='date',
                         group=params.grp_arg,
                         outfile=output[0])
rule plot_prepped_data:
    input:
        "{outdir}/prepped.npz",
    output:
        "{outdir}/{variable}_part_{partition}.png",
    run:
        plot_obs(input[0],
                 wildcards.variable,
                 output[0],
                 spatial_idx_name="site_id",
                 time_idx_name="date",
                 partition=wildcards.partition)
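
The LSTMModel class is imported from a local model module that is not shown in this diff. For orientation only, here is a minimal sketch of a multitask Keras LSTM that matches the constructor call used above; the layer choices are assumptions, not the repository's actual model.py:

import tensorflow as tf

class LSTMModel(tf.keras.Model):
    # hypothetical stand-in for the repo's model.LSTMModel; only the
    # constructor signature is taken from the Snakefile above
    def __init__(self, hidden_size, recurrent_dropout=0.0, dropout=0.0, num_tasks=1):
        super().__init__()
        self.rnn = tf.keras.layers.LSTM(hidden_size,
                                        return_sequences=True,
                                        recurrent_dropout=recurrent_dropout,
                                        dropout=dropout)
        # one output per task (e.g., a DO variable), predicted at every timestep
        self.dense = tf.keras.layers.Dense(num_tasks)

    def call(self, inputs):
        return self.dense(self.rnn(inputs))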