Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
[#106] match train cli with train.py fxn
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Jun 4, 2021
1 parent 3799509 commit a4d54e2
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions river_dl/train_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
import os
import argparse
from river_dl.train import train_model
import river_dl.loss_functions as lf


def get_loss_func_from_str(loss_func_str, lambdas=None):
    """Map a loss-function name from the CLI to the corresponding callable.

    :param loss_func_str: [str] one of "rmse", "nse", "kge",
        "multitask_rmse", "multitask_nse", "multitask_kge"
    :param lambdas: [list] per-variable weights; used only by the
        multitask losses (ignored for the single-task ones)
    :return: the loss function callable (single-task names map directly to
        module-level functions; multitask names are factories that are
        called here with ``lambdas``)
    :raises ValueError: if ``loss_func_str`` is not a recognized name
    """
    # Single-task losses are plain functions and are returned as-is.
    simple_losses = {
        'rmse': lf.rmse,
        'nse': lf.nse,
        'kge': lf.kge,
    }
    # Multitask losses are factories that must be called with lambdas.
    multitask_factories = {
        'multitask_rmse': lf.multitask_rmse,
        'multitask_nse': lf.multitask_nse,
        'multitask_kge': lf.multitask_kge,
    }
    if loss_func_str in simple_losses:
        return simple_losses[loss_func_str]
    if loss_func_str in multitask_factories:
        return multitask_factories[loss_func_str](lambdas)
    raise ValueError(f'loss function {loss_func_str} not supported')


parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -44,13 +61,25 @@
# Architecture / training options mirrored from train.train_model's signature.
parser.add_argument(
    "--num_tasks",
    default=1,
    type=int,
    help="number of tasks (outputs to be predicted)",
)
parser.add_argument(
    "--loss_func",
    default='rmse',
    type=str,
    choices=["rmse", "nse", "kge", "multitask_rmse", "multitask_kge", "multitask_nse"],
    help="loss function",
)
parser.add_argument(
    "--dropout",
    default=0,
    type=float,
    help="dropout rate",
)
parser.add_argument(
    "--recurrent_dropout",
    default=0,
    type=float,
    help="recurrent dropout",
)
# NOTE(fix): the previous ``type=list`` split the argument string into
# individual characters (argparse applies the type per value, and
# list("1,1") == ['1', ',', '1']).  Accept one float per task instead,
# e.g. ``--lambdas 1 0.5``; the default matches the original [1, 1].
parser.add_argument(
    "--lambdas",
    help="lambdas for weighting variable losses",
    default=[1, 1],
    nargs="+",
    type=float,
)


args = parser.parse_args()

# Forward the CLI lambdas so the multitask_* losses are weighted as
# requested; without this, the multitask factories are always built with
# lambdas=None.
loss_func = get_loss_func_from_str(args.loss_func, lambdas=args.lambdas)


# -------- train ------
model = train_model(
args.in_data_file,
Expand All @@ -59,7 +88,9 @@
args.hidden_units,
out_dir=args.out_dir,
num_tasks=args.num_tasks,
lambdas=args.lambdas,
loss_func=loss_func,
dropout=args.dropout,
recurrent_dropout=args.recurrent_dropout,
seed=args.random_seed,
learning_rate_ft=args.ft_learn_rate,
learning_rate_pre=args.pt_learn_rate,
Expand Down

0 comments on commit a4d54e2

Please sign in to comment.