import os
import queue
import subprocess
import sys
import threading
import tkinter as tk
from tkinter import scrolledtext, ttk, Menu

import duckdb
from matplotlib.figure import Figure

# Make the model modules importable by their bare names as well.
# NOTE(review): presumably needed because modules inside machine_learning_models import each other directly — confirm.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'machine_learning_models')))
# Repository root: the parent of the directory that contains this file.
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Python interpreter used to launch the capture script. By default this is the
# project's ".venv", whose layout differs per platform (Windows:
# Scripts\python.exe, POSIX: bin/python). If that interpreter does not exist,
# fall back to the interpreter running this GUI. Replace the path if your
# environment differs from this layout.
_venv_bin = ("Scripts", "python.exe") if os.name == "nt" else ("bin", "python")
python_executable = os.path.join(base_dir, ".venv", *_venv_bin)
if not os.path.isfile(python_executable):
    python_executable = sys.executable

from machine_learning_models import utilities as util

# NSL-KDD train/test files shipped next to the model code.
_dataset_dir = os.path.join(os.path.dirname(__file__), 'machine_learning_models', 'nsl-kdd-dataset')
train_file_path = os.path.abspath(os.path.join(_dataset_dir, 'KDDTrain+.arff'))
test_file_path = os.path.abspath(os.path.join(_dataset_dir, 'KDDTest+.arff'))

# Load the dataset once at import time.
# NOTE(review): the returned frames are not referenced elsewhere in this file; confirm this load is still needed.
df_train, df_test, model_name = util.import_data(
    train_file_path=train_file_path,
    test_file_path=test_file_path,
    model_name=None
)

from matplotlib import pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

