Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 17.97 KiB
import os
import queue
import subprocess
import sys
import threading
import tkinter as tk
from tkinter import Menu, scrolledtext, ttk

import duckdb
from matplotlib.figure import Figure

# Make modules inside the machine_learning_models directory importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'machine_learning_models')))
# Parent of the directory containing this file.
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Python interpreter used to run the packet-capture script: the project's
# virtual environment in <base_dir>/.venv ("Scripts/python.exe" on Windows,
# "bin/python" elsewhere). Adjust this if your environment differs. If that
# interpreter does not exist, fall back to the interpreter running this GUI
# rather than failing when capturing starts.
if os.name == "nt":
    python_executable = os.path.join(base_dir, ".venv", "Scripts", "python.exe")
else:
    python_executable = os.path.join(base_dir, ".venv", "bin", "python")
if not os.path.isfile(python_executable):
    python_executable = sys.executable

from machine_learning_models import utilities as util

# NSL-KDD train/test splits stored next to the machine-learning models.
_nsl_kdd_dir = os.path.join(os.path.dirname(__file__), 'machine_learning_models', 'nsl-kdd-dataset')
train_file_path = os.path.abspath(os.path.join(_nsl_kdd_dir, 'KDDTrain+.arff'))
test_file_path = os.path.abspath(os.path.join(_nsl_kdd_dir, 'KDDTest+.arff'))

# Load both splits at import time through the project's utilities module.
# NOTE(review): df_train, df_test and model_name are not referenced by the
# code in this view; confirm whether this eager load is still needed here.
df_train, df_test, model_name = util.import_data(
    model_name=None,
    test_file_path=test_file_path,
    train_file_path=train_file_path,
)

from matplotlib import pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

