-
Vladyslav Lubkovskyi authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 17.97 KiB
import os
import queue
import subprocess
import sys
import threading
import tkinter as tk
from tkinter import scrolledtext, ttk, Menu

import duckdb
# Directory containing this file; bundled resources are located relative to it.
_this_dir = os.path.dirname(os.path.abspath(__file__))
# Put machine_learning_models itself on sys.path (presumably so its modules can
# import each other by bare name -- confirm against that package).
sys.path.append(os.path.join(_this_dir, 'machine_learning_models'))
# Project root: the parent of this file's directory.
base_dir = os.path.dirname(_this_dir)


def _venv_python(project_root):
    """Return the project's virtualenv interpreter, or the current one.

    A venv keeps its interpreter in ``Scripts/python.exe`` on Windows and in
    ``bin/python`` elsewhere. If the expected file does not exist, fall back
    to ``sys.executable`` (the interpreter running this GUI) so launching the
    capture script does not fail with FileNotFoundError.
    """
    if os.name == "nt":
        candidate = os.path.join(project_root, ".venv", "Scripts", "python.exe")
    else:
        candidate = os.path.join(project_root, ".venv", "bin", "python")
    return candidate if os.path.exists(candidate) else sys.executable


# Interpreter used to run the packet-capture script. Replace this value if
# your environment lives somewhere other than <project root>/.venv.
python_executable = _venv_python(base_dir)

from machine_learning_models import utilities as util

