Skip to content

Commit

Permalink
Merge pull request USGS-R#40 from jsadler2/28-baseline-lstm-model
Browse files Browse the repository at this point in the history
28 baseline lstm model
  • Loading branch information
jsadler2 authored Feb 14, 2022
2 parents 4ef8ad5 + 3f367b8 commit cee3da5
Show file tree
Hide file tree
Showing 9 changed files with 500 additions and 17 deletions.
16 changes: 14 additions & 2 deletions 2_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,13 @@ p2_targets_list <- list(
p2_daily_with_seg_ids,
{
seg_and_site_ids <- p2_sites_w_segs %>% select(site_id, segidnat)
left_join(p2_daily_combined, seg_and_site_ids, by=c("site_no" = "site_id"))
left_join(p2_daily_combined, seg_and_site_ids, by=c("site_no" = "site_id")) %>%
rename(site_id = site_no,
date = Date,
do_mean = Value,
do_min = Value_Min,
do_max = Value_Max
)
}
),

Expand All @@ -80,7 +86,13 @@ p2_targets_list <- list(
p2_daily_with_seg_ids_csv,
write_to_csv(p2_daily_with_seg_ids, "2_process/out/daily_do_data.csv"),
format = "file"
)
),

# make list of "well-observed" sites
tar_target(
p2_well_observed_sites,
p2_sites_w_segs %>% filter(count_days_total > 300) %>% pull(site_id)
)


)
96 changes: 96 additions & 0 deletions 2a_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
source("2a_model/src/model_ready_data_utils.R")

p2a_targets_list <- list(

## PREPARE (RENAME, JOIN) INPUT AND OUTPUT FILES ##
# match to site_ids to seg_ids
tar_target(
p2a_met_data_w_sites,
match_site_ids_to_segs(p1_prms_met_data, p2_sites_w_segs)
),

# match seg attributes with site_ids
tar_target(
p2a_seg_attr_w_sites,
match_site_ids_to_segs(p1_seg_attr_data, p2_sites_w_segs)
),

## SPLIT SITES INTO (train) and (train and validation) ##
# char vector of well-observed train sites
tar_target(
p2a_trn_sites,
p2_well_observed_sites[!(p2_well_observed_sites %in% val_sites) & !(p2_well_observed_sites %in% tst_sites)]
),

# char vector of well-observed val and training sites
tar_target(
p2a_trn_val_sites,
p2_well_observed_sites[(p2_well_observed_sites %in% p2a_trn_sites) | (p2_well_observed_sites %in% val_sites)]
),

# get sites that we use for trning, but also have data in the val time period
tar_target(
p2a_trn_sites_w_val_data,
p2_daily_with_seg_ids %>%
filter(site_id %in% p2a_trn_val_sites,
!site_id %in% val_sites,
date >= val_start_date,
date < val_end_date) %>%
group_by(site_id) %>%
summarise(val_count = sum(!is.na(do_mean))) %>%
filter(val_count > 0) %>%
pull(site_id)
),

# sites that are trning sites but do not have data in val period
tar_target(
p2a_trn_only,
p2a_trn_sites[!p2a_trn_sites %in% p2a_trn_sites_w_val_data]
),


## WRITE OUT PARTITION INPUT AND OUTPUT DATA ##
# write trn met and seg attribute data to zarr
# note - I have to subset before passing to subset_and_write_zarr or else I
# get a memory error on the join
tar_target(
p2a_trn_inputs_zarr,
{
trn_input <- p2a_met_data_w_sites %>%
filter(site_id %in% p2a_trn_sites) %>%
inner_join(p2a_seg_attr_w_sites, by = "site_id")
subset_and_write_zarr(trn_input, "2a_model/out/well_observed_trn_inputs.zarr")
},
format="file"
),

# write trn and val met and seg attribute data to zarr
# note - I have to subset before passing to subset_and_write_zarr or else I
# get a memory error on the join
tar_target(
p2a_trn_val_inputs_zarr,
{
trn_input <- p2a_met_data_w_sites %>%
filter(site_id %in% p2a_trn_val_sites) %>%
inner_join(p2a_seg_attr_w_sites, by = "site_id")
subset_and_write_zarr(trn_input, "2a_model/out/well_observed_trn_inputs.zarr")
},
format="file"
),


# write trn do data to zarr
tar_target(
p2a_trn_do_zarr,
subset_and_write_zarr(p2_daily_with_seg_ids, "2a_model/out/well_observed_trn_do.zarr", p2a_trn_sites),
format="file"
),

# write trn and val do data to zarr
tar_target(
p2a_trn_val_do_zarr,
subset_and_write_zarr(p2_daily_with_seg_ids, "2a_model/out/well_observed_trn_do.zarr", p2a_trn_val_sites),
format="file"
)

)
57 changes: 57 additions & 0 deletions 2a_model/src/model_ready_data_utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