class PacketCaptureGUI:
    """Main window of the intrusion detection system GUI.

    Starts and stops the packet-capture subprocess and streams its output into
    a text pane. Keeps simple packet counters derived from that output, and
    offers windows for statistics, stored packet data, saved figures and
    model training.
    """

    # Model names offered by the "Train Model" dialog; start_training()
    # dispatches on exactly these strings.
    MODEL_NAMES = ("Decision Tree", "Random Forest", "KNN", "Logistic Regression")

    # Interval (ms) at which the Tk main loop drains output queued by the
    # subprocess reader threads, and the maximum items handled per tick so a
    # chatty subprocess cannot starve the UI.
    OUTPUT_POLL_MS = 100
    OUTPUT_MAX_ITEMS_PER_POLL = 500

    def __init__(self, root):
        self.root = root
        self.root.title("Intrusion Detection Systems")

        # Capture subprocess state; self.process is None while idle.
        self.process = None
        self.capture_thread = None
        # Tk widgets must only be touched from the main thread. The reader
        # threads put their output here and _poll_output() displays it.
        self._output_queue = queue.Queue()

        # State of the "Model Training Plots" viewer (see show_plots).
        self.canvas = None
        self.ax = None
        self.dragging = False
        self.last_mouse_pos = None

        # Create a menu bar
        self.menu_bar = Menu(root)
        root.config(menu=self.menu_bar)

        # "File" menu. Exit, like closing the main window, also stops a
        # running capture so the subprocess does not outlive the GUI.
        self.file_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="File", menu=self.file_menu)
        self.file_menu.add_command(label="Exit", command=self.on_exit)
        root.protocol("WM_DELETE_WINDOW", self.on_exit)

        # "View" menu
        self.view_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="View", menu=self.view_menu)
        self.view_menu.add_command(label="Statistics", command=self.show_statistics)
        self.view_menu.add_command(label="Graph Viewer", command=self.open_graph_viewer)
        self.view_menu.add_command(label="Packet Data", command=self.show_packet_data)

        # "Train" menu
        self.train_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="Train", menu=self.train_menu)
        self.train_menu.add_command(label="Train Model", command=self.train_model)

        # Packet counters, updated from the capture subprocess's stdout.
        self.captured_packets = 0
        self.suspicious_packets = 0
        self.non_suspicious_packets = 0

        # Frame holding the output pane and the model selection
        self.main_frame = tk.Frame(root)
        self.main_frame.pack(padx=10, pady=10)

        # Scrolled text pane showing the capture/training output
        self.packet_display = scrolledtext.ScrolledText(self.main_frame, width=80, height=20)
        self.packet_display.pack(side=tk.LEFT, padx=5)

        # Model selection (only the rule-based option remains)
        self.model_frame = tk.Frame(self.main_frame)
        self.model_frame.pack(side=tk.LEFT, padx=5)

        self.model_var = tk.StringVar(value="rule_based")
        self.rule_based_radio = tk.Radiobutton(self.model_frame, text="Rule-Based", variable=self.model_var, value="rule_based", command=self.toggle_model_selection)
        self.rule_based_radio.pack(anchor=tk.W)

        # Start / stop capture buttons
        self.start_button = tk.Button(root, text="Start Capturing", command=self.start_capturing)
        self.start_button.pack(pady=10)

        self.stop_button = tk.Button(root, text="Stop Capturing", command=self.stop_capturing, state=tk.DISABLED)
        self.stop_button.pack(pady=10)

        self.toggle_model_selection()

        # Begin moving worker-thread output into the UI.
        self._poll_output()

    def toggle_model_selection(self):
        """Radio-button hook; a no-op since only the rule-based model remains."""
        pass

    # ------------------------------------------------------------------
    # Packet capture
    # ------------------------------------------------------------------

    def _append_text(self, text):
        """Append text to the output pane and scroll to it (main thread only)."""
        self.packet_display.insert(tk.END, text)
        self.packet_display.see(tk.END)

    def _set_capture_ui(self, capturing):
        """Put the controls into the capturing (True) or idle (False) state."""
        idle_only = tk.DISABLED if capturing else tk.NORMAL
        self.rule_based_radio.config(state=idle_only)
        self.start_button.config(state=idle_only)
        self.stop_button.config(state=tk.NORMAL if capturing else tk.DISABLED)
        if capturing:
            self.disable_menu_items()
        else:
            self.enable_menu_items()

    def start_capturing(self):
        """Launch packet_capturing.py as a subprocess and stream its output."""
        self._set_capture_ui(capturing=True)

        if self.model_var.get() == "rule_based":
            self._append_text("Starting packet capture with rule-based model...\n")

        script_path = os.path.join(base_dir, "code", "package_capture", "src", "packet_capturing.py")

        try:
            self.process = subprocess.Popen(
                [python_executable, "-u", script_path],  # -u ensures unbuffered output
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )
        except OSError as exc:
            # E.g. the interpreter or script is missing: report it instead of
            # leaving the UI stuck in the "capturing" state.
            self.process = None
            self._append_text(f"Failed to start packet capture: {exc}\n")
            self._set_capture_ui(capturing=False)
            return

        # Both pipes are read on worker threads. stderr must be drained too:
        # an unread pipe eventually fills up and blocks the child.
        self.capture_thread = threading.Thread(target=self.update_output, args=(self.process,), daemon=True)
        self.capture_thread.start()
        threading.Thread(target=self._drain_stderr, args=(self.process,), daemon=True).start()

    def update_output(self, process=None):
        """Read a capture subprocess's stdout until EOF (runs on a worker thread).

        Updates the packet counters and queues every line for display. Once
        stdout is closed, it queues the process object itself so the main loop
        can detect that this capture has ended.

        Args:
            process: the Popen to read; defaults to self.process.
        """
        if process is None:
            process = self.process
        for line in iter(process.stdout.readline, ''):  # Keeps reading until EOF
            if "Packet captured" in line:
                self.captured_packets += 1
            if "WARNING" in line:
                self.suspicious_packets += 1
            self.non_suspicious_packets = self.captured_packets - self.suspicious_packets
            self._output_queue.put(line)
        self._output_queue.put(process)

    def _drain_stderr(self, process):
        """Forward a subprocess's stderr lines to the output pane (worker thread).

        These lines are displayed but not counted.
        """
        for line in iter(process.stderr.readline, ''):
            self._output_queue.put(line)

    def _poll_output(self):
        """Move queued worker output into the UI, then re-arm the timer (main thread)."""
        for _ in range(self.OUTPUT_MAX_ITEMS_PER_POLL):
            try:
                item = self._output_queue.get_nowait()
            except queue.Empty:
                break
            if isinstance(item, str):
                self._append_text(item)
            else:
                self._on_process_output_closed(item)
        self.root.after(self.OUTPUT_POLL_MS, self._poll_output)

    def _on_process_output_closed(self, process):
        """Handle EOF on a capture process's stdout (main thread).

        If that process is still the active capture, it ended on its own
        (e.g. it crashed), so reap it, report it and return the UI to idle.
        If the user already stopped it, self.process no longer refers to it
        and nothing needs doing.
        """
        if process is not self.process:
            return
        self._terminate_process()
        self._append_text(f"Packet capture ended (exit code {process.returncode}).\n")
        self._set_capture_ui(capturing=False)

    def _terminate_process(self):
        """Terminate the capture subprocess, if any, and wait for it to exit.

        Returns:
            True if there was a process to stop, False otherwise.
        """
        process, self.process = self.process, None
        if process is None:
            return False
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # The child ignored the termination request; force it.
            process.kill()
            process.wait()
        return True

    def stop_capturing(self):
        """Stop the capture subprocess and restore the idle UI state."""
        if self._terminate_process():
            # Queued rather than inserted directly so it appears after the
            # output lines that are already waiting to be displayed.
            self._output_queue.put("Stopping packet capture...\n")
        self._set_capture_ui(capturing=False)

    def on_exit(self):
        """Stop any running capture, then close the application."""
        self._terminate_process()
        self.root.destroy()

    # ------------------------------------------------------------------
    # Dialogs
    # ------------------------------------------------------------------

    def show_statistics(self):
        """Open a window with a bar chart of the packet counters."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        stats_window = tk.Toplevel(self.root)
        stats_window.title("Statistics")
        stats_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(stats_window))

        # A bare Figure (not pyplot) is not tracked by pyplot's figure manager,
        # so it is released with the window instead of leaking.
        fig = Figure()
        ax = fig.add_subplot(111)
        ax.bar(["Captured", "Suspicious", "Non-Suspicious"], [self.captured_packets, self.suspicious_packets, self.non_suspicious_packets])
        ax.set_ylabel("Number of Packets")
        ax.set_title("Packet Capture Statistics")

        canvas = FigureCanvasTkAgg(fig, master=stats_window)
        canvas.draw()
        canvas.get_tk_widget().pack()

    def train_model(self):
        """Open the dialog for choosing a model to train."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        train_window = tk.Toplevel(self.root)
        train_window.title("Train Model")
        train_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(train_window))

        model_label = tk.Label(train_window, text="Select Model:")
        model_label.pack(anchor=tk.W)
        # Read-only so only the listed models can be chosen (no free text).
        model_combobox = ttk.Combobox(train_window, values=list(self.MODEL_NAMES), state="readonly")
        model_combobox.pack(anchor=tk.W)

        start_training_button = tk.Button(train_window, text="Start Training", command=lambda: self.start_training(model_combobox.get()))
        start_training_button.pack(pady=10)

    def start_training(self, model_name):
        """Train the selected model, then open a window with its saved figures.

        model_name must be one of MODEL_NAMES. Anything else (e.g. nothing
        selected in the dialog) is reported in the output pane instead of
        raising. Training runs on the Tk main thread, so the UI does not
        respond until it finishes.
        """
        if model_name not in self.MODEL_NAMES:
            # Without this guard an empty/unknown name fell through every
            # branch below and crashed with UnboundLocalError on `model`.
            self._append_text("Please select a model to train.\n")
            return

        self._append_text(f"Starting training with {model_name} model...\n")
        # Render the message now; the blocking training below would delay it.
        self.root.update_idletasks()

        if model_name == "Random Forest":
            from machine_learning_models import random_forest as model
        elif model_name == "Decision Tree":
            from machine_learning_models import decision_tree as model
        elif model_name == "Logistic Regression":
            from machine_learning_models import logistic_regression as model
        else:  # "KNN", the only remaining entry of MODEL_NAMES
            from machine_learning_models import knn as model

        model.train()
        self._append_text("Training complete.\n")
        # NOTE(review): graphs() presumably writes the figures that show_plots
        # lists from resulting_figures — confirm in the model modules.
        model.graphs()
        self.show_plots(model_name)
        self._append_text("Graphs displayed.\n")

    # ------------------------------------------------------------------
    # Model training plots viewer
    # ------------------------------------------------------------------

    def show_plots(self, model_name):
        """Open a window listing the figures whose file name contains model_name."""
        plots_window = tk.Toplevel(self.root)
        plots_window.title("Model Training Plots")
        plots_window.geometry("1200x800")  # Set the window size

        self.graph_list = self.get_graph_list(model_name=model_name)
        self.current_graph = None

        self.listbox = tk.Listbox(plots_window, width=50)
        self.listbox.pack(side=tk.LEFT, fill=tk.Y)
        for graph in self.graph_list:
            self.listbox.insert(tk.END, graph)
        self.listbox.bind("<<ListboxSelect>>", self.on_graph_select)

        self.canvas_frame = tk.Frame(plots_window)
        self.canvas_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        # Zoom/pan handlers are attached to each figure canvas in
        # display_graph as matplotlib events, which carry data coordinates.
        self.canvas = None
        self.ax = None
        self.dragging = False
        self.last_mouse_pos = None

    def get_graph_list(self, model_name):
        """Return the sorted .png names in resulting_figures containing model_name.

        The directory is resolved against the current working directory. A
        missing directory yields an empty list.
        """
        try:
            names = os.listdir('resulting_figures')
        except FileNotFoundError:
            return []
        return sorted(f for f in names if f.endswith('.png') and model_name in f)

    def on_graph_select(self, event):
        """Display the figure chosen in the listbox."""
        selected_index = self.listbox.curselection()
        if selected_index:
            self.display_graph(self.graph_list[selected_index[0]])

    def display_graph(self, graph_file):
        """Show resulting_figures/graph_file, replacing any previous figure."""
        if self.canvas is not None:
            self.canvas.get_tk_widget().destroy()

        img = plt.imread(os.path.join("resulting_figures", graph_file))
        # A bare Figure (not pyplot) is not tracked by pyplot's figure
        # manager, so replacing it does not leak figures.
        fig = Figure()
        self.ax = fig.add_subplot(111)
        self.ax.imshow(img)
        self.ax.axis('off')

        self.canvas = FigureCanvasTkAgg(fig, master=self.canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        self.canvas.mpl_connect("scroll_event", self.on_mouse_wheel)
        self.canvas.mpl_connect("button_press_event", self.on_button_press)
        self.canvas.mpl_connect("button_release_event", self.on_button_release)
        self.canvas.mpl_connect("motion_notify_event", self.on_mouse_move)

    def zoom(self, event, scale_factor):
        """Scale the view by scale_factor around the pointer.

        event is a matplotlib mouse event. Nothing happens when the pointer
        is not over the displayed axes.
        """
        if self.ax is None or event.inaxes is not self.ax or event.xdata is None:
            return
        x, y = event.xdata, event.ydata
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the point under the cursor fixed while scaling.
        relx = (cur_xlim[1] - x) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - y) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim([x - new_width * (1 - relx), x + new_width * relx])
        self.ax.set_ylim([y - new_height * (1 - rely), y + new_height * rely])

        self.canvas.draw_idle()

    def on_mouse_wheel(self, event):
        """Zoom in on scroll-up and out on scroll-down (matplotlib scroll_event)."""
        scale_factor = 0.8 if event.button == 'up' else 1.2
        self.zoom(event, scale_factor)

    def on_button_press(self, event):
        """Start a left-button drag (pan)."""
        if event.button == 1:
            self.dragging = True
            self.last_mouse_pos = (event.x, event.y)

    def on_button_release(self, event):
        """End a left-button drag."""
        if event.button == 1:
            self.dragging = False
            self.last_mouse_pos = None

    def on_mouse_move(self, event):
        """Pan the view so the image follows the cursor while dragging."""
        if not (self.dragging and self.last_mouse_pos) or self.ax is None:
            return
        # Convert both pixel positions to data coordinates so the pan distance
        # is correct at any zoom level.
        inverse = self.ax.transData.inverted()
        prev_x, prev_y = inverse.transform(self.last_mouse_pos)
        now_x, now_y = inverse.transform((event.x, event.y))
        dx, dy = prev_x - now_x, prev_y - now_y

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        self.ax.set_xlim([cur_xlim[0] + dx, cur_xlim[1] + dx])
        self.ax.set_ylim([cur_ylim[0] + dy, cur_ylim[1] + dy])

        self.canvas.draw_idle()
        self.last_mouse_pos = (event.x, event.y)

    # ------------------------------------------------------------------
    # Other windows and menu state
    # ------------------------------------------------------------------

    def open_graph_viewer(self):
        """Open the GraphViewer window over all saved figures."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        viewer_window = tk.Toplevel(self.root)
        # Fallback close handler so the menus are restored even if the viewer
        # fails to build; GraphViewer installs its own on success.
        viewer_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(viewer_window))
        GraphViewer(viewer_window, self)

    def show_packet_data(self):
        """Open a table of all rows in the DuckDB `packets` table."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        data_window = tk.Toplevel(self.root)
        data_window.title("Packet Data")
        # Register the close handler first so the menus are re-enabled even
        # if loading the data below fails.
        data_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(data_window))

        columns = ("source_ip", "destination_ip", "source_mac", "destination_mac", "source_port", "destination_port", "is_dangerous", "type_of_threat")
        headings = ("Source IP", "Destination IP", "Source MAC", "Destination MAC", "Source Port", "Destination Port", "Is Dangerous", "Type of Threat")
        tree = ttk.Treeview(data_window, columns=columns, show="headings")
        tree.pack(fill=tk.BOTH, expand=True)
        for column, heading in zip(columns, headings):
            tree.heading(column, text=heading)

        # Path is relative to the current working directory.
        try:
            conn = duckdb.connect('code/packets_data.duckdb')
            try:
                rows = conn.execute("SELECT * FROM packets").fetchall()
            finally:
                conn.close()
        except duckdb.Error as exc:
            tk.Label(data_window, text=f"Could not load packet data: {exc}").pack(fill=tk.X)
            return

        for row in rows:
            tree.insert("", tk.END, values=row)

    def disable_menu_items(self):
        """Grey out the View and Train menu entries."""
        self.view_menu.entryconfig("Statistics", state=tk.DISABLED)
        self.view_menu.entryconfig("Graph Viewer", state=tk.DISABLED)
        self.view_menu.entryconfig("Packet Data", state=tk.DISABLED)
        self.train_menu.entryconfig("Train Model", state=tk.DISABLED)

    def enable_menu_items(self):
        """Re-enable the View and Train menu entries and the Start button."""
        self.view_menu.entryconfig("Statistics", state=tk.NORMAL)
        self.view_menu.entryconfig("Graph Viewer", state=tk.NORMAL)
        self.view_menu.entryconfig("Packet Data", state=tk.NORMAL)
        self.train_menu.entryconfig("Train Model", state=tk.NORMAL)
        self.start_button.config(state=tk.NORMAL)

    def on_close_window(self, window):
        """Destroy a dialog window and restore the main window's controls."""
        window.destroy()
        self.enable_menu_items()

