# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import asyncio
import logging
import random
import string

from typing import Any, Dict, List, Tuple

import aiodns
import aiohttp

from kubernetes_asyncio.client import (
    CoreV1Api,
    AppsV1Api,
    CustomObjectsApi
)

from nv.svc.kubernetes.client import KubernetesClient


class HostnameNotFoundError(Exception):
    """Signal that the expected hostname could not be located.

    Raised when a service does not carry the configured hostname annotation.
    """

    pass


class _APIError(Exception):
    """Error carrying the HTTP status and detail text of a failed API call.

    Attributes:
        status_code: HTTP status code returned by the API.
        details: Human-readable description of the failure.
    """

    def __init__(self, status_code: int, details: str):
        message = f"Error {status_code}: {details}"
        super().__init__(message)
        # Kept as attributes so callers can branch on the status programmatically.
        self.details = details
        self.status_code = status_code


class _CSP(object):
    """Custom CSP behaviour base class.

    Subclasses override ``on_create``, ``on_delete`` and ``resolve_endpoints``.
    The remaining methods are shared helpers for DNS lookups, Kubernetes
    resource listing and TCP readiness probing.
    """

    def __init__(self, k8s_client: KubernetesClient) -> None:
        """Initialize.

        Args:
            k8s_client: Client entered as an async context manager by
                ``_fetch_resources`` to obtain an API connection.
        """
        self._k8s_client = k8s_client
        # Every resource listing in ``_fetch_resources`` targets this namespace.
        self._default_namespace = "omni-streaming"

    def _generate_random_dns_prefix(self, length=6):
        """Return ``length`` random lowercase ASCII letters for use as a DNS label.

        NOTE(review): uses ``random`` rather than ``secrets``. That is adequate
        for avoiding collisions, but not if the name must be unguessable.
        """
        letters = string.ascii_lowercase
        return ''.join(random.choice(letters) for _ in range(length))

    async def on_create(self, profile_data: dict, settings: dict) -> Dict:
        """Process CSP customisations on stream creation.

        The base implementation makes no changes and returns an empty dict.
        """
        return {}

    async def on_delete(self, data: dict) -> None:
        """Process CSP customisations on stream termination (no-op here)."""

    async def resolve_endpoints(self, session_id: str) -> Tuple[Dict, bool]:
        """Resolve the endpoint to connect to for a given session.

        Returns:
            A ``(routes, ready)`` pair. The base implementation returns
            ``({}, False)``.
        """
        return {}, False

    async def _resolve_hostname(self, hostname):
        """Resolve ``hostname`` to a list of IPv4 address strings (A records).

        A DNS failure is logged and reported as an empty list rather than
        raised, because a freshly created record may not have propagated yet.
        """
        resolver = aiodns.DNSResolver()
        try:
            ipv4_records = await resolver.query(hostname, 'A')
        except aiodns.error.DNSError as exc:
            logging.warning(f"Failed to resolve DNS for {hostname}. The domain might not have propagated yet if this is a new stream. {exc}")
            return []

        return [record.host for record in ipv4_records]

    async def _fetch_resources(self, resource_type, selectors=None, args: Dict = None):
        """List namespaced Kubernetes resources of ``resource_type``.

        Args:
            resource_type: Key for ``_get_k8s_api_class``. It also forms the
                default list method name ``list_namespaced_<resource_type>``.
            selectors: Optional label mapping, joined into a ``k=v,...`` label
                selector.
            args: Optional extra keyword arguments for the list call. The
                special key ``func_name`` overrides the API method to invoke.
                The caller's dict is never modified.

        Returns:
            The ``items`` of the list response.
        """
        api_class = self._get_k8s_api_class(resource_type)

        # Work on a copy: ``func_name`` is popped and ``label_selector`` is
        # injected below, and neither should leak back into the caller's dict.
        call_args = dict(args or {})
        func_name = call_args.pop('func_name', f"list_namespaced_{resource_type}")
        if selectors:
            call_args['label_selector'] = ",".join(f"{k}={v}" for k, v in selectors.items())

        async with self._k8s_client as api_client:
            api_instance = api_class(api_client.api_http)
            func = getattr(api_instance, func_name)
            res = await func(namespace=self._default_namespace, **call_args)
            # The response may be a plain dict (e.g. custom objects) or a
            # typed model object exposing ``.items``.
            return res['items'] if isinstance(res, dict) else res.items

    def _get_k8s_api_class(self, resource_type):
        """Return the kubernetes_asyncio API class serving ``resource_type``.

        Unknown resource types fall back to ``CoreV1Api``.
        """
        resource_to_api_class = {
            'pod': CoreV1Api,
            'service': CoreV1Api,
            'deployment': AppsV1Api,
            'targetgroupbinding': CustomObjectsApi,
        }
        return resource_to_api_class.get(resource_type, CoreV1Api)

    async def _tcp_port_ready(self, host: str, port: int, timeout: float = 2) -> bool:
        """Return True if a TCP connection to ``host:port`` can be opened.

        Args:
            host: Host name or IP address to connect to.
            port: TCP port.
            timeout: Seconds to wait for the connection to be established.

        Returns:
            True when the connection succeeds. False on timeout or any other
            error; errors are logged, never raised.
        """
        try:
            logging.debug(f"Testing stream readiness on {host}:{port}")
            _, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)
            writer.close()
            await writer.wait_closed()
            logging.debug(f"Stream connection on {host}:{port} successful")
            return True
        except asyncio.TimeoutError:
            logging.info(f"Failed to connect to {host}:{port}. Connection not ready.")
        except Exception as exc:
            # Best-effort probe: any connection failure simply means "not ready".
            logging.error(f"Failed to connect to {host}:{port}: {exc}")

        return False


