From 649975f0c878e6e08b9223858ef5af1d61904e4c Mon Sep 17 00:00:00 2001 From: Daniel Yang <t5wol3yv@duck.com> Date: Thu, 20 Mar 2025 09:10:03 +0100 Subject: [PATCH] refactored code --- code/machine_learning_models/utilities.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/code/machine_learning_models/utilities.py b/code/machine_learning_models/utilities.py index cfbca48..5c409f0 100644 --- a/code/machine_learning_models/utilities.py +++ b/code/machine_learning_models/utilities.py @@ -12,7 +12,7 @@ show_plots = False # Plots -def heat_map(df, model_name=None): +def heat_map(df: pd.DataFrame, model_name=None): """ Generates a heatmap of the correlation matrix for numerical features in the DataFrame. @@ -39,7 +39,7 @@ def plot_counts(model_name=None): if show_plots: plt.show() -def plot_xy(df, x, y, model_name=None): +def plot_xy(df: pd.DataFrame, x, y, model_name=None): """ Creates a scatter plot for two numerical columns. @@ -113,8 +113,9 @@ def plot_precision_recall_curve(precision, recall, model_name=None): A good curve is mostly at the top and right. """ + prc_auc = auc(recall, precision) plt.figure(figsize=(8, 6)) - plt.plot(recall, precision, marker='.', label="Precision-Recall Curve") + plt.plot(recall, precision, color = 'darkorange', marker = '.', lw = 2, label = f"Precision Recall Curve (area = {prc_auc:.2f})") plt.xlabel("Recall") plt.ylabel("Precision") plt.title("Precision-Recall Curve") -- GitLab