# Copyright (C) 2016 JWCrypto Project Contributors - see LICENSE file

import os
import struct
from abc import ABCMeta, abstractmethod
from binascii import hexlify, unhexlify

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import constant_time, hashes, hmac
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import utils as ec_utils
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap, aes_key_wrap
from cryptography.hazmat.primitives.padding import PKCS7

from jwcrypto.common import InvalidCEKeyLength
from jwcrypto.common import InvalidJWAAlgorithm
from jwcrypto.common import InvalidJWEKeyLength
from jwcrypto.common import InvalidJWEKeyType
from jwcrypto.common import InvalidJWEOperation
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode
from jwcrypto.jwk import JWK

# Implements:
# - RFC 7518: JSON Web Algorithms (JWA)
# - RFC 8037: CFRG Elliptic Curve Diffie-Hellman (ECDH) and Signatures
#             in JSON Object Signing and Encryption (JOSE)
# - RFC 9864: Fully-Specified Algorithms for JSON Object Signing and
#             Encryption (JOSE) and CBOR Object Signing and Encryption
#             (COSE)

default_max_pbkdf2_iterations = 16384
"""The maximum number of iterations allowed for PBKDF2 key derivation.

This is a security measure to prevent denial-of-service attacks by malicious
actors providing a very high iteration count.
"""

default_enforce_hmac_key_length = True
"""Enforces that the HMAC key length is at least the size of the hash
function's output, as recommended by RFC 7518.

This can be disabled for compatibility with legacy or non-compliant systems
that use shorter keys.
"""


class JWAAlgorithm(metaclass=ABCMeta):

    @property
    @abstractmethod
    def name(self):
        """The algorithm Name"""

    @property
    @abstractmethod
    def description(self):
        """A short description"""

    @property
    @abstractmethod
    def keysize(self):
        """The algorithm key size"""

    @property
    @abstractmethod
    def algorithm_usage_location(self):
        """One of 'alg', 'enc' or 'JWK'"""

    @property
    @abstractmethod
    def algorithm_use(self):
        """One of 'sig', 'kex', 'enc'"""

    @property
    def input_keysize(self):
        """The input key size"""
        try:
            return self.wrap_key_size
        except AttributeError:
            return self.keysize


def _bitsize(x):
    return len(x) * 8


def _inbytes(x):
    return x // 8


def _randombits(x):
    if x % 8 != 0:
        raise ValueError("length must be a multiple of 8")
    return os.urandom(_inbytes(x))


# Note: the number of bits should be a multiple of 16
def _encode_int(n, bits):
    e = '{:x}'.format(n)
    ilen = ((bits + 7) // 8) * 2  # number of bytes rounded up times 2 bytes
    return unhexlify(e.rjust(ilen, '0')[:ilen])


def _decode_int(n):
    return int(hexlify(n), 16)


class _RawJWS:

    def sign(self, key, payload):
        raise NotImplementedError

    def verify(self, key, payload, signature):
        raise NotImplementedError


class _RawHMAC(_RawJWS):

    keysize = None

    def __init__(self, hashfn):
        self.backend = default_backend()
        self.hashfn = hashfn

    def _hmac_setup(self, key, payload):
        # Ensure key size matches RFC 7518 requirements
        if default_enforce_hmac_key_length and _bitsize(key) < self.keysize:
            raise InvalidJWEKeyLength(self.keysize, _bitsize(key))
        h = hmac.HMAC(key, self.hashfn, backend=self.backend)
        h.update(payload)
        return h

    def sign(self, key, payload):
        skey = base64url_decode(key.get_op_key('sign'))
        h = self._hmac_setup(skey, payload)
        return h.finalize()

    def verify(self, key, payload, signature):
        vkey = base64url_decode(key.get_op_key('verify'))
        h = self._hmac_setup(vkey, payload)
        h.verify(signature)


class _RawRSA(_RawJWS):
    def __init__(self, padfn, hashfn):
        self.padfn = padfn
        self.hashfn = hashfn

    def sign(self, key, payload):
        skey = key.get_op_key('sign')
        return skey.sign(payload, self.padfn, self.hashfn)

    def verify(self, key, payload, signature):
        pkey = key.get_op_key('verify')
        pkey.verify(signature, payload, self.padfn, self.hashfn)


class _RawEC(_RawJWS):
    def __init__(self, curve, hashfn):
        self._curve = curve
        self.hashfn = hashfn

    @property
    def curve(self):
        return self._curve

    def sign(self, key, payload):
        skey = key.get_op_key('sign', self._curve)
        size = skey.key_size
        signature = skey.sign(payload, ec.ECDSA(self.hashfn))
        r, s = ec_utils.decode_dss_signature(signature)
        return _encode_int(r, size) + _encode_int(s, size)

    def verify(self, key, payload, signature):
        pkey = key.get_op_key('verify', self._curve)
        r = signature[:len(signature) // 2]
        s = signature[len(signature) // 2:]
        enc_signature = ec_utils.encode_dss_signature(
            int(hexlify(r), 16), int(hexlify(s), 16))
        pkey.verify(enc_signature, payload, ec.ECDSA(self.hashfn))


class _RawNone(_RawJWS):

    def sign(self, key, payload):
        return ''

    def verify(self, key, payload, signature):
        if key['kty'] != 'oct' or key.get_op_key() != '':
            raise InvalidSignature('The "none" signature cannot be verified')


class _HS256(_RawHMAC, JWAAlgorithm):

    name = "HS256"
    description = "HMAC using SHA-256"
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_HS256, self).__init__(hashes.SHA256())


class _HS384(_RawHMAC, JWAAlgorithm):

    name = "HS384"
    description = "HMAC using SHA-384"
    keysize = 384
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_HS384, self).__init__(hashes.SHA384())


class _HS512(_RawHMAC, JWAAlgorithm):

    name = "HS512"
    description = "HMAC using SHA-512"
    keysize = 512
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_HS512, self).__init__(hashes.SHA512())


