import functools
import math
import ssl
import typing
from types import TracebackType

import trio

from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
    BaseBackgroundManager,
    BaseEvent,
    BasePoolSemaphore,
    BaseQueue,
    BaseTCPStream,
    ConcurrencyBackend,
    TimeoutFlag,
)


def _or_inf(value: typing.Optional[float]) -> float:
    return value if value is not None else float("inf")


class TCPStream(BaseTCPStream):
    """Trio-backed ``BaseTCPStream`` wrapping a plain socket or a TLS stream."""

    def __init__(
        self,
        stream: typing.Union[trio.SocketStream, trio.SSLStream],
        timeout: TimeoutConfig,
    ) -> None:
        self.stream = stream
        # Default timeouts, used whenever read()/write() are not given one.
        self.timeout = timeout
        # Bytes queued by write_no_block(); flushed at the start of the next write().
        self.write_buffer = b""
        # Held around send_all() so only one task sends on the stream at a time.
        self.write_lock = trio.Lock()

    def get_http_version(self) -> str:
        """Return "HTTP/2" if TLS ALPN negotiated "h2", otherwise "HTTP/1.1"."""
        if not isinstance(self.stream, trio.SSLStream):
            return "HTTP/1.1"

        ident = self.stream.selected_alpn_protocol()
        if ident is None:
            return "HTTP/1.1"

        return "HTTP/2" if ident == "h2" else "HTTP/1.1"

    async def read(
        self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
    ) -> bytes:
        """Receive up to ``n`` bytes from the stream.

        ``timeout`` defaults to the stream's own ``TimeoutConfig``. While
        ``flag`` is given and ``flag.raise_on_read_timeout`` is false, the
        receive is retried in short 0.01s slices so the flag is re-checked
        often. Once the flag flips (or when no flag is given), the full
        ``timeout.read_timeout`` applies.

        Raises:
            ReadTimeout: nothing was received within the read timeout while
                read timeouts are enabled.
        """
        if timeout is None:
            timeout = self.timeout

        while True:
            # Check our flag at the first possible moment, and use a fine
            # grained retry loop if we're not yet in read-timeout mode.
            should_raise = flag is None or flag.raise_on_read_timeout
            read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)

            with trio.move_on_after(read_timeout):
                return await self.stream.receive_some(max_bytes=n)

            if should_raise:
                raise ReadTimeout() from None

    def is_connection_dropped(self) -> bool:
        """Return True if the underlying socket is readable while idle.

        For an idle connection that usually means the peer has closed it.
        """
        # Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
        stream = self.stream

        # Peek through any SSLStream wrappers to get the underlying SocketStream.
        while hasattr(stream, "transport_stream"):
            stream = stream.transport_stream
        assert isinstance(stream, trio.SocketStream)

        # Counter-intuitively, what we really want to know here is whether the socket is
        # *readable*, i.e. whether it would return immediately with empty bytes if we
        # called `.recv()` on it, indicating that the other end has closed the socket.
        # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
        return stream.socket.is_readable()

    def write_no_block(self, data: bytes) -> None:
        """Queue ``data`` to be sent by the next call to ``write``."""
        self.write_buffer += data

    async def write(
        self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
    ) -> None:
        """Send any bytes queued by ``write_no_block``, then ``data``.

        ``timeout`` defaults to the stream's own ``TimeoutConfig``. A send that
        exceeds ``timeout.write_timeout`` is retried for as long as ``flag``
        suppresses write timeouts (``flag.raise_on_write_timeout`` is false).

        Raises:
            WriteTimeout: a send timed out while write timeouts are enabled.
                If this happens while flushing the queued bytes, they are put
                back into ``write_buffer`` so a later ``write`` retries them.
        """
        if self.write_buffer:
            previous_data = self.write_buffer
            # Reset before recursive call, otherwise we'll go through
            # this branch indefinitely.
            self.write_buffer = b""
            try:
                await self.write(previous_data, timeout=timeout, flag=flag)
            except WriteTimeout:
                # Put the unsent bytes back, ahead of anything write_no_block()
                # queued while we were awaiting, so the next write() retries them.
                # NOTE(review): a cancelled send_all() may already have sent a
                # prefix of this data, so a retry could resend it — confirm.
                self.write_buffer = previous_data + self.write_buffer
                raise

        if not data:
            return

        if timeout is None:
            timeout = self.timeout

        write_timeout = _or_inf(timeout.write_timeout)

        while True:
            with trio.move_on_after(write_timeout):
                async with self.write_lock:
                    await self.stream.send_all(data)
                break
            # We check our flag at the first possible moment, in order to
            # allow us to suppress write timeouts, if we've since
            # switched over to read-timeout mode.
            should_raise = flag is None or flag.raise_on_write_timeout
            if should_raise:
                raise WriteTimeout() from None

    async def close(self) -> None:
        """Close the wrapped stream."""
        await self.stream.aclose()


class PoolSemaphore(BasePoolSemaphore):
    """Limit concurrent connections to ``pool_limits.hard_limit``.

    When ``hard_limit`` is ``None`` there is no limit, and acquire/release
    are no-ops.
    """

    def __init__(self, pool_limits: PoolLimits):
        self.pool_limits = pool_limits

    @property
    def semaphore(self) -> typing.Optional[trio.Semaphore]:
        """The semaphore, built on first access, or ``None`` when unlimited."""
        try:
            return self._semaphore
        except AttributeError:
            pass
        limit = self.pool_limits.hard_limit
        self._semaphore = (
            None if limit is None else trio.Semaphore(limit, max_value=limit)
        )
        return self._semaphore

    async def acquire(self) -> None:
        """Take a slot, waiting at most ``pool_limits.pool_timeout`` seconds.

        Raises:
            PoolTimeout: no slot became free before the timeout elapsed.
        """
        semaphore = self.semaphore
        if semaphore is None:
            return

        deadline = _or_inf(self.pool_limits.pool_timeout)
        with trio.move_on_after(deadline) as scope:
            await semaphore.acquire()
        if scope.cancelled_caught:
            raise PoolTimeout()

    def release(self) -> None:
        """Give back a slot previously taken by ``acquire``."""
        semaphore = self.semaphore
        if semaphore is not None:
            semaphore.release()


