This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

146 simplify training #148

Merged
merged 22 commits from 146-simplify-training into USGS-R:main on Jan 24, 2022

Conversation

Collaborator

@jsadler2 jsadler2 commented Dec 6, 2021

**Warning: This is a breaking PR**

This moves the definition of the TensorFlow model out of train.py. If agreed upon, this means you will no longer be able to pass a string as a model_type argument (e.g., 'rgcn') to the train_model function. You would instead pass a model object (river_dl.RGCN.RGCNModel).

Pros
The main pro is that train.py is much lighter and more flexible. This is mainly because we don't have to handle all of the possible model_type arguments passed as strings via if/else statements. For example, there are now no gw-specific pieces in train.py.

We also leave the definition of "pretraining" and "finetuning" out of train.py. This means you could have any number of training phases with different data/epochs/loss functions, etc. It also means you can define your model anywhere (e.g., my_awesome_model.py), import it into whatever file you are using to call train_model, instantiate it, and pass the object into train_model.

Tradeoff
The tradeoff is that the model has to be defined and compiled with its loss and optimizer somewhere else, so the burden falls more on the individual projects (e.g., in the Snakefile). I edited the Snakefile to show what that would look like there.
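To make this concrete, here is a minimal sketch of the proposed calling pattern (the train_model arguments, config values, and .npz keys shown are illustrative, not the exact Snakefile code in this PR):

```python
# Minimal sketch, assuming the RGCNModel signature used elsewhere in river_dl;
# the .npz keys and train_model arguments are illustrative.
import numpy as np
import tensorflow as tf
from river_dl.RGCN import RGCNModel
from river_dl.train import train_model

data = np.load("prepped.npz")  # output of the data-prep rule (illustrative path)

# The calling code (e.g., the Snakefile) now owns model definition and compilation...
model = RGCNModel(20, A=data["dist_matrix"], num_tasks=1)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.MeanSquaredError(),
)

# ...and train_model just takes the compiled model object instead of a model_type string.
train_model(
    model,
    x_trn=data["x_trn"],
    y_trn=data["y_obs_trn"],
    epochs=100,
    batch_size=2,
)
```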


Summary
This PR is intended to make train.py project-agnostic; whatever calls train.py (e.g., the Snakefile) then carries the burden of project-specific setup and model definitions.

closes #146, #118

Contributor

@SimonTopp SimonTopp left a comment

Overall I think this is much cleaner and makes some great progress towards having a more modular, generic workflow. The thing that stands out most to me is that since this is more modular, it puts the onus on the user to configure the Snakefile rather than just setting some arguments in the config.yml and hitting go. The only problem with that is that the Snakefile and config.yml change regularly depending on who last did a PR (i.e., they aren't really canonical in the repository). What do you think about adding some language to that effect in the readme, and then maybe adding a folder with some example Snakefile/config pairs for specific use cases?

# Pretrain the model on process based model
rule pre_train:
    input:
        "{outdir}/prepped.npz"
    output:
        directory("{outdir}/pretrained_weights/"),
        touch("{outdir}/pretrained_weights/pretrain.done")
Contributor

What happens here if you don't want to pre-train? This touch call just made it so you could set pretraining to zero but not break the pipeline. Are you thinking that that's a use-case scenario where folks should just write their own Snakefile that doesn't include pre-training? This relates to a larger discussion I've had with Janet where we've chatted about creating a handful of example Snakefiles (e.g. baseline run, running replicates, with and without pre-training, etc.) and explicitly stating that the config.yml and Snakefile in the repo aren't canonical and should only be used as reference. What do you think?

Collaborator Author

What I was thinking is that if someone didn't want to pretrain, they would just nix that from their Snakefile. So I think that what you are saying about having example Snakefile/config.yml files instead of canonical ones is exactly what I was thinking.

I'm thinking that those can maybe go in their own directory. I'll adjust that.

params:
    # getting the base path to put the training outputs in
    # I omit the last slash (hence '[:-1]') so the split works properly
    run_dir=lambda wildcards, output: os.path.split(output[0][:-1])[0],
    weight_dir=lambda wildcards, output: os.path.split(output[0][:-1])[0],
Contributor

Not new to this PR, but since the wildcard {outdir} gets defined in the inputs/outputs, can't you just pass wildcards.outdir to the function rather than creating the parameter?
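A hypothetical sketch of that suggestion (the rule name and body here are illustrative, not the PR's actual rule):

```
# Hypothetical sketch: use the wildcard directly instead of re-deriving the
# base path from the output with os.path.split(output[0][:-1]).
rule pre_train:
    input:
        "{outdir}/prepped.npz"
    output:
        directory("{outdir}/pretrained_weights/")
    run:
        # wildcards.outdir is already the base output path
        run_dir = wildcards.outdir
        weight_dir = output[0]
```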

hidden_size=config['hidden_size'], io_data=input[1],
partition=wildcards.partition, outfile=output[0],
num_tasks=len(config['y_vars_finetune']),
weight_dir = input[0] + '/'
Contributor

I think this should probably remain an option in config.yml (whether you want to use the final fine-tune weights or the early stopping weights). Seems like an easy thing to overlook in the Snakefile.

Collaborator

I agree that keeping this in the config file is a good idea

Collaborator Author

I think that's a good idea too. I'll make sure that's in one of the examples.

num_tasks=len(config['y_vars_finetune']),
weight_dir = input[0] + '/'
model.load_weights(weight_dir)
predict_from_io_data(model=model,
Contributor

So much cleaner just to pass a compiled model into here!!!!!

y_val_obs = np.concatenate(
    [io_data["y_obs_val"], io_data["GW_val_reshape"], air_val], axis=2
)
# Run the finetuning within the training engine on CPU for the GW loss function
Contributor

Am I missing something, or do you need to pass the use_cpu argument here?

Collaborator Author

I think I just missed this. Thanks


# Choose whether to use the final weights from the end of training ('trained_weights') or the weights from the best
# validation epoch ('best_val_weights')
pred_weights: 'best_val_weights'
Contributor

Again, maybe this isn't the best way to do it, but I think if you define early stopping in the config then it makes sense to be able to point to the early stopping vs finetune weights in the config.

Comment on lines 125 to 123
dates = np.reshape(dates, [dates.shape[0] * dates.shape[1], dates.shape[2]])
ids = np.reshape(ids, [ids.shape[0] * ids.shape[1], ids.shape[2]])
df_preds = pd.DataFrame(data_array, columns=col_names)
Contributor

Again, not specific to this PR, nor am I sure if it's more or less elegant, but in my adapted version of this I just use x.flatten() for dates, ids, and preds. It just makes it agnostic to the shape of the inputs.
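A quick illustration of the difference (toy arrays, not the repo's code):

```python
# Toy example: flatten() gives the same values as the explicit reshape for
# these single-column (n_seq, seq_len, 1) arrays, without hard-coding shapes.
import numpy as np

dates = np.arange(6).reshape(2, 3, 1)   # e.g., 2 reaches x 3 timesteps x 1

via_reshape = np.reshape(dates, [dates.shape[0] * dates.shape[1], dates.shape[2]])
via_flatten = dates.flatten()

assert np.array_equal(via_reshape.squeeze(), via_flatten)
```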

Collaborator Author

Oh. Wow. That's so much simpler! Great idea.


# Initialize our model within the training engine
engine = trainer(model, optimizer, loss_func, weights)
print(best_val_weight_dir)
Contributor

Might want to make this print statement a little more informative.

Collaborator Author

haha 😆. that was actually for debugging. glad you caught that.

Suggested change
print(best_val_weight_dir)

if weight_dir:
    model.save_weights(weight_dir)

# Save alternate weight file that saves the best validation weights
Contributor

Move this comment up to where you define the early stopping log directory.

Collaborator Author

jsadler2 commented Dec 7, 2021

Thanks for the review and thoughts, @SimonTopp. I think you hit the nail on the head here:

it puts the onus on the user to configure the Snakefile rather than just setting some arguments in the config.yml and hitting go.

This change definitely puts more of the onus on the modeler/user. And I think it's still an open question whether that is the direction we want to go.

For fun I just sketched out where along the "flexibility - helpfulness" spectrum I see this PR.

[photo attachment: hand-drawn sketch of the flexibility vs. helpfulness spectrum]

I think the upside of flexibility we gain is pretty nice. True, we are doing less of the work for a given user/modeler. That said, I think we sometimes need the ability to make custom models, and not having to figure out how to plug them into a rigid (but helpful once plugged in) code base can be pretty freeing. And knowing how to instantiate a model object is pretty powerful and actually not too hard.

The one thing that gives me pause is needing to know better how to manipulate the Snakefile. And maybe that's where your and Janet's idea comes in:

we've chatted about creating a handful of example Snakefiles (e.g. baseline run, running replicates, with and without pre-training, etc.) and explicitly stating that the config.yml and Snakefile in the repo aren't canonical and should only be used as reference.

This could be the happy medium where we provide a more flexible tool but also show people how to use it for their own application.

Snakefile_gw Outdated
out_dir = config['out_dir']
code_dir = config['code_dir']
pred_weights = config['pred_weights']
out_dir = config['out_dir'] + "_gw"
Collaborator

Is there a reason for adding the gw tag to the output directory here? I might rather keep the directory naming in the config file since I'm already reviewing that before every run.

Collaborator Author

Good point. Much better in the config file

Snakefile_gw Outdated
out_file=output[0],
reach_file= config['reach_attr_file'])
out_file=output[0])
#reach_file= config['reach_attr_file'])
Collaborator

Why is the reach file coming out? We use it in gw_utils to flag the reaches that are known to be in / downstream of reservoirs so we don't try to calculate the annual temperature signal properties on those reaches.

Collaborator Author

This and a lot of the changes I made were to make the testing of code changes self-contained. In the river_dl/tests/test_data/ directory we don't have a reach attributes file, but now that you say this, I think it makes sense to have one so that we can test that functionality.

Contributor

At the risk of sounding like a dumb dumb, river_dl/tests/ has always been somewhat of an enigma to me. In looking at it now it seems very helpful and like I should probably use it more ;).

Collaborator Author

Haha. My intent was for that directory to hold some test data and scripts so we'd have a standard way of testing our code. I probably haven't used it as much as I should either :).

Snakefile_gw Outdated
pred_weights = config['pred_weights']
out_dir = config['out_dir'] + "_gw"
#code_dir = config['code_dir']
#pred_weights = config['pred_weights']
Collaborator

flagging this so it can be adjusted to match the pred_weights in the main Snakefile (if that's edited to keep the option for training / best validation weights in the config file)

output:
    directory("{outdir}/trained_weights/"),
    directory("{outdir}/finetune_weights/"),
    directory("{outdir}/best_val_weights/"),
Collaborator

flagging this so it can be adjusted to match the pred_weights in the main Snakefile (if that's edited to keep the option for training / best validation weights in the config file)

Comment on lines +121 to +133
temp_air_index = np.where(io_data['x_vars'] == 'seg_tave_air')[0]
air_unscaled = io_data['x_trn'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + \
               io_data['x_mean'][temp_air_index]
y_trn_obs = np.concatenate(
    [io_data["y_obs_trn"], io_data["GW_trn_reshape"], air_unscaled], axis=2
)
air_val = io_data['x_val'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + io_data['x_mean'][
    temp_air_index]
y_val_obs = np.concatenate(
    [io_data["y_obs_val"], io_data["GW_val_reshape"], air_val], axis=2
Collaborator

should "io_data" be changed to data here? (per line 120)

Collaborator Author

Ah. Yes indeed. Good catch.

x_trn = data['x_pre_full'],
y_trn = data['y_pre_full'],
epochs = config['pt_epochs'],
batch_size = 2,
Collaborator

is this (batch_size=2) correct?

Collaborator Author

This is because I was using the data in river_dl/tests/test_data which only has two sites

x_trn = data['x_trn'],
y_trn = y_trn_obs,
epochs = config['pt_epochs'],
batch_size = 2,
Collaborator

is batch_size=2 correct? (same comment as on the Snakefile)

Collaborator Author

This is because I was using the data in river_dl/tests/test_data which only has two sites

@@ -1,18 +1,13 @@
# Input files
obs_file: "data_DRB/Obs_temp_flow_drb_full_no3558"
sntemp_file: "data_DRB/sntemp_inputs_outputs_drb_full_no3558"
dist_matrix_file: "data_DRB/distance_matrix_drb_full_no3558.npz"
Collaborator

Is the distance matrix taken out because this example config/Snakefile uses an LSTM rather than the RGCN?

Collaborator Author

Yes. But since we are doing different examples, I'll add this back in.

@@ -79,54 +80,89 @@ rule prep_io_data:
# """


model = LSTMModel(
Collaborator

Is anyone using this repo with an LSTM model? If not, would it make sense to have an RGCN as the example since that is what's being used?

Collaborator Author

I am using an LSTM for the DO project

@jsadler2
Collaborator Author

@janetrbarclay mentioned that it would be ideal (paraphrasing here) if we could add the new functionality without losing the existing functionality. I think that is wise, and I think there is a pretty straightforward way to do it. I will revise this PR accordingly.

Collaborator Author

jsadler2 commented Jan 4, 2022

@janetrbarclay and @SimonTopp - Thank you both for your review comments. I think the biggest change is just to shift our thinking about the Snakefile/config.yml files from "this is the way to use river-dl" to "this is an example of how one could use river-dl, but you will likely have to modify it for your own purposes." I will make some edits to this PR to address that and your comments.

@jsadler2 jsadler2 force-pushed the 146-simplify-training branch from 2628249 to 6157859 on January 20, 2022 21:38
@jsadler2
Collaborator Author

@janetrbarclay and @SimonTopp - this is ready for another look. I summarize the major changes as:

  1. the training and prediction functions now take a compiled TensorFlow model
  2. the Snakefiles and config files are now in their own directory (workflow_examples/); that directory also has a readme describing the different examples
  3. the asRunConfig function now takes the code directory as an argument (since we are no longer assuming the Snakefile one is using is located in the root river-dl directory)

Contributor

@SimonTopp SimonTopp left a comment

@jsadler2, this was a huge lift and super well done!!! I think it manages to make the repository more modular while maintaining its value on the "helpfulness" dimension. All my comments are pretty minor, but let me know if you want to talk about any of them. I think it's ready to go after a couple of small changes.

river_dl/postproc_utils.py (resolved)
Comment on lines -39 to -63
def load_model_from_weights(
    model_type, model_weights_dir, hidden_size, dist_matrix=None, num_tasks=1,
):
    """
    load a TF model from the model weights directory
    :param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru'
    :param model_weights_dir: [str] directory to saved model weights
    :param hidden_size: [int] the number of hidden units in model
    :param dist_matrix: [np array] the distance matrix if using 'rgcn'
    :param num_tasks: [int] number of tasks (variables_to_log to be predicted)
    :return: TF model
    """
    if model_type == "rgcn":
        model = RGCNModel(hidden_size, A=dist_matrix, num_tasks=num_tasks)
    elif model_type.startswith("lstm"):
        model = LSTMModel(hidden_size, num_tasks=num_tasks)
    elif model_type == "gru":
        model = GRUModel(hidden_size, num_tasks=num_tasks)
    else:
        raise ValueError(
            f'model_type must be "lstm", "gru" or "rgcn", (not {model_type})'
        )

    model.load_weights(model_weights_dir)
    return model
Contributor

Don't have a strong opinion here, but do you think it'd keep the Snakefile cleaner if we changed this to something like compile_model and moved it to one of the utils files? That way in the Snakefile you could compile the model and optionally load weights in one line rather than a handful of lines. Probably only a marginal gain and potentially makes the workflow more opaque. Thoughts?
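For reference, a hypothetical sketch of what such a helper might look like (this is not an existing river_dl function; the name, arguments, and defaults are made up):

```python
# Hypothetical compile_model helper (not implemented in river_dl): compile a
# model and optionally load saved weights in one call from the Snakefile.
import tensorflow as tf

def compile_model(model, optimizer=None, loss=None, weights_dir=None):
    """Compile `model` and, if given, load weights from `weights_dir`."""
    model.compile(
        optimizer=optimizer or tf.keras.optimizers.Adam(learning_rate=0.01),
        loss=loss or tf.keras.losses.MeanSquaredError(),
    )
    if weights_dir:
        model.load_weights(weights_dir)
    return model
```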

Collaborator Author

I'd rather not have a compile_model function. I think that we'd be back to needing to maintain something that will be hard to actually maintain.

river_dl/tests/generate_test_data.py (resolved)
river_dl/train.py (outdated, resolved)
river_dl/train.py (outdated, resolved)
workflow_examples/Snakefile_rgcn.smk (resolved)
workflow_examples/Snakefile_rgcn.smk (outdated, resolved)
Comment on lines +23 to +34
train_start_date:
- '2003-09-15'
train_end_date:
- '2005-09-14'
val_start_date:
- '2005-09-14'
val_end_date:
- '2006-09-14'
test_start_date:
- '1980-10-01'
test_end_date:
- '1985-09-30'
Contributor

Somewhere (maybe in the readme), we should explicitly state the baseline run conditions we've all agreed upon across projects (partition years, segments)

Collaborator Author

Yeah. Good idea.

Comment on lines +2 to +5
obs_file: "../river_dl/tests/test_data/obs_temp_flow"
sntemp_file: "../river_dl/tests/test_data/test_data"
dist_matrix_file: "../river_dl/tests/test_data/test_dist_matrix.npz"
code_dir: ".."
Contributor

I think there are two ways we could think of these examples.

  1. They are loose guidelines and we state explicitly in the Readme what aspects of them will likely change for individual runs (dates, data files, input vars), or
  2. We make them as "out-of-the-box" as possible, meaning they have the most common input files and run conditions so users literally don't have to change anything.

Thoughts? Maybe some combination of the two is possible as well.

Collaborator Author

I much prefer 1 (loose guidelines).

workflow_examples/readme.md (resolved)
@jsadler2 jsadler2 merged commit dd3b84c into USGS-R:main Jan 24, 2022
@jsadler2 jsadler2 deleted the 146-simplify-training branch January 24, 2022 17:16

Successfully merging this pull request may close these issues.

Simplify training routine
3 participants