num_tasks, lambdas in train functions
jsadler2 committed Jun 4, 2021
1 parent c8c140a commit 3e54465
Showing 2 changed files with 33 additions and 37 deletions.
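In short, this commit replaces the boolean `flow_in_temp` flag and the scalar `lambda_aux` weight with an explicit `num_tasks` count plus a per-task `lambdas` weight sequence. As an illustration only, a call under the new signature might look like the sketch below; the path, epoch counts, and unit count are placeholders, and the positional argument order follows the updated call in river_dl/train_model.py.

```python
from river_dl.train import train_model

# Hypothetical call under the new signature: two prediction targets
# (e.g. water temperature and streamflow), weighted equally in the loss.
model = train_model(
    "prepped_data.npz",  # placeholder path to preprocessed input data
    100,                 # pretrain epochs
    50,                  # finetune epochs
    20,                  # hidden units
    out_dir="output/",
    model_type="rgcn",
    num_tasks=2,         # number of target variables to predict
    lambdas=(1, 1),      # per-task loss weights
)
```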
river_dl/train.py: 21 additions & 15 deletions
@@ -6,7 +6,7 @@
 import tensorflow as tf
 from river_dl.RGCN import RGCNModel
 from river_dl.loss_functions import weighted_masked_rmse
-from river_dl.rnns import LSTMModel, GRUModel
+from river_dl.rnns import SingletaskLSTMModel, MultitaskLSTMModel, SingletaskGRUModel, MultitaskGRUModel


@@ -27,11 +27,11 @@ def train_model(
     finetune_epochs,
     hidden_units,
     out_dir,
-    flow_in_temp=False,
     model_type="rgcn",
     seed=None,
     dropout=0,
-    lambda_aux=1,
+    lambdas=(1, 1),
+    num_tasks=1,
     learning_rate_pre=0.005,
     learning_rate_ft=0.01,
 ):
@@ -42,14 +42,12 @@ def train_model(
     :param finetune_epochs: [int] number of finetune epochs
     :param hidden_units: [int] number of hidden layers
     :param out_dir: [str] directory where the output files should be written
-    :param flow_in_temp: [bool] whether the flow predictions should feed
-    into the temp predictions
     :param model_type: [str] which model to use (either 'lstm', 'rgcn', or
     'lstm_grad_correction')
     :param seed: [int] random seed
-    :param lambda_aux: [float] weight between 0 and 1. How much
-    to weight the auxiliary rmse is weighted compared to the main rmse. The
-    difference between one and lambda becomes the main rmse weight.
+    :param lambdas: [array-like] weights to multiply the loss from each target
+    variable by
+    :param num_tasks: [int] number of tasks (outputs to be predicted)
     :param learning_rate_pre: [float] the pretrain learning rate
     :param learning_rate_ft: [float] the finetune learning rate
     :return: [tf model] finetuned model
@@ -71,25 +69,33 @@ def train_model(
         batch_size = num_years
 
     if model_type == "lstm":
-        model = LSTMModel(hidden_units, lambda_aux=lambda_aux)
+        if num_tasks == 1:
+            model = SingletaskLSTMModel(hidden_units)
+        elif num_tasks == 2:
+            model = MultitaskLSTMModel(hidden_units, lambdas=lambdas)
     elif model_type == "rgcn":
         model = RGCNModel(
             hidden_units,
-            flow_in_temp=flow_in_temp,
+            num_tasks=num_tasks,
             A=dist_matrix,
             rand_seed=seed,
         )
     elif model_type == "lstm_grad_correction":
         grad_log_file = os.path.join(out_dir, "grad_correction.txt")
-        model = LSTMModel(
+        model = MultitaskLSTMModel(
             hidden_units,
             gradient_correction=True,
-            lambda_aux=lambda_aux,
+            lambdas=lambdas,
             dropout=dropout,
             grad_log_file=grad_log_file,
         )
     elif model_type == "gru":
-        model = GRUModel(hidden_units, lambda_aux=lambda_aux)
+        if num_tasks == 1:
+            model = SingletaskGRUModel(hidden_units)
+        elif num_tasks == 2:
+            model = MultitaskGRUModel(hidden_units, lambdas=lambdas)
+    else:
+        raise ValueError(f"The 'model_type' provided ({model_type}) is not supported")
 
     if seed:
         os.environ["PYTHONHASHSEED"] = str(seed)
@@ -110,7 +116,7 @@ def train_model(
     )
 
     if model_type == "rgcn":
-        model.compile(optimizer_pre, loss=weighted_masked_rmse(lambda_aux=lambda_aux))
+        model.compile(optimizer_pre, loss=weighted_masked_rmse(lambdas=lambdas))
     else:
         model.compile(optimizer_pre)

@@ -141,7 +147,7 @@ def train_model(
     optimizer_ft = tf.optimizers.Adam(learning_rate=learning_rate_ft)
 
     if model_type == "rgcn":
-        model.compile(optimizer_ft, loss=weighted_masked_rmse(lambda_aux=lambda_aux))
+        model.compile(optimizer_ft, loss=weighted_masked_rmse(lambdas=lambdas))
     else:
         model.compile(optimizer_ft)
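For intuition about the `lambdas` weighting passed to `weighted_masked_rmse` in the compile calls above: the idea is one masked RMSE per target variable, each scaled by its lambda and summed. The sketch below is a generic illustration under assumed conventions (targets stacked on the last axis, NaNs marking missing observations); it is not river_dl's actual implementation.

```python
import tensorflow as tf

def weighted_masked_rmse_sketch(lambdas=(1.0, 1.0)):
    """Illustrative loss factory: per-task masked RMSE, weighted and summed."""
    def loss(y_true, y_pred):
        total = tf.constant(0.0)
        for task, lam in enumerate(lambdas):
            obs = y_true[..., task]
            pred = y_pred[..., task]
            mask = tf.math.is_finite(obs)  # assumed convention: NaN = unobserved
            obs = tf.where(mask, obs, tf.zeros_like(obs))
            pred = tf.where(mask, pred, tf.zeros_like(pred))
            n_obs = tf.reduce_sum(tf.cast(mask, tf.float32))
            mse = tf.reduce_sum(tf.square(obs - pred)) / tf.maximum(n_obs, 1.0)
            total += lam * tf.sqrt(mse)
        return total
    return loss
```

With `lambdas=(1.0, 0.5)`, for example, errors on the first target would count twice as much as errors on the second.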
river_dl/train_model.py: 12 additions & 22 deletions
@@ -23,13 +23,6 @@
 parser.add_argument(
     "-f", "--finetune_epochs", help="number of finetune" "epochs", type=int
 )
-parser.add_argument(
-    "-q",
-    "--flow-in-temp",
-    help="whether or not to do flow\
-    in temp",
-    action="store_true",
-)
 parser.add_argument(
     "--pt_learn_rate",
     help="learning rate for pretraining",
@@ -45,31 +38,28 @@
 parser.add_argument(
     "--model",
     help="type of model to train",
-    choices=["lstm", "rgcn"],
+    choices=["lstm", "rgcn", "gru"],
     default="rgcn",
 )
 parser.add_argument(
-    "--lambda_aux", help="lambda for weighting aux gradient", default=1.0, type=float
+    "--num_tasks", help="number of tasks (outputs to be predicted)", default=1, type=int
 )
+parser.add_argument(
+    "--lambdas", help="lambdas for weighting variable losses", default=[1, 1], type=list
+)
 
 
 args = parser.parse_args()
-flow_in_temp = args.flow_in_temp
-in_data_file = args.in_data
-hidden_units = args.hidden_units
-out_dir = args.outdir
-pt_epochs = args.pretrain_epochs
-ft_epochs = args.finetune_epochs
 
 # -------- train ------
 model = train_model(
-    in_data_file,
-    pt_epochs,
-    ft_epochs,
-    hidden_units,
-    out_dir=out_dir,
-    flow_in_temp=flow_in_temp,
-    lambda_aux=args.lambda_aux,
+    args.in_data_file,
+    args.pretrain_epochs,
+    args.finetune_epochs,
+    args.hidden_units,
+    out_dir=args.out_dir,
+    num_tasks=args.num_tasks,
+    lambdas=args.lambdas,
     seed=args.random_seed,
     learning_rate_ft=args.ft_learn_rate,
     learning_rate_pre=args.pt_learn_rate,
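One caveat for anyone adapting the new `--lambdas` flag: argparse applies `type` to the raw argument string, so `type=list` splits it into individual characters (e.g. "1,0.5" becomes ['1', ',', '0', '.', '5']). A sketch of an alternative declaration that yields floats (not part of this commit):

```python
import argparse

parser = argparse.ArgumentParser()
# nargs="+" collects space-separated values and type=float converts each one,
# so "--lambdas 1 0.5" parses as [1.0, 0.5].
parser.add_argument("--lambdas", nargs="+", type=float, default=[1, 1])

args = parser.parse_args(["--lambdas", "1", "0.5"])
print(args.lambdas)  # [1.0, 0.5]
```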
