diff --git a/code/machine_learning_models/decision_tree.py b/code/machine_learning_models/decision_tree.py index c46df4d15470c710ca209dad0a34555f8ca1006f..2438a152b01c6641228146c9c59d6e2082f6f1dd 100644 --- a/code/machine_learning_models/decision_tree.py +++ b/code/machine_learning_models/decision_tree.py @@ -9,8 +9,7 @@ from sklearn.metrics import classification_report, confusion_matrix from sklearn.preprocessing import StandardScaler, LabelEncoder from sklearn.tree import DecisionTreeClassifier -from utilities import plot_counts -from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data +from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data, plot_counts warnings.filterwarnings("ignore") @@ -77,14 +76,9 @@ def train(): print("Training complete.") def graphs(): - y_prediction = dtc.predict(X_test) print("Classification report: \n", classification_report(y_test, y_prediction)) - # Plot absolute quantities of class 0 and class 1 - sns.countplot(x=y_data, data=df_train) - plot_counts(model_name=model_name) - # Determine feature importance features = pd.DataFrame(dtc.feature_importances_, index=X_train.columns, @@ -96,5 +90,6 @@ def graphs(): model_name=model_name) print_high_confidence_samples(model=dtc, x=X_train) + plot_counts(y_data, df_train) print("Graphs complete.") \ No newline at end of file diff --git a/code/machine_learning_models/knn.py b/code/machine_learning_models/knn.py index cd62cd80c7cee288a49e426c2a6fbd233127289b..37d97955b4b203333a108b54e9b3835e9cb297ea 100644 --- a/code/machine_learning_models/knn.py +++ b/code/machine_learning_models/knn.py @@ -8,7 +8,7 @@ from sklearn.metrics import classification_report, confusion_matrix from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler, LabelEncoder -from utilities import ordinal_encode, normalize, plot_confusion_matrix, plot_counts, import_data, plot_roc_curve +from utilities import ordinal_encode, normalize, plot_confusion_matrix, import_data, plot_roc_curve, plot_counts warnings.filterwarnings("ignore") @@ -58,8 +58,6 @@ def graphs(): # Calculate prediction probabilities for ROC curve y_score = model.predict_proba(X_test)[:, 1] plot_roc_curve(y_test, y_score, model_name=model_name) + plot_counts(y_data, df_train) - # Plot absolute quantities of class 0 and class 1 - sns.countplot(x=y_data, data=df_train) - plot_counts(model_name=model_name) print("Graphs complete.") \ No newline at end of file diff --git a/code/machine_learning_models/logistic_regression.py b/code/machine_learning_models/logistic_regression.py index a0aea01af336f17bac26bef974f398c409d3eebc..1b56321d0c5ae6662867fe781b96825ae2132900 100644 --- a/code/machine_learning_models/logistic_regression.py +++ b/code/machine_learning_models/logistic_regression.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd -import seaborn as sns import os import warnings @@ -9,7 +8,7 @@ from sklearn.model_selection import learning_curve from sklearn.preprocessing import StandardScaler, LabelEncoder from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, \ - print_high_confidence_samples, plot_counts, import_data, plot_roc_curve, plot_learning_curve + print_high_confidence_samples, import_data, plot_roc_curve, plot_learning_curve, plot_counts warnings.filterwarnings("ignore") @@ -117,4 +116,6 @@ def graphs(): # Plot ROC curve using the function from utilities plot_roc_curve(y_test, y_score, model_name=model_name) + plot_counts(y_data, df_train) + print("Graphs complete.") diff --git a/code/machine_learning_models/random_forest.py b/code/machine_learning_models/random_forest.py index 6e4168d4c3cc1503ae6dff50e625f80ed492b9c6..7fa6abbe3ffa5565edfea84f1d70bfbab7abad6b 100644 --- a/code/machine_learning_models/random_forest.py +++ b/code/machine_learning_models/random_forest.py @@ -10,7 +10,7 @@ from sklearn.model_selection import learning_curve from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve -from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve +from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve, plot_counts from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, print_high_confidence_samples sys.stdout.reconfigure(line_buffering=True) @@ -95,4 +95,10 @@ def graphs(): # Plot ROC curve using the function from utilities plot_roc_curve(y_test, y_score, model_name=model_name) + plot_counts(y_data, df_train) + print("Graphs complete.") + +if __name__ == "__main__": + train() + graphs() \ No newline at end of file diff --git a/code/machine_learning_models/utilities.py b/code/machine_learning_models/utilities.py index 54a77c3c61e968067027ec4794d53699aeaad806..17268dce4b653f65d84ef1a9d23eeeba7ef526a3 100644 --- a/code/machine_learning_models/utilities.py +++ b/code/machine_learning_models/utilities.py @@ -11,6 +11,7 @@ from sklearn.preprocessing import OrdinalEncoder show_plots = False +y_data = ['normal', 'anomaly'] # Plots def heat_map(df, model_name=None): @@ -34,12 +35,6 @@ def heat_map(df, model_name=None): if show_plots: plt.show() -def plot_counts(model_name=None): - if model_name: - save_plot(model_name + " - Counts") - if show_plots: - plt.show() - def plot_xy(df, x, y, model_name=None): """ Creates a scatter plot for two numerical columns. @@ -153,6 +148,13 @@ def plot_learning_curve(train_sizes, train_scores, test_scores, model_name=None) if show_plots: plt.show() +def plot_counts(target, df): + plt.clf() + sns.countplot(x = target, data = df) + save_plot("Count") + if show_plots: + plt.show() + def plot_roc_curve(y_true, y_score, model_name=None): """ Plots the ROC curve for a binary classification model.