class _RS256(_RawRSA, JWAAlgorithm):

    name = "RS256"
    description = "RSASSA-PKCS1-v1_5 using SHA-256"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_RS256, self).__init__(padding.PKCS1v15(), hashes.SHA256())


class _RS384(_RawRSA, JWAAlgorithm):

    name = "RS384"
    description = "RSASSA-PKCS1-v1_5 using SHA-384"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_RS384, self).__init__(padding.PKCS1v15(), hashes.SHA384())


class _RS512(_RawRSA, JWAAlgorithm):

    name = "RS512"
    description = "RSASSA-PKCS1-v1_5 using SHA-512"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_RS512, self).__init__(padding.PKCS1v15(), hashes.SHA512())


class _ES256(_RawEC, JWAAlgorithm):

    name = "ES256"
    description = "ECDSA using P-256 and SHA-256"
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_ES256, self).__init__('P-256', hashes.SHA256())


class _ES256K(_RawEC, JWAAlgorithm):

    name = "ES256K"
    description = "ECDSA using secp256k1 curve and SHA-256"
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_ES256K, self).__init__('secp256k1', hashes.SHA256())


class _ES384(_RawEC, JWAAlgorithm):

    name = "ES384"
    description = "ECDSA using P-384 and SHA-384"
    keysize = 384
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_ES384, self).__init__('P-384', hashes.SHA384())


class _ES512(_RawEC, JWAAlgorithm):

    name = "ES512"
    description = "ECDSA using P-521 and SHA-512"
    keysize = 512
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_ES512, self).__init__('P-521', hashes.SHA512())


class _PS256(_RawRSA, JWAAlgorithm):

    name = "PS256"
    description = "RSASSA-PSS using SHA-256 and MGF1 with SHA-256"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        padfn = padding.PSS(padding.MGF1(hashes.SHA256()),
                            hashes.SHA256.digest_size)
        super(_PS256, self).__init__(padfn, hashes.SHA256())


class _PS384(_RawRSA, JWAAlgorithm):

    name = "PS384"
    description = "RSASSA-PSS using SHA-384 and MGF1 with SHA-384"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        padfn = padding.PSS(padding.MGF1(hashes.SHA384()),
                            hashes.SHA384.digest_size)
        super(_PS384, self).__init__(padfn, hashes.SHA384())


class _PS512(_RawRSA, JWAAlgorithm):

    name = "PS512"
    description = "RSASSA-PSS using SHA-512 and MGF1 with SHA-512"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        padfn = padding.PSS(padding.MGF1(hashes.SHA512()),
                            hashes.SHA512.digest_size)
        super(_PS512, self).__init__(padfn, hashes.SHA512())


class _None(_RawNone, JWAAlgorithm):

    name = "none"
    description = "No digital signature or MAC performed"
    keysize = 0
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'


class _RawKeyMgmt:

    def wrap(self, key, bitsize, cek, headers):
        raise NotImplementedError

    def unwrap(self, key, bitsize, ek, headers):
        raise NotImplementedError