class PacketCaptureGUI:
    """Main window of the intrusion-detection front end.

    Launches ``packet_capturing.py`` as a child process, streams its output
    into a scrolled text area and keeps simple packet counters. Menus open
    views for statistics, saved graphs, captured packet data and model
    training.

    Tk widgets may only be touched from the Tk thread. The reader threads
    therefore never update the GUI themselves: they push lines onto
    ``output_queue``, and the Tk thread drains it periodically in
    ``_poll_output``.
    """

    # Model names offered by the "Train Model" dialog (see start_training).
    _TRAINABLE_MODELS = ("Decision Tree", "Random Forest", "KNN", "Logistic Regression")

    # (column id, heading) pairs for the "Packet Data" table, in display order.
    _PACKET_COLUMNS = (
        ("source_ip", "Source IP"),
        ("destination_ip", "Destination IP"),
        ("source_mac", "Source MAC"),
        ("destination_mac", "Destination MAC"),
        ("source_port", "Source Port"),
        ("destination_port", "Destination Port"),
        ("is_dangerous", "Is Dangerous"),
        ("type_of_threat", "Type of Threat"),
    )

    # How often (ms) queued subprocess output is moved into the display.
    _POLL_INTERVAL_MS = 100
    # Maximum number of lines handled per poll, so a burst of output cannot
    # freeze the event loop.
    _MAX_LINES_PER_POLL = 500

    def __init__(self, root):
        self.root = root
        self.root.title("Intrusion Detection Systems")

        # Capture subprocess (None while idle) and its stdout reader thread.
        self.process = None
        self.capture_thread = None
        # (line, counted) tuples produced by reader threads, consumed on the Tk thread.
        self.output_queue = queue.Queue()

        # Create a menu bar
        self.menu_bar = Menu(root)
        root.config(menu=self.menu_bar)

        # "File" menu; Exit stops a running capture before leaving the main loop.
        self.file_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="File", menu=self.file_menu)
        self.file_menu.add_command(label="Exit", command=lambda: self._shutdown(root.quit))

        # "View" menu
        self.view_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="View", menu=self.view_menu)
        self.view_menu.add_command(label="Statistics", command=self.show_statistics)
        self.view_menu.add_command(label="Graph Viewer", command=self.open_graph_viewer)
        self.view_menu.add_command(label="Packet Data", command=self.show_packet_data)

        # "Train" menu
        self.train_menu = Menu(self.menu_bar, tearoff=0)
        self.menu_bar.add_cascade(label="Train", menu=self.train_menu)
        self.train_menu.add_command(label="Train Model", command=self.train_model)

        # Packet statistics, updated from the capture's stdout (see _count_line).
        self.captured_packets = 0
        self.suspicious_packets = 0
        self.non_suspicious_packets = 0

        # Frame holding the packet display and the model selection
        self.main_frame = tk.Frame(root)
        self.main_frame.pack(padx=10, pady=10)

        # Scrolled text widget that shows the capture output and status messages
        self.packet_display = scrolledtext.ScrolledText(self.main_frame, width=80, height=20)
        self.packet_display.pack(side=tk.LEFT, padx=5)

        # Frame for the model selection
        self.model_frame = tk.Frame(self.main_frame)
        self.model_frame.pack(side=tk.LEFT, padx=5)

        # Radio buttons for model selection (only the rule-based model remains)
        self.model_var = tk.StringVar(value="rule_based")
        self.rule_based_radio = tk.Radiobutton(
            self.model_frame,
            text="Rule-Based",
            variable=self.model_var,
            value="rule_based",
            command=self.toggle_model_selection,
        )
        self.rule_based_radio.pack(anchor=tk.W)

        # Start / stop buttons for packet capturing
        self.start_button = tk.Button(root, text="Start Capturing", command=self.start_capturing)
        self.start_button.pack(pady=10)
        self.stop_button = tk.Button(root, text="Stop Capturing", command=self.stop_capturing, state=tk.DISABLED)
        self.stop_button.pack(pady=10)

        self.toggle_model_selection()

        # Closing the main window must not leave the capture process running.
        root.protocol("WM_DELETE_WINDOW", lambda: self._shutdown(root.destroy))

        # Start draining subprocess output on the Tk thread.
        self._poll_job = self.root.after(self._POLL_INTERVAL_MS, self._poll_output)

    def toggle_model_selection(self):
        pass  # No need to toggle anything since we removed the ML model selection

    # ------------------------------------------------------------------ capture

    def _append_message(self, text):
        """Append *text* to the packet display and scroll to it."""
        self.packet_display.insert(tk.END, text)
        self.packet_display.see(tk.END)

    def _set_capture_controls(self, capturing):
        """Configure buttons/menus for the capturing (True) or idle (False) state."""
        if capturing:
            self.rule_based_radio.config(state=tk.DISABLED)
            self.start_button.config(state=tk.DISABLED)
            self.stop_button.config(state=tk.NORMAL)
            self.disable_menu_items()
        else:
            self.rule_based_radio.config(state=tk.NORMAL)
            self.start_button.config(state=tk.NORMAL)
            self.stop_button.config(state=tk.DISABLED)
            self.enable_menu_items()

    def start_capturing(self):
        """Launch packet_capturing.py and stream its output into the display."""
        self._set_capture_controls(capturing=True)

        if self.model_var.get() == "rule_based":
            self._append_message("Starting packet capture with rule-based model...\n")

        script_path = os.path.join(base_dir, "code", "package_capture", "src", "packet_capturing.py")

        try:
            self.process = subprocess.Popen(
                [python_executable, "-u", script_path],  # -u ensures unbuffered output
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )
        except OSError as exc:
            # e.g. the interpreter path is wrong: report it and return to idle.
            self.process = None
            self._append_message(f"Could not start packet capture: {exc}\n")
            self._set_capture_controls(capturing=False)
            return

        # stdout feeds the statistics; stderr is only displayed. Both pipes
        # must be drained, otherwise a child that fills one of them blocks.
        self.capture_thread = threading.Thread(
            target=self._pump_stream, args=(self.process.stdout, True), daemon=True
        )
        self.capture_thread.start()
        threading.Thread(
            target=self._pump_stream, args=(self.process.stderr, False), daemon=True
        ).start()

    def _pump_stream(self, stream, counted):
        """Queue every line of *stream* until EOF, then close it (worker thread).

        Args:
            stream: Text stream to read (a subprocess pipe).
            counted: Whether the lines feed the packet statistics.
        """
        try:
            for line in iter(stream.readline, ''):  # Keeps reading until EOF
                self.output_queue.put((line, counted))
        finally:
            stream.close()

    def update_output(self):
        """Read the current capture's stdout until EOF and queue it for display.

        Blocking; meant for a worker thread. Kept for compatibility —
        start_capturing now starts _pump_stream directly.
        """
        self._pump_stream(self.process.stdout, True)

    def _count_line(self, line):
        """Update the packet counters from one line of capture stdout."""
        if "Packet captured" in line:
            self.captured_packets += 1
        if "WARNING" in line:
            self.suspicious_packets += 1
        self.non_suspicious_packets = self.captured_packets - self.suspicious_packets

    def _poll_output(self):
        """Move queued subprocess output into the display (Tk thread), then reschedule."""
        shown = False
        for _ in range(self._MAX_LINES_PER_POLL):
            try:
                line, counted = self.output_queue.get_nowait()
            except queue.Empty:
                break
            if counted:
                self._count_line(line)
            self.packet_display.insert(tk.END, line)
            shown = True
        if shown:
            self.packet_display.see(tk.END)
        self._poll_job = self.root.after(self._POLL_INTERVAL_MS, self._poll_output)

    def _terminate_process(self, timeout=5):
        """Terminate the capture subprocess, if any, and reap it.

        Args:
            timeout: Seconds to wait after terminate() before falling back to kill().

        Returns:
            True if there was a process to stop, False otherwise.
        """
        proc = self.process
        if proc is None:
            return False
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                # The capture did not exit on terminate(); force it.
                proc.kill()
                proc.wait()
        self.process = None
        return True

    def stop_capturing(self):
        """Stop the capture subprocess and return the controls to the idle state."""
        if self._terminate_process():
            self._append_message("Stopping packet capture...\n")
        self._set_capture_controls(capturing=False)

    def _shutdown(self, finish):
        """Stop a running capture and the output poller, then call *finish*.

        Args:
            finish: ``root.quit`` or ``root.destroy``.
        """
        self._terminate_process()
        if self._poll_job is not None:
            self.root.after_cancel(self._poll_job)
            self._poll_job = None
        finish()

    # ------------------------------------------------------------------- views

    def show_statistics(self):
        """Open a window with a bar chart of the packet counters."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        stats_window = tk.Toplevel(self.root)
        stats_window.title("Statistics")

        # A plain Figure (not pyplot) is not tracked by pyplot, so it is freed
        # with the window and never shown by a later plt.show().
        fig = Figure()
        ax = fig.add_subplot(111)
        ax.bar(["Captured", "Suspicious", "Non-Suspicious"],
               [self.captured_packets, self.suspicious_packets, self.non_suspicious_packets])
        ax.set_ylabel("Number of Packets")
        ax.set_title("Packet Capture Statistics")

        canvas = FigureCanvasTkAgg(fig, master=stats_window)
        canvas.draw()
        canvas.get_tk_widget().pack()

        stats_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(stats_window))

    def train_model(self):
        """Open the dialog for choosing and training a machine-learning model."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        train_window = tk.Toplevel(self.root)
        train_window.title("Train Model")

        model_label = tk.Label(train_window, text="Select Model:")
        model_label.pack(anchor=tk.W)
        # Read-only so only the supported model names can be chosen.
        model_combobox = ttk.Combobox(train_window, values=self._TRAINABLE_MODELS, state="readonly")
        model_combobox.pack(anchor=tk.W)

        start_training_button = tk.Button(
            train_window, text="Start Training", command=lambda: self.start_training(model_combobox.get())
        )
        start_training_button.pack(pady=10)

        train_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(train_window))

    def start_training(self, model_name):
        """Train the selected model, draw its graphs and open the plot browser.

        Args:
            model_name: One of ``_TRAINABLE_MODELS``; anything else (e.g. the
                empty string when nothing is selected) only shows a hint.
        """
        if model_name not in self._TRAINABLE_MODELS:
            self._append_message("Please select a model to train.\n")
            return

        self._append_message(f"Starting training with {model_name} model...\n")
        # Training blocks the event loop; flush the message to the screen first.
        self.root.update_idletasks()

        if model_name == "Random Forest":
            from machine_learning_models import random_forest as model
        elif model_name == "Decision Tree":
            from machine_learning_models import decision_tree as model
        elif model_name == "Logistic Regression":
            from machine_learning_models import logistic_regression as model
        else:  # "KNN"
            from machine_learning_models import knn as model

        model.train()
        self.packet_display.insert(tk.END, "Training complete.\n")
        model.graphs()
        self.packet_display.see(tk.END)
        self.show_plots(model_name)
        self._append_message("Graphs displayed.\n")

    def show_plots(self, model_name):
        """Open a browser for the saved figures whose file names contain *model_name*."""
        plots_window = tk.Toplevel(self.root)
        plots_window.title("Model Training Plots")
        plots_window.geometry("1200x800")

        self.graph_list = self.get_graph_list(model_name=model_name)
        self.current_graph = None

        self.listbox = tk.Listbox(plots_window, width=50)
        self.listbox.pack(side=tk.LEFT, fill=tk.Y)
        for graph in self.graph_list:
            self.listbox.insert(tk.END, graph)
        self.listbox.bind("<<ListboxSelect>>", self.on_graph_select)

        self.canvas_frame = tk.Frame(plots_window)
        self.canvas_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        # Zoom/pan handlers are attached to each figure canvas in display_graph,
        # so they never fire before an image (and self.ax) exists.
        self.canvas = None
        self.ax = None
        self.dragging = False
        self.last_mouse_pos = None

    def get_graph_list(self, model_name):
        """Return sorted PNG names in ``resulting_figures`` containing *model_name*.

        Returns an empty list if the folder does not exist yet.
        """
        try:
            names = os.listdir('resulting_figures')
        except FileNotFoundError:
            return []
        return sorted(f for f in names if f.endswith('.png') and model_name in f)

    def on_graph_select(self, event):
        selected_index = self.listbox.curselection()
        if selected_index:
            graph_file = self.graph_list[selected_index[0]]
            self.display_graph(graph_file)

    def display_graph(self, graph_file):
        """Show *graph_file* from ``resulting_figures`` with wheel zoom and drag pan."""
        if self.canvas:
            self.canvas.get_tk_widget().destroy()

        img = plt.imread(os.path.join("resulting_figures", graph_file))
        fig = Figure()
        self.ax = fig.add_subplot(111)
        self.ax.imshow(img)
        self.ax.axis('off')

        self.canvas = FigureCanvasTkAgg(fig, master=self.canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Matplotlib events carry data coordinates (xdata/ydata), which is what
        # the axis limits are expressed in.
        self.canvas.mpl_connect("scroll_event", self.on_mouse_wheel)
        self.canvas.mpl_connect("button_press_event", self.on_button_press)
        self.canvas.mpl_connect("button_release_event", self.on_button_release)
        self.canvas.mpl_connect("motion_notify_event", self.on_mouse_move)
        self.dragging = False
        self.last_mouse_pos = None

    def zoom(self, event, scale_factor):
        """Scale the view by *scale_factor* about the data point under the cursor.

        Ignored when the cursor is outside the image axes.
        """
        if self.ax is None or event.inaxes is not self.ax or event.xdata is None:
            return
        x, y = event.xdata, event.ydata
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the cursor point at the same relative position in the view.
        relx = (cur_xlim[1] - x) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - y) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim([x - new_width * (1 - relx), x + new_width * relx])
        self.ax.set_ylim([y - new_height * (1 - rely), y + new_height * rely])

        self.canvas.draw()

    def on_mouse_wheel(self, event):
        scale_factor = 0.8 if event.button == 'up' else 1.2
        self.zoom(event, scale_factor)

    def on_button_press(self, event):
        """Start panning: remember the data point grabbed with the left button."""
        if event.button == 1 and event.inaxes is self.ax and event.xdata is not None:
            self.dragging = True
            self.last_mouse_pos = (event.xdata, event.ydata)

    def on_button_release(self, event):
        if event.button == 1:
            self.dragging = False
            self.last_mouse_pos = None

    def on_mouse_move(self, event):
        """Shift the view so the grabbed data point follows the cursor."""
        if not self.dragging or self.last_mouse_pos is None:
            return
        if event.inaxes is not self.ax or event.xdata is None:
            return
        dx = event.xdata - self.last_mouse_pos[0]
        dy = event.ydata - self.last_mouse_pos[1]

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        self.ax.set_xlim([cur_xlim[0] - dx, cur_xlim[1] - dx])
        self.ax.set_ylim([cur_ylim[0] - dy, cur_ylim[1] - dy])

        self.canvas.draw()
        # last_mouse_pos keeps the grabbed point: after the shift it lies under the cursor again.

    def open_graph_viewer(self):
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        viewer_window = tk.Toplevel(self.root)
        GraphViewer(viewer_window, self)

    def show_packet_data(self):
        """Open a table of the rows stored in the DuckDB ``packets`` table."""
        self.disable_menu_items()
        self.start_button.config(state=tk.DISABLED)
        data_window = tk.Toplevel(self.root)
        data_window.title("Packet Data")
        # Register the close handler first so the menus are re-enabled even
        # when loading the data fails.
        data_window.protocol("WM_DELETE_WINDOW", lambda: self.on_close_window(data_window))

        columns = [column for column, _ in self._PACKET_COLUMNS]
        tree = ttk.Treeview(data_window, columns=columns, show="headings")
        tree.pack(fill=tk.BOTH, expand=True)
        for column, heading in self._PACKET_COLUMNS:
            tree.heading(column, text=heading)

        # NOTE(review): path is relative to the working directory and SELECT *
        # assumes the table's column order matches _PACKET_COLUMNS — confirm
        # against packet_capturing.py.
        try:
            conn = duckdb.connect('code/packets_data.duckdb', read_only=True)
            try:
                rows = conn.execute("SELECT * FROM packets").fetchall()
            finally:
                conn.close()
        except duckdb.Error as exc:
            tk.Label(data_window, text=f"Could not load packet data: {exc}").pack(pady=5)
            return

        for row in rows:
            tree.insert("", tk.END, values=row)

    # ------------------------------------------------------------------- menus

    def disable_menu_items(self):
        self.view_menu.entryconfig("Statistics", state=tk.DISABLED)
        self.view_menu.entryconfig("Graph Viewer", state=tk.DISABLED)
        self.view_menu.entryconfig("Packet Data", state=tk.DISABLED)
        self.train_menu.entryconfig("Train Model", state=tk.DISABLED)

    def enable_menu_items(self):
        self.view_menu.entryconfig("Statistics", state=tk.NORMAL)
        self.view_menu.entryconfig("Graph Viewer", state=tk.NORMAL)
        self.view_menu.entryconfig("Packet Data", state=tk.NORMAL)
        self.train_menu.entryconfig("Train Model", state=tk.NORMAL)
        self.start_button.config(state=tk.NORMAL)

    def on_close_window(self, window):
        window.destroy()
        self.enable_menu_items()

