Skip to content
Snippets Groups Projects
Commit 0f515314 authored by Daniel Yang
Browse files

fixed count plot

parent 6bb34210
No related branches found
No related tags found
No related merge requests found
......@@ -9,8 +9,7 @@ from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from utilities import plot_counts
from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data
from utilities import plot_features, ordinal_encode, normalize, plot_confusion_matrix, print_high_confidence_samples, import_data, plot_counts
warnings.filterwarnings("ignore")
......@@ -77,14 +76,9 @@ def train():
print("Training complete.")
def graphs():
y_prediction = dtc.predict(X_test)
print("Classification report: \n", classification_report(y_test, y_prediction))
# Plot absolute quantities of class 0 and class 1
sns.countplot(x=y_data, data=df_train)
plot_counts(model_name=model_name)
# Determine feature importance
features = pd.DataFrame(dtc.feature_importances_,
index=X_train.columns,
......@@ -96,5 +90,6 @@ def graphs():
model_name=model_name)
print_high_confidence_samples(model=dtc, x=X_train)
plot_counts(y_data, df_train)
print("Graphs complete.")
\ No newline at end of file
......@@ -8,7 +8,7 @@ from sklearn.metrics import classification_report, confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler, LabelEncoder
from utilities import ordinal_encode, normalize, plot_confusion_matrix, plot_counts, import_data, plot_roc_curve
from utilities import ordinal_encode, normalize, plot_confusion_matrix, import_data, plot_roc_curve, plot_counts
warnings.filterwarnings("ignore")
......@@ -58,8 +58,6 @@ def graphs():
# Calculate prediction probabilities for ROC curve
y_score = model.predict_proba(X_test)[:, 1]
plot_roc_curve(y_test, y_score, model_name=model_name)
plot_counts(y_data, df_train)
# Plot absolute quantities of class 0 and class 1
sns.countplot(x=y_data, data=df_train)
plot_counts(model_name=model_name)
print("Graphs complete.")
\ No newline at end of file
import numpy as np
import pandas as pd
import seaborn as sns
import os
import warnings
......@@ -9,7 +8,7 @@ from sklearn.model_selection import learning_curve
from sklearn.preprocessing import StandardScaler, LabelEncoder
from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, \
print_high_confidence_samples, plot_counts, import_data, plot_roc_curve, plot_learning_curve
print_high_confidence_samples, import_data, plot_roc_curve, plot_learning_curve, plot_counts
warnings.filterwarnings("ignore")
......@@ -117,4 +116,6 @@ def graphs():
# Plot ROC curve using the function from utilities
plot_roc_curve(y_test, y_score, model_name=model_name)
plot_counts(y_data, df_train)
print("Graphs complete.")
......@@ -10,7 +10,7 @@ from sklearn.model_selection import learning_curve
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve
from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve
from utilities import plot_precision_recall_curve, plot_learning_curve, import_data, plot_roc_curve, plot_counts
from utilities import ordinal_encode, heat_map, plot_features, plot_confusion_matrix, normalize, print_high_confidence_samples
sys.stdout.reconfigure(line_buffering=True)
......@@ -95,4 +95,10 @@ def graphs():
# Plot ROC curve using the function from utilities
plot_roc_curve(y_test, y_score, model_name=model_name)
plot_counts(y_data, df_train)
print("Graphs complete.")
# Script entry point: run training first, then produce the plots.
if __name__ == "__main__":
    train()
    graphs()
\ No newline at end of file
......@@ -11,6 +11,7 @@ from sklearn.preprocessing import OrdinalEncoder
show_plots = False
y_data = ['normal', 'anomaly']
# Plots
def heat_map(df, model_name=None):
......@@ -34,12 +35,6 @@ def heat_map(df, model_name=None):
if show_plots:
plt.show()
def plot_counts(model_name=None):
    """
    Saves and/or shows the class-count plot.

    This function draws nothing itself; it only calls save_plot and
    plt.show, so the caller is expected to have drawn the count plot
    beforehand (NOTE(review): confirm save_plot acts on the current figure).

    Parameters:
    model_name (str, optional): When truthy, the plot is saved under the
        name "<model_name> - Counts". When omitted, nothing is saved.
    """
    if model_name:
        title = model_name + " - Counts"
        save_plot(title)

    # Only open a window when plots are globally enabled.
    if show_plots:
        plt.show()
def plot_xy(df, x, y, model_name=None):
"""
Creates a scatter plot for two numerical columns.
......@@ -153,6 +148,13 @@ def plot_learning_curve(train_sizes, train_scores, test_scores, model_name=None)
if show_plots:
plt.show()
def plot_counts(target, df, model_name=None):
    """
    Plots the number of samples per class as a count plot.

    Parameters:
    target: Forwarded to seaborn's countplot as ``x``
        (presumably a column name in df or a vector of labels — confirm
        against the callers).
    df: Forwarded to seaborn's countplot as ``data``.
    model_name (str, optional): When given, the plot is saved under the
        name "<model_name> - Counts" instead of the shared default
        "Count", so each model can get its own plot name. Omitting it
        keeps the previous behaviour.
    """
    # Clear the current figure so a previous plot does not bleed into this one.
    plt.clf()
    sns.countplot(x=target, data=df)

    plot_name = model_name + " - Counts" if model_name else "Count"
    save_plot(plot_name)

    # Only open a window when plots are globally enabled.
    if show_plots:
        plt.show()
def plot_roc_curve(y_true, y_score, model_name=None):
"""
Plots the ROC curve for a binary classification model.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment