""" Python implementation of the Happy Eyeballs Algorithm described in RFC 6555. """

# Copyright 2017 Seth Michael Larson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import errno
import socket
try:
    from selectors  import DefaultSelector, EVENT_WRITE
except (ImportError, AttributeError):
    from selectors2 import DefaultSelector, EVENT_WRITE

# time.perf_counter() is defined in Python 3.3
try:
    from time import perf_counter
except (ImportError, AttributeError):
    from time import time as perf_counter


# This list is due to socket.error and IOError not being a
# subclass of OSError until later versions of Python.
_SOCKET_ERRORS = (socket.error, OSError, IOError)


# Detects whether an IPv6 socket can be allocated.
def _detect_ipv6():
    if getattr(socket, 'has_ipv6', False) and hasattr(socket, 'AF_INET6'):
        _sock = None
        try:
            _sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
            _sock.bind(('::1', 0))
            return True
        except _SOCKET_ERRORS:
            if _sock:
                _sock.close()
    return False


_HAS_IPV6 = _detect_ipv6()

# These are error numbers for asynchronous operations which can
# be safely ignored by RFC 6555 as being non-errors.
_ASYNC_ERRNOS = set([errno.EINPROGRESS,
                     errno.EAGAIN,
                     errno.EWOULDBLOCK])
if hasattr(errno, 'WSAWOULDBLOCK'):
    _ASYNC_ERRNOS.add(errno.WSAWOULDBLOCK)

_DEFAULT_CACHE_DURATION = 60 * 10  # 10 minutes according to the RFC.

# This value that can be used to disable RFC 6555 globally.
RFC6555_ENABLED = _HAS_IPV6

__all__ = ['RFC6555_ENABLED',
           'create_connection',
           'cache']

__version__ = '1.0.0'
__author__ = 'Seth Michael Larson'
__email__ = 'sethmichaellarson@protonmail.com'
__license__ = 'Apache-2.0'


class _RFC6555CacheManager(object):
    def __init__(self):
        self.validity_duration = _DEFAULT_CACHE_DURATION
        self.enabled = True
        self.entries = {}

    def add_entry(self, address, family):
        if self.enabled:
            current_time = perf_counter()

            # Don't over-write old entries to reset their expiry.
            if address not in self.entries or self.entries[address][1] > current_time:
                self.entries[address] = (family, current_time + self.validity_duration)

    def get_entry(self, address):
        if not self.enabled or address not in self.entries:
            return None

        family, expiry = self.entries[address]
        if perf_counter() > expiry:
            del self.entries[address]
            return None

        return family


cache = _RFC6555CacheManager()