match_site_ids_to_segs <-
function(seg_data, sites_w_segs) {
#'
#' @description match site ids to segment data (e.g., met or attributes)
#'
#' @param seg_data a data frame of meterological data with column 'seg_id_nat'
#' @param sites_w_segs a dataframe with both segment ids ('segidnat') and site ids ('site_id')
#'
#' @value A data frame of seg data with site ids

seg_data <- seg_data %>%
left_join(.,sites_w_segs[,c("site_id","segidnat")],
by=c("seg_id_nat" = "segidnat"))
return(seg_data)
}

write_df_to_zarr <- function(df, index_cols, out_zarr) {
#'
#' @description use reticulate to write an R data frame to a Zarr data store (the file format river-dl currently takes)
#'
#' @param df a data frame of data
#' @param index vector of strings - the column(s) that should be the index
#' @param out_zarr where the zarr data will be written
#'
#' @value the out_zarr path

# convert to a python (pandas) DataFrame so we have access to the object methods (set_index and to_xarray)
py_df <- reticulate::r_to_py(df)

# set the index so that when we convert to an xarray dataset it is indexed properly
py_df <- py_df$set_index(index_cols)

# convert to an xarray dataset
ds <- py_df$to_xarray()

ds$to_zarr(out_zarr, mode = 'w')

return(out_zarr)

}


subset_and_write_zarr <- function(df, out_zarr, sites_subset = NULL){
#' @description write out to zarr and optionally take a subset. This assumes your zarr index
#' names will be "site_id" and "date"
#'
#' @param df a data frame of data
#' @param out_zarr where the zarr data will be written
#' @param sites_subset - character vector of sites to subset to
#'
#' @value the out_zarr path
if (!is.null(sites_subset)){
df <- df %>% filter(site_id %in% sites_subset)
}
write_df_to_zarr(df, c("site_id", "date"), out_zarr)
}
206 changes: 206 additions & 0 deletions 2a_model/src/models/0_baseline_LSTM/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import os
import tensorflow as tf
import numpy as np
import pandas as pd

from river_dl.preproc_utils import asRunConfig
from river_dl.preproc_utils import prep_all_data
from river_dl.evaluate import combined_metrics
from river_dl.postproc_utils import plot_obs, plot_ts
from river_dl.predict import predict_from_arbitrary_data
from river_dl.train import train_model
from river_dl import loss_functions as lf
from model import LSTMModel

out_dir = os.path.join(config['out_dir'], config['exp_name'])
loss_function = lf.multitask_rmse(config['lambdas'])

rule all:
input:
expand("{outdir}/{metric_type}_metrics.csv",
outdir=out_dir,
metric_type=['overall', 'reach'],
)


rule as_run_config:
output:
"{outdir}/asRunConfig.yml"
run:
asRunConfig(config,output[0])


rule prep_io_data:
input:
"../../../out/well_observed_train_inputs.zarr",
"../../../out/well_observed_train_do.zarr",
output:
"{outdir}/prepped.npz"
run:
prep_all_data(
x_data_file=input[0],
y_data_file=input[1],
x_vars=config['x_vars'],
y_vars_finetune=config['y_vars'],
spatial_idx_name='site_id',
time_idx_name='date',
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
val_end_date=config['val_end_date'],
test_start_date=config['test_start_date'],
test_end_date=config['test_end_date'],
out_file=output[0],
trn_offset = config['trn_offset'],
tst_val_offset = config['tst_val_offset'])


model = LSTMModel(
config['hidden_size'],
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)


