Conversation
---
Overall I think this is much cleaner and makes some great progress towards having a more modular, generic workflow. The thing that stands out most to me is that since this is more modular, it puts the onus on the user to configure the `Snakefile` rather than just setting some arguments in the `config.yml` and hitting go. The only problem with that is that the `Snakefile` and `config.yml` change regularly depending on who last did a PR (i.e. they aren't really canonical in the repository). What do you think about adding some language to that extent in the readme, and then maybe adding a folder with some example `Snakefile`/`config` pairs for specific use cases as examples?
```python
# Pretrain the model on the process-based model
rule pre_train:
    input:
        "{outdir}/prepped.npz"
    output:
        directory("{outdir}/pretrained_weights/"),
        touch("{outdir}/pretrained_weights/pretrain.done")
```
---
What happens here if you don't want to pre-train? This `touch` call just made it so you could set pretraining to zero without breaking the pipeline. Are you thinking that that's a use case where folks should just write their own `Snakefile` that doesn't include pre-training? This relates to a larger discussion I've had with Janet where we've chatted about creating a handful of example Snakefiles (e.g. baseline run, running replicates, with and without pre-training, etc.) and explicitly stating that the `config.yml` and `Snakefile` in the repo aren't canonical and should only be used as reference. What do you think?
---
What I was thinking is that if someone didn't want to pretrain, they would just nix that from their Snakefile. So I think that what you are saying about having example `Snakefile`/`config.yml` files instead of canonical ones is exactly what I was thinking.
I'm thinking that those can maybe go in their own directory. I'll adjust that.
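For illustration, a minimal sketch of what a no-pretraining variant could look like (the rule and file names here are hypothetical, loosely following the rules above):

```python
# Hypothetical Snakefile variant that skips pre-training entirely:
# the training rule consumes the prepped data directly, so the
# pre_train rule and its touch() sentinel never enter the DAG.
rule finetune:
    input:
        "{outdir}/prepped.npz"
    output:
        directory("{outdir}/finetune_weights/")
```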
```python
params:
    # getting the base path to put the training outputs in;
    # I omit the last slash (hence '[:-1]') so the split works properly
    run_dir=lambda wildcards, output: os.path.split(output[0][:-1])[0],
    weight_dir=lambda wildcards, output: os.path.split(output[0][:-1])[0],
```
---
Not new to this PR, but since the wildcard `{outdir}` gets defined in the inputs/outputs, can't you just pass `wildcards.outdir` to the function rather than creating the parameter?
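A minimal sketch of that suggestion (assuming the outputs keep the `{outdir}/...` layout shown above, where the split ultimately just recovers `outdir`):

```python
params:
    # equivalent to os.path.split(output[0][:-1])[0] when the output
    # is of the form "{outdir}/pretrained_weights/"
    run_dir=lambda wildcards: wildcards.outdir,
    weight_dir=lambda wildcards: wildcards.outdir,
```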
```python
hidden_size=config['hidden_size'], io_data=input[1],
partition=wildcards.partition, outfile=output[0],
num_tasks=len(config['y_vars_finetune']),
weight_dir=input[0] + '/'
```
---
I think this should probably remain an option in `config.yml` (whether you want to use the final fine-tune weights or the early stopping weights). Seems like an easy thing to overlook in the Snakefile.
---
I agree that keeping this in the config file is a good idea.
---
I think that's a good idea too. I'll make sure that's in one of the examples.
```python
num_tasks=len(config['y_vars_finetune']),
weight_dir=input[0] + '/'

model.load_weights(weight_dir)
predict_from_io_data(model=model,
```
---
So much cleaner just to pass a compiled model into here!!!!!
```python
y_val_obs = np.concatenate(
    [io_data["y_obs_val"], io_data["GW_val_reshape"], air_val], axis=2
)
# Run the finetuning within the training engine on CPU for the GW loss function
```
---
Am I missing something, or do you need to pass the `use_cpu` argument here?
---
I think I just missed this. Thanks
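For context, a generic illustration of what a `use_cpu` flag typically gates in TensorFlow (the flag name, data variables, and the surrounding training call are assumptions here, not the repo's exact API):

```python
import tensorflow as tf

# Pin the training step to the CPU when use_cpu is set; otherwise let
# TensorFlow place ops on the default (GPU) device.
device = '/CPU:0' if use_cpu else '/GPU:0'
with tf.device(device):
    model.fit(x_trn, y_trn, epochs=epochs, batch_size=batch_size)
```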
```yaml
# Choose whether to use the final weights from the end of training ('trained_weights')
# or the weights from the best validation epoch ('best_val_weights')
pred_weights: 'best_val_weights'
```
---
Again, maybe this isn't the best way to do it, but I think if you define early stopping in the config then it makes sense to be able to point to the early stopping vs finetune weights in the config.
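A sketch of how the Snakefile could then pick the weights from the config (variable names follow the snippets elsewhere in this PR; the exact rule context is assumed):

```python
import os

# 'trained_weights' or 'best_val_weights', chosen once in config.yml
pred_weights = config['pred_weights']
weight_dir = os.path.join(out_dir, pred_weights) + '/'
model.load_weights(weight_dir)
```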
`river_dl/postproc_utils.py`
```python
dates = np.reshape(dates, [dates.shape[0] * dates.shape[1], dates.shape[2]])
ids = np.reshape(ids, [ids.shape[0] * ids.shape[1], ids.shape[2]])
df_preds = pd.DataFrame(data_array, columns=col_names)
```
---
Again, not specific to this PR, nor am I sure if it's more or less elegant, but in my adapted version of this I just use `x.flatten()` for dates, ids, and preds. It just makes it agnostic to the shape of the inputs.
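A small self-contained sketch of the difference (shapes are hypothetical):

```python
import numpy as np

# A (n_sequences, seq_len, 1) array, like the dates/ids arrays above
dates = np.arange(12).reshape(3, 4, 1)

# The explicit reshape has to track the input's rank...
explicit = np.reshape(dates, [dates.shape[0] * dates.shape[1] * dates.shape[2]])

# ...while flatten() always yields a 1-D array, whatever the rank.
assert (dates.flatten() == explicit).all()
```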
---
Oh. Wow. That's so much simpler! Great idea.
`river_dl/train.py`
```python
# Initialize our model within the training engine
engine = trainer(model, optimizer, loss_func, weights)
print(best_val_weight_dir)
```
---
Might want to make this print statement a little more informative.
---
Haha 😆. That was actually for debugging. Glad you caught that.
`river_dl/train.py`
```python
if weight_dir:
    model.save_weights(weight_dir)

# Save alternate weight file that saves the best validation weights
```
---
Move this comment up to where you define the early stopping log directory.
Thanks for the review and thoughts, @SimonTopp. I think you hit the nail on the head here:

This change definitely puts more of the onus on the modeler/user. And I think it's still an open question whether that is the direction we want to go. For fun, I just sketched out where along the "flexibility - helpfulness" spectrum I see this PR. I think the upside of flexibility we gain is pretty nice. True, we are doing less of the work for a given user/modeler; that said, I think that sometimes we need the ability to make custom models, and not having to figure out how to plug them into a rigid (but helpful when plugged in) code base can be pretty freeing. And knowing how to instantiate a model object is pretty powerful and actually not too hard. The one thing that gives me pause is needing to know better how to manipulate the `Snakefile`.

This could be the happy medium where we provide a more flexible tool but also show people how to use it for their own application.
`Snakefile_gw`
```python
out_dir = config['out_dir']
code_dir = config['code_dir']
pred_weights = config['pred_weights']
out_dir = config['out_dir'] + "_gw"
```
---
Is there a reason for adding the `_gw` tag to the output directory here? I might rather keep the directory naming in the config file since I'm already reviewing that before every run.
---
Good point. Much better in the config file.
`Snakefile_gw`
```python
out_file=output[0],
reach_file= config['reach_attr_file'])
out_file=output[0])
#reach_file= config['reach_attr_file'])
```
---
Why is the reach file coming out? We use it in gw_utils to flag the reaches that are known to be in / downstream of reservoirs so we don't try to calculate the annual temperature signal properties on those reaches.
---
This and a lot of the changes I made were to make the testing of code changes self-contained. In the `river_dl/tests/test_data/` directory we don't have a reach attributes file, but now that you say this, I think it makes sense to have one so that we can test that functionality.
---
At the risk of sounding like a dumb dumb, `river_dl/tests/` has always been somewhat of an enigma to me. In looking at it now it seems very helpful and like I should probably use it more ;).
---
Haha. My intent was for that to hold some test data and scripts so we'd have a standard way of testing our code. I probably haven't used it as much as I should either :).
`Snakefile_gw`
```python
pred_weights = config['pred_weights']
out_dir = config['out_dir'] + "_gw"
#code_dir = config['code_dir']
#pred_weights = config['pred_weights']
```
---
Flagging this so it can be adjusted to match the `pred_weights` in the main Snakefile (if that's edited to keep the option for training / best validation weights in the config file).
```python
output:
    directory("{outdir}/trained_weights/"),
    directory("{outdir}/finetune_weights/"),
    directory("{outdir}/best_val_weights/"),
```
---
Flagging this so it can be adjusted to match the `pred_weights` in the main Snakefile (if that's edited to keep the option for training / best validation weights in the config file).
```python
temp_air_index = np.where(io_data['x_vars'] == 'seg_tave_air')[0]
air_unscaled = io_data['x_trn'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + \
               io_data['x_mean'][temp_air_index]
y_trn_obs = np.concatenate(
    [io_data["y_obs_trn"], io_data["GW_trn_reshape"], air_unscaled], axis=2
)
air_val = io_data['x_val'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + \
          io_data['x_mean'][temp_air_index]
y_val_obs = np.concatenate(
    [io_data["y_obs_val"], io_data["GW_val_reshape"], air_val], axis=2
)
```
---
should "io_data" be changed to data here? (per line 120)
---
Ah. Yes indeed. Good catch.
```python
x_trn = data['x_pre_full'],
y_trn = data['y_pre_full'],
epochs = config['pt_epochs'],
batch_size = 2,
```
---
Is this (`batch_size=2`) correct?
---
This is because I was using the data in `river_dl/tests/test_data`, which only has two sites.
```python
x_trn = data['x_trn'],
y_trn = y_trn_obs,
epochs = config['pt_epochs'],
batch_size = 2,
```
---
Is `batch_size=2` correct? (same comment as on the Snakefile)
---
This is because I was using the data in `river_dl/tests/test_data`, which only has two sites.
```yaml
# Input files
obs_file: "data_DRB/Obs_temp_flow_drb_full_no3558"
sntemp_file: "data_DRB/sntemp_inputs_outputs_drb_full_no3558"
dist_matrix_file: "data_DRB/distance_matrix_drb_full_no3558.npz"
```
---
Is the distance matrix taken out b/c this example config / Snakefile are using an LSTM rather than the RGCN?
---
Yes. But since we are doing different examples, I'll add this back in.
```python
model = LSTMModel(
```
---
Is anyone using this repo and using an LSTM model? If not, would it make sense to have an RGCN as the example, since that is what's being used?
---
I am using an LSTM for the DO project.
@janetrbarclay mentioned that it would be ideal (paraphrasing here) if we can add the new functionality without losing the existing functionality. I think that is wise, and I think there is a pretty straightforward way to do that. I will revise this PR to do that.
@janetrbarclay and @SimonTopp - Thank you guys for your review comments. I think the biggest change is just to shift our thinking from the
Commits (force-pushed from 2628249 to 6157859):
- rm trainer class, model as input, files as input; this is to accommodate new train routine
- also found bug in not passing spatial/time idx names all the way through
@janetrbarclay and @SimonTopp - this is ready for another look. I summarize the major changes as:
---
@jsadler2, this was a huge lift and super well done!!! I think it manages to actually make the repository more modular while maintaining its value on the "helpfulness" dimension. All my comments are pretty minor, but let me know if you want to talk about any of them. I think it's ready to go after a couple small changes.
```python
def load_model_from_weights(
    model_type, model_weights_dir, hidden_size, dist_matrix=None, num_tasks=1,
):
    """
    load a TF model from the model weights directory
    :param model_type: [str] model to use: either 'rgcn', 'lstm', or 'gru'
    :param model_weights_dir: [str] directory to saved model weights
    :param hidden_size: [int] the number of hidden units in model
    :param dist_matrix: [np array] the distance matrix if using 'rgcn'
    :param num_tasks: [int] number of tasks (variables to be predicted)
    :return: TF model
    """
    if model_type == "rgcn":
        model = RGCNModel(hidden_size, A=dist_matrix, num_tasks=num_tasks)
    elif model_type.startswith("lstm"):
        model = LSTMModel(hidden_size, num_tasks=num_tasks)
    elif model_type == "gru":
        model = GRUModel(hidden_size, num_tasks=num_tasks)
    else:
        raise ValueError(
            f'model_type must be "lstm", "gru" or "rgcn" (not {model_type})'
        )

    model.load_weights(model_weights_dir)
    return model
```
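For example (the path and sizes here are hypothetical):

```python
# Rebuild an RGCN and restore its saved weights in one call
model = load_model_from_weights(
    "rgcn",
    "output/trained_weights/",
    hidden_size=20,
    dist_matrix=dist_matrix,
    num_tasks=2,
)
```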
---
Don't have a strong opinion here, but do you think it'd keep the Snakefile cleaner if we changed this to something like `compile_model` and moved it to one of the utils files? That way in the Snakefile you could compile the model and optionally load weights in one line rather than a handful of lines. Probably only a marginal gain and potentially makes the workflow more opaque. Thoughts?
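A hypothetical sketch of what such a helper might look like (this was not adopted; see the reply below):

```python
def compile_model(model, loss_func, optimizer, weights_dir=None):
    """Compile a model and optionally restore saved weights in one call."""
    model.compile(optimizer=optimizer, loss=loss_func)
    if weights_dir:
        model.load_weights(weights_dir)
    return model
```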
---
I'd rather not have a `compile_model` function. I think that we'd be back to needing to maintain something that will be hard to actually maintain.
```yaml
train_start_date:
  - '2003-09-15'
train_end_date:
  - '2005-09-14'
val_start_date:
  - '2005-09-14'
val_end_date:
  - '2006-09-14'
test_start_date:
  - '1980-10-01'
test_end_date:
  - '1985-09-30'
```
---
Somewhere (maybe in the readme), we should explicitly state the baseline run conditions we've all agreed upon across projects (partition years, segments).
---
Yeah. Good idea.
```yaml
obs_file: "../river_dl/tests/test_data/obs_temp_flow"
sntemp_file: "../river_dl/tests/test_data/test_data"
dist_matrix_file: "../river_dl/tests/test_data/test_dist_matrix.npz"
code_dir: ".."
```
---
I think there are two ways we could think of these examples:

1. They are loose guidelines, and we state explicitly in the `Readme` what aspects of them will likely change for individual runs (dates, data files, input vars), or
2. We make them as "out-of-the-box" as possible, meaning they have the most common input files and run conditions so users literally don't have to change anything.

Thoughts? Maybe some combination of the two is possible as well.
---
I much prefer 1 - loose guidelines.
**Warning: This is a breaking PR**

This moves the definition of the TensorFlow model out of `train.py`. If agreed upon, this means you will no longer be able to pass a string as a `model_type` argument (e.g., `'rgcn'`) to the `train_model` function. You would instead pass a model object (`river_dl.RGCN.RGCNModel`).

**Pros**
The main pro is that `train.py` is much lighter and more flexible. This is mainly because we don't have to handle all of the possible `model_type` arguments passed as strings via `if else` statements. For example, there are now no gw-specific pieces in `train.py`.

We also leave the defining of "pretraining" and "finetuning" out of `train.py`. This means that you could have any number of training phases with different data/epochs/loss functions, etc. This also means that you can define your model anywhere (e.g., `my_awesome_model.py`), import it into whatever file you are using to call `train_model`, instantiate it, and pass the object into `train_model`.
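A minimal sketch of that calling pattern (keyword names follow the snippets reviewed above; the exact `train_model` signature, and the free variables like `hidden_size` and `loss_func`, are assumptions standing in for project config):

```python
import tensorflow as tf
from river_dl.RGCN import RGCNModel
from river_dl.train import train_model

# Define and compile the model in project code (e.g., the Snakefile)...
model = RGCNModel(hidden_size, A=dist_matrix, num_tasks=num_tasks)
model.compile(optimizer=tf.optimizers.Adam(), loss=loss_func)

# ...then hand the compiled object to the model-agnostic train routine.
train_model(
    model,
    x_trn=data['x_pre_full'],
    y_trn=data['y_pre_full'],
    epochs=config['pt_epochs'],
    batch_size=batch_size,
)
```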
**Tradeoff**

The tradeoff is that the model has to be defined and compiled with its loss and optimizer somewhere else, so the burden is more on the individual projects (e.g., in the Snakefile). I edited the `Snakefile` to show what that would look like there.

**Summary**
This PR is intended to make `train.py` project-agnostic; this means whatever calls `train.py` (e.g., the `Snakefile`) has the burden of project-specific model definitions.

Closes #146, #118