From 2c7580b94c9c602a74f2d79d5dd106f177b5398f Mon Sep 17 00:00:00 2001
From: jsadler2
Date: Fri, 4 Jun 2021 11:53:14 -0500
Subject: [PATCH] [#106] provide loss_func to train func; compiles rnns

---
 river_dl/train.py | 42 ++++++++++++------------------------------
 1 file changed, 12 insertions(+), 30 deletions(-)

diff --git a/river_dl/train.py b/river_dl/train.py
index e03d5e5..3bfb551 100644
--- a/river_dl/train.py
+++ b/river_dl/train.py
@@ -5,8 +5,7 @@ import datetime
 import tensorflow as tf
 
 from river_dl.RGCN import RGCNModel
-from river_dl.loss_functions import weighted_masked_rmse
-from river_dl.rnns import SingletaskLSTMModel, MultitaskLSTMModel, SingletaskGRUModel, MultitaskGRUModel
+from river_dl.rnns import LSTMModel, GRUModel
 
 
 def get_data_if_file(d):
@@ -26,11 +25,12 @@ def train_model(
     pretrain_epochs,
     finetune_epochs,
     hidden_units,
+    loss_func,
     out_dir,
     model_type="rgcn",
     seed=None,
     dropout=0,
-    lambdas=(1, 1),
+    recurrent_dropout=0,
     num_tasks=1,
     learning_rate_pre=0.005,
     learning_rate_ft=0.01,
@@ -41,12 +41,13 @@
     :param pretrain_epochs: [int] number of pretrain epochs
     :param finetune_epochs: [int] number of finetune epochs
     :param hidden_units: [int] number of hidden layers
+    :param loss_func: [function] loss function used to fit the model
     :param out_dir: [str] directory where the output files should be written
     :param model_type: [str] which model to use (either 'lstm', 'rgcn', or
-        'lstm_grad_correction')
+        'gru')
     :param seed: [int] random seed
+    :param recurrent_dropout: [float] value between 0 and 1 for the probability that a recurrent element will be set to zero
+    :param dropout: [float] value between 0 and 1 for the probability that an input element will be set to zero
-    :param lambdas: [array-like] weights to multiply the loss from each target
-        variable by
     :param num_tasks: [int] number of tasks (outputs to be predicted)
     :param learning_rate_pre: [float] the pretrain learning rate
     :param learning_rate_ft: [float] the finetune learning rate
@@ -69,31 +70,18 @@
     batch_size = num_years
 
     if model_type == "lstm":
-        if num_tasks == 1:
-            model = SingletaskLSTMModel(hidden_units)
-        elif num_tasks == 2:
-            model = MultitaskLSTMModel(hidden_units, lambdas=lambdas)
+        model = LSTMModel(hidden_units, num_tasks=num_tasks, recurrent_dropout=recurrent_dropout, dropout=dropout)
     elif model_type == "rgcn":
         model = RGCNModel(
             hidden_units,
             num_tasks=num_tasks,
             A=dist_matrix,
             rand_seed=seed,
-        )
-    elif model_type == "lstm_grad_correction":
-        grad_log_file = os.path.join(out_dir, "grad_correction.txt")
-        model = MultitaskLSTMModel(
-            hidden_units,
-            gradient_correction=True,
-            lambdas=lambdas,
             dropout=dropout,
-            grad_log_file=grad_log_file,
+            recurrent_dropout=recurrent_dropout
         )
     elif model_type == "gru":
-        if num_tasks == 1:
-            model = SingletaskGRUModel(hidden_units)
-        elif num_tasks == 2:
-            model = MultitaskGRUModel(hidden_units, lambdas=lambdas)
+        model = GRUModel(hidden_units, num_tasks=num_tasks, recurrent_dropout=recurrent_dropout, dropout=dropout)
     else:
         raise ValueError(f"The 'model_type' provided ({model_type}) is not supported")
 
@@ -115,10 +103,7 @@
         [io_data["y_pre_trn"], io_data["y_pre_wgts"]], axis=2
     )
 
-    if model_type == "rgcn":
-        model.compile(optimizer_pre, loss=weighted_masked_rmse(lambdas=lambdas))
-    else:
-        model.compile(optimizer_pre)
+    model.compile(optimizer_pre, loss=loss_func)
 
     csv_log_pre = tf.keras.callbacks.CSVLogger(
         os.path.join(out_dir, f"pretrain_log.csv")
@@ -146,10 +131,7 @@
     if finetune_epochs > 0:
         optimizer_ft = tf.optimizers.Adam(learning_rate=learning_rate_ft)
 
-        if model_type == "rgcn":
-            model.compile(optimizer_ft, loss=weighted_masked_rmse(lambdas=lambdas))
-        else:
-            model.compile(optimizer_ft)
+        model.compile(optimizer_ft, loss=loss_func)
 
         csv_log_ft = tf.keras.callbacks.CSVLogger(
             os.path.join(out_dir, "finetune_log.csv")
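
Usage note (not part of the patch): after this change the caller chooses the
loss instead of train.py hard-coding weighted_masked_rmse for the RGCN; the
function is passed straight to model.compile() for both pretraining and
finetuning. Below is a minimal sketch of calling the updated train_model. It
assumes weighted_masked_rmse is still exposed by river_dl.loss_functions and
that io_data may be a path to a prepped .npz file (as get_data_if_file
suggests); the file name and argument values are hypothetical. Any callable
with the Keras loss(y_true, y_pred) signature should work.

    from river_dl.train import train_model
    from river_dl.loss_functions import weighted_masked_rmse  # assumed still available

    train_model(
        io_data="prepped.npz",      # hypothetical prepped-data file
        pretrain_epochs=20,
        finetune_epochs=100,
        hidden_units=20,
        loss_func=weighted_masked_rmse(lambdas=(1, 1)),  # factory returns loss(y_true, y_pred)
        out_dir="output",
        model_type="lstm",          # 'lstm', 'rgcn', or 'gru'; 'rgcn' also needs a distance matrix
        dropout=0.1,
        recurrent_dropout=0.1,
        num_tasks=2,
    )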