#
# Copyright (c) 2021 Contributors to the Eclipse Foundation
#
# This program and the accompanying materials are made
# available under the terms of the Eclipse Public License 2.0
# which is available at https://www.eclipse.org/legal/epl-2.0/
#
# SPDX-License-Identifier: EPL-2.0
#

import threading, random, asyncio, queue, ctypes, datetime, subprocess, os, io, json, traceback, sys
from typing import Dict, Callable, List, Optional, Union, Type
import tkinter as tk
from tkinter import ttk
from snakes.nets import PetriNet
if not 'SELF_CONTAINED' in globals():
    from model import Event, EventType, Parameters, Constraint
    from model import Event, Constraint
    from walker import Walker
    nets: Dict[str, Callable[[], PetriNet]]
    constraints: List[Type['Constraint']]


class TestClientWalker():
    """Runs a Walker over the test-client Petri nets on a background thread.

    Events received from the adapter (handed in through ``received_event``) are checked against
    the model and applied to it. When no event arrives within a short timeout, a random enabled
    step is chosen, sent out through ``send_event`` and taken on the model. If the run ends for any
    reason other than a ``stop()`` from another thread, the ``stopped`` callback is invoked with
    the error message that ended it.
    """

    def __init__(self, nets: Dict[str, Callable[[], PetriNet]], constraints: List[Type['Constraint']], send_event: Callable[['Event'], None], stopped: Callable[[Optional[str]], None], log: Callable[[str], None]) -> None:
        """
        :param nets: Petri net factories, forwarded to Walker.
        :param constraints: constraint classes, forwarded to Walker.
        :param send_event: called (on the walker thread) with each event the test client emits.
        :param stopped: called (on the walker thread) with the error message when the run ends by itself.
        :param log: logging callback, forwarded to Walker.
        """
        self.send_event = send_event
        self.stopped = stopped
        self.walker = Walker(nets, constraints, log)
        # Received events; None is a wake-up sentinel pushed by stop().
        self.event_queue: queue.Queue[Union[Event, None]] = queue.Queue()
        self.thread: Optional[threading.Thread] = None
        self.stop_requested = False

    def start(self):
        """Start the run loop on a new thread. May only be called once per instance."""
        assert self.thread is None, "Already running"
        self.thread = threading.Thread(target=self.__run_non_async)
        self.thread.start()

    def stop(self):
        """Ask the run loop to end and wait for it (unless called from the walker thread itself).

        When called from another thread, the ``stopped`` callback is not invoked for this run.
        """
        if self.thread is not None:
            self.stop_requested = True
            self.event_queue.put(None)  # Wake the run loop if it is blocked on the queue
            # Joining the current thread raises RuntimeError; this happens e.g. when the
            # ``stopped`` callback calls back into stop() from the walker thread.
            if threading.current_thread() is not self.thread:
                self.thread.join()
            self.stop_requested = False

    def received_event(self, event: 'Event'):
        """Queue an event received from the adapter for processing on the walker thread."""
        self.event_queue.put(event)

    def __run_non_async(self):
        """Thread entry point: run ``__run`` to completion on a fresh event loop for this thread."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.__run())
        finally:
            loop.close()

    async def __run(self):
        """Main loop: apply received events to the model, otherwise take and send a step of our own.

        Runs until stop() is requested or an error occurs. Unless the run was stopped on request,
        the error is reported through ``stopped``.
        """
        # Per port on which we sent a command that still awaits its reply: notifications received
        # there that the model could not take yet. They are replayed after the reply is taken.
        port_notification_during_command_transition: Dict[str, List[Event]] = {}
        # A command received from the adapter that must be answered next with its reply.
        take_reply_to_cmd: Optional[Event] = None
        # Set after a round in which no own step was possible: the next wait is extended and, if
        # still no step is possible afterwards, the run fails.
        stop_on_no_events = False
        error: Optional[str] = None

        self.walker.log("Initial states:")
        for (port, states) in self.walker.states.items():
            for (machine, state) in states.items():
                self.walker.log(f"Port '{port}', machine '{machine}' is in '{state}'")

        try:
            while not self.stop_requested and error is None:
                event: Optional[Event] = None
                try:
                    # Returns immediately when the queue is non-empty; otherwise waits up to the timeout.
                    event = self.event_queue.get(True, 5 if stop_on_no_events else 0.1)
                    if event is None:
                        continue  # Sentinel pushed by stop(): re-check the loop condition
                except queue.Empty:
                    pass

                if event is not None:
                    stop_on_no_events = False
                    parameter_place_name = f"P_{event.method}{'_reply' if event.kind == EventType.Reply else ''}"
                    if event.port not in self.walker.nets:
                        error = f"Received event '{str(event)}' from unknown port '{event.port}'"
                        continue
                    if parameter_place_name not in self.walker.nets[event.port]._place:
                        error = f"Event '{event.method}' is unknown for port '{event.port}'"
                        continue
                    # Store the received parameter values in the event's parameter place
                    # (presumably consumed by the matching transition).
                    place = self.walker.nets[event.port]._place[parameter_place_name]
                    place.add([Parameters([p.value for p in event.parameters])])
                    steps = [e for e in self.walker.next_steps(event.port) if e.event == event]
                    if len(steps) == 0:
                        if event.kind == EventType.Notification and event.port in port_notification_during_command_transition:
                            # Our command on this port still awaits its reply: defer the notification.
                            port_notification_during_command_transition[event.port].append(event)
                        else:
                            error = f"Event '{str(event)}' is not possible"
                    else:
                        self.walker.take_step(random.choice(steps))
                        if event.kind == EventType.Reply:
                            # Replay notifications deferred while waiting for this reply. pop() with a
                            # default so a reply without a recorded pending command cannot raise KeyError.
                            for notification in port_notification_during_command_transition.pop(event.port, []):
                                steps = [e for e in self.walker.next_steps(event.port) if e.event == notification]
                                if len(steps) == 0:
                                    error = f"Event '{str(notification)}' is not possible"
                                    break
                                self.walker.take_step(random.choice(steps))
                        elif event.kind == EventType.Command:
                            take_reply_to_cmd = event
                else:
                    steps = []
                    if take_reply_to_cmd is not None:
                        # Only the reply to the received command (same port and method) may be sent now.
                        steps = [c for c in self.walker.next_steps(take_reply_to_cmd.port)
                                 if c.event.port == take_reply_to_cmd.port
                                 and c.event.kind == EventType.Reply
                                 and c.event.method == take_reply_to_cmd.method]
                    else:
                        # Try the ports in random order and use the first one with any enabled step.
                        ports = list(self.walker.nets.keys())
                        random.shuffle(ports)
                        for port in ports:
                            steps = self.walker.next_steps(port)
                            if len(steps) > 0:
                                break

                    if len(steps) == 0:
                        if stop_on_no_events:
                            error = "No next steps possible from test client"
                        else:
                            stop_on_no_events = True
                    else:
                        step = random.choice(steps)
                        assert step.event is not None
                        if step.event.kind == EventType.Command:
                            # Start collecting notifications that arrive before this command's reply.
                            port_notification_during_command_transition[step.event.port] = []
                        self.send_event(step.event)
                        take_reply_to_cmd = None
                        self.walker.take_step(step)
        except Exception as e:
            error = f"Error while running: {repr(e)}"
            traceback.print_exc()

        if not self.stop_requested:
            self.stopped(error)


class TestClient:
    """Tkinter front-end that starts an adapter process and drives a TestClientWalker against it.

    The adapter is launched from the command typed into the entry field. Messages are exchanged
    with it as one JSON object per line over its stdin/stdout; its stderr is echoed to the log.
    """
    running = False  # True from a successful start until the end of stop()
    stopping_or_starting = False  # Guards start_stop()/stop() while a transition is in progress
    start_time: datetime.datetime
    update_ui_timer: threading.Timer
    adapter: subprocess.Popen
    walker: 'TestClientWalker'

    # NOTE(review): the Tk root is created while the class body executes, i.e. when this module
    # is imported, so importing it requires a display.
    window = tk.Tk()
    cmd_entry: tk.Entry
    log_text: tk.Text
    start_button: ttk.Button
    save_coverage_button: ttk.Button
    running_time_label: tk.Label
    state_coverage_label: tk.Label
    event_coverage_label: tk.Label

    def start_adapter(self):
        """Launch the adapter command from the entry field and start threads reading its output.

        Each stdout line is passed to on_stdout and each stderr line to on_stderr. When either
        pipe reaches end-of-file, stop() is called with the reason "Adapter stopped".

        Raises:
            ValueError: if the command entry is empty.
            OSError: (or another ``subprocess.Popen`` error) if the command cannot be started.
        """
        import shlex

        def reader(pipe: io.BytesIO, cb: Callable[[str], None]):
            with pipe:
                for line in iter(pipe.readline, b''):
                    # Strip only the line terminator ('\n' or '\r\n'): chopping the last character
                    # would corrupt a final line without a newline. Undecodable bytes are replaced
                    # instead of raising and silently killing this reader thread.
                    cb(line.decode(errors='replace').rstrip('\r\n'))
            self.stop("Adapter stopped")

        cmd = self.cmd_entry.get().strip()
        # On POSIX, Popen treats a plain string as the program path only, so arguments must be
        # split out; on Windows the string is passed on as a command line unchanged.
        args: Union[str, List[str]] = cmd if sys.platform == 'win32' else shlex.split(cmd)
        if not args:
            raise ValueError("No adapter command specified")
        self.adapter = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
        threading.Thread(target=reader, args=[self.adapter.stdout, self.on_stdout]).start()
        threading.Thread(target=reader, args=[self.adapter.stderr, self.on_stderr]).start()

    def log(self, line: str):
        """Append a timestamped line to the log view and scroll to it."""
        # NOTE(review): also called from the walker and adapter reader threads; Tkinter widgets
        # are not generally safe to touch off the main thread — consider marshalling via after().
        time = datetime.datetime.now().replace(microsecond=0).isoformat()
        self.log_text.insert(tk.END, f"{time}: {line}\n")
        self.log_text.see(tk.END)

    def send_to_adapter(self, event: 'Event'):
        """Log an outgoing event and write it to the adapter's stdin as one JSON line."""
        self.log(f"-> {str(event)}")
        assert self.adapter.stdin is not None
        self.adapter.stdin.write((json.dumps(event.to_json()) + "\n").encode())
        self.adapter.stdin.flush()

    def on_stdout(self, line: str):
        """Handle one JSON line from the adapter's stdout.

        An ``Adapter``/``started`` message starts the walker; any other message is decoded as an
        Event and forwarded to the walker while running. Any failure stops the run.
        """
        try:
            jsn = json.loads(line)
            if jsn['kind'] == 'Adapter':
                if jsn['type'] == 'started':
                    self.log("Adapter started")
                    self.walker.start()
            elif self.running:
                event = Event.from_json(jsn)
                self.log(f"<- {str(event)}")
                self.walker.received_event(event)
        except Exception as e:
            self.stop(f"Error while processing adapter stdout: {repr(e)}")
            traceback.print_exc()

    def on_stderr(self, line: str):
        """Echo one line of the adapter's stderr to the log."""
        self.log(f"Adapter: {line}")

    def show_ui(self):
        """Build the main window and run the Tk main loop (returns once the window is closed)."""
        if sys.platform == 'win32':
            try:
                ctypes.windll.shcore.SetProcessDpiAwareness(1)
            except (AttributeError, OSError):
                # shcore.dll / SetProcessDpiAwareness is unavailable before Windows 8.1:
                # keep the default DPI behaviour rather than failing to start.
                pass
        self.window.title("Eclipse CommaSuite Test Client")
        self.window.geometry('1200x600')
        self.window.grid_columnconfigure(0, weight=1)
        self.window.grid_rowconfigure(1, weight=1)

        # Command entry for the adapter command line
        self.cmd_entry = tk.Entry(self.window)
        self.cmd_entry.insert(0, "")
        self.cmd_entry.grid(row=0, column=0, sticky=tk.EW, pady=2, padx=2)

        # Log view with horizontal and vertical scrollbars
        log_frame = tk.Frame(self.window)
        log_frame.grid_rowconfigure(0, weight=1)
        log_frame.grid_columnconfigure(0, weight=1)
        log_frame.grid(row=1, column=0, padx=2, pady=2, sticky=tk.NSEW)
        log_xscroll = tk.Scrollbar(log_frame, orient=tk.HORIZONTAL)
        log_xscroll.grid(row=1, column=0, sticky=tk.NSEW)
        log_yscroll = tk.Scrollbar(log_frame, orient=tk.VERTICAL)
        log_yscroll.grid(row=0, column=1, sticky=tk.NSEW)
        self.log_text = tk.Text(log_frame, wrap=tk.NONE, yscrollcommand=log_yscroll.set, xscrollcommand=log_xscroll.set)
        log_yscroll['command'] = self.log_text.yview # type: ignore
        log_xscroll['command'] = self.log_text.xview # type: ignore
        self.log_text.grid(row=0, column=0, sticky=tk.NSEW)

        # Right-hand column: start/stop button, statistics, coverage button
        right_frame = tk.Frame(self.window)
        right_frame.grid(row=0, column=1, sticky=tk.N, pady=2, rowspan=2)

        self.start_button = ttk.Button(right_frame, text="Start", command=self.start_stop)
        self.start_button.grid(column=0, row=0, sticky=tk.EW, columnspan=2)

        self.save_coverage_button = ttk.Button(right_frame, text="Save Coverage Info", command=self.save_coverage)
        self.save_coverage_button.grid(column=0, row=4, sticky=tk.W)

        # Statistic labels; values are filled in by update_ui() while running
        tk.Label(right_frame, text = "Running time:").grid(column=0, row=1, sticky=tk.W)
        tk.Label(right_frame, text = "State coverage:").grid(column=0, row=2, sticky=tk.W)
        tk.Label(right_frame, text = "Event coverage:").grid(column=0, row=3, sticky=tk.W)
        self.running_time_label = tk.Label(right_frame, text='-')
        self.running_time_label.grid(column=1, row=1, sticky=tk.W)
        self.state_coverage_label = tk.Label(right_frame, text='-')
        self.state_coverage_label.grid(column=1, row=2, sticky=tk.W)
        self.event_coverage_label = tk.Label(right_frame, text='-')
        self.event_coverage_label.grid(column=1, row=3, sticky=tk.W)

        self.update_ui()
        self.window.mainloop()

    def update_ui(self):
        """Sync the button/entry state with the run state and, while running, refresh statistics."""
        self.start_button['text'] = 'Stop' if self.running else 'Start'
        # 'normal' is the counterpart of 'disabled' for the ttk -state compatibility option.
        self.start_button['state'] = 'disabled' if self.stopping_or_starting else 'normal'
        self.cmd_entry['state'] = 'disabled' if self.running else 'normal'
        if self.running:
            self.running_time_label['text'] = f"{round((datetime.datetime.now() - self.start_time).total_seconds())}s"
            self.state_coverage_label['text'] = f"{len(self.walker.walker.seen_states)}/{len(self.walker.walker.all_states)}"
            self.event_coverage_label['text'] = f"{len(self.walker.walker.seen_events)}/{len(self.walker.walker.all_events)}"

    def update_ui_timer_tick(self):
        """Refresh the UI, then reschedule itself in one second while a run is active or starting."""
        self.update_ui()
        # Only keep ticking while there is something to show: a tick racing with stop()'s cancel()
        # would otherwise keep the timer chain alive forever.
        if self.running or self.stopping_or_starting:
            self.update_ui_timer = threading.Timer(1, self.update_ui_timer_tick)
            # Daemon, so a pending tick never keeps the process alive after the window is closed.
            self.update_ui_timer.daemon = True
            self.update_ui_timer.start()

    def stop(self, reason: Optional[str], force: bool = False):
        """Stop the walker, kill the adapter and cancel the UI timer.

        Does nothing when not running. Also does nothing while another start/stop transition is in
        progress, unless ``force`` is set (start_stop() uses it after claiming the transition itself).
        """
        if (self.stopping_or_starting and not force) or not self.running:
            return
        self.stopping_or_starting = True
        self.update_ui()
        self.log(f"Stopping{f': {reason}' if reason is not None else ''}")
        self.walker.stop()
        self.adapter.kill()
        self.update_ui_timer.cancel()
        self.stopping_or_starting = False
        self.running = False
        self.log("Stopped")
        self.update_ui()

    def start_stop(self):
        """Start/Stop button handler: toggle the run on a background thread.

        Ignored while a start or stop is already in progress.
        """
        if self.stopping_or_starting:
            return
        self.stopping_or_starting = True
        self.update_ui()

        def handle():
            if not self.running:
                self.log('Starting...')
                self.start_time = datetime.datetime.now()
                self.walker = TestClientWalker(nets, constraints, self.send_to_adapter, self.stop, self.log)
                try:
                    self.start_adapter()
                except Exception as e:
                    self.log(f"Failed to start adapter: '{str(e)}'")
                    self.stopping_or_starting = False
                    self.update_ui()
                    return
                self.update_ui_timer_tick()
                self.running = True
                self.log('Started')
            else:
                self.stop(None, True)
            self.stopping_or_starting = False
            self.update_ui()

        threading.Thread(target=handle).start()

    def save_coverage(self):
        """Write state and event coverage of the current/last run next to this script.

        Produces ``state_coverage.txt`` and ``event_coverage.txt``, each with an "Uncovered" and a
        "Covered" section when non-empty. Does nothing if no run has been started yet.
        """
        # NOTE(review): may run while the walker thread is still updating the coverage sets.
        if not hasattr(self, 'walker'):
            return
        coverage = self.walker.walker
        # join + abspath: a bare relative __file__ has an empty dirname, and the former
        # `dirname + "/name"` concatenation then pointed at the filesystem root.
        out_dir = os.path.dirname(os.path.abspath(__file__))
        self._write_coverage_file(os.path.join(out_dir, "state_coverage.txt"), "states", "State",
                                  coverage.all_states, coverage.seen_states)
        self._write_coverage_file(os.path.join(out_dir, "event_coverage.txt"), "events", "Event",
                                  coverage.all_events, coverage.seen_events)

    def _write_coverage_file(self, path, noun, column, all_items, seen_items):
        """Write the uncovered and covered sections for one kind of item to ``path``."""
        unseen_items = all_items.difference(seen_items)
        # The with-block guarantees the file is flushed and closed, even on errors.
        with open(path, "w", encoding="utf-8") as f:
            if len(unseen_items) > 0:
                f.write(f"Uncovered {noun}:\n\n")
                f.write(f"Port,Interface,{column}\n")
                self.print_coverage_info(f, unseen_items)
                f.write("\n")
            if len(seen_items) > 0:
                f.write(f"Covered {noun}:\n\n")
                f.write(f"Port,Interface,{column}\n")
                self.print_coverage_info(f, seen_items)

    def print_coverage_info(self, f, info):
        """Write each dotted name in ``info`` as a ``Port,Interface,<item>`` line to ``f``.

        Each name is split on '.': fragment 1 fills the Port column, fragment 0 the Interface
        column, and the remaining fragments (re-joined with '.') form the item. Duplicates per
        port are collapsed, and the output is sorted so repeated saves produce identical files.
        """
        ports: Dict[str, set] = {}
        for s in info:
            fragments = s.split(".")
            port = fragments[1] + "," + fragments[0]
            ports.setdefault(port, set()).add(".".join(fragments[2:]))
        for port in sorted(ports):
            for content in sorted(ports[port]):
                f.write(port + "," + content + "\n")

if __name__ == "__main__":
    # Script entry point: build the client UI and block in the Tk main loop until the window closes.
    client = TestClient()
    client.show_ui()