class _RFC6555ConnectionManager(object):
    def __init__(self, address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
        self.address = address
        self.timeout = timeout
        self.source_address = source_address

        self._error = None
        self._selector = DefaultSelector()
        self._sockets = []
        self._start_time = None

    def create_connection(self):
        self._start_time = perf_counter()

        host, port = self.address
        addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
        ret = self._connect_with_cached_family(addr_info)

        # If it's a list, then these are the remaining values to try.
        if isinstance(ret, list):
            addr_info = ret
        else:
            cache.add_entry(self.address, ret.family)
            return ret

        # If we don't get any results back then just skip to the end.
        if not addr_info:
            raise socket.error('getaddrinfo returns an empty list')

        sock = self._attempt_connect_with_addr_info(addr_info)

        if sock:
            cache.add_entry(self.address, sock.family)
            return sock
        elif self._error:
            raise self._error
        else:
            raise socket.timeout()

    def _attempt_connect_with_addr_info(self, addr_info):
        sock = None
        try:
            for family, socktype, proto, _, sockaddr in addr_info:
                self._create_socket(family, socktype, proto, sockaddr)
                sock = self._wait_for_connection(False)
                if sock:
                    break
            if sock is None:
                sock = self._wait_for_connection(True)
        finally:
            self._remove_all_sockets()
        return sock

    def _connect_with_cached_family(self, addr_info):
        family = cache.get_entry(self.address)
        if family is None:
            return addr_info

        is_family = []
        not_family = []

        for value in addr_info:
            if value[0] == family:
                is_family.append(value)
            else:
                not_family.append(value)

        sock = self._attempt_connect_with_addr_info(is_family)
        if sock is not None:
            return sock

        return not_family

    def _create_socket(self, family, socktype, proto, sockaddr):
        sock = None
        try:
            sock = socket.socket(family, socktype, proto)

            # If we're using the 'default' socket timeout we have
            # to set it to a real value here as this is the earliest
            # opportunity to without pre-allocating a socket just for
            # this purpose.
            if self.timeout is socket._GLOBAL_DEFAULT_TIMEOUT:
                self.timeout = sock.gettimeout()

            if self.source_address:
                sock.bind(self.source_address)

            # Make the socket non-blocking so we can use our selector.
            sock.settimeout(0.0)

            if self._is_acceptable_errno(sock.connect_ex(sockaddr)):
                self._selector.register(sock, EVENT_WRITE)
                self._sockets.append(sock)

        except _SOCKET_ERRORS as e:
            self._error = e
            if sock is not None:
                _RFC6555ConnectionManager._close_socket(sock)

    def _wait_for_connection(self, last_wait):
        self._remove_all_errored_sockets()

        # This is a safe-guard to make sure sock.gettimeout() is called in the
        # case that the default socket timeout is used. If there are no
        # sockets then we may not have called sock.gettimeout() yet.
        if not self._sockets:
            return None

        # If this is the last time we're waiting for connections
        # then we should wait until we should raise a timeout
        # error, otherwise we should only wait >0.2 seconds as
        # recommended by RFC 6555.
        if last_wait:
            if self.timeout is None:
                select_timeout = None
            else:
                select_timeout = self._get_remaining_time()
        else:
            select_timeout = self._get_select_time()

        # Wait for any socket to become writable as a sign of being connected.
        for key, _ in self._selector.select(select_timeout):
            sock = key.fileobj

            if not self._is_socket_errored(sock):

                # Restore the old proper timeout of the socket.
                sock.settimeout(self.timeout)

                # Remove it from this list to exempt the socket from cleanup.
                self._sockets.remove(sock)
                self._selector.unregister(sock)
                return sock

        return None

    def _get_remaining_time(self):
        if self.timeout is None:
            return None
        return max(self.timeout - (perf_counter() - self._start_time), 0.0)

    def _get_select_time(self):
        if self.timeout is None:
            return 0.2
        return min(0.2, self._get_remaining_time())

    def _remove_all_errored_sockets(self):
        socks = []
        for sock in self._sockets:
            if self._is_socket_errored(sock):
                socks.append(sock)
        for sock in socks:
            self._selector.unregister(sock)
            self._sockets.remove(sock)
            _RFC6555ConnectionManager._close_socket(sock)

    @staticmethod
    def _close_socket(sock):
        try:
            sock.close()
        except _SOCKET_ERRORS:
            pass

    def _is_acceptable_errno(self, errno):
        if errno == 0 or errno in _ASYNC_ERRNOS:
            return True
        self._error = socket.error()
        self._error.errno = errno
        return False

    def _is_socket_errored(self, sock):
        errno = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        return not self._is_acceptable_errno(errno)

    def _remove_all_sockets(self):
        for sock in self._sockets:
            self._selector.unregister(sock)
            _RFC6555ConnectionManager._close_socket(sock)
        self._sockets = []


def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
    if RFC6555_ENABLED and _HAS_IPV6:
        manager = _RFC6555ConnectionManager(address, timeout, source_address)
        return manager.create_connection()
    else:
        # This code is the same as socket.create_connection() but is
        # here to make sure the same code is used across all Python versions as
        # the source_address parameter was added to socket.create_connection() in 3.2
        # This segment of code is licensed under the Python Software Foundation License
        # See LICENSE: https://github.com/python/cpython/blob/3.6/LICENSE
        host, port = address
        err = None
        for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)
                if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)
                sock.connect(sa)
                return sock

            except socket.error as _:
                err = _
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        else:
            raise socket.error("getaddrinfo returns an empty list")