# Finetune/train the model on observations
rule train:
input:
"{outdir}/prepped.npz",
output:
directory("{outdir}/train_weights/"),
#directory("{outdir}/best_val_weights/"),
"{outdir}/train_log.csv",
"{outdir}/train_time.txt",
run:
optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
model.compile(optimizer=optimizer, loss=loss_function)
data = np.load(input[0], allow_pickle=True)
nsegs = len(np.unique(data["ids_trn"]))
train_model(model,
x_trn = data['x_trn'],
y_trn = data['y_obs_trn'],
epochs = config['pt_epochs'],
batch_size = nsegs,
x_val = data['x_val'],
y_val = data['y_obs_val'],
# I need to add a trailing slash here. Otherwise the wgts
# get saved in the "outdir"
weight_dir = output[0] + "/",
#best_val_weight_dir = output[1] + "/",
log_file = output[1],
time_file = output[2],
early_stop_patience=config['early_stopping'])


rule make_predictions:
input:
"{outdir}/train_weights/",
"../../../out/well_observed_train_val_inputs.zarr",
"{outdir}/prepped.npz",
output:
"{outdir}/preds.feather",
run:
weight_dir = input[0] + "/"
model.load_weights(weight_dir)
preds = predict_from_arbitrary_data(raw_data_file=input[1],
pred_start_date="1980-01-01",
pred_end_date="2019-01-01",
train_io_data=input[2],
model=model,
spatial_idx_name='site_id',
time_idx_name='date')
preds.reset_index(drop=True).to_feather(output[0])


def filter_predictions(all_preds_file, partition, out_file):
df_preds = pd.read_feather(all_preds_file)
all_sites = df_preds.site_id.unique()
trn_sites = all_sites[(~np.isin(all_sites, config["validation_sites"])) &
(~np.isin(all_sites, config["test_sites"]))]

df_preds_trn_sites = df_preds[df_preds.site_id.isin(trn_sites)]

df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])]


if partition == "train":
df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) &
(df_preds_trn_sites.date < config['train_end_date'][0])]
elif partition == "val":
# get all of the data in the validation sites and in the validation period
# this assumes that the test period follows the validation period which follows the train period
df_preds_filt_val = df_preds_val_sites[df_preds_val_sites.date < config['test_start_date'][0]]
df_preds_filt_trn = df_preds_trn_sites[(df_preds_trn_sites.date < config['val_end_date'][0]) &
(df_preds_trn_sites.date >= config['val_start_date'][0])]
df_preds_filt = pd.concat([df_preds_filt_val , df_preds_filt_trn], axis=0)

elif partition == "val_times":
# get the data in just the validation times at train and val sites
df_preds_filt_val = df_preds_val_sites[(df_preds_val_sites.date < config['val_end_date'][0]) &
(df_preds_val_sites.date >= config['val_start_date'][0])]
df_preds_filt_trn = df_preds_trn_sites[(df_preds_trn_sites.date < config['val_end_date'][0]) &
(df_preds_trn_sites.date >= config['val_start_date'][0])]
df_preds_filt = pd.concat([df_preds_filt_val , df_preds_filt_trn], axis=0)


df_preds_filt.reset_index(drop=True).to_feather(out_file)




rule make_filtered_predictions:
input:
"{outdir}/preds.feather"
output:
"{outdir}/{partition}_preds.feather"
run:
filter_predictions(input[0], wildcards.partition, output[0])


def get_grp_arg(wildcards):
if wildcards.metric_type == 'overall':
return None
elif wildcards.metric_type == 'month':
return 'month'
elif wildcards.metric_type == 'reach':
return 'seg_id_nat'
elif wildcards.metric_type == 'month_reach':
return ['seg_id_nat', 'month']


rule combine_metrics:
input:
"../../../out/well_observed_train_val_do.zarr",
"{outdir}/trn_preds.feather",
"{outdir}/val_preds.feather",
"{outdir}/val_times_preds.feather"
output:
"{outdir}/{metric_type}_metrics.csv"
params:
grp_arg = get_grp_arg
run:
combined_metrics(obs_file=input[0],
pred_data = {"train": input[1],
"val": input[2],
"val_times": input[3]},
spatial_idx_name='site_id',
time_idx_name='date',
group=params.grp_arg,
outfile=output[0])


rule plot_prepped_data:
input:
"{outdir}/prepped.npz",
output:
"{outdir}/{variable}_part_{partition}.png",
run:
plot_obs(input[0],
wildcards.variable,
output[0],
spatial_idx_name="site_id",
time_idx_name="date",
partition=wildcards.partition)


Loading

0 comments on commit cee3da5

Please sign in to comment.