import argparse
import asyncio
import os
import shutil
import socket
import socketserver
import subprocess
import sys
import time
from argparse import Namespace
from typing import List, Tuple, Union


def tmux_cp(data: str) -> None:
  """Store `data` in a tmux paste buffer when running inside tmux.

  Does nothing unless the `TMUX` environment variable is set to a
  non-empty value.

  Args:
    data: Text to hand to `tmux set-buffer`.
  """
  if not os.environ.get("TMUX"):
    return
  # "--" ends option parsing, so text that begins with "-" is stored as
  # buffer content instead of being read as a `set-buffer` flag.
  subprocess.run(["tmux", "set-buffer", "--", data])


def clippy(data: str) -> None:
  """Pipe `data` into `pbcopy`.

  Args:
    data: Text to copy; it is encoded with `str.encode()` before being
      written to the `pbcopy` process's stdin.

  Raises:
    NotImplementedError: If no `pbcopy` executable is found on PATH.
  """
  if shutil.which("pbcopy") is None:
    raise NotImplementedError("⚠️ No clipboard integration ⚠️")

  payload = data.encode()
  subprocess.run(["pbcopy"], input=payload)


def clean_socket(path: str) -> None:
  """Remove the file at `path` if it exists.

  Args:
    path: Filesystem path to unlink.

  Raises:
    OSError: If the path exists but cannot be unlinked (for example it is
      a directory or permission is denied). Only a missing file is
      treated as "nothing to do".
  """
  try:
    os.unlink(path)
  except FileNotFoundError:
    # Already gone: nothing to clean up.
    pass


def locate_prog() -> str:
  """Return this script's path, expressed relative to `$HOME` when possible.

  Both the script path and `HOME` are resolved with `os.path.realpath`
  before being compared, so a symlinked home directory still matches.

  Returns:
    `"$HOME/<relative path>"` (with the literal, unexpanded text `$HOME`)
    when the resolved script lies inside the resolved home directory;
    otherwise the resolved absolute path of the script.

  Raises:
    KeyError: If the `HOME` environment variable is not set.
  """
  home = os.path.realpath(os.environ["HOME"])
  canonical = os.path.realpath(__file__)

  # Compare whole path components: a plain string-prefix test would treat
  # "/home/bobby/x" as living inside "/home/bob".
  if os.path.commonpath([home, canonical]) != home:
    return canonical

  prog = os.path.relpath(canonical, home)
  return f"$HOME/{prog}"


def supervise(args: List[str]) -> None:
  """Run the daemon over ssh once and copy every record it emits.

  Starts `ssh <args> "<prog> --daemon"` and reads its stdout line by line.
  Bytes are accumulated until a NUL byte is seen; the accumulated bytes
  are then decoded and passed to `tmux_cp` and `clippy`, and whatever
  follows the NUL on that line is dropped. Once ssh's stdout reaches EOF
  the exit status and ssh's stderr are printed to stderr, a bell character
  is printed to stdout, and the function returns.

  Args:
    args: Arguments placed between `ssh` and the remote command.
  """
  prog = locate_prog()
  # The context manager closes both pipes and reaps the child, so repeated
  # calls do not leak file descriptors.
  with subprocess.Popen(
      ["ssh", *args, f"{prog} --daemon"],
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE) as process:
    try:
      buf = bytearray()
      # readline() returns b"" only at EOF, which ends the loop. Testing
      # the truthiness of poll() instead would mistake exit status 0 for
      # "still running" and spin forever on empty reads.
      for line in iter(process.stdout.readline, b""):
        head, nul, _ = line.partition(b"\0")
        buf += head
        if nul:
          data: str = buf.decode()
          buf = bytearray()
          tmux_cp(data)
          clippy(data)

      # Drain stderr before waiting so a child blocked on a full stderr
      # pipe cannot stall the wait.
      errors = process.stderr.read()
      code = process.wait()
      print(f"ssh exited - {code}", file=sys.stderr)
      print(errors.decode(errors="replace"), file=sys.stderr)
      print("\a")
    finally:
      # Reached with ssh still running only when an exception is
      # propagating; stop it so leaving the `with` block cannot hang.
      if process.poll() is None:
        process.terminate()


