Skip to content
Snippets Groups Projects
Commit 2bbed70a authored by Daniel Yang's avatar Daniel Yang
Browse files

added print_high_confidence_samples

parent b268e577
No related branches found
No related tags found
No related merge requests found
......@@ -129,6 +129,13 @@ def plot_confusion_matrix(confusion_matrix: List[List[int]], accuracy: float, mo
if show_plots:
plt.show()
def print_high_confidence_samples(model, x: pd.DataFrame):
# Get predicted probabilities
predicted_probabilities = pd.DataFrame(model.predict_proba(x)[:, 1],
columns=['confidence level']) # Probability of being class 1
# Filter samples where the model is at least 90% sure
high_confidence_samples = predicted_probabilities[predicted_probabilities['confidence level'] > 0.9]
print(high_confidence_samples.head())
def save_plot(name):
plt.savefig("resulting_figures/" + name, dpi=300, bbox_inches='tight')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment