-
Notifications
You must be signed in to change notification settings - Fork 148
/
run_all_algorithms.py
121 lines (84 loc) · 4.25 KB
/
run_all_algorithms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from Recommenders.Recommender_import_list import *
from Data_manager.Movielens.Movielens1MReader import Movielens1MReader
from Data_manager.DataSplitter_leave_k_out import DataSplitter_leave_k_out
from Recommenders.Incremental_Training_Early_Stopping import Incremental_Training_Early_Stopping
from Recommenders.BaseCBFRecommender import BaseItemCBFRecommender, BaseUserCBFRecommender
from Evaluation.Evaluator import EvaluatorHoldout
import traceback, os
def _get_instance(recommender_class, URM_train, ICM_all, UCM_all):
    """Instantiate ``recommender_class`` with the constructor arguments it expects.

    The extra matrix passed to the constructor depends on the class hierarchy:

    * subclasses of ``BaseItemCBFRecommender`` are built from ``URM_train`` and ``ICM_all``;
    * subclasses of ``BaseUserCBFRecommender`` are built from ``URM_train`` and ``UCM_all``;
    * any other class is built from ``URM_train`` alone.

    The item-based check is performed first, so a class deriving from both
    bases is treated as item-based.

    :param recommender_class: the recommender class to instantiate
    :param URM_train: passed as first positional argument to every constructor
    :param ICM_all: passed only to item content-based recommenders
    :param UCM_all: passed only to user content-based recommenders
    :return: the newly created recommender instance
    """
    if issubclass(recommender_class, BaseItemCBFRecommender):
        return recommender_class(URM_train, ICM_all)

    if issubclass(recommender_class, BaseUserCBFRecommender):
        return recommender_class(URM_train, UCM_all)

    return recommender_class(URM_train)
if __name__ == '__main__':
    # Script entry point: fit every recommender in recommender_class_list,
    # evaluate it, check that a save/load round trip gives the same evaluation
    # results, and append the outcome (or the exception) to a log file.

    # Load the dataset through the leave-k-out splitter and get the three splits
    dataset_object = Movielens1MReader()
    dataSplitter = DataSplitter_leave_k_out(dataset_object, k_out_value=2)
    dataSplitter.load_data()
    URM_train, URM_validation, URM_test = dataSplitter.get_holdout_split()

    # Content matrices handed to the content-based and hybrid recommenders
    ICM_all = dataSplitter.get_loaded_ICM_dict()["ICM_genres"]
    UCM_all = dataSplitter.get_loaded_UCM_dict()["UCM_all"]

    recommender_class_list = [
        Random,
        TopPop,
        GlobalEffects,
        SLIMElasticNetRecommender,
        UserKNNCFRecommender,
        IALSRecommender,
        MatrixFactorization_BPR_Cython,
        MatrixFactorization_FunkSVD_Cython,
        MatrixFactorization_AsySVD_Cython,
        EASE_R_Recommender,
        ItemKNNCFRecommender,
        P3alphaRecommender,
        SLIM_BPR_Cython,
        RP3betaRecommender,
        PureSVDRecommender,
        NMFRecommender,
        UserKNNCBFRecommender,
        ItemKNNCBFRecommender,
        UserKNN_CFCBF_Hybrid_Recommender,
        ItemKNN_CFCBF_Hybrid_Recommender,
        LightFMCFRecommender,
        LightFMUserHybridRecommender,
        LightFMItemHybridRecommender,
    ]

    # Evaluator used for the reported results (built on the test split)
    evaluator = EvaluatorHoldout(URM_test, [5, 20], exclude_seen=True)

    # from MatrixFactorization.PyTorch.MF_MSE_PyTorch import MF_MSE_PyTorch

    # Extra fit() arguments, only passed to recommenders that are instances of
    # Incremental_Training_Early_Stopping; validation uses the validation split
    earlystopping_keywargs = {"validation_every_n": 5,
                              "stop_on_validation": True,
                              "evaluator_object": EvaluatorHoldout(URM_validation, [20], exclude_seen=True),
                              "lower_validations_allowed": 5,
                              "validation_metric": "MAP",
                              }

    output_root_path = "./result_experiments/"
    temp_model_file_name = "temp_model.zip"

    # Create the output directory if needed; exist_ok avoids the race between
    # an explicit existence check and the creation itself
    os.makedirs(output_root_path, exist_ok=True)

    # The context manager guarantees the log file is closed when the script
    # ends, including when it is interrupted by a non-Exception error
    with open(output_root_path + "result_all_algorithms.txt", "a") as logFile:

        for recommender_class in recommender_class_list:

            # A failure of one algorithm is logged and must not stop the others
            try:
                print("Algorithm: {}".format(recommender_class))

                recommender_object = _get_instance(recommender_class, URM_train, ICM_all, UCM_all)

                if isinstance(recommender_object, Incremental_Training_Early_Stopping):
                    fit_params = {"epochs": 15, **earlystopping_keywargs}
                else:
                    fit_params = {}

                recommender_object.fit(**fit_params)

                results_run_1, results_run_string_1 = evaluator.evaluateRecommender(recommender_object)

                # Save/load round trip: store the fitted model, load it into a
                # fresh instance, delete the temporary file and evaluate again
                recommender_object.save_model(output_root_path, file_name=temp_model_file_name)

                recommender_object = _get_instance(recommender_class, URM_train, ICM_all, UCM_all)
                recommender_object.load_model(output_root_path, file_name=temp_model_file_name)

                os.remove(output_root_path + temp_model_file_name)

                results_run_2, results_run_string_2 = evaluator.evaluateRecommender(recommender_object)

                # Random is excluded from the consistency check. The check is an
                # explicit raise (not an assert statement) so that it is still
                # performed under "python -O" and the log gets a useful message
                if recommender_class not in [Random]:
                    if not results_run_1.equals(results_run_2):
                        raise AssertionError("Evaluation results differ after saving and reloading the model")

                print("Algorithm: {}, results: \n{}".format(recommender_class, results_run_string_1))
                logFile.write("Algorithm: {}, results: \n{}\n".format(recommender_class, results_run_string_1))
                logFile.flush()

            except Exception as e:
                traceback.print_exc()
                logFile.write("Algorithm: {} - Exception: {}\n".format(recommender_class, str(e)))
                logFile.flush()