class GraphViewer:
    """Window listing the PNG files in ``resulting_figures`` with a zoom/pan preview.

    Args:
        root: Toplevel window that hosts the viewer.
        parent: Object whose ``enable_menu_items()`` is called when the viewer
            closes (the PacketCaptureGUI that opened it).
    """

    def __init__(self, root, parent):
        self.root = root
        self.parent = parent
        self.root.title("Graph Viewer")
        self.root.geometry("1200x800")

        self.graph_list = self.get_graph_list()
        self.current_graph = None
        # Figure/axes of the image currently shown; None until one is selected.
        self.fig = None
        self.ax = None

        self.create_widgets()

        # Pan state: whether the left button is held, and the data point grabbed on press.
        self.dragging = False
        self.last_mouse_pos = None

        self.root.protocol("WM_DELETE_WINDOW", self.on_close)

    def create_widgets(self):
        """Build the file list (left) and the frame hosting the preview canvas (right)."""
        self.listbox = tk.Listbox(self.root, width=50)
        self.listbox.pack(side=tk.LEFT, fill=tk.Y)
        for graph in self.graph_list:
            self.listbox.insert(tk.END, graph)

        self.listbox.bind("<<ListboxSelect>>", self.on_graph_select)

        self.canvas_frame = tk.Frame(self.root)
        self.canvas_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        self.canvas = None

    def get_graph_list(self):
        """Return the sorted PNG file names in ``resulting_figures``.

        Returns an empty list if the folder does not exist yet.
        """
        try:
            names = os.listdir("resulting_figures")
        except FileNotFoundError:
            return []
        return sorted(f for f in names if f.endswith(".png"))

    def on_graph_select(self, event):
        selected_index = self.listbox.curselection()
        if selected_index:
            graph_file = self.graph_list[selected_index[0]]
            self.display_graph(graph_file)

    def display_graph(self, graph_file):
        """Replace the preview with *graph_file* and attach zoom/pan handlers."""
        if self.canvas:
            self.canvas.get_tk_widget().destroy()

        img = plt.imread(os.path.join("resulting_figures", graph_file))
        # A plain Figure (not pyplot) is released together with its canvas;
        # pyplot would keep every previewed figure alive.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.imshow(img)
        self.ax.axis('off')

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        self.canvas.mpl_connect("scroll_event", self.on_mouse_wheel)
        self.canvas.mpl_connect("button_press_event", self.on_button_press)
        self.canvas.mpl_connect("button_release_event", self.on_button_release)
        self.canvas.mpl_connect("motion_notify_event", self.on_mouse_move)
        self.dragging = False
        self.last_mouse_pos = None

    def zoom(self, event, scale_factor):
        """Scale the view by *scale_factor* about the data point under the cursor.

        Events outside the image axes (where xdata/ydata are None) are ignored.
        """
        if self.ax is None or event.inaxes is not self.ax or event.xdata is None:
            return
        x, y = event.xdata, event.ydata
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the cursor point at the same relative position in the view.
        relx = (cur_xlim[1] - x) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - y) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim([x - new_width * (1 - relx), x + new_width * relx])
        self.ax.set_ylim([y - new_height * (1 - rely), y + new_height * rely])

        self.canvas.draw()

    def on_mouse_wheel(self, event):
        scale_factor = 0.8 if event.button == 'up' else 1.2
        self.zoom(event, scale_factor)

    def on_button_press(self, event):
        """Start panning: remember the data point grabbed with the left button."""
        if event.button == 1 and event.inaxes is self.ax and event.xdata is not None:
            self.dragging = True
            self.last_mouse_pos = (event.xdata, event.ydata)

    def on_button_release(self, event):
        if event.button == 1:
            self.dragging = False
            self.last_mouse_pos = None

    def on_mouse_move(self, event):
        """Shift the view so the grabbed data point follows the cursor.

        Works in data coordinates, so the pan speed matches the zoom level
        and the inverted y axis of ``imshow``.
        """
        if not self.dragging or self.last_mouse_pos is None:
            return
        if event.inaxes is not self.ax or event.xdata is None:
            return
        dx = event.xdata - self.last_mouse_pos[0]
        dy = event.ydata - self.last_mouse_pos[1]

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        self.ax.set_xlim([cur_xlim[0] - dx, cur_xlim[1] - dx])
        self.ax.set_ylim([cur_ylim[0] - dy, cur_ylim[1] - dy])

        self.canvas.draw()
        # last_mouse_pos keeps the grabbed point: after the shift it lies under the cursor again.

    def on_close(self):
        self.root.destroy()
        self.parent.enable_menu_items()

if __name__ == "__main__":
    # Script entry point: build the main window, attach the GUI, run Tk's event loop.
    main_window = tk.Tk()
    gui = PacketCaptureGUI(main_window)
    main_window.mainloop()