from hashlib import md5
import logging
from pathlib import Path


from wsgidav import util
from wsgidav.dc.base_dc import BaseDomainController
import jwt


# Markup format used by this module's docstrings.
__docformat__ = "reStructuredText"

# Module-level logger, obtained through wsgidav's ``util.get_module_logger`` helper.
_logger = util.get_module_logger(__name__)


class ACRMDomainController(BaseDomainController):
    """Domain controller that checks basic-auth credentials against an ACRM endpoint.

    Credentials are posted to ``<endpoint>/auth/signin``; the response body is
    treated as a JWT, verified with the configured key, and cached on disk so
    that later requests with the same credentials skip the network round trip.
    Access is granted only when the token's ``scope`` claim contains the
    ``<org_name>:Users`` entry.

    Configuration is read from the ``acrm_dc`` section of the WsgiDAV config:

    * ``endpoint`` (required) -- base URL of the auth service.
    * ``pub_key_path`` (required) -- path of the key used to verify tokens.
    * ``org_name`` (optional) -- organisation whose roles are considered.
    * ``tokens_dir`` (optional) -- token cache directory,
      default :attr:`DEFAULT_TOKENS_DIR`.
    * ``jwt_algorithms`` (optional) -- accepted signature algorithms,
      default :attr:`DEFAULT_JWT_ALGORITHMS`.
    * ``request_timeout`` (optional) -- timeout in seconds for the sign-in
      request, default :attr:`DEFAULT_REQUEST_TIMEOUT`; ``None`` disables it.
    """

    # Cache directory used when ``tokens_dir`` is not configured.
    DEFAULT_TOKENS_DIR = "/opt/rdisk/tokens"

    # NOTE(review): accepting a symmetric (HS256) and an asymmetric (RS256)
    # algorithm with the same key is the classic "algorithm confusion" setup.
    # The default is kept for backward compatibility; pin ``jwt_algorithms``
    # in the config to the single algorithm the issuer really uses.
    DEFAULT_JWT_ALGORITHMS = ("RS256", "HS256")

    # Seconds to wait for the auth endpoint before giving up.
    DEFAULT_REQUEST_TIMEOUT = 30

    def __init__(self, wsgidav_app, config):
        """Read the ``acrm_dc`` settings and load the verification key if present.

        :raises KeyError: if ``endpoint`` or ``pub_key_path`` is not configured.
        """
        super().__init__(wsgidav_app, config)
        self.dc_conf = config.get("acrm_dc", {})
        self.endpoint = self.dc_conf["endpoint"]
        self.pub_key_path = Path(self.dc_conf["pub_key_path"])
        self.pub_key = None
        if self.pub_key_path.exists():
            self.pub_key = self.pub_key_path.read_bytes()
        self.org_name = self.dc_conf.get("org_name")
        self.tokens_dir = Path(self.dc_conf.get("tokens_dir", self.DEFAULT_TOKENS_DIR))

        algorithms = self.dc_conf.get("jwt_algorithms", self.DEFAULT_JWT_ALGORITHMS)
        if isinstance(algorithms, str):
            # A bare string would otherwise be split into single characters.
            algorithms = [algorithms]
        self.jwt_algorithms = list(algorithms)

        self.request_timeout = self.dc_conf.get(
            "request_timeout", self.DEFAULT_REQUEST_TIMEOUT
        )
        self._session = None

    @property
    def session(self):
        """Return the shared ``requests`` session, creating it on first use."""
        if self._session is None:
            # Imported lazily so that ``requests`` is only needed once a
            # sign-in request is actually made.
            from requests import session

            self._session = session()
        return self._session

    def get_domain_realm(self, path_info, environ):
        """Resolve a relative url to the appropriate realm name."""
        return self._calc_realm_from_path_provider(path_info, environ)

    def supports_http_digest_auth(self):
        """Return False: this controller does not offer digest authentication."""
        return False

    def require_authentication(self, realm, environ):
        """Return True: every realm requires authentication."""
        return True

    def basic_auth_user(self, realm, user_name, password, environ):
        """Validate ``user_name``/``password`` and publish the token in ``environ``.

        A cached token for these credentials is used when it still validates.
        A cached token that is expired or otherwise invalid is removed and a
        fresh one is requested from the endpoint, so valid credentials are not
        rejected merely because the cache went stale.

        On success ``environ["wsgidav.auth.jwt"]`` holds the token text and
        ``environ["wsgidav.auth.roles"]`` the list of roles granted within
        ``org_name``.

        :returns: True when the token is valid and grants the ``Users`` role,
            False otherwise (including when the endpoint cannot be reached or
            the verification key is unavailable).
        """
        if self._load_pub_key() is None:
            _logger.error(
                "verification key %s is not available, rejecting %s",
                self.pub_key_path,
                user_name,
            )
            return False

        cache_path = self._token_cache_path(user_name, password)

        claims = None
        jwt_content = self._read_cached_token(cache_path)
        if jwt_content is not None:
            claims = self._decode_token(jwt_content, user_name)
            if claims is None:
                # Stale or corrupt cache entry: drop it so it cannot lock the
                # user out, then fall through to a fresh sign-in.
                self._drop_cached_token(cache_path)

        if claims is None:
            jwt_content = self._request_token(user_name, password)
            if jwt_content is None:
                return False
            claims = self._decode_token(jwt_content, user_name)
            if claims is None:
                return False
            # Cache only tokens that passed validation.
            cache_path.write_bytes(jwt_content)

        roles = self._roles_from_scope(claims.get("scope") or "")
        if "Users" not in roles:
            _logger.info("Нет доступа к wsgidav:%s", user_name)
            return False

        # Log the user, never the token itself: the JWT is a bearer credential.
        _logger.info("Успешный логин %s", user_name)
        environ["wsgidav.auth.jwt"] = jwt_content.decode()
        environ["wsgidav.auth.roles"] = roles
        return True

    def _load_pub_key(self):
        """Return the verification key, retrying the file read if it was missing."""
        if self.pub_key is None:
            try:
                self.pub_key = self.pub_key_path.read_bytes()
            except OSError:
                return None
        return self.pub_key

    def _token_cache_path(self, user_name, password):
        """Return the cache file path for the given credentials, creating the directory."""
        self.tokens_dir.mkdir(exist_ok=True, parents=True)
        # NOTE(review): the file name is an unsalted MD5 of the credentials,
        # kept so that existing cache entries remain valid. Anyone who can list
        # the directory can attack the password offline; consider a keyed hash.
        filename = md5((user_name + ":" + password).encode()).hexdigest()
        return self.tokens_dir / filename

    @staticmethod
    def _read_cached_token(cache_path):
        """Return the cached token bytes, or None when there is no cache entry."""
        try:
            return cache_path.read_bytes()
        except FileNotFoundError:
            return None

    @staticmethod
    def _drop_cached_token(cache_path):
        """Remove a cache entry, tolerating concurrent removal by another request."""
        try:
            cache_path.unlink()
        except FileNotFoundError:
            pass

    def _request_token(self, user_name, password):
        """Sign in at the endpoint and return the raw response body, or None on failure."""
        from requests import RequestException

        data = {"login": user_name, "password": password}
        params = {"action": "signin_with_login_password"}
        try:
            resp = self.session.post(
                f"{self.endpoint}/auth/signin",
                params=params,
                data=data,
                timeout=self.request_timeout,
            )
        except RequestException:
            # Fail closed: an unreachable auth service must not grant access.
            _logger.exception("auth request failed for %s", user_name)
            return None

        if not resp.ok:
            _logger.info("bad auth for %s", user_name)
            return None
        return resp.content

    def _decode_token(self, jwt_content, user_name):
        """Verify ``jwt_content`` and return its claims, or None if it is not valid."""
        try:
            return jwt.decode(
                jwt_content, key=self.pub_key, algorithms=self.jwt_algorithms
            )
        except jwt.ExpiredSignatureError:
            _logger.info("%s used expired jwt", user_name)
        except jwt.PyJWTError:
            _logger.exception("failed jwt decode for user %s", user_name)
        return None

    def _roles_from_scope(self, scope):
        """Return the roles that ``scope`` grants within ``self.org_name``.

        ``scope`` is a whitespace-separated list of ``<org>:<role>`` entries;
        entries without a colon are logged and skipped.
        """
        roles = []
        for entry in scope.split():
            org_name, separator, role = entry.partition(":")
            if not separator:
                _logger.warning("bad scope %s", entry)
                continue
            if org_name == self.org_name:
                roles.append(role)
        return roles
