Commit e306756

updates to trainer and monitoring

borauyar committed Jul 6, 2024
1 parent fa91072 commit e306756

Showing 2 changed files with 9 additions and 1 deletion.
flexynesis/__main__.py (2 changes: 1 addition & 1 deletion)
@@ -231,7 +231,7 @@ class AvailableModels(NamedTuple):
                             string_organism=args.string_organism,
                             string_node_name=args.string_node_name,
                             downsample = args.subsample)
-    train_dataset, test_dataset = data_importer.import_data(force = True)
+    train_dataset, test_dataset = data_importer.import_data(force = False)
 
     if args.model_class == 'GNNEarly':
         # overlay datasets with network info
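
The only behavioral change in this file is the `force` flag. In a typical importer, `force=True` rebuilds the processed datasets from scratch, while `force=False` reuses results cached by an earlier run. The diff does not show `import_data` itself, so the following is only a minimal sketch of the usual cache-or-recompute pattern behind such a flag; `expensive_reimport`, the cache filename, and the joblib serializer are illustrative assumptions, not flexynesis internals.

```python
import os
import joblib  # illustrative serializer, not necessarily what flexynesis uses

def expensive_reimport(data_path):
    # Stand-in for the real parsing and harmonization work the importer does.
    return {"train": [], "test": []}

def import_data(data_path, force=False):
    """Hypothetical sketch of the cache-or-recompute pattern behind `force`."""
    cache_path = os.path.join(data_path, "processed.joblib")  # illustrative
    if not force and os.path.exists(cache_path):
        # force=False fast path: reuse the datasets cached by an earlier run.
        return joblib.load(cache_path)
    datasets = expensive_reimport(data_path)
    joblib.dump(datasets, cache_path)  # refresh the cache for next time
    return datasets
```

Read this way, `force = False` lets repeated runs on the same data reuse earlier work, while `force = True` (the value being replaced) always rebuilds.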

flexynesis/main.py (8 changes: 8 additions & 0 deletions)
@@ -1,4 +1,6 @@
+from lightning import seed_everything
+seed_everything(42, workers=True)
 
 import torch
 from torch.utils.data import DataLoader, random_split
 import torch_geometric
@@ -151,8 +153,11 @@ def setup_trainer(self, params, current_step, total_steps, full_train = False):
         mycallbacks.append(early_stop_callback)
 
         trainer = pl.Trainer(
+            #deterministic = True,
             precision = '16-mixed', # mixed precision training
             max_epochs=int(params['epochs']),
+            gradient_clip_val=1.0,
+            gradient_clip_algorithm='norm',
             log_every_n_steps=5,
             callbacks=mycallbacks,
             default_root_dir="./",
@@ -274,6 +279,9 @@ def perform_tuning(self, hpo_patience = 0):
             if no_improvement_count >= hpo_patience & hpo_patience > 0:
                 print(f"No improvement in best loss for {hpo_patience} iterations, stopping hyperparameter optimisation early.")
                 break # Break out of the loop
+            best_params_dict = {param.name: value for param, value in zip(self.space, best_params)} if best_params else None
+            print(f"[INFO] current best val loss: {best_loss}; best params: {best_params_dict} since {no_improvement_count} hpo iterations")
+
 
         # Convert best parameters from list to dictionary and include epochs
         best_params_dict = {param.name: value for param, value in zip(self.space, best_params)}
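
Taken together, the new `seed_everything(42, workers=True)` call at import time and the commented-out `deterministic` flag are Lightning's standard reproducibility recipe: the former seeds Python's `random`, NumPy, and torch (CPU and CUDA) and derives per-worker DataLoader seeds, while the latter would additionally force deterministic kernels. A minimal standalone sketch of that recipe, independent of flexynesis:

```python
from lightning import Trainer, seed_everything

# Seed Python's random module, NumPy, and torch (CPU + CUDA) in one call;
# workers=True also derives per-worker seeds for DataLoader workers.
seed_everything(42, workers=True)

# The commented-out `deterministic = True` in the commit corresponds to this
# Trainer flag, which forces deterministic (often slower) CUDA kernels.
trainer = Trainer(max_epochs=1, deterministic=True)
```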
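The new `gradient_clip_val=1.0` with `gradient_clip_algorithm='norm'` makes Lightning rescale gradients before each optimizer step whenever their global L2 norm exceeds 1.0, a common guard against the loss spikes that `'16-mixed'` precision can amplify. Under the hood this is roughly equivalent to torch's built-in clipping utility:

```python
import torch

# Two dummy parameters with gradients, standing in for model weights.
params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

# Roughly what gradient_clip_val=1.0 / gradient_clip_algorithm='norm' applies:
# scale all gradients down together so their global L2 norm is at most 1.0
# (a no-op when the norm is already below the threshold).
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, norm_type=2.0)
print(f"pre-clip global grad norm: {float(total_norm):.3f}")
```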

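One subtlety in the early-stopping check above: it uses the bitwise `&` where a boolean `and` is intended. The condition still behaves correctly, but only by an accident of precedence: `&` binds tighter than `>=`, so the expression parses as the chained comparison `no_improvement_count >= (hpo_patience & hpo_patience) > 0`, and `x & x == x` for integers. A hypothetical standalone helper making the intended condition explicit:

```python
def should_stop_hpo(no_improvement_count: int, hpo_patience: int) -> bool:
    """Explicit form of the patience check (hypothetical helper, not flexynesis API).

    hpo_patience <= 0 disables early stopping; otherwise stop once the best
    validation loss has not improved for `hpo_patience` HPO iterations.
    """
    return hpo_patience > 0 and no_improvement_count >= hpo_patience

# The committed expression evaluates identically for integer inputs:
assert should_stop_hpo(5, 3) == (5 >= 3 & 3 > 0)
assert should_stop_hpo(2, 0) == (2 >= 0 & 0 > 0)
```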
0 comments on commit e306756