diff --git a/code/machine_learning_models/random_forest.py b/code/machine_learning_models/random_forest.py
index 5bdbbf01c42c68807f8ba6280a0939abea539225..902b968bb9f461157e2d0855f76aea2d54a4e2b0 100644
--- a/code/machine_learning_models/random_forest.py
+++ b/code/machine_learning_models/random_forest.py
@@ -31,9 +31,6 @@ ordinal_encode(df = df_test, categories = y_values, target = y_data)
 
 normalize(df_train, df_test, y_data, sc, enc)
 
-# Correlation
-heat_map(df_train, model_name = model_name)
-
 # Separate X and y
 X_train = df_train.select_dtypes(include=[np.number]).drop(columns = [y_data])
 X_test = df_test.select_dtypes(include=[np.number]).drop(columns = [y_data])
@@ -42,48 +39,70 @@ y_test = df_test[[y_data]]
 
 # Train Random Forest Model
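+# n_estimators=100 trees; max_depth=None expands trees until leaves are pure
+# (or fall below min_samples_split); random_state=42 makes runs reproducible.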
 model = RandomForestClassifier(n_estimators=100, max_depth=None, random_state=42)
-model.fit(X_train, y_train.values.ravel())
 
-# Predictions
-y_prediction = model.predict(X_test)
+# Prediction function
+def predict(prediction_input):
+	"""Classify raw feature rows as "anomaly"/"normal"; requires train() to have run."""
+	if len(prediction_input) == 0:
+		return []
+	input_df = pd.DataFrame(prediction_input, columns=X_train.columns)
+	# Re-apply the scaler fitted in normalize() so inputs match the training scale
+	input_df[numerical_columns] = sc.transform(input_df[numerical_columns])
+	return ["anomaly" if x == 1 else "normal" for x in model.predict(input_df)]
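+
+# Usage sketch (hypothetical input): rows must be raw, unscaled feature values
+# in X_train's column order, since predict() applies the fitted scaler itself:
+#   labels = predict(raw_rows)  # e.g. ["normal", "anomaly"]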
 
-# Plot Confusion Matrix
-plot_confusion_matrix(confusion_matrix=confusion_matrix(y_test, y_prediction),
-					  accuracy=model.score(X_test, y_test),
-					  model_name=model_name)
+def train():
+	"""Fit the random forest on the training split, then plot the evaluation graphs."""
+	model.fit(X_train, y_train.values.ravel())
+	graphs()
+	print("Training complete.")
 
-print("Classification Report: \n", classification_report(y_test, y_prediction))
+def graphs():
+	"""Evaluate the fitted model on the held-out test split and plot diagnostics."""
+	# Correlation
+	heat_map(df_train, model_name=model_name)
 
+	# Predictions
+	y_prediction = model.predict(X_test)
 
-# Get high confidence samples for which the model is 90% confident
-print_high_confidence_samples(model, X_train)
+	# Plot Confusion Matrix
+	plot_confusion_matrix(confusion_matrix=confusion_matrix(y_test, y_prediction),
+	                      accuracy=model.score(X_test, y_test),
+	                      model_name=model_name)
 
-# Feature Importance Plot
-features = pd.DataFrame(model.feature_importances_, index=X_train.columns, columns=['Importance']).sort_values(by='Importance', ascending=False)
-plot_features(features, "Higher importance = More impact on classification", model_name=model_name)
+	print("Classification Report:\n", classification_report(y_test, y_prediction))
 
-# Precision-Recall Curve
-print("Calculating Precision Recall Curve")
-y_scores = model.predict_proba(X_test)[:, 1]
-precision, recall, _ = precision_recall_curve(y_test, y_scores)
-plot_precision_recall_curve(precision, recall, model_name)
+	# Print samples for which the model is at least 90% confident
+	print_high_confidence_samples(model, X_train)
 
-# Learning Curve
-print("Calculating Learning Curve")
-train_sizes, train_scores, test_scores = learning_curve(model, X_train, y_train.values.ravel(), cv=5, scoring="accuracy")
-plot_learning_curve(train_sizes, train_scores, test_scores, model_name)
+	# Feature Importance Plot
+	features = pd.DataFrame(model.feature_importances_, index=X_train.columns, columns=['Importance']).sort_values(
+		by='Importance', ascending=False)
+	plot_features(features, "Higher importance = More impact on classification", model_name=model_name)
 
-# Calculate prediction probabilities for ROC curve
-y_score = model.predict_proba(X_test)[:, 1]
+	# Precision-Recall Curve
+	print("Calculating Precision Recall Curve")
+	y_scores = model.predict_proba(X_test)[:, 1]
+	precision, recall, _ = precision_recall_curve(y_test, y_scores)
+	plot_precision_recall_curve(precision, recall, model_name)
 
-# Plot ROC curve using the function from utilities
-plot_roc_curve(y_test, y_score, model_name=model_name)
+	# Learning Curve
+	print("Calculating Learning Curve")
+	train_sizes, train_scores, test_scores = learning_curve(model, X_train, y_train.values.ravel(), cv=5,
+	                                                        scoring="accuracy")
+	plot_learning_curve(train_sizes, train_scores, test_scores, model_name)
+
+	# ROC Curve (reuse the positive-class probabilities computed above)
+	plot_roc_curve(y_test, y_scores, model_name=model_name)
+
+if __name__ == "__main__":
+	train()
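+	# Hypothetical follow-up: classify new, unscaled rows once training is done,
+	# e.g. print(predict(raw_rows)) where raw_rows matches X_train's columns.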
 
-# Prediction function
-def predict(prediction_input):
-	if len(prediction_input) == 0:
-		return
-	input_df = pd.DataFrame(prediction_input, columns=X_train.columns)
-	input_df[numerical_columns] = sc.transform(input_df[numerical_columns])
-	return ["anomaly" if x == 1 else "normal" for x in model.predict(input_df)]