From 649975f0c878e6e08b9223858ef5af1d61904e4c Mon Sep 17 00:00:00 2001
From: Daniel Yang <t5wol3yv@duck.com>
Date: Thu, 20 Mar 2025 09:10:03 +0100
Subject: [PATCH] refactored code

---
 code/machine_learning_models/utilities.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/code/machine_learning_models/utilities.py b/code/machine_learning_models/utilities.py
index cfbca48..5c409f0 100644
--- a/code/machine_learning_models/utilities.py
+++ b/code/machine_learning_models/utilities.py
@@ -12,7 +12,7 @@ show_plots = False
 
 # Plots
 
-def heat_map(df, model_name=None):
+def heat_map(df: pd.DataFrame, model_name=None):
     """
     Generates a heatmap of the correlation matrix for numerical features in the DataFrame.
 
@@ -39,7 +39,7 @@ def plot_counts(model_name=None):
     if show_plots:
         plt.show()
 
-def plot_xy(df, x, y, model_name=None):
+def plot_xy(df: pd.DataFrame, x, y, model_name=None):
     """
     Creates a scatter plot for two numerical columns.
 
@@ -113,8 +113,9 @@ def plot_precision_recall_curve(precision, recall, model_name=None):
 
     A good curve is mostly at the top and right.
     """
+    prc_auc = auc(recall, precision)
     plt.figure(figsize=(8, 6))
-    plt.plot(recall, precision, marker='.', label="Precision-Recall Curve")
+    plt.plot(recall, precision, color = 'darkorange', marker = '.', lw = 2, label = f"Precision Recall Curve (area = {prc_auc:.2f})")
     plt.xlabel("Recall")
     plt.ylabel("Precision")
     plt.title("Precision-Recall Curve")
-- 
GitLab