diff --git a/code/machine_learning_models/utilities.py b/code/machine_learning_models/utilities.py index cfbca48dce5fb43d2475adbb797951e8279e2922..5c409f03828635a003d3fd5aaaf8a9899103e1f4 100644 --- a/code/machine_learning_models/utilities.py +++ b/code/machine_learning_models/utilities.py @@ -12,7 +12,7 @@ show_plots = False # Plots -def heat_map(df, model_name=None): +def heat_map(df: pd.DataFrame, model_name=None): """ Generates a heatmap of the correlation matrix for numerical features in the DataFrame. @@ -39,7 +39,7 @@ def plot_counts(model_name=None): if show_plots: plt.show() -def plot_xy(df, x, y, model_name=None): +def plot_xy(df: pd.DataFrame, x, y, model_name=None): """ Creates a scatter plot for two numerical columns. @@ -113,8 +113,9 @@ def plot_precision_recall_curve(precision, recall, model_name=None): A good curve is mostly at the top and right. """ + prc_auc = auc(recall, precision) plt.figure(figsize=(8, 6)) - plt.plot(recall, precision, marker='.', label="Precision-Recall Curve") + plt.plot(recall, precision, color = 'darkorange', marker = '.', lw = 2, label = f"Precision Recall Curve (area = {prc_auc:.2f})") plt.xlabel("Recall") plt.ylabel("Precision") plt.title("Precision-Recall Curve")