From 0f515314b9a3aef991daff97b318d07587ebe07d Mon Sep 17 00:00:00 2001
From: ugmom <ugmom@student.kit.edu>
Date: Sun, 23 Mar 2025 20:28:02 +0100
Subject: [PATCH] fixed count plot

---
 code/machine_learning_models/decision_tree.py      |  9 ++-------
 code/machine_learning_models/knn.py                |  6 ++----
 .../machine_learning_models/logistic_regression.py |  5 +++--
 code/machine_learning_models/random_forest.py      |  8 +++++++-
 code/machine_learning_models/utilities.py          | 14 ++++++++------
 5 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/code/machine_learning_models/decision_tree.py b/code/machine_learning_models/decision_tree.py
index c46df4d..2438a15 100644
--- a/code/machine_learning_models/decision_tree.py
+++ b/code/machine_learning_models/decision_tree.py
@@ -9,8 +9,7 @@ from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 from sklearn.tree import DecisionTreeClassifier
 
-from utilities import plot_counts
-from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data
+from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data, plot_counts
 
 warnings.filterwarnings("ignore")
 
@@ -77,14 +76,9 @@ def train():
     print("Training complete.")
 
 
 def graphs():
-    y_prediction = dtc.predict(X_test)
     print("Classification report: \n", classification_report(y_test, y_prediction))
 
-    # Plot absolute quantities of class 0 and class 1
-    sns.countplot(x=y_data, data=df_train)
-    plot_counts(model_name=model_name)
-
     # Determine feature importance
     features = pd.DataFrame(dtc.feature_importances_,
                             index=X_train.columns,
@@ -96,5 +90,6 @@ def graphs():
                           model_name=model_name)
 
     print_high_confidence_samples(model=dtc, x=X_train)
+    plot_counts(y_data, df_train)
 
     print("Graphs complete.")
\ No newline at end of file
diff --git a/code/machine_learning_models/knn.py b/code/machine_learning_models/knn.py
index cd62cd8..37d9795 100644
--- a/code/machine_learning_models/knn.py
+++ b/code/machine_learning_models/knn.py
@@ -8,7 +8,7 @@ from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 
-from utilities import ordinal_encode, normalize, plot_confusion_matrix, plot_counts, import_data, plot_roc_curve
+from utilities import ordinal_encode, normalize, plot_confusion_matrix, import_data, plot_roc_curve, plot_counts
 
 warnings.filterwarnings("ignore")
 
@@ -58,8 +58,6 @@ def graphs():
     # Calculate prediction probabilities for ROC curve
     y_score = model.predict_proba(X_test)[:, 1]
     plot_roc_curve(y_test, y_score, model_name=model_name)
+    plot_counts(y_data, df_train)
 
-    # Plot absolute quantities of class 0 and class 1
-    sns.countplot(x=y_data, data=df_train)
-    plot_counts(model_name=model_name)
     print("Graphs complete.")
\ No newline at end of file
diff --git a/code/machine_learning_models/logistic_regression.py b/code/machine_learning_models/logistic_regression.py
index a0aea01..1b56321 100644
--- a/code/machine_learning_models/logistic_regression.py
+++ b/code/machine_learning_models/logistic_regression.py
@@ -1,6 +1,5 @@
 import numpy as np
 import pandas as pd
-import seaborn as sns
 import os
 import warnings
 
@@ -9,7 +8,7 @@ from sklearn.model_selection import learning_curve
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 
 from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, \
-    print_high_confidence_samples, plot_counts, import_data, plot_roc_curve, plot_learning_curve
+    print_high_confidence_samples, import_data, plot_roc_curve, plot_learning_curve, plot_counts
 
 warnings.filterwarnings("ignore")
 
@@ -117,4 +116,6 @@ def graphs():
     # Plot ROC curve using the function from utilities
     plot_roc_curve(y_test, y_score, model_name=model_name)
 
+    plot_counts(y_data, df_train)
+
     print("Graphs complete.")
diff --git a/code/machine_learning_models/random_forest.py b/code/machine_learning_models/random_forest.py
index 6e4168d..7fa6abb 100644
--- a/code/machine_learning_models/random_forest.py
+++ b/code/machine_learning_models/random_forest.py
@@ -10,7 +10,7 @@ from sklearn.model_selection import learning_curve
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve
 
-from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve
+from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve, plot_counts
 from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, print_high_confidence_samples
 
 sys.stdout.reconfigure(line_buffering=True)
@@ -95,4 +95,10 @@ def graphs():
     # Plot ROC curve using the function from utilities
     plot_roc_curve(y_test, y_score, model_name=model_name)
 
+    plot_counts(y_data, df_train)
+
     print("Graphs complete.")
+
+if __name__ == "__main__":
+    train()
+    graphs()
\ No newline at end of file
diff --git a/code/machine_learning_models/utilities.py b/code/machine_learning_models/utilities.py
index 54a77c3..17268dc 100644
--- a/code/machine_learning_models/utilities.py
+++ b/code/machine_learning_models/utilities.py
@@ -11,6 +11,7 @@ from sklearn.preprocessing import OrdinalEncoder
 
 show_plots = False
 
+y_data = ['normal', 'anomaly']
 
 # Plots
 def heat_map(df, model_name=None):
@@ -34,12 +35,6 @@ def heat_map(df, model_name=None):
     if show_plots:
         plt.show()
 
-def plot_counts(model_name=None):
-    if model_name:
-        save_plot(model_name + " - Counts")
-    if show_plots:
-        plt.show()
-
 def plot_xy(df, x, y, model_name=None):
     """
     Creates a scatter plot for two numerical columns.
@@ -153,6 +148,13 @@ def plot_learning_curve(train_sizes, train_scores, test_scores, model_name=None)
     if show_plots:
         plt.show()
 
+def plot_counts(target, df):
+    plt.clf()
+    sns.countplot(x = target, data = df)
+    save_plot("Count")
+    if show_plots:
+        plt.show()
+
 def plot_roc_curve(y_true, y_score, model_name=None):
     """
     Plots the ROC curve for a binary classification model.

-- 
GitLab
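
For context, the patch above replaces the per-model "sns.countplot + plot_counts(model_name=...)" calls with a single plot_counts(target, df) helper in utilities.py and a module-level y_data list. The code below is a minimal standalone sketch of that pattern, not the repository's exact code: the toy DataFrame, the "class" column name, and the plt.savefig call (standing in for the project's save_plot helper) are assumptions added for illustration, and the target is passed as a column name, which is the usual seaborn usage.

    import matplotlib
    matplotlib.use("Agg")  # headless backend so the sketch runs without a display
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    show_plots = False                # mirrors the module-level flag in utilities.py
    y_data = ['normal', 'anomaly']    # class labels added to utilities.py by this patch

    def plot_counts(target, df):
        """Plot absolute class counts, mirroring the refactored helper."""
        plt.clf()                           # clear any figure left over from earlier plots
        sns.countplot(x=target, data=df)    # bar chart of how often each class occurs
        plt.savefig("Count.png")            # stand-in for the project's save_plot("Count")
        if show_plots:
            plt.show()

    if __name__ == "__main__":
        # Hypothetical training frame with a 'class' column holding the two labels
        df_train = pd.DataFrame({"class": ["normal", "anomaly", "normal", "normal"]})
        plot_counts("class", df_train)

Centralizing the plot this way keeps the seaborn call in one place, so every model script produces the same saved count figure instead of duplicating the plotting code.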