class _RSA(_RawKeyMgmt):

    def __init__(self, padfn):
        self.padfn = padfn

    def _check_key(self, key):
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key['kty'] != 'RSA':
            raise InvalidJWEKeyType('RSA', key['kty'])

    # FIXME: get key size and ensure > 2048 bits
    def wrap(self, key, bitsize, cek, headers):
        self._check_key(key)
        if not cek:
            cek = _randombits(bitsize)
        rk = key.get_op_key('wrapKey')
        ek = rk.encrypt(cek, self.padfn)
        return {'cek': cek, 'ek': ek}

    def unwrap(self, key, bitsize, ek, headers):
        self._check_key(key)
        rk = key.get_op_key('unwrapKey')
        cek = rk.decrypt(ek, self.padfn)
        if _bitsize(cek) != bitsize:
            raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
        return cek


class _Rsa15(_RSA, JWAAlgorithm):

    name = 'RSA1_5'
    description = "RSAES-PKCS1-v1_5"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'

    def __init__(self):
        super(_Rsa15, self).__init__(padding.PKCS1v15())

    def unwrap(self, key, bitsize, ek, headers):
        self._check_key(key)
        # Address MMA attack by implementing RFC 3218 - 2.3.2. Random Filling
        # provides a random cek that will cause the decryption engine to
        # run to the end, but will fail decryption later.

        # always generate a random cek so we spend roughly the
        # same time as in the exception side of the branch
        cek = _randombits(bitsize)
        try:
            cek = super(_Rsa15, self).unwrap(key, bitsize, ek, headers)
            # always raise so we always run through the exception handling
            # code in all cases
            raise ValueError('Dummy')
        except Exception:  # pylint: disable=broad-except
            return cek


class _RsaOaep(_RSA, JWAAlgorithm):

    name = 'RSA-OAEP'
    description = "RSAES OAEP using default parameters"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'

    def __init__(self):
        super(_RsaOaep, self).__init__(
            padding.OAEP(padding.MGF1(hashes.SHA1()),
                         hashes.SHA1(), None))


class _RsaOaep256(_RSA, JWAAlgorithm):  # noqa: ignore=N801

    name = 'RSA-OAEP-256'
    description = "RSAES OAEP using SHA-256 and MGF1 with SHA-256"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'

    def __init__(self):
        super(_RsaOaep256, self).__init__(
            padding.OAEP(padding.MGF1(hashes.SHA256()),
                         hashes.SHA256(), None))


class _AesKw(_RawKeyMgmt):

    keysize = None

    def __init__(self):
        self.backend = default_backend()

    def _get_key(self, key, op):
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key['kty'] != 'oct':
            raise InvalidJWEKeyType('oct', key['kty'])
        rk = base64url_decode(key.get_op_key(op))
        if _bitsize(rk) != self.keysize:
            raise InvalidJWEKeyLength(self.keysize, _bitsize(rk))
        return rk

    def wrap(self, key, bitsize, cek, headers):
        rk = self._get_key(key, 'encrypt')
        if not cek:
            cek = _randombits(bitsize)

        ek = aes_key_wrap(rk, cek, default_backend())

        return {'cek': cek, 'ek': ek}

    def unwrap(self, key, bitsize, ek, headers):
        rk = self._get_key(key, 'decrypt')

        cek = aes_key_unwrap(rk, ek, default_backend())
        if _bitsize(cek) != bitsize:
            raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
        return cek


