diff --git a/superstyl/svm.py b/superstyl/svm.py
index b66cbe24..3bf49924 100755
--- a/superstyl/svm.py
+++ b/superstyl/svm.py
@@ -1,3 +1,6 @@
+
+
+
 import sklearn.svm as sk
 import sklearn.metrics as metrics
 import sklearn.decomposition as decomp
@@ -101,20 +104,27 @@ def train_svm(train, test, cross_validate=None, k=10, dim_reduc=None, norms=True
             estimators.append(('sampling', over.RandomOverSampler(random_state=42)))
 
         if balance in ['SMOTE', 'SMOTETomek']:
-            # Adjust n_neighbors for SMOTE/SMOTETomek based on smallest class size:
+            # Adjust n_neighbors for SMOTE based on smallest class size:
             # Ensures that the resampling method does not attempt to use more neighbors than available samples
             # in the minority class, which produced the error.
             min_class_size = min(Counter(classes).values())
             n_neighbors = min(5, min_class_size - 1)  # Default n_neighbors in SMOTE is 5
             # In case we have to temper with the n_neighbors, we print a warning message to the user
             # (might be written more clearly, but we want a short message, right?)
-            if n_neighbors >= min_class_size:
+            if 0 < n_neighbors < 5:
                 print(
-                    f"Warning: Adjusting n_neighbors for SMOTE / SMOTETomek to {n_neighbors} due to small class size.")
-            if balance == 'SMOTE':
-                estimators.append(('sampling', over.SMOTE(n_neighbors=n_neighbors, random_state=42)))
-            elif balance == 'SMOTETomek':
-                estimators.append(('sampling', comb.SMOTETomek(n_neighbors=n_neighbors, random_state=42)))
+                    f"Warning: Adjusting n_neighbors for SMOTE to {n_neighbors} due to small class size.")
+
+            if n_neighbors == 0:
+                print(
+                    f"Warning: at least one class only has a single individual; cannot apply SMOTE(Tomek) due to small class size.")
+            else:
+
+                if balance == 'SMOTE':
+                    estimators.append(('sampling', over.SMOTE(k_neighbors=n_neighbors, random_state=42)))
+
+                elif balance == 'SMOTETomek':
+                    estimators.append(('sampling', comb.SMOTETomek(random_state=42, smote=over.SMOTE(k_neighbors=n_neighbors, random_state=42))))
 
     print(".......... choosing SVM ........")
 
@@ -255,7 +265,6 @@ def train_svm(train, test, cross_validate=None, k=10, dim_reduc=None, norms=True
 
     return results
 
-
 # Following function from Aneesha Bakharia
 # https://aneesha.medium.com/visualising-top-features-in-linear-svm-with-scikit-learn-and-matplotlib-3454ab18a14d
 
diff --git a/tests/test_train_svm.py b/tests/test_train_svm.py
index 2cc11668..e5335e96 100644
--- a/tests/test_train_svm.py
+++ b/tests/test_train_svm.py
@@ -53,8 +53,51 @@ def test_train_svm(self):
         self.assertEqual(results["misattributions"].to_dict(), expected_results["misattributions"])
         self.assertEqual(list(results.keys()), expected_keys)
 
+        # TODO: quick tests for SMOTE, SMOTETomek, to improve
         # WHEN
-        #results = superstyl.train_svm(train, test, final_pred=False, balance="SMOTETomek")
+        results = superstyl.train_svm(train, test, final_pred=False, balance="SMOTETomek")
+        # THEN
+        self.assertEqual(results["confusion_matrix"].to_dict(), expected_results["confusion_matrix"])
+        self.assertEqual(results["classification_report"], expected_results["classification_report"])
+        self.assertEqual(results["misattributions"].to_dict(), expected_results["misattributions"])
+        self.assertEqual(list(results.keys()), expected_keys)
+
+        # WHEN
+        results = superstyl.train_svm(train, test, final_pred=False, balance="SMOTE")
+        # THEN
+        self.assertEqual(results["confusion_matrix"].to_dict(), expected_results["confusion_matrix"])
+        self.assertEqual(results["classification_report"], expected_results["classification_report"])
+        self.assertEqual(results["misattributions"].to_dict(), expected_results["misattributions"])
+        self.assertEqual(list(results.keys()), expected_keys)
+
+        # Now, the case where SMOTE can be applied but k_neighbors has to be recomputed,
+        # because the maximum possible number of neighbors is < 5
+        train2 = pandas.DataFrame({'author': {'Dupont_Letter1.txt': 'Dupont', 'Dupont_Letter2.txt': 'Dupont', 'Smith_Letter1.txt': 'Smith', 'Smith_Letter2.txt': 'Smith'},
+                                   'lang': {'Dupont_Letter1.txt': 'NA', 'Dupont_Letter2.txt': 'NA', 'Smith_Letter1.txt': 'NA', 'Smith_Letter2.txt': 'NA'},
+                                   'this': {'Dupont_Letter1.txt': 0.0, 'Dupont_Letter2.txt': 0.0, 'Smith_Letter1.txt': 0.25, 'Smith_Letter2.txt': 0.2},
+                                   'is': {'Dupont_Letter1.txt': 0.0, 'Dupont_Letter2.txt': 0.0, 'Smith_Letter1.txt': 0.25, 'Smith_Letter2.txt': 0.2},
+                                   'the': {'Dupont_Letter1.txt': 0.0, 'Dupont_Letter2.txt': 0.0, 'Smith_Letter1.txt': 0.25, 'Smith_Letter2.txt': 0.2},
+                                   'text': {'Dupont_Letter1.txt': 0.0, 'Dupont_Letter2.txt': 0.0, 'Smith_Letter1.txt': 0.25, 'Smith_Letter2.txt': 0.2},
+                                   'voici': {'Dupont_Letter1.txt': 1/3, 'Dupont_Letter2.txt': 1/3, 'Smith_Letter1.txt': 0.0, 'Smith_Letter2.txt': 0.0},
+                                   'le': {'Dupont_Letter1.txt': 1/3, 'Dupont_Letter2.txt': 1/3, 'Smith_Letter1.txt': 0.0, 'Smith_Letter2.txt': 0.0},
+                                   'texte': {'Dupont_Letter1.txt': 1/3, 'Dupont_Letter2.txt': 1/3, 'Smith_Letter1.txt': 0.0, 'Smith_Letter2.txt': 0.0},
+                                   'also': {'Dupont_Letter1.txt': 0.0, 'Dupont_Letter2.txt': 0.0, 'Smith_Letter1.txt': 0.0, 'Smith_Letter2.txt': 0.2}})
+
+        # WHEN
+        results = superstyl.train_svm(train2, test, final_pred=True, balance="SMOTETomek")
+        # THEN
+        expected_preds = {'filename': {0: 'Dupont_Letter1.txt', 1: 'Smith_Letter1.txt',
+                                       2: 'Smith_Letter2.txt'},
+                          'author': {0: 'Dupont', 1: 'Smith', 2: 'Smith'},
+                          'Decision function': {0: -0.883772535448984, 1: 0.8756912342726781,
+                                                2: 0.873288374519472}}
+
+        self.assertEqual(results['final_predictions'].to_dict()["author"], expected_preds["author"])
+
+        # WHEN
+        results = superstyl.train_svm(train2, test, final_pred=True, balance="SMOTE")
+        # THEN
+        self.assertEqual(results['final_predictions'].to_dict()["author"], expected_preds["author"])
 
 
         # This is only the first minimal tests for this function
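
Note for reviewers unfamiliar with imbalanced-learn's API: SMOTE's default k_neighbors=5 raises an error whenever the smallest class has fewer than six samples, and SMOTETomek exposes no k_neighbors argument of its own, only a smote= parameter that takes a configured SMOTE instance. The snippet below is a minimal standalone sketch of the capping logic the patch applies inside the pipeline; the class labels and the sampler variable are made up for illustration and are not part of the diff.

    from collections import Counter

    import imblearn.combine as comb
    import imblearn.over_sampling as over

    # Hypothetical labels: two samples per author, the smallest case where SMOTE can still run.
    classes = ["Dupont", "Dupont", "Smith", "Smith"]

    min_class_size = min(Counter(classes).values())   # -> 2
    n_neighbors = min(5, min_class_size - 1)          # -> 1 (SMOTE's default k_neighbors is 5)

    if n_neighbors == 0:
        # A class with a single sample has no neighbors at all,
        # so SMOTE(Tomek) cannot be applied and resampling is skipped.
        sampler = None
    else:
        # SMOTETomek takes the neighbor count only through a nested SMOTE instance.
        sampler = comb.SMOTETomek(
            random_state=42,
            smote=over.SMOTE(k_neighbors=n_neighbors, random_state=42),
        )

With two texts per author, as in the train2 fixture above, the cap yields k_neighbors=1, which is the situation the adjusted warning message is meant to cover.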