class GraphViewer:
    """Window that lists every .png in resulting_figures, with zoom and pan.

    Args:
        root: the Toplevel window to populate.
        parent: the PacketCaptureGUI whose menu items are re-enabled when
            this window is closed.
    """

    def __init__(self, root, parent):
        self.root = root
        self.parent = parent
        self.root.title("Graph Viewer")
        self.root.geometry("1200x800")
        # Register the close handler before anything that can fail, so the
        # parent's menus are always restored when this window is closed.
        self.root.protocol("WM_DELETE_WINDOW", self.on_close)

        self.graph_list = self.get_graph_list()
        self.current_graph = None
        self.fig = None
        self.ax = None

        self.create_widgets()

        # Left-button drag state used for panning.
        self.dragging = False
        self.last_mouse_pos = None

    def create_widgets(self):
        """Build the file listbox (left) and the figure area (right)."""
        self.listbox = tk.Listbox(self.root, width=50)
        self.listbox.pack(side=tk.LEFT, fill=tk.Y)
        for graph in self.graph_list:
            self.listbox.insert(tk.END, graph)

        self.listbox.bind("<<ListboxSelect>>", self.on_graph_select)

        self.canvas_frame = tk.Frame(self.root)
        self.canvas_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        self.canvas = None

    def get_graph_list(self):
        """Return the sorted .png file names in resulting_figures.

        The directory is resolved against the current working directory. A
        missing directory yields an empty list instead of an exception.
        """
        try:
            names = os.listdir("resulting_figures")
        except FileNotFoundError:
            return []
        return sorted(f for f in names if f.endswith(".png"))

    def on_graph_select(self, event):
        """Display the figure chosen in the listbox."""
        selected_index = self.listbox.curselection()
        if selected_index:
            self.display_graph(self.graph_list[selected_index[0]])

    def display_graph(self, graph_file):
        """Show resulting_figures/graph_file, replacing any previous figure."""
        if self.canvas is not None:
            self.canvas.get_tk_widget().destroy()

        img = plt.imread(os.path.join("resulting_figures", graph_file))
        # A bare Figure (not pyplot) is not tracked by pyplot's figure
        # manager, so replacing it does not leak figures.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.imshow(img)
        self.ax.axis('off')

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        self.canvas.mpl_connect("scroll_event", self.on_mouse_wheel)
        self.canvas.mpl_connect("button_press_event", self.on_button_press)
        self.canvas.mpl_connect("button_release_event", self.on_button_release)
        self.canvas.mpl_connect("motion_notify_event", self.on_mouse_move)

    def zoom(self, event, scale_factor):
        """Scale the view by scale_factor around the pointer.

        Nothing happens when the pointer is outside the displayed axes, where
        matplotlib reports xdata/ydata as None.
        """
        if self.ax is None or event.inaxes is not self.ax or event.xdata is None:
            return
        x, y = event.xdata, event.ydata
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the point under the cursor fixed while scaling.
        relx = (cur_xlim[1] - x) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - y) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim([x - new_width * (1 - relx), x + new_width * relx])
        self.ax.set_ylim([y - new_height * (1 - rely), y + new_height * rely])

        self.canvas.draw_idle()

    def on_mouse_wheel(self, event):
        """Zoom in on scroll-up and out on scroll-down."""
        scale_factor = 0.8 if event.button == 'up' else 1.2
        self.zoom(event, scale_factor)

    def on_button_press(self, event):
        """Start a left-button drag (pan)."""
        if event.button == 1:
            self.dragging = True
            self.last_mouse_pos = (event.x, event.y)

    def on_button_release(self, event):
        """End a left-button drag."""
        if event.button == 1:
            self.dragging = False
            self.last_mouse_pos = None

    def on_mouse_move(self, event):
        """Pan the view so the image follows the cursor while dragging."""
        if not (self.dragging and self.last_mouse_pos) or self.ax is None:
            return
        # Convert both pixel positions to data coordinates so the pan distance
        # is correct at any zoom level.
        inverse = self.ax.transData.inverted()
        prev_x, prev_y = inverse.transform(self.last_mouse_pos)
        now_x, now_y = inverse.transform((event.x, event.y))
        dx, dy = prev_x - now_x, prev_y - now_y

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        self.ax.set_xlim([cur_xlim[0] + dx, cur_xlim[1] + dx])
        self.ax.set_ylim([cur_ylim[0] + dy, cur_ylim[1] + dy])

        self.canvas.draw_idle()
        self.last_mouse_pos = (event.x, event.y)

    def on_close(self):
        """Destroy this window and re-enable the parent's menu items."""
        self.root.destroy()
        self.parent.enable_menu_items()

if __name__ == "__main__":
    # Script entry point: create the Tk root window, build the GUI on it,
    # and run the Tk event loop.
    main_window = tk.Tk()
    app = PacketCaptureGUI(main_window)
    main_window.mainloop()