class _A128KW(_AesKw, JWAAlgorithm):

    name = 'A128KW'
    description = "AES Key Wrap using 128-bit key"
    keysize = 128
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _A192KW(_AesKw, JWAAlgorithm):

    name = 'A192KW'
    description = "AES Key Wrap using 192-bit key"
    keysize = 192
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _A256KW(_AesKw, JWAAlgorithm):

    name = 'A256KW'
    description = "AES Key Wrap using 256-bit key"
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _AesGcmKw(_RawKeyMgmt):

    keysize = None

    def __init__(self):
        self.backend = default_backend()

    def _get_key(self, key, op):
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key['kty'] != 'oct':
            raise InvalidJWEKeyType('oct', key['kty'])
        rk = base64url_decode(key.get_op_key(op))
        if _bitsize(rk) != self.keysize:
            raise InvalidJWEKeyLength(self.keysize, _bitsize(rk))
        return rk

    def wrap(self, key, bitsize, cek, headers):
        rk = self._get_key(key, 'encrypt')
        if not cek:
            cek = _randombits(bitsize)

        iv = _randombits(96)
        cipher = Cipher(algorithms.AES(rk), modes.GCM(iv),
                        backend=self.backend)
        encryptor = cipher.encryptor()
        ek = encryptor.update(cek) + encryptor.finalize()

        tag = encryptor.tag
        return {'cek': cek, 'ek': ek,
                'header': {'iv': base64url_encode(iv),
                           'tag': base64url_encode(tag)}}

    def unwrap(self, key, bitsize, ek, headers):
        rk = self._get_key(key, 'decrypt')

        if 'iv' not in headers:
            raise ValueError('Invalid Header, missing "iv" parameter')
        iv = base64url_decode(headers['iv'])
        if 'tag' not in headers:
            raise ValueError('Invalid Header, missing "tag" parameter')
        tag = base64url_decode(headers['tag'])

        cipher = Cipher(algorithms.AES(rk), modes.GCM(iv, tag),
                        backend=self.backend)
        decryptor = cipher.decryptor()
        cek = decryptor.update(ek) + decryptor.finalize()
        if _bitsize(cek) != bitsize:
            raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
        return cek


class _A128GcmKw(_AesGcmKw, JWAAlgorithm):

    name = 'A128GCMKW'
    description = "Key wrapping with AES GCM using 128-bit key"
    keysize = 128
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _A192GcmKw(_AesGcmKw, JWAAlgorithm):

    name = 'A192GCMKW'
    description = "Key wrapping with AES GCM using 192-bit key"
    keysize = 192
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _A256GcmKw(_AesGcmKw, JWAAlgorithm):

    name = 'A256GCMKW'
    description = "Key wrapping with AES GCM using 256-bit key"
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _Pbes2HsAesKw(_RawKeyMgmt):

    name = None
    keysize = None
    hashsize = None

    def __init__(self):
        self.backend = default_backend()
        self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW}

    def _get_key(self, alg, key, p2s, p2c):
        if p2c > default_max_pbkdf2_iterations:
            raise ValueError('Invalid p2c value, too large')

        if not isinstance(key, JWK):
            # backwards compatibility for old interface
            if isinstance(key, bytes):
                plain = key
            else:
                plain = key.encode('utf8')
        else:
            plain = base64url_decode(key.get_op_key())

        salt = bytes(self.name.encode('utf8')) + b'\x00' + p2s

        if self.hashsize == 256:
            hashalg = hashes.SHA256()
        elif self.hashsize == 384:
            hashalg = hashes.SHA384()
        elif self.hashsize == 512:
            hashalg = hashes.SHA512()
        else:
            raise ValueError('Unknown Hash Size')

        kdf = PBKDF2HMAC(algorithm=hashalg, length=_inbytes(self.keysize),
                         salt=salt, iterations=p2c, backend=self.backend)
        rk = kdf.derive(plain)
        if _bitsize(rk) != self.keysize:
            raise InvalidJWEKeyLength(self.keysize, len(rk))
        return JWK(kty="oct", use="enc", k=base64url_encode(rk))

    def wrap(self, key, bitsize, cek, headers):
        ret_header = {}
        if 'p2s' in headers:
            p2s = base64url_decode(headers['p2s'])
            if len(p2s) < 8:
                raise ValueError('Invalid Salt, must be 8 or more octects')
        else:
            p2s = _randombits(128)
            ret_header['p2s'] = base64url_encode(p2s)
        if 'p2c' in headers:
            p2c = headers['p2c']
        else:
            p2c = 8192
            ret_header['p2c'] = p2c
        kek = self._get_key(headers['alg'], key, p2s, p2c)

        aeskw = self.aeskwmap[self.keysize]()
        ret = aeskw.wrap(kek, bitsize, cek, headers)
        if len(ret_header) > 0:
            ret['header'] = ret_header
        return ret

    def unwrap(self, key, bitsize, ek, headers):
        if 'p2s' not in headers:
            raise ValueError('Invalid Header, missing "p2s" parameter')
        if 'p2c' not in headers:
            raise ValueError('Invalid Header, missing "p2c" parameter')
        p2s = base64url_decode(headers['p2s'])
        p2c = headers['p2c']
        kek = self._get_key(headers['alg'], key, p2s, p2c)

        aeskw = self.aeskwmap[self.keysize]()
        return aeskw.unwrap(kek, bitsize, ek, headers)