class Generic(_CSP):
    """Generic and default CSP manager.

    Endpoints are discovered from Kubernetes services labelled with the session
    id. With WSS enabled, each service's hostname annotation is reported.
    Otherwise the load balancer ingress IPs are reported.
    """

    def __init__(
        self,
        k8s_client: KubernetesClient,
        enable_wss: bool = False,
        hostname_annotation_key: str = "external-dns.alpha.kubernetes.io/hostname",
        service_annotations_location: str = "streamingKit.service.annotations",
        base_domain: str = ""
    ) -> None:
        """Initialize.

        Args:
            k8s_client: Kubernetes client passed to the base class.
            enable_wss: When True, ``on_create`` injects a generated hostname
                annotation and endpoints are reported by hostname.
            hostname_annotation_key: Service annotation that carries the hostname.
            service_annotations_location: Dotted path to the service
                annotations inside the profile values. It is also used as the
                key written into ``settings``.
            base_domain: Domain appended to the random prefix in generated
                hostnames.
        """
        super().__init__(k8s_client)
        self._enable_wss = enable_wss
        self._service_annotations_location = service_annotations_location
        self._hostname_annotation_key = hostname_annotation_key
        self._base_domain = base_domain

    async def on_create(self, profile_data: Dict, settings: Dict) -> Dict:
        """Process CSP customisations on stream creation.

        When WSS is enabled, this copies the service annotations found at
        ``service_annotations_location`` in the profile values. It adds a
        ``<random-prefix>.<base_domain>`` hostname annotation to the copy and
        stores the result in ``settings`` under the same dotted key.

        Returns:
            ``{}`` when WSS is disabled, otherwise the updated ``settings``.
        """
        if not self._enable_wss:
            return {}

        # Walk the dotted path. Missing or None levels are treated as empty.
        data = (profile_data.get("settings") or {}).get("values") or {}
        for key in self._service_annotations_location.split("."):
            data = data.get(key) or {}

        # Copy so the profile's own annotations are never mutated. If the
        # caller reuses the profile dict, writing into it would make earlier
        # settings alias later hostnames.
        annotations = dict(data)
        prefix = self._generate_random_dns_prefix()
        annotations[self._hostname_annotation_key] = f"{prefix}.{self._base_domain}"
        settings[self._service_annotations_location] = annotations

        logging.debug(f"Generated settings {settings}")

        return settings

    async def resolve_endpoints(self, session_id: str) -> Tuple[Dict, bool]:
        """Resolve the endpoint to connect to for a given session.

        Returns:
            ``(routes_by_endpoint, ready)`` built from the services labelled
            ``sessionId=<session_id>``.
        """
        services = await self._fetch_resources("service", selectors={"sessionId": session_id})
        routes, status = await self._extract_routes(services)
        return routes, status

    def _extract_ports(self, service_spec: Any) -> Tuple[Dict, bool]:
        """Extract the port information from a service object.

        Args:
            service_spec: Service object exposing ``.spec.ports``.

        Returns:
            ``({"routes": [...]}, ready)``. Each route maps a port to
            ``source_port``, ``description``, ``protocol`` and
            ``destination_port`` (the node port). ``ready`` is True when at
            least one port exists.
        """
        # ``spec.ports`` may be None when a service declares no ports.
        ports = service_spec.spec.ports or []

        routes = [
            {
                "source_port": port.port,
                "description": port.name,
                "protocol": port.protocol,
                "destination_port": port.node_port
            }
            for port in ports
        ]

        status = bool(routes)

        logging.debug(f"Extracted port mappings {routes}, readiness status {status}")
        return {"routes": routes}, status

    async def _extract_routes(self, services):
        """Map each ready service endpoint to its port routes.

        Endpoints are the hostname annotation when WSS is enabled, otherwise
        the load balancer IPs. Services whose endpoint is not ready are left
        out of the map, but they still count towards overall readiness.

        Returns:
            ``(routes_by_endpoint, ready)``. ``ready`` is True only when at
            least one service was found and every service has ports and a
            ready endpoint.

        Raises:
            HostnameNotFoundError: WSS is enabled and a service lacks the
                hostname annotation.
        """
        routes = {}
        statuses = []

        for service in services:
            ports, port_ready = self._extract_ports(service)

            entries = []
            lb_ready = False
            if self._enable_wss:
                hostname, lb_ready = await self._extract_hostname(service)
                entries.append(hostname)
            else:
                ips, lb_ready = await self._extract_lb_ips(service)
                entries.extend(ips)

            statuses.append(port_ready and lb_ready)
            if not lb_ready:
                continue

            for entry in entries:
                routes[entry] = ports

        status = all(statuses) if statuses else False
        logging.debug(f"Extracted routes {routes}, readiness status {status}")
        return routes, status

    async def _extract_lb_ips(self, service) -> Tuple[List, bool]:
        """Return the load balancer ingress IPs of ``service`` and a ready flag.

        If no ingress entry carries an IP, the ingress hostname is resolved via
        DNS instead. The service is ready when at least one IP was found.
        """
        # Tolerate a missing status or load balancer section.
        lb_status = getattr(service.status, "load_balancer", None)
        ingress = getattr(lb_status, "ingress", None)
        if not ingress:
            return [], False

        ips = []
        hostname = None
        for entry in ingress:
            if entry.ip:
                ips.append(entry.ip)
            # An entry without a hostname must not discard one seen earlier.
            hostname = entry.hostname or hostname

        if not ips and hostname:
            logging.debug("No IPs were found attached to the service, trying to resolve hostname")
            ips = await self._resolve_hostname(hostname)

        status = bool(ips)

        logging.debug(f"Extracted IPs {ips}, readiness status {status}")
        return ips, status

    async def _extract_hostname(self, service) -> Tuple[str, bool]:
        """Return the service's hostname annotation and whether it resolves.

        Raises:
            HostnameNotFoundError: The hostname annotation is absent.
        """
        # ``metadata.annotations`` may be None when a service has no
        # annotations. Treat that the same as a missing key.
        annotations = service.metadata.annotations or {}

        try:
            hostname = annotations[self._hostname_annotation_key]
        except KeyError:
            logging.error(f"Hostname field `{self._hostname_annotation_key}` not found in annotations")
            raise HostnameNotFoundError("Unable to find hostname") from None

        ready = False
        try:
            # Ready once the name resolves to at least one IPv4 address.
            ready = bool(await self._resolve_hostname(hostname))
        except aiodns.error.DNSError as exc:
            logging.warning(f"Unable to resolve {hostname}: {exc}")

        return hostname, ready


