This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

[#106] provide loss_func to train func; compiles rnns
jsadler2 committed Jun 4, 2021
1 parent 17bfab8 commit 2c7580b
Showing 1 changed file with 12 additions and 30 deletions.
river_dl/train.py — 42 changes: 12 additions & 30 deletions
@@ -5,8 +5,7 @@
import datetime
import tensorflow as tf
from river_dl.RGCN import RGCNModel
from river_dl.loss_functions import weighted_masked_rmse
from river_dl.rnns import SingletaskLSTMModel, MultitaskLSTMModel, SingletaskGRUModel, MultitaskGRUModel
from river_dl.rnns import LSTMModel, GRUModel


def get_data_if_file(d):
@@ -26,11 +25,12 @@ def train_model(
pretrain_epochs,
finetune_epochs,
hidden_units,
loss_func,
out_dir,
model_type="rgcn",
seed=None,
dropout=0,
lambdas=(1, 1),
recurrent_dropout=0,
num_tasks=1,
learning_rate_pre=0.005,
learning_rate_ft=0.01,
@@ -41,12 +41,13 @@
:param pretrain_epochs: [int] number of pretrain epochs
:param finetune_epochs: [int] number of finetune epochs
:param hidden_units: [int] number of hidden units
:param loss_func: [function] loss function that the model will be fit to
:param out_dir: [str] directory where the output files should be written
:param model_type: [str] which model to use (either 'lstm', 'rgcn', or
'lstm_grad_correction')
'gru')
:param seed: [int] random seed
:param lambdas: [array-like] weights to multiply the loss from each target
variable by
:param recurrent_dropout: [float] value between 0 and 1 for the probability that a recurrent element is zeroed
:param dropout: [float] value between 0 and 1 for the probability that an input element is zeroed
:param num_tasks: [int] number of tasks (outputs to be predicted)
:param learning_rate_pre: [float] the pretrain learning rate
:param learning_rate_ft: [float] the finetune learning rate
@@ -69,31 +70,18 @@
batch_size = num_years

if model_type == "lstm":
if num_tasks == 1:
model = SingletaskLSTMModel(hidden_units)
elif num_tasks == 2:
model = MultitaskLSTMModel(hidden_units, lambdas=lambdas)
model = LSTMModel(hidden_units, num_tasks=num_tasks, recurrent_dropout=recurrent_dropout, dropout=dropout)
elif model_type == "rgcn":
model = RGCNModel(
hidden_units,
num_tasks=num_tasks,
A=dist_matrix,
rand_seed=seed,
)
elif model_type == "lstm_grad_correction":
grad_log_file = os.path.join(out_dir, "grad_correction.txt")
model = MultitaskLSTMModel(
hidden_units,
gradient_correction=True,
lambdas=lambdas,
dropout=dropout,
grad_log_file=grad_log_file,
recurrent_dropout=recurrent_dropout
)
elif model_type == "gru":
if num_tasks == 1:
model = SingletaskGRUModel(hidden_units)
elif num_tasks == 2:
model = MultitaskGRUModel(hidden_units, lambdas=lambdas)
model = GRUModel(hidden_units, num_tasks=num_tasks, recurrent_dropout=recurrent_dropout, dropout=dropout)
else:
raise ValueError(f"The 'model_type' provided ({model_type}) is not supported")

@@ -115,10 +103,7 @@
[io_data["y_pre_trn"], io_data["y_pre_wgts"]], axis=2
)

if model_type == "rgcn":
model.compile(optimizer_pre, loss=weighted_masked_rmse(lambdas=lambdas))
else:
model.compile(optimizer_pre)
model.compile(optimizer_pre, loss=loss_func)

csv_log_pre = tf.keras.callbacks.CSVLogger(
os.path.join(out_dir, f"pretrain_log.csv")
@@ -146,10 +131,7 @@
if finetune_epochs > 0:
optimizer_ft = tf.optimizers.Adam(learning_rate=learning_rate_ft)

if model_type == "rgcn":
model.compile(optimizer_ft, loss=weighted_masked_rmse(lambdas=lambdas))
else:
model.compile(optimizer_ft)
model.compile(optimizer_ft, loss=loss_func)

csv_log_ft = tf.keras.callbacks.CSVLogger(
os.path.join(out_dir, "finetune_log.csv")
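After this commit, the caller constructs the loss and injects it through the new loss_func parameter instead of train_model hard-coding weighted_masked_rmse for the RGCN branch. Below is a minimal sketch of the new calling pattern; the epoch, unit, and dropout values and the io_data input are illustrative assumptions, and any leading positional arguments hidden by the collapsed hunks (for example, the distance matrix passed as A=dist_matrix in the RGCN branch) would also need to be supplied.

# Minimal sketch (not from the commit) of calling train_model after this change.
# weighted_masked_rmse(lambdas=...) is the same factory the old code invoked
# internally when compiling the RGCN model; now the caller builds it.
from river_dl.loss_functions import weighted_masked_rmse
from river_dl.train import train_model

# Hypothetical prepped-data input; get_data_if_file() suggests train_model
# accepts either a file path or already-loaded data.
io_data = "prepped_data.npz"

loss = weighted_masked_rmse(lambdas=(1, 1))

train_model(
    io_data,                # assumed leading argument (hidden in the collapsed hunk)
    pretrain_epochs=20,     # illustrative values, not taken from the diff
    finetune_epochs=100,
    hidden_units=20,
    loss_func=loss,
    out_dir="output/",
    model_type="lstm",      # 'lstm', 'gru', or 'rgcn'
    num_tasks=1,
    dropout=0.1,
    recurrent_dropout=0.1,
)

Because model.compile(optimizer, loss=loss_func) is now used for every model type, any Keras-compatible loss callable can be swapped in without touching train.py.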