class _Pbes2Hs256A128Kw(_Pbes2HsAesKw, JWAAlgorithm):

    name = 'PBES2-HS256+A128KW'
    description = 'PBES2 with HMAC SHA-256 and "A128KW" wrapping'
    keysize = 128
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'
    hashsize = 256


class _Pbes2Hs384A192Kw(_Pbes2HsAesKw, JWAAlgorithm):

    name = 'PBES2-HS384+A192KW'
    description = 'PBES2 with HMAC SHA-384 and "A192KW" wrapping'
    keysize = 192
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'
    hashsize = 384


class _Pbes2Hs512A256Kw(_Pbes2HsAesKw, JWAAlgorithm):

    name = 'PBES2-HS512+A256KW'
    description = 'PBES2 with HMAC SHA-512 and "A256KW" wrapping'
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'
    hashsize = 512


class _Direct(_RawKeyMgmt, JWAAlgorithm):

    name = 'dir'
    description = "Direct use of a shared symmetric key"
    keysize = 128
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'

    def _check_key(self, key):
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key['kty'] != 'oct':
            raise InvalidJWEKeyType('oct', key['kty'])

    def wrap(self, key, bitsize, cek, headers):
        self._check_key(key)
        if cek:
            return (cek, None)
        k = base64url_decode(key.get_op_key('encrypt'))
        if _bitsize(k) != bitsize:
            raise InvalidCEKeyLength(bitsize, _bitsize(k))
        return {'cek': k}

    def unwrap(self, key, bitsize, ek, headers):
        self._check_key(key)
        if ek != b'':
            raise ValueError('Invalid Encryption Key.')
        cek = base64url_decode(key.get_op_key('decrypt'))
        if _bitsize(cek) != bitsize:
            raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
        return cek


class _EcdhEs(_RawKeyMgmt, JWAAlgorithm):

    name = 'ECDH-ES'
    description = "ECDH-ES using Concat KDF"
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'
    keysize = None

    def __init__(self):
        self.backend = default_backend()
        self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW}

    def _check_key(self, key):
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key['kty'] not in ['EC', 'OKP']:
            raise InvalidJWEKeyType('EC or OKP', key['kty'])
        if key['kty'] == 'OKP':
            if key['crv'] not in ['X25519', 'X448']:
                raise InvalidJWEKeyType('X25519 or X448',
                                        key['crv'])

    def _derive(self, privkey, pubkey, alg, bitsize, headers):
        # OtherInfo is defined in NIST SP 56A 5.8.1.2.1

        # AlgorithmID
        otherinfo = struct.pack('>I', len(alg))
        otherinfo += bytes(alg.encode('utf8'))

        # PartyUInfo
        apu = base64url_decode(headers['apu']) if 'apu' in headers else b''
        otherinfo += struct.pack('>I', len(apu))
        otherinfo += apu

        # PartyVInfo
        apv = base64url_decode(headers['apv']) if 'apv' in headers else b''
        otherinfo += struct.pack('>I', len(apv))
        otherinfo += apv

        # SuppPubInfo
        otherinfo += struct.pack('>I', bitsize)

        # no SuppPrivInfo

        # Shared Key generation
        if isinstance(privkey, ec.EllipticCurvePrivateKey):
            shared_key = privkey.exchange(ec.ECDH(), pubkey)
        else:
            # X25519/X448
            shared_key = privkey.exchange(pubkey)

        ckdf = ConcatKDFHash(algorithm=hashes.SHA256(),
                             length=_inbytes(bitsize),
                             otherinfo=otherinfo,
                             backend=self.backend)
        return ckdf.derive(shared_key)

    def wrap(self, key, bitsize, cek, headers):
        self._check_key(key)
        dk_size = self.keysize
        if self.keysize is None:
            if cek is not None:
                raise InvalidJWEOperation('ECDH-ES cannot use an existing CEK')
            alg = headers['enc']
            dk_size = bitsize
        else:
            alg = headers['alg']

        epk = JWK.generate(kty=key['kty'], crv=key['crv'])
        dk = self._derive(epk.get_op_key('unwrapKey'),
                          key.get_op_key('wrapKey'),
                          alg, dk_size, headers)

        if self.keysize is None:
            ret = {'cek': dk}
        else:
            aeskw = self.aeskwmap[self.keysize]()
            kek = JWK(kty="oct", use="enc", k=base64url_encode(dk))
            ret = aeskw.wrap(kek, bitsize, cek, headers)

        ret['header'] = {'epk': json_decode(epk.export_public())}
        return ret

    def unwrap(self, key, bitsize, ek, headers):
        if 'epk' not in headers:
            raise ValueError('Invalid Header, missing "epk" parameter')
        self._check_key(key)
        dk_size = self.keysize
        if self.keysize is None:
            alg = headers['enc']
            dk_size = bitsize
        else:
            alg = headers['alg']

        epk = JWK(**headers['epk'])
        dk = self._derive(key.get_op_key('unwrapKey'),
                          epk.get_op_key('wrapKey'),
                          alg, dk_size, headers)
        if self.keysize is None:
            return dk
        else:
            aeskw = self.aeskwmap[self.keysize]()
            kek = JWK(kty="oct", use="enc", k=base64url_encode(dk))
            cek = aeskw.unwrap(kek, bitsize, ek, headers)
            return cek