class AWS(_CSP):
    """AWS customisations.

    Stream ports are exposed through an AWS network load balancer. Its
    listeners and target groups are allocated by an NLB management service.
    Endpoints are discovered from ``TargetGroupBinding`` custom resources
    labelled with the session id.
    """

    def __init__(
        self,
        k8s_client: KubernetesClient,
        nlb_mgmt_svc: str = "",
        enable_wss: bool = False
    ) -> None:
        """Initialize AWS class.

        Args:
            k8s_client: Kubernetes client passed to the base class.
            nlb_mgmt_svc: Base URL of the NLB management service. Allocations
                are requested from ``<nlb_mgmt_svc>/allocation``.
            enable_wss: When True, non-media listeners are requested as TLS and
                endpoints are reported by alias hostname instead of NLB IPs.
        """
        super().__init__(k8s_client=k8s_client)
        self._nlb_mgmt_svc_url = nlb_mgmt_svc

        # Dotted paths into the profile values, keyed by allocation name.
        self._port_locations = {
            "media": "streamingKit.service.mediaPort",
            "signaling": "streamingKit.service.signalingPort"
        }

        # Dotted settings keys written by on_create, keyed by allocation name.
        self._targetgroup_arn_locations = {
            "media": "streamingKit.aws.targetgroups.media",
            "signaling": "streamingKit.aws.targetgroups.signaling",
        }

        self._listeners_arn_locations = {
            "media": "streamingKit.aws.listeners.media",
            "signaling": "streamingKit.aws.listeners.signaling",
        }

        self._nlb_location = "streamingKit.aws.nlb"
        self._alias_location = "streamingKit.aws.alias"
        self._enable_wss = enable_wss

    def _lookup_nested_dict(self, nested_dict, key_string):
        """Return the value at dotted path ``key_string`` (raises KeyError if absent)."""
        keys = key_string.split('.')
        value = nested_dict
        for key in keys:
            value = value[key]
        return value

    async def on_create(self, profile_data: dict, settings: dict) -> Dict:
        """Process AWS customisations on stream creation.

        Requests one listener allocation per configured port. "media" uses UDP;
        the others use TLS when WSS is enabled, otherwise TCP. The returned
        load balancer, target group and listener details are recorded in
        ``settings`` under their dotted keys.

        Returns:
            The updated ``settings``.

        Raises:
            KeyError: A configured port is missing from the profile values,
                or the service response lacks an expected field.
            _APIError: The allocation request did not return HTTP 200.
        """
        values = profile_data["settings"]["values"]

        # The port values themselves are not sent. The lookup only validates
        # that every port is configured in the profile (KeyError otherwise).
        for location in self._port_locations.values():
            self._lookup_nested_dict(values, location)

        default_protocol = "TLS" if self._enable_wss else "TCP"
        allocations = [
            {
                "name": name,
                "protocol": "UDP" if name == "media" else default_protocol
            }
            for name in self._port_locations
        ]

        url = f"{self._nlb_mgmt_svc_url}/allocation"
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json={"allocations": allocations}) as resp:
                if resp.status != 200:
                    detail = await resp.text()
                    error_msg = f"Failed request to {url}: {resp.status}, {detail}"
                    logging.error(error_msg)
                    raise _APIError(status_code=resp.status, details=error_msg)

                arns = await resp.json()

        loadbalancer = arns["loadbalancer"]
        settings[self._nlb_location] = loadbalancer["dnsName"]
        settings[self._alias_location] = loadbalancer.get("alias", "")

        for key, location in self._targetgroup_arn_locations.items():
            allocation = arns["allocations"][key]
            # Same "<arn>@<port>@<protocol>@<name>" layout that _extract_routes
            # parses from the listener annotation.
            listener = (
                f"{allocation['listenerArn']}@{allocation['listenerPort']}"
                f"@{allocation['listenerProtocol']}@{key}"
            )
            settings[location] = allocation["targetGroupArn"]
            settings[self._listeners_arn_locations[key]] = listener

        logging.debug(f"Generated settings {settings}")
        return settings

    async def resolve_endpoints(self, session_id: str) -> Tuple[Dict, bool]:
        """Resolve the endpoint to connect to for a given session.

        Returns:
            ``(routes_by_endpoint, ready)`` built from the TargetGroupBinding
            objects labelled ``sessionId=<session_id>``.
        """
        args = {
            'func_name': 'list_namespaced_custom_object',
            'group': 'elbv2.k8s.aws',
            'version': 'v1beta1',
            'plural': 'targetgroupbindings',
        }

        tgbs = await self._fetch_resources(
            "targetgroupbinding",
            selectors={"sessionId": session_id},
            args=args
        )

        routes, status = await self._extract_routes(tgbs)

        return routes, status

    async def _extract_routes(self, tgbs: List) -> Tuple[Dict, bool]:
        """Build the route map from the session's TargetGroupBinding objects.

        Each valid binding contributes one route, parsed from its listener
        annotation ("<arn>@<port>@<protocol>@<name>"). Every route is attached
        to every endpoint. Endpoints are the alias hostnames when WSS is
        enabled, otherwise the IPs resolved from the NLB hostnames.

        Returns:
            ``(routes_by_endpoint, ready)``. ``ready`` is True only if at least
            one binding was seen, every binding is valid, and every TCP
            listener accepts connections.
        """
        routes = []
        statuses = []
        hostnames = []

        for tgb in tgbs:
            # Custom objects are plain dicts and annotations may be absent.
            annotations = (tgb.get("metadata") or {}).get("annotations") or {}
            listener = annotations.get('nvidia.com/omniverse.listener', '')
            nlb_hostname = annotations.get('nvidia.com/omniverse.nlb', '')
            alias = annotations.get('nvidia.com/omniverse.alias', '')

            hostname = alias if self._enable_wss else nlb_hostname

            if not listener or not hostname:
                statuses.append(False)
                logging.error(f"Invalid targetgroupbinding was found. Missing listener or hostname annotation {tgb}")
                continue

            try:
                _, port_str, protocol, name = listener.split("@")
                port = int(port_str)
            except ValueError:
                # A malformed annotation marks this binding as not ready
                # instead of aborting resolution for the whole session.
                statuses.append(False)
                logging.error(f"Invalid targetgroupbinding was found. Malformed listener annotation {listener!r}")
                continue

            hostnames.append(hostname)
            routes.append({
                "source_port": port,
                "description": name,
                "protocol": protocol,
                "destination_port": -1
            })
            ready = await self._tcp_port_ready(hostname, port) if protocol.lower() == "tcp" else True
            statuses.append(ready)

        status = all(statuses) if statuses else False

        # Order-preserving dedup so each hostname is resolved only once.
        endpoints = list(dict.fromkeys(hostnames))
        if not self._enable_wss:
            ips = []
            for hostname in endpoints:
                ips.extend(await self._resolve_hostname(hostname))
            endpoints = ips

        assembled_routes = {endpoint: {"routes": routes} for endpoint in endpoints}

        logging.debug(f"Extracted routes {assembled_routes}, readiness status {status}")
        return assembled_routes, status