_dataset_dir = os.path.join(_this_dir, 'machine_learning_models', 'nsl-kdd-dataset')
train_file_path = os.path.join(_dataset_dir, 'KDDTrain+.arff')
test_file_path = os.path.join(_dataset_dir, 'KDDTest+.arff')
# NOTE(review): df_train, df_test and model_name are not used by the code
# visible in this file; confirm the dataset really must be loaded at startup.
df_train, df_test, model_name = util.import_data(
    train_file_path=train_file_path,
    test_file_path=test_file_path,
    model_name=None,
)
from matplotlib import pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
class PacketCaptureGUI:
    """Main IDS window.

    Runs the rule-based packet-capture script as a subprocess, streams its
    output into a text area, keeps packet counters derived from that output,
    and offers menu entries for statistics, a figure viewer, the stored
    packet table and model training.
    """

    # How often (ms) the Tk main loop drains lines queued by the reader thread.
    OUTPUT_POLL_MS = 100
    # Seconds to wait for the capture script to exit after terminate().
    STOP_TIMEOUT_S = 5
    # Train-window label -> module name inside the machine_learning_models package.
    TRAINABLE_MODELS = {
        "Decision Tree": "decision_tree",
        "Random Forest": "random_forest",
        "KNN": "knn",
        "Logistic Regression": "logistic_regression",
    }

    def __init__(self, root):
        """Build the main window, menus and capture controls on ``root``."""
        self.root = root
        self.root.title("Intrusion Detection Systems")

        # Menu bar
        self.menu_bar = Menu(root)
        root.config(menu=self.menu_bar)

        # "File" menu
        self.file_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="File", menu=self.file_menu)
        self.file_menu.add_command(label="Exit", command=root.quit)

        # "View" menu
        self.view_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="View", menu=self.view_menu)
        self.view_menu.add_command(label="Statistics", command=self.show_statistics)
        self.view_menu.add_command(label="Graph Viewer", command=self.open_graph_viewer)
        self.view_menu.add_command(label="Packet Data", command=self.show_packet_data)

        # "Train" menu
        self.train_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="Train", menu=self.train_menu)
        self.train_menu.add_command(label="Train Model", command=self.train_model)

        # Packet statistics, derived from the capture script's output lines.
        self.captured_packets = 0
        self.suspicious_packets = 0
        self.non_suspicious_packets = 0

        # Capture subprocess state. The reader thread only puts lines on
        # output_queue; all widget updates happen on the Tk main loop, since
        # Tkinter widgets should not be driven from worker threads.
        self.process = None
        self.capture_thread = None
        self.output_queue = queue.Queue()

        # State of the model-training plots window (populated by show_plots).
        self.graph_list = []
        self.current_graph = None
        self.listbox = None
        self.canvas_frame = None
        self.canvas = None
        self.ax = None
        self.dragging = False
        self.last_mouse_pos = None

        # Frame for the packet display and model selection
        self.main_frame = tk.Frame(root)
        self.main_frame.pack(padx=10, pady=10)
        # Scrolled text widget showing the capture script's output
        self.packet_display = scrolledtext.ScrolledText(self.main_frame, width=80, height=20)
        self.packet_display.pack(side=tk.LEFT, padx=5)
        # Frame for the model selection
        self.model_frame = tk.Frame(self.main_frame)
        self.model_frame.pack(side=tk.LEFT, padx=5)
        # Only the rule-based detector is currently offered
        self.model_var = tk.StringVar(value="rule_based")
        self.rule_based_radio = tk.Radiobutton(self.model_frame, text="Rule-Based", variable=self.model_var, value="rule_based", command=self.toggle_model_selection)
        self.rule_based_radio.pack(anchor=tk.W)
        # Start / stop capture buttons
        self.start_button = tk.Button(root, text="Start Capturing", command=self.start_capturing)
        self.start_button.pack(pady=10)
        self.stop_button = tk.Button(root, text="Stop Capturing", command=self.stop_capturing, state=tk.DISABLED)
        self.stop_button.pack(pady=10)
        self.toggle_model_selection()

    def toggle_model_selection(self):
        """No-op kept for the radio button callback; ML model selection was removed."""
        pass

    def start_capturing(self):
        """Launch the rule-based packet-capture script and stream its output."""
        self.rule_based_radio.config(state=tk.DISABLED)
        self.start_button.config(state=tk.DISABLED)
        self.stop_button.config(state=tk.NORMAL)
        self.disable_menu_items()
        if self.model_var.get() == "rule_based":
            self.packet_display.insert(tk.END, "Starting packet capture with rule-based model...\n")
            self.packet_display.see(tk.END)
            script_path = os.path.join(base_dir, "code", "package_capture", "src", "packet_capturing.py")
            try:
                # -u: unbuffered child output, so lines arrive as they are printed.
                # stderr is merged into stdout: an unread stderr PIPE can fill up
                # and block the child, and its error messages are worth showing.
                self.process = subprocess.Popen(
                    [python_executable, "-u", script_path],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    bufsize=1,
                )
            except OSError as err:
                self.process = None
                self.packet_display.insert(tk.END, f"Failed to start packet capture: {err}\n")
                self.packet_display.see(tk.END)
                self._reset_capture_controls()
                return
            self.capture_thread = threading.Thread(target=self.update_output, args=(self.process,), daemon=True)
            self.capture_thread.start()
            self.root.after(self.OUTPUT_POLL_MS, self._drain_output)

    def update_output(self, process=None):
        """Read the capture subprocess's output and queue it for the GUI.

        Runs on the background reader thread. ``process`` defaults to
        ``self.process``. Each line is put on ``self.output_queue``; nothing
        here touches Tk widgets.
        """
        if process is None:
            process = self.process
        if process is None or process.stdout is None:
            return
        try:
            for line in iter(process.stdout.readline, ""):  # "" signals EOF
                self.output_queue.put(line)
        finally:
            process.stdout.close()

    def _drain_output(self):
        """Render queued output lines; reschedule while the reader is active."""
        while True:
            try:
                line = self.output_queue.get_nowait()
            except queue.Empty:
                break
            self._handle_output_line(line)
        reader = self.capture_thread
        if (reader is not None and reader.is_alive()) or not self.output_queue.empty():
            self.root.after(self.OUTPUT_POLL_MS, self._drain_output)

    def _count_line(self, line):
        """Update the packet counters from one line of capture output."""
        if "Packet captured" in line:
            self.captured_packets += 1
        if "WARNING" in line:
            self.suspicious_packets += 1
        self.non_suspicious_packets = self.captured_packets - self.suspicious_packets

    def _handle_output_line(self, line):
        """Count and display one output line (Tk main loop only)."""
        self._count_line(line)
        self.packet_display.insert(tk.END, line)
        self.packet_display.see(tk.END)

    def _reset_capture_controls(self):
        """Return buttons and menus to the idle (not capturing) state."""
        self.rule_based_radio.config(state=tk.NORMAL)
        self.start_button.config(state=tk.NORMAL)
        self.stop_button.config(state=tk.DISABLED)
        self.enable_menu_items()

    def stop_capturing(self):
        """Terminate the capture subprocess and restore the idle UI state."""
        process, self.process = self.process, None
        if process is not None:
            process.terminate()
            try:
                process.wait(timeout=self.STOP_TIMEOUT_S)
            except subprocess.TimeoutExpired:
                # The script ignored terminate(); force it down so the GUI
                # does not hang here.
                process.kill()
                process.wait()
        self.packet_display.insert(tk.END, "Stopping packet capture...\n")
        self.packet_display.see(tk.END)
        self._reset_capture_controls()

    def show_statistics(self):
        """Open a window with a bar chart of the current packet counters."""
        from matplotlib.figure import Figure

        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        stats_window = tk.Toplevel(self.root)
        stats_window.title("Statistics")
        # Registered first so the menus come back even if plotting fails.
        stats_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(stats_window))
        # A standalone Figure belongs to this window only; pyplot would keep a
        # global reference to every figure until it is explicitly closed.
        fig = Figure()
        ax = fig.add_subplot(111)
        ax.bar(
            ["Captured", "Suspicious", "Non-Suspicious"],
            [self.captured_packets, self.suspicious_packets, self.non_suspicious_packets],
        )
        ax.set_ylabel("Number of Packets")
        ax.set_title("Packet Capture Statistics")
        canvas = FigureCanvasTkAgg(fig, master=stats_window)
        canvas.draw()
        canvas.get_tk_widget().pack()

    def train_model(self):
        """Open a window for choosing and training one of TRAINABLE_MODELS."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        train_window = tk.Toplevel(self.root)
        train_window.title("Train Model")
        train_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(train_window))
        model_label = tk.Label(train_window, text="Select Model:")
        model_label.pack(anchor=tk.W)
        model_combobox = ttk.Combobox(train_window)
        model_combobox['values'] = list(self.TRAINABLE_MODELS)
        model_combobox.pack(anchor=tk.W)
        start_training_button = tk.Button(
            train_window,
            text="Start Training",
            command=lambda: self.start_training(model_combobox.get()),
        )
        start_training_button.pack(pady=10)

    def start_training(self, model_name):
        """Train the model chosen in the Train window and show its figures.

        ``model_name`` should be a TRAINABLE_MODELS key. The module's
        ``train()`` and ``graphs()`` are called synchronously from this Tk
        callback, so the GUI does not respond until they return.
        """
        module_name = self.TRAINABLE_MODELS.get(model_name)
        if module_name is None:
            # Empty or free-typed combobox value: there is no module to train.
            self.packet_display.insert(tk.END, "Please select a model from the list before training.\n")
            self.packet_display.see(tk.END)
            return
        import importlib

        self.packet_display.insert(tk.END, f"Starting training with {model_name} model...\n")
        self.packet_display.see(tk.END)
        model = importlib.import_module(f"machine_learning_models.{module_name}")
        model.train()
        self.packet_display.insert(tk.END, "Training complete.\n")
        model.graphs()
        self.packet_display.see(tk.END)
        self.show_plots(model_name)
        self.packet_display.insert(tk.END, "Graphs displayed.\n")
        self.packet_display.see(tk.END)

    def show_plots(self, model_name):
        """Open a window listing the saved figures whose name contains model_name."""
        plots_window = tk.Toplevel(self.root)
        plots_window.title("Model Training Plots")
        plots_window.geometry("1200x800")
        self.graph_list = self.get_graph_list(model_name=model_name)
        self.current_graph = None
        self.listbox = tk.Listbox(plots_window, width=50)
        self.listbox.pack(side=tk.LEFT, fill=tk.Y)
        for graph in self.graph_list:
            self.listbox.insert(tk.END, graph)
        self.listbox.bind("<<ListboxSelect>>", self.on_graph_select)
        self.canvas_frame = tk.Frame(plots_window)
        self.canvas_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        # Zoom/pan handlers are attached to each figure canvas in display_graph,
        # so they only ever run once an image (and self.ax) exists.
        self.canvas = None
        self.ax = None
        self.dragging = False
        self.last_mouse_pos = None

    def get_graph_list(self, model_name):
        """Return sorted PNG names in resulting_figures that contain model_name.

        Returns [] when the folder does not exist. The folder is resolved
        against the current working directory.
        """
        try:
            names = os.listdir('resulting_figures')
        except FileNotFoundError:
            return []
        return sorted(f for f in names if f.endswith('.png') and model_name in f)

    def on_graph_select(self, event):
        """Display the figure picked in the plots-window listbox."""
        selected_index = self.listbox.curselection()
        if selected_index:
            self.display_graph(self.graph_list[selected_index[0]])

    def display_graph(self, graph_file):
        """Show ``graph_file`` from resulting_figures, replacing the previous one."""
        from matplotlib.figure import Figure

        if self.canvas:
            self.canvas.get_tk_widget().destroy()
        img = plt.imread(os.path.join("resulting_figures", graph_file))
        fig = Figure()
        self.ax = fig.add_subplot(111)
        self.ax.imshow(img)
        self.ax.axis('off')
        self.canvas = FigureCanvasTkAgg(fig, master=self.canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        # Matplotlib events carry figure-pixel and data coordinates, which the
        # zoom/pan math needs (Tk window events only give widget/screen pixels).
        self.canvas.mpl_connect("scroll_event", self.on_mouse_wheel)
        self.canvas.mpl_connect("button_press_event", self.on_button_press)
        self.canvas.mpl_connect("button_release_event", self.on_button_release)
        self.canvas.mpl_connect("motion_notify_event", self.on_mouse_move)
        self.current_graph = graph_file

    def zoom(self, event, scale_factor):
        """Scale the view by ``scale_factor`` about the cursor (< 1 zooms in)."""
        x, y = event.xdata, event.ydata
        if self.ax is None or x is None or y is None:
            return  # no image yet, or the cursor is outside the axes
        x0, x1 = self.ax.get_xlim()
        y0, y1 = self.ax.get_ylim()
        self.ax.set_xlim([x - (x - x0) * scale_factor, x + (x1 - x) * scale_factor])
        self.ax.set_ylim([y - (y - y0) * scale_factor, y + (y1 - y) * scale_factor])
        self.canvas.draw_idle()

    def on_mouse_wheel(self, event):
        """Zoom in on scroll-up, out on scroll-down."""
        scale_factor = 0.8 if event.button == 'up' else 1.2
        self.zoom(event, scale_factor)

    def on_button_press(self, event):
        """Start a pan drag on left-button press."""
        if event.button == 1:
            self.dragging = True
            self.last_mouse_pos = (event.x, event.y)

    def on_button_release(self, event):
        """End a pan drag on left-button release."""
        if event.button == 1:
            self.dragging = False
            self.last_mouse_pos = None

    def on_mouse_move(self, event):
        """Pan so the image point under the cursor follows the cursor."""
        if not (self.dragging and self.last_mouse_pos) or self.ax is None:
            return
        # Convert both pixel positions to data coordinates so the pan distance
        # stays correct at every zoom level.
        to_data = self.ax.transData.inverted()
        prev_x, prev_y = to_data.transform(self.last_mouse_pos)
        cur_x, cur_y = to_data.transform((event.x, event.y))
        dx, dy = cur_x - prev_x, cur_y - prev_y
        x0, x1 = self.ax.get_xlim()
        y0, y1 = self.ax.get_ylim()
        self.ax.set_xlim([x0 - dx, x1 - dx])
        self.ax.set_ylim([y0 - dy, y1 - dy])
        self.canvas.draw_idle()
        self.last_mouse_pos = (event.x, event.y)

    def open_graph_viewer(self):
        """Open the general figure viewer in a new window."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        viewer_window = tk.Toplevel(self.root)
        GraphViewer(viewer_window, self)

    def show_packet_data(self):
        """Open a window listing the rows of the ``packets`` DuckDB table."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        data_window = tk.Toplevel(self.root)
        data_window.title("Packet Data")
        # Registered before the query so the menus are re-enabled on close
        # even if loading the data fails.
        data_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(data_window))
        headings = {
            "source_ip": "Source IP",
            "destination_ip": "Destination IP",
            "source_mac": "Source MAC",
            "destination_mac": "Destination MAC",
            "source_port": "Source Port",
            "destination_port": "Destination Port",
            "is_dangerous": "Is Dangerous",
            "type_of_threat": "Type of Threat",
        }
        tree = ttk.Treeview(data_window, columns=tuple(headings), show="headings")
        tree.pack(fill=tk.BOTH, expand=True)
        for column, text in headings.items():
            tree.heading(column, text=text)
        # NOTE(review): the path is relative to the current working directory;
        # SELECT * assumes the table's column order matches `headings`.
        try:
            # read_only: viewing must not create an empty database file, and
            # the context manager closes the connection on every path.
            with duckdb.connect('code/packets_data.duckdb', read_only=True) as conn:
                rows = conn.execute("SELECT * FROM packets").fetchall()
        except duckdb.Error as err:
            tk.Label(data_window, text=f"Could not load packet data: {err}").pack(pady=5)
            return
        for row in rows:
            tree.insert("", tk.END, values=row)

    def disable_menu_items(self):
        """Grey out the View/Train entries while a tool window or capture is active."""
        self.view_menu.entryconfig("Statistics", state=tk.DISABLED)
        self.view_menu.entryconfig("Graph Viewer", state=tk.DISABLED)
        self.view_menu.entryconfig("Packet Data", state=tk.DISABLED)
        self.train_menu.entryconfig("Train Model", state=tk.DISABLED)

    def enable_menu_items(self):
        """Re-enable the View/Train entries and the Start button."""
        self.view_menu.entryconfig("Statistics", state=tk.NORMAL)
        self.view_menu.entryconfig("Graph Viewer", state=tk.NORMAL)
        self.view_menu.entryconfig("Packet Data", state=tk.NORMAL)
        self.train_menu.entryconfig("Train Model", state=tk.NORMAL)
        self.start_button.config(state=tk.NORMAL)

    def on_close_window(self, window):
        """Destroy a tool window and restore the main window's controls."""
        window.destroy()
        self.enable_menu_items()