class _EcdhEsAes128Kw(_EcdhEs):

    name = 'ECDH-ES+A128KW'
    description = 'ECDH-ES using Concat KDF and "A128KW" wrapping'
    keysize = 128
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _EcdhEsAes192Kw(_EcdhEs):

    name = 'ECDH-ES+A192KW'
    description = 'ECDH-ES using Concat KDF and "A192KW" wrapping'
    keysize = 192
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _EcdhEsAes256Kw(_EcdhEs):

    name = 'ECDH-ES+A256KW'
    description = 'ECDH-ES using Concat KDF and "A256KW" wrapping'
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'


class _EdDsa(_RawJWS, JWAAlgorithm):

    name = 'EdDSA'
    description = 'EdDSA using Ed25519 or Ed448 algorithms'
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'
    keysize = None

    def sign(self, key, payload):
        if key['crv'] in ['Ed25519', 'Ed448']:
            skey = key.get_op_key('sign')
            return skey.sign(payload)
        raise NotImplementedError

    def verify(self, key, payload, signature):
        if key['crv'] in ['Ed25519', 'Ed448']:
            pkey = key.get_op_key('verify')
            return pkey.verify(signature, payload)
        raise NotImplementedError


class _Ed25519(_RawJWS, JWAAlgorithm):

    name = 'Ed25519'
    description = 'EdDSA using Ed25519'
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'
    keysize = None

    def sign(self, key, payload):
        if key['crv'] != 'Ed25519':
            raise InvalidJWEKeyType('Ed25519', key['crv'])
        skey = key.get_op_key('sign')
        return skey.sign(payload)

    def verify(self, key, payload, signature):
        if key['crv'] != 'Ed25519':
            raise InvalidJWEKeyType('Ed25519', key['crv'])
        pkey = key.get_op_key('verify')
        return pkey.verify(signature, payload)


class _Ed448(_RawJWS, JWAAlgorithm):

    name = 'Ed448'
    description = 'EdDSA using Ed448'
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'
    keysize = None

    def sign(self, key, payload):
        if key['crv'] != 'Ed448':
            raise InvalidJWEKeyType('Ed448', key['crv'])
        skey = key.get_op_key('sign')
        return skey.sign(payload)

    def verify(self, key, payload, signature):
        if key['crv'] != 'Ed448':
            raise InvalidJWEKeyType('Ed448', key['crv'])
        pkey = key.get_op_key('verify')
        return pkey.verify(signature, payload)


class _RawJWE:

    def encrypt(self, k, aad, m):
        raise NotImplementedError

    def decrypt(self, k, aad, iv, e, t):
        raise NotImplementedError