class TrioBackend(ConcurrencyBackend):
    """``ConcurrencyBackend`` implementation built on trio primitives."""

    async def open_tcp_stream(
        self,
        hostname: str,
        port: int,
        ssl_context: typing.Optional[ssl.SSLContext],
        timeout: TimeoutConfig,
    ) -> TCPStream:
        """Connect to ``hostname:port``, doing a TLS handshake if ``ssl_context``.

        Raises:
            ConnectTimeout: connecting (including any TLS handshake) took longer
                than ``timeout.connect_timeout``.
        """
        connect_timeout = _or_inf(timeout.connect_timeout)

        with trio.move_on_after(connect_timeout) as cancel_scope:
            stream: typing.Union[
                trio.SocketStream, trio.SSLStream
            ] = await trio.open_tcp_stream(hostname, port)
            if ssl_context is not None:
                ssl_stream = trio.SSLStream(
                    stream, ssl_context, server_hostname=hostname
                )
                try:
                    await ssl_stream.do_handshake()
                except BaseException:
                    # The TCP connection is already open. Close it before
                    # propagating a handshake error or the timeout's
                    # cancellation, otherwise the socket would leak.
                    await trio.aclose_forcefully(ssl_stream)
                    raise
                stream = ssl_stream

        if cancel_scope.cancelled_caught:
            raise ConnectTimeout()

        return TCPStream(stream=stream, timeout=timeout)

    async def run_in_threadpool(
        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
    ) -> typing.Any:
        """Run the sync ``func(*args, **kwargs)`` in a worker thread and return its result."""
        # trio.to_thread.run_sync only forwards positional args, so bind kwargs first.
        return await trio.to_thread.run_sync(
            functools.partial(func, **kwargs) if kwargs else func, *args
        )

    def run(
        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
    ) -> typing.Any:
        """Run ``coroutine(*args, **kwargs)`` to completion under ``trio.run``."""
        # trio.run only forwards positional args, so bind kwargs first.
        return trio.run(
            functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
        )

    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
        """Return a connection-limiting semaphore for ``limits``."""
        return PoolSemaphore(limits)

    def create_queue(self, max_size: int) -> BaseQueue:
        """Return a FIFO queue bounded by ``max_size``."""
        return Queue(max_size=max_size)

    def create_event(self) -> BaseEvent:
        """Return a new, unset event."""
        return Event()

    def background_manager(
        self, coroutine: typing.Callable, *args: typing.Any
    ) -> "BackgroundManager":
        """Return a context manager that runs ``coroutine(*args)`` in the background."""
        return BackgroundManager(coroutine, *args)


class Queue(BaseQueue):
    """FIFO queue backed by a trio memory channel."""

    def __init__(self, max_size: int) -> None:
        # Honour max_size the way asyncio.Queue(maxsize=...) does, where a value
        # <= 0 means "unbounded". Trio treats a buffer size of 0 as an unbuffered
        # (rendezvous) channel instead, so map None/<= 0 to math.inf explicitly.
        unbounded = max_size is None or max_size <= 0
        buffer_size = math.inf if unbounded else max_size
        self.send_channel, self.receive_channel = trio.open_memory_channel(
            buffer_size
        )

    async def get(self) -> typing.Any:
        """Remove and return the next item, waiting until one is available."""
        return await self.receive_channel.receive()

    async def put(self, value: typing.Any) -> None:
        """Append ``value``, waiting while the queue is full."""
        await self.send_channel.send(value)


class Event(BaseEvent):
    """Resettable event built on top of ``trio.Event``."""

    def __init__(self) -> None:
        self.clear()

    def set(self) -> None:
        """Set the current underlying event, waking tasks waiting on it."""
        self._event.set()

    def is_set(self) -> bool:
        """Report whether the current underlying event has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Block until the underlying event current at call time is set."""
        current = self._event
        await current.wait()

    def clear(self) -> None:
        """Return to the unset state by installing a fresh ``trio.Event``.

        Tasks already blocked in ``wait`` keep waiting on the replaced event,
        so a later ``set`` does not wake them.
        """
        # trio.Event.clear() was deprecated in Trio 0.12.
        # https://github.com/python-trio/trio/issues/637
        self._event = trio.Event()


class BackgroundManager(BaseBackgroundManager):
    """Async context manager that runs ``coroutine(*args)`` in a trio nursery.

    The task starts on entry. On exit, the nursery's own exit logic runs,
    which waits for the task to finish.
    """

    def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
        self.coroutine = coroutine
        self.args = args
        self.nursery_manager = trio.open_nursery()
        # Populated by __aenter__ once the nursery is open.
        self.nursery: typing.Optional[trio.Nursery] = None

    async def __aenter__(self) -> "BackgroundManager":
        opened = await self.nursery_manager.__aenter__()
        self.nursery = opened
        opened.start_soon(self.coroutine, *self.args)
        return self

    async def __aexit__(
        self,
        exc_type: typing.Type[BaseException] = None,
        exc_value: BaseException = None,
        traceback: TracebackType = None,
    ) -> None:
        assert self.nursery is not None
        # The nursery's return value is deliberately dropped, so this method
        # always returns None and never suppresses an in-flight exception.
        await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
