diff --git a/code/machine_learning_models/decision_tree.py b/code/machine_learning_models/decision_tree.py
index c46df4d15470c710ca209dad0a34555f8ca1006f..2438a152b01c6641228146c9c59d6e2082f6f1dd 100644
--- a/code/machine_learning_models/decision_tree.py
+++ b/code/machine_learning_models/decision_tree.py
@@ -9,8 +9,7 @@ from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 from sklearn.tree import DecisionTreeClassifier
 
-from utilities import plot_counts
-from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data
+from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data, plot_counts
 
 warnings.filterwarnings("ignore")
 
@@ -77,14 +76,9 @@ def train():
     print("Training complete.")
 
 def graphs():
-
     y_prediction = dtc.predict(X_test)
     print("Classification report: \n", classification_report(y_test, y_prediction))
 
-    # Plot absolute quantities of class 0 and class 1
-    sns.countplot(x=y_data, data=df_train)
-    plot_counts(model_name=model_name)
-
     # Determine feature importance
     features = pd.DataFrame(dtc.feature_importances_,
                             index=X_train.columns,
@@ -96,5 +90,6 @@ def graphs():
                           model_name=model_name)
 
     print_high_confidence_samples(model=dtc, x=X_train)
+    plot_counts(y_data, df_train)
 
     print("Graphs complete.")
\ No newline at end of file
diff --git a/code/machine_learning_models/knn.py b/code/machine_learning_models/knn.py
index cd62cd80c7cee288a49e426c2a6fbd233127289b..37d97955b4b203333a108b54e9b3835e9cb297ea 100644
--- a/code/machine_learning_models/knn.py
+++ b/code/machine_learning_models/knn.py
@@ -8,7 +8,7 @@ from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 
-from utilities import ordinal_encode, normalize, plot_confusion_matrix, plot_counts, import_data, plot_roc_curve
+from utilities import ordinal_encode, normalize, plot_confusion_matrix, import_data, plot_roc_curve, plot_counts
 
 warnings.filterwarnings("ignore")
 
@@ -58,8 +58,6 @@ def graphs():
     # Calculate prediction probabilities for ROC curve
     y_score = model.predict_proba(X_test)[:, 1]
     plot_roc_curve(y_test, y_score, model_name=model_name)
+    plot_counts(y_data, df_train)
 
-    # Plot absolute quantities of class 0 and class 1
-    sns.countplot(x=y_data, data=df_train)
-    plot_counts(model_name=model_name)
     print("Graphs complete.")
\ No newline at end of file
diff --git a/code/machine_learning_models/logistic_regression.py b/code/machine_learning_models/logistic_regression.py
index a0aea01af336f17bac26bef974f398c409d3eebc..1b56321d0c5ae6662867fe781b96825ae2132900 100644
--- a/code/machine_learning_models/logistic_regression.py
+++ b/code/machine_learning_models/logistic_regression.py
@@ -1,6 +1,5 @@
 import numpy as np
 import pandas as pd
-import seaborn as sns
 import os
 
 import warnings
@@ -9,7 +8,7 @@ from sklearn.model_selection import learning_curve
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 
 from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, \
-    print_high_confidence_samples, plot_counts, import_data, plot_roc_curve, plot_learning_curve
+    print_high_confidence_samples, import_data, plot_roc_curve, plot_learning_curve, plot_counts
 
 warnings.filterwarnings("ignore")
 
@@ -117,4 +116,6 @@ def graphs():
     # Plot ROC curve using the function from utilities
     plot_roc_curve(y_test, y_score, model_name=model_name)
 
+    plot_counts(y_data, df_train)
+
     print("Graphs complete.")
diff --git a/code/machine_learning_models/random_forest.py b/code/machine_learning_models/random_forest.py
index 6e4168d4c3cc1503ae6dff50e625f80ed492b9c6..7fa6abbe3ffa5565edfea84f1d70bfbab7abad6b 100644
--- a/code/machine_learning_models/random_forest.py
+++ b/code/machine_learning_models/random_forest.py
@@ -10,7 +10,7 @@ from sklearn.model_selection import learning_curve
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve
 
-from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve
+from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve, plot_counts
 from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, print_high_confidence_samples
 
 sys.stdout.reconfigure(line_buffering=True)
@@ -95,4 +95,10 @@ def graphs():
 	# Plot ROC curve using the function from utilities
 	plot_roc_curve(y_test, y_score, model_name=model_name)
 
+	plot_counts(y_data, df_train)
+
 	print("Graphs complete.")
+
+if __name__ == "__main__":
+	train()
+	graphs()
\ No newline at end of file
diff --git a/code/machine_learning_models/utilities.py b/code/machine_learning_models/utilities.py
index 54a77c3c61e968067027ec4794d53699aeaad806..17268dce4b653f65d84ef1a9d23eeeba7ef526a3 100644
--- a/code/machine_learning_models/utilities.py
+++ b/code/machine_learning_models/utilities.py
@@ -11,6 +11,7 @@ from sklearn.preprocessing import OrdinalEncoder
 
 show_plots = False
 
+y_data = ['normal', 'anomaly']
 # Plots
 
 def heat_map(df, model_name=None):
@@ -34,12 +35,6 @@ def heat_map(df, model_name=None):
     if show_plots:
         plt.show()
 
-def plot_counts(model_name=None):
-    if model_name:
-        save_plot(model_name + " - Counts")
-    if show_plots:
-        plt.show()
-
 def plot_xy(df, x, y, model_name=None):
     """
     Creates a scatter plot for two numerical columns.
@@ -153,6 +148,13 @@ def plot_learning_curve(train_sizes, train_scores, test_scores, model_name=None)
     if show_plots:
         plt.show()
 
+def plot_counts(target, df):
+    plt.clf()
+    sns.countplot(x = target, data = df)
+    save_plot("Count")
+    if show_plots:
+        plt.show()
+
 def plot_roc_curve(y_true, y_score, model_name=None):
     """
     Plots the ROC curve for a binary classification model.