class _AesCbcHmacSha2(_RawJWE):

    keysize = None

    def __init__(self, hashfn):
        self.backend = default_backend()
        self.hashfn = hashfn
        self.blocksize = algorithms.AES.block_size
        self.wrap_key_size = self.keysize * 2

    def _mac(self, k, aad, iv, e):
        al = _encode_int(_bitsize(aad), 64)
        h = hmac.HMAC(k, self.hashfn, backend=self.backend)
        h.update(aad)
        h.update(iv)
        h.update(e)
        h.update(al)
        m = h.finalize()
        return m[:_inbytes(self.keysize)]

    # RFC 7518 - 5.2.2
    def encrypt(self, k, aad, m):
        """ Encrypt according to the selected encryption and hashing
        functions.

        :param k: Encryption key
        :param aad: Additional Authentication Data
        :param m: Plaintext

        Returns a dictionary with the computed data.
        """
        if len(k) != _inbytes(self.wrap_key_size):
            raise ValueError("Invalid input key size")

        hkey = k[:_inbytes(self.keysize)]
        ekey = k[_inbytes(self.keysize):]

        # encrypt
        iv = _randombits(self.blocksize)
        cipher = Cipher(algorithms.AES(ekey), modes.CBC(iv),
                        backend=self.backend)
        encryptor = cipher.encryptor()
        padder = PKCS7(self.blocksize).padder()
        padded_data = padder.update(m) + padder.finalize()
        e = encryptor.update(padded_data) + encryptor.finalize()

        # mac
        t = self._mac(hkey, aad, iv, e)

        return (iv, e, t)

    def decrypt(self, k, aad, iv, e, t):
        """ Decrypt according to the selected encryption and hashing
        functions.
        :param k: Encryption key
        :param aad: Additional Authenticated Data
        :param iv: Initialization Vector
        :param e: Ciphertext
        :param t: Authentication Tag

        Returns plaintext or raises an error
        """
        if len(k) != _inbytes(self.wrap_key_size):
            raise ValueError("Invalid input key size")

        hkey = k[:_inbytes(self.keysize)]
        dkey = k[_inbytes(self.keysize):]

        # verify mac
        if not constant_time.bytes_eq(t, self._mac(hkey, aad, iv, e)):
            raise InvalidSignature('Failed to verify MAC')

        # decrypt
        cipher = Cipher(algorithms.AES(dkey), modes.CBC(iv),
                        backend=self.backend)
        decryptor = cipher.decryptor()
        d = decryptor.update(e) + decryptor.finalize()
        unpadder = PKCS7(self.blocksize).unpadder()
        return unpadder.update(d) + unpadder.finalize()


class _A128CbcHs256(_AesCbcHmacSha2, JWAAlgorithm):

    name = 'A128CBC-HS256'
    description = "AES_128_CBC_HMAC_SHA_256 authenticated"
    keysize = 128
    algorithm_usage_location = 'enc'
    algorithm_use = 'enc'

    def __init__(self):
        super(_A128CbcHs256, self).__init__(hashes.SHA256())


class _A192CbcHs384(_AesCbcHmacSha2, JWAAlgorithm):

    name = 'A192CBC-HS384'
    description = "AES_192_CBC_HMAC_SHA_384 authenticated"
    keysize = 192
    algorithm_usage_location = 'enc'
    algorithm_use = 'enc'

    def __init__(self):
        super(_A192CbcHs384, self).__init__(hashes.SHA384())


class _A256CbcHs512(_AesCbcHmacSha2, JWAAlgorithm):

    name = 'A256CBC-HS512'
    description = "AES_256_CBC_HMAC_SHA_512 authenticated"
    keysize = 256
    algorithm_usage_location = 'enc'
    algorithm_use = 'enc'

    def __init__(self):
        super(_A256CbcHs512, self).__init__(hashes.SHA512())


class _AesGcm(_RawJWE):

    keysize = None

    def __init__(self):
        self.backend = default_backend()
        self.wrap_key_size = self.keysize

    # RFC 7518 - 5.3
    def encrypt(self, k, aad, m):
        """ Encrypt according to the selected encryption and hashing
        functions.

        :param k: Encryption key
        :param aad: Additional Authentication Data
        :param m: Plaintext

        Returns a dictionary with the computed data.
        """
        iv = _randombits(96)
        cipher = Cipher(algorithms.AES(k), modes.GCM(iv),
                        backend=self.backend)
        encryptor = cipher.encryptor()
        encryptor.authenticate_additional_data(aad)
        e = encryptor.update(m) + encryptor.finalize()

        return (iv, e, encryptor.tag)

    def decrypt(self, k, aad, iv, e, t):
        """ Decrypt according to the selected encryption and hashing
        functions.
        :param k: Encryption key
        :param aad: Additional Authenticated Data
        :param iv: Initialization Vector
        :param e: Ciphertext
        :param t: Authentication Tag

        Returns plaintext or raises an error
        """
        cipher = Cipher(algorithms.AES(k), modes.GCM(iv, t),
                        backend=self.backend)
        decryptor = cipher.decryptor()
        decryptor.authenticate_additional_data(aad)
        return decryptor.update(e) + decryptor.finalize()


class _A128Gcm(_AesGcm, JWAAlgorithm):

    name = 'A128GCM'
    description = "AES GCM using 128-bit key"
    keysize = 128
    algorithm_usage_location = 'enc'
    algorithm_use = 'enc'


class _A192Gcm(_AesGcm, JWAAlgorithm):

    name = 'A192GCM'
    description = "AES GCM using 192-bit key"
    keysize = 192
    algorithm_usage_location = 'enc'
    algorithm_use = 'enc'