def remote_daemon(path: str) -> None:
  """Serve a Unix stream socket at `path`, relaying client data to stdout.

  Any existing file at `path` is removed via `clean_socket` first. For
  each connection, everything the client sends until it closes is written
  to this process's stdout and flushed. Blocks in `serve_forever()`.

  Args:
    path: Filesystem path at which the Unix socket is bound.
  """
  class Handler(socketserver.BaseRequestHandler):
    def handle(self) -> None:
      # Relay raw bytes. A text-mode makefile() would decode with the
      # locale's encoding and rewrite "\r\n" / "\r" as "\n", altering the
      # payload.
      with self.request.makefile("rb") as fd:
        data: bytes = fd.read()
      out = sys.stdout.buffer
      out.write(data)
      out.flush()

  clean_socket(path)
  with socketserver.UnixStreamServer(path, Handler) as srv:
    srv.serve_forever()


def remote_copy(path: str, data: str) -> None:
  """Send `data` to the Unix stream socket at `path`.

  The encoded text is followed by a NUL byte and a newline, then the
  connection is closed.

  Args:
    path: Filesystem path of the Unix socket to connect to.
    data: Text to transmit.
  """
  terminator = b'\0\n'
  with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as conn:
    conn.connect(path)
    for chunk in (data.encode(), terminator):
      conn.sendall(chunk)


def parse_args() -> Tuple[bool, Union[Namespace, List[str]]]:
  """Interpret `sys.argv`.

  Returns:
    `(True, rest)` when the first argument is `-d`, where `rest` is every
    argument after it, unparsed. Otherwise `(False, namespace)` where
    `namespace` carries the `tmux` and `daemon` flags and the `socket`
    path. The socket path defaults to `copy_socket` inside
    `$XDG_RUNTIME_DIR`, or inside `~/.ssh` when that variable is unset.
  """
  if len(sys.argv) > 1 and sys.argv[1] == "-d":
    return True, sys.argv[2:]

  runtime_dir = os.environ.get("XDG_RUNTIME_DIR")
  if runtime_dir is None:
    # Resolve the home directory only when it is actually needed, so an
    # unset HOME does not raise while XDG_RUNTIME_DIR is available.
    home = os.environ.get("HOME")
    if home is None:
      home = os.path.expanduser("~")
    runtime_dir = os.path.join(home, ".ssh")
  socket_path = os.path.join(runtime_dir, "copy_socket")

  parser = argparse.ArgumentParser()
  parser.add_argument("--tmux", action="store_true")
  parser.add_argument("--daemon", action="store_true")
  parser.add_argument("--socket", default=socket_path)
  return False, parser.parse_args()


def run() -> None:
  """Dispatch to the mode selected by the command line and environment.

  * `-d ...`: call `supervise` in an endless loop, sleeping one second
    between calls.
  * `--daemon`: call `remote_daemon` on the configured socket.
  * otherwise: read stdin, strip leading and trailing newlines, call
    `tmux_cp` unless `--tmux` was given, then call `remote_copy` when
    `SSH_TTY` is set or `clippy` when it is not.
  """
  supervising, args = parse_args()

  if supervising:
    while True:
      supervise(args)
      time.sleep(1)

  if args.daemon:
    remote_daemon(args.socket)
    return

  over_ssh = os.environ.get("SSH_TTY") is not None
  text = sys.stdin.read().strip("\n")
  if not args.tmux:
    tmux_cp(text)

  if over_ssh:
    remote_copy(args.socket, text)
  else:
    clippy(text)


def main() -> None:
  """Call `run`, turning failures into a message, a bell and exit status 1.

  `KeyboardInterrupt` is swallowed silently. Any other `Exception` is
  printed to stderr, a bell character is printed to stdout, and the
  process exits with status 1.
  """
  try:
    run()
  except KeyboardInterrupt:
    pass
  except Exception as e:
    print(e, file=sys.stderr)
    print("\a")
    # sys.exit rather than the bare `exit` helper, which is only injected
    # by the `site` module and is missing under `python -S`.
    sys.exit(1)


# Run only when executed as a script, so importing this module has no
# side effects.
if __name__ == "__main__":
  main()