class GraphViewer:
    """Window that lists the PNG figures in FIGURES_DIR and displays them.

    Selecting an entry shows the image on an embedded matplotlib canvas with
    scroll-wheel zoom about the cursor and left-button drag panning. Closing
    the window re-enables the parent's menu items.
    """

    # Folder scanned for figures, resolved against the current working directory.
    FIGURES_DIR = "resulting_figures"

    def __init__(self, root, parent):
        """Populate ``root`` (a Toplevel) with the figure list and canvas area.

        ``parent`` is the object whose ``enable_menu_items()`` is called when
        this window closes (PacketCaptureGUI.open_graph_viewer passes itself).
        """
        self.root = root
        self.parent = parent
        self.root.title("Graph Viewer")
        self.root.geometry("1200x800")
        # Registered first so the parent's menus come back even if setup fails.
        self.root.protocol("WM_DELETE_WINDOW", self.on_close)
        self.current_graph = None
        self.fig = None
        self.ax = None
        self.dragging = False
        self.last_mouse_pos = None
        self.graph_list = self.get_graph_list()
        self.create_widgets()

    def create_widgets(self):
        """Create the figure listbox (left) and the canvas frame (right)."""
        self.listbox = tk.Listbox(self.root, width=50)
        self.listbox.pack(side=tk.LEFT, fill=tk.Y)
        for graph in self.graph_list:
            self.listbox.insert(tk.END, graph)
        self.listbox.bind("<<ListboxSelect>>", self.on_graph_select)
        self.canvas_frame = tk.Frame(self.root)
        self.canvas_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        self.canvas = None

    def get_graph_list(self):
        """Return the sorted PNG file names in FIGURES_DIR ([] if it is missing)."""
        try:
            names = os.listdir(self.FIGURES_DIR)
        except FileNotFoundError:
            return []
        return sorted(name for name in names if name.endswith(".png"))

    def on_graph_select(self, event):
        """Display the figure picked in the listbox."""
        selected_index = self.listbox.curselection()
        if selected_index:
            self.display_graph(self.graph_list[selected_index[0]])

    def display_graph(self, graph_file):
        """Show ``graph_file`` from FIGURES_DIR, replacing any previous image."""
        from matplotlib.figure import Figure

        if self.canvas:
            self.canvas.get_tk_widget().destroy()
        img = plt.imread(os.path.join(self.FIGURES_DIR, graph_file))
        # A standalone Figure is released with its canvas; pyplot would keep a
        # global reference to every figure created per selection.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.imshow(img)
        self.ax.axis('off')
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        self.canvas.mpl_connect("scroll_event", self.on_mouse_wheel)
        self.canvas.mpl_connect("button_press_event", self.on_button_press)
        self.canvas.mpl_connect("button_release_event", self.on_button_release)
        self.canvas.mpl_connect("motion_notify_event", self.on_mouse_move)
        self.current_graph = graph_file

    def zoom(self, event, scale_factor):
        """Scale the view by ``scale_factor`` about the cursor (< 1 zooms in)."""
        x, y = event.xdata, event.ydata
        if self.ax is None or x is None or y is None:
            return  # cursor is outside the axes: matplotlib reports no data coords
        x0, x1 = self.ax.get_xlim()
        y0, y1 = self.ax.get_ylim()
        self.ax.set_xlim([x - (x - x0) * scale_factor, x + (x1 - x) * scale_factor])
        self.ax.set_ylim([y - (y - y0) * scale_factor, y + (y1 - y) * scale_factor])
        self.canvas.draw_idle()

    def on_mouse_wheel(self, event):
        """Zoom in on scroll-up, out on scroll-down."""
        scale_factor = 0.8 if event.button == 'up' else 1.2
        self.zoom(event, scale_factor)

    def on_button_press(self, event):
        """Start a pan drag on left-button press."""
        if event.button == 1:
            self.dragging = True
            self.last_mouse_pos = (event.x, event.y)

    def on_button_release(self, event):
        """End a pan drag on left-button release."""
        if event.button == 1:
            self.dragging = False
            self.last_mouse_pos = None

    def on_mouse_move(self, event):
        """Pan so the image point under the cursor follows the cursor."""
        if not (self.dragging and self.last_mouse_pos) or self.ax is None:
            return
        # Pixel deltas are converted to data deltas; using raw pixels as data
        # units made panning too fast when zoomed in and too slow when zoomed out.
        to_data = self.ax.transData.inverted()
        prev_x, prev_y = to_data.transform(self.last_mouse_pos)
        cur_x, cur_y = to_data.transform((event.x, event.y))
        dx, dy = cur_x - prev_x, cur_y - prev_y
        x0, x1 = self.ax.get_xlim()
        y0, y1 = self.ax.get_ylim()
        self.ax.set_xlim([x0 - dx, x1 - dx])
        self.ax.set_ylim([y0 - dy, y1 - dy])
        self.canvas.draw_idle()
        self.last_mouse_pos = (event.x, event.y)

    def on_close(self):
        """Destroy the viewer window and re-enable the parent's menus."""
        self.root.destroy()
        self.parent.enable_menu_items()
if __name__ == "__main__":
    # Attach the IDS front-end to a fresh Tk root window, then run Tk's event loop.
    app = PacketCaptureGUI(tk.Tk())
    root = app.root
    root.mainloop()