class _A256Gcm(_AesGcm, JWAAlgorithm):

    name = 'A256GCM'
    description = "AES GCM using 256-bit key"
    keysize = 256
    algorithm_usage_location = 'enc'
    algorithm_use = 'enc'


class _BP256R1(_RawEC, JWAAlgorithm):

    name = "BP256R1"
    description = (
        "ECDSA using Brainpool256R1 curve and SHA-256"
        " (unregistered, custom-defined in breach"
        " of IETF rules by gematik GmbH)"
    )
    keysize = 256
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_BP256R1, self).__init__('BP-256', hashes.SHA256())


class _BP384R1(_RawEC, JWAAlgorithm):

    name = "BP384R1"
    description = (
        "ECDSA using Brainpool384R1 curve and SHA-384"
        " (unregistered, custom-defined in breach"
        " of IETF rules by gematik GmbH)"
    )
    keysize = 384
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_BP384R1, self).__init__('BP-384', hashes.SHA384())


class _BP512R1(_RawEC, JWAAlgorithm):

    name = "BP512R1"
    description = (
        "ECDSA using Brainpool512R1 curve and SHA-512"
        " (unregistered, custom-defined in breach"
        " of IETF rules by gematik GmbH)"
    )
    keysize = 512
    algorithm_usage_location = 'alg'
    algorithm_use = 'sig'

    def __init__(self):
        super(_BP512R1, self).__init__('BP-512', hashes.SHA512())


class JWA:
    """JWA Signing Algorithms.

    This class provides access to all JWA algorithms.
    """

    algorithms_registry = {
        'HS256': _HS256,
        'HS384': _HS384,
        'HS512': _HS512,
        'RS256': _RS256,
        'RS384': _RS384,
        'RS512': _RS512,
        'ES256': _ES256,
        'ES256K': _ES256K,
        'ES384': _ES384,
        'ES512': _ES512,
        'PS256': _PS256,
        'PS384': _PS384,
        'PS512': _PS512,
        'none': _None,
        'RSA1_5': _Rsa15,
        'RSA-OAEP': _RsaOaep,
        'RSA-OAEP-256': _RsaOaep256,
        'A128KW': _A128KW,
        'A192KW': _A192KW,
        'A256KW': _A256KW,
        'dir': _Direct,
        'ECDH-ES': _EcdhEs,
        'ECDH-ES+A128KW': _EcdhEsAes128Kw,
        'ECDH-ES+A192KW': _EcdhEsAes192Kw,
        'ECDH-ES+A256KW': _EcdhEsAes256Kw,
        'EdDSA': _EdDsa,
        'A128GCMKW': _A128GcmKw,
        'A192GCMKW': _A192GcmKw,
        'A256GCMKW': _A256GcmKw,
        'PBES2-HS256+A128KW': _Pbes2Hs256A128Kw,
        'PBES2-HS384+A192KW': _Pbes2Hs384A192Kw,
        'PBES2-HS512+A256KW': _Pbes2Hs512A256Kw,
        'A128CBC-HS256': _A128CbcHs256,
        'A192CBC-HS384': _A192CbcHs384,
        'A256CBC-HS512': _A256CbcHs512,
        'A128GCM': _A128Gcm,
        'A192GCM': _A192Gcm,
        'A256GCM': _A256Gcm,
        'BP256R1': _BP256R1,
        'BP384R1': _BP384R1,
        'BP512R1': _BP512R1,
        'Ed25519': _Ed25519,
        'Ed448': _Ed448
    }

    @classmethod
    def instantiate_alg(cls, name, use=None):
        alg = cls.algorithms_registry[name]
        if use is not None and alg.algorithm_use != use:
            raise KeyError
        return alg()

    @classmethod
    def signing_alg(cls, name):
        try:
            return cls.instantiate_alg(name, use='sig')
        except KeyError:
            raise InvalidJWAAlgorithm('%s is not a valid Signing algorithm'
                                      ' name' % name) from None

    @classmethod
    def keymgmt_alg(cls, name):
        try:
            return cls.instantiate_alg(name, use='kex')
        except KeyError:
            raise InvalidJWAAlgorithm('%s is not a valid Key Management'
                                      ' algorithm name' % name) from None

    @classmethod
    def encryption_alg(cls, name):
        try:
            return cls.instantiate_alg(name, use='enc')
        except KeyError:
            raise InvalidJWAAlgorithm('%s is not a valid Encryption'
                                      ' algorithm name' % name) from None
