#!/usr/bin/python3
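"""Scan a local directory tree and upload its files to an S3 bucket.

Skips files that are locked, are not regular files, or already exist in the
bucket (optionally comparing MD5 checksums), and uploads the rest either
sequentially or via a thread pool, verifying each upload afterwards.
"""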
import os
import stat
import fcntl
import yaml
import logging
import mimetypes
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path

import boto3
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError, EndpointConnectionError


@dataclass(frozen=True)
class Config:
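    """Runtime settings, built from CLI arguments or a YAML file."""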
    source_path: Path

    s3_bucket: str = ""
    s3_prefix: str = ""

    multipart_threshold_mb: int = 100
    multipart_chunk_mb: int = 50
    max_workers: int = 4
    max_retries: int = 3
    retry_base_delay_s: float = 1.0

    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    aws_region: str = "us-east-1"
    endpoint_url: str | None = None

    dry_run: bool = False
    recursive: bool = True
    run_async: bool = False
    follow_symlinks: bool = False
    skip_existing: bool = True
    log_level: str = "INFO"
    log_file: Path | None = None

    @classmethod
    def from_args(cls, args) -> "Config":
        return cls(**{k: v for k, v in vars(args).items() if v is not None})

    @classmethod
    def from_yaml(cls, file_path: str = "s3_config.yaml") -> "Config":
        with open(file_path) as f:
            data = yaml.safe_load(f) or {}

        # YAML yields plain strings; coerce path fields to Path objects.
        for key in ("source_path", "log_file"):
            if data.get(key) is not None:
                data[key] = Path(data[key])

        return cls(**data)
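
# Example s3_config.yaml (illustrative values, not shipped defaults):
#
#   source_path: /data/outgoing
#   s3_bucket: my-bucket
#   s3_prefix: backups
#   max_workers: 8
#   run_async: true
#   log_file: /var/log/s3_transfer.log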


@dataclass(frozen=True)
class TransferResult:
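    """Outcome of a single transfer; status is one of: ok, failed, skipped, dry_run."""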
    path: Path
    s3_key: str | None = None
    status: str | None = None
    checksum: str | None = None
    error: str | None = None


class TransferLogger:
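    """Thin wrapper around logging that proxies calls (info, error, ...) to the underlying logger."""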
    def __init__(self, config: Config):
        self.config = config
        self._logger = None

        self._init_logger()

    def _init_logger(self):
        self._logger = logging.getLogger("s3_transfer")
        self._logger.setLevel(self.config.log_level)

        # log_file defaults to None, and FileHandler(None) raises;
        # fall back to stderr when no log file is configured.
        if self.config.log_file:
            handler = logging.FileHandler(self.config.log_file)
        else:
            handler = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
        )
        handler.setFormatter(formatter)

        # Avoid duplicate handlers if more than one instance is created.
        if not self._logger.handlers:
            self._logger.addHandler(handler)

    def __getattr__(self, name):
        if self._logger and hasattr(self._logger, name):
            return getattr(self._logger, name)

        raise AttributeError(name)


class FileScanner:
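    """Yields regular, unlocked files beneath the configured source path."""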
    def __init__(self, config: Config, logger: TransferLogger):
        self.config = config
        self.logger = logger

    def scan(self):
        for path in self._walk(self.config.source_path):
            if not all((
                self._is_regular_file(path),
                self._is_closed(path)
            )):
                continue

            yield path

    def _walk(self, path: Path):
        try:
            with os.scandir(path) as it:
                for entry in it:
                    try:
                        follow = self.config.follow_symlinks
                        if entry.is_file(follow_symlinks=follow):
                            yield Path(entry.path)
                        elif entry.is_dir(follow_symlinks=follow) and self.config.recursive:
                            yield from self._walk(Path(entry.path))
                    except PermissionError as e:
                        self.logger.error(f"Scan failed {entry.path}: {e}")
        except PermissionError as e:
            self.logger.error(f"Scan failed {path}: {e}")

    def _is_regular_file(self, path: Path):
        try:
            if path.is_symlink() and not self.config.follow_symlinks:
                self.logger.error(f"Scan failed {path}: file is symlink")
                return False

            if not stat.S_ISREG(path.stat().st_mode):
                self.logger.error(f"Scan failed {path}: not regular file")
                return False

            return True
        except OSError as e:
            self.logger.error(f"Scan failed {path}: {e}")
            return False

    def _is_closed(self, path: Path):
        # Best-effort check: flock is advisory, so this only detects writers
        # that also use flock; the lock is released when the file is closed.
        try:
            with open(path, "rb") as f:
                fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)

            return True
        except OSError:
            return False

    def hash(self, path: Path):
        """MD5 of the file, read in 8 MiB chunks (compared against the S3 ETag)."""
        CHUNK_SIZE = 8 * 1024 * 1024

        h = hashlib.md5()
        with open(path, "rb") as fh:
            while True:
                chunk = fh.read(CHUNK_SIZE)
                if not chunk:
                    break

                h.update(chunk)
        return h.hexdigest()


class S3Connector:
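    """boto3-based S3 access: existence and checksum lookups, uploads, and key mapping."""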
    def __init__(self, config: Config, logger: TransferLogger):
        self.config = config
        self.logger = logger

        self.conn = boto3.client(
            "s3",
            endpoint_url=config.endpoint_url,
            aws_access_key_id=config.aws_access_key_id,
            aws_secret_access_key=config.aws_secret_access_key,
            aws_session_token=None,
            region_name=config.aws_region,
            config=boto3.session.Config(
                signature_version="s3v4",
                max_pool_connections=50,
                tcp_keepalive=True,
                s3={"addressing_style": "path"},
                # Adaptive retry mode handles its own backoff, so
                # retry_base_delay_s is not applied here.
                retries={
                    "mode": "adaptive",
                    "max_attempts": config.max_retries,
                },
            ),
            # NOTE: TLS certificate verification is disabled, e.g. for
            # endpoints with self-signed certificates.
            verify=False,
        )

        self.check_connection()

    def _get_metadata(self, s3_key):
        # The bucket "root" is not a real object; report it as a directory.
        pseudo_dir = {
            "ContentType": "application/x-directory",
            "ContentLength": 0,
            "Metadata": {},
        }
        try:
            return self.conn.head_object(Bucket=self.config.s3_bucket, Key=s3_key)
        except ClientError as e:
            code = e.response["ResponseMetadata"]["HTTPStatusCode"]
            if code == 404:
                return pseudo_dir if s3_key == "/" else None
            if code == 400 and s3_key == "/":
                return pseudo_dir
            raise

    def check_connection(self):
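        """Fail fast if the bucket is missing, inaccessible, or the endpoint is unreachable."""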
        try:
            self.conn.head_bucket(Bucket=self.config.s3_bucket)
        except ClientError as e:
            code = e.response["ResponseMetadata"]["HTTPStatusCode"]
            if code == 404:
                self.logger.error(f"Bucket {self.config.s3_bucket} does not exist")
            elif code == 403:
                self.logger.error(f"No access to bucket {self.config.s3_bucket}")
            else:
                self.logger.error(f"Connection to S3 failed: {e}")
            raise e
        except EndpointConnectionError as e:
            self.logger.error(f"Connection to S3 failed: {e}")
            raise e

    def exists(self, s3_key: str):
        return self._get_metadata(s3_key) is not None

    def hash(self, s3_key):
        """Remote ETag; equals the local MD5 only for single-part, non-KMS uploads."""
        return self._get_metadata(s3_key)["ETag"].strip('"')

    def upload(self, path: Path, s3_key: str, **kwargs):
        st = path.stat()
        metadata = {
            "ctime": str(st.st_ctime),
            "mtime": str(st.st_mtime),
            "mode": oct(stat.S_IMODE(st.st_mode)),
        }

        extra_args = {"Metadata": metadata}
        content_type = self._guess_mimetype(path)
        if content_type:
            extra_args["ContentType"] = content_type

        # Honour the configured multipart thresholds instead of boto3 defaults.
        transfer_config = TransferConfig(
            multipart_threshold=self.config.multipart_threshold_mb * 1024 * 1024,
            multipart_chunksize=self.config.multipart_chunk_mb * 1024 * 1024,
        )

        with open(path, "rb") as f:
            return self.conn.upload_fileobj(
                f,
                self.config.s3_bucket,
                s3_key,
                ExtraArgs=extra_args,
                Config=transfer_config,
            )

    def _guess_mimetype(self, path):
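        """Best-effort Content-Type guess, with a couple of extra extensions registered."""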
        mimetypes.add_type('application/vnd.ms-outlook', '.msg')
        mimetypes.add_type('text/plain', '.log')

        mimetype, _ = mimetypes.guess_type(path)
        return mimetype

    def as_s3_key(self, path: Path) -> str:
        # as_posix() keeps key separators as "/" regardless of platform.
        relative = path.relative_to(self.config.source_path).as_posix()
        return f"{self.config.s3_prefix.rstrip('/')}/{relative}".lstrip("/")


class TransferManager:
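    """Coordinates scanning, skip decisions, and sequential or threaded uploads."""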
    def __init__(self, config: Config):
        self.config = config
        self.logger = TransferLogger(config)
        self.file_scanner = FileScanner(config, self.logger)
        self.s3_connector = S3Connector(config, self.logger)

        self.results = []

    def _iter_paths(self):
        for path in self.file_scanner.scan():
            s3_key = self.s3_connector.as_s3_key(path)

            if self.s3_connector.exists(s3_key):
                # Skip existing objects outright when skip_existing is set;
                # otherwise skip only if the local and remote checksums match.
                if self.config.skip_existing or (
                    self.file_scanner.hash(path) == self.s3_connector.hash(s3_key)
                ):
                    res = TransferResult(path=path, s3_key=s3_key, status="skipped")
                    self.results.append(res)
                    self.logger.info(f"Transfer result: {res}")
                    continue

            yield path

    def upload_file(self, path: Path):
        s3_key = self.s3_connector.as_s3_key(path)
        local_checksum = self.file_scanner.hash(path)

        if self.config.dry_run:
            return TransferResult(path=path, s3_key=s3_key, checksum=local_checksum, status="dry_run")

        try:
            self.s3_connector.upload(path, s3_key)
            remote_checksum = self.s3_connector.hash(s3_key)

            if local_checksum != remote_checksum:
                raise ValueError(f"Checksum mismatch after upload: {path}")

            return TransferResult(path=path, s3_key=s3_key, checksum=local_checksum, status="ok")
        except Exception as e:
            self.logger.error(f"Upload failed for {path}: {e}")
            return TransferResult(path=path, s3_key=s3_key, status="failed", error=str(e))

    def run(self):
        self.logger.info("Transfer start")

        if self.config.run_async:
            self.run_async()
        else:
            self.run_sync()

        self._finalize()

    def run_async(self):
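        """Upload with a thread pool, collecting results as futures complete."""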
        paths = list(self._iter_paths())

        with ThreadPoolExecutor(max_workers=self.config.max_workers) as pool:
            futures = {pool.submit(self.upload_file, path): path for path in paths}

            for future in as_completed(futures):
                result = future.result()
                self.results.append(result)
                self.logger.info(f"Transfer result: {result}")

    def run_sync(self):
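        """Upload files one at a time, in scan order."""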
        for path in self._iter_paths():
            result = self.upload_file(path)
            self.results.append(result)
            self.logger.info(f"Transfer result: {result}")

    def _finalize(self):
        ok = sum(1 for r in self.results if r.status == "ok")
        failed = sum(1 for r in self.results if r.status == "failed")
        skipped = sum(1 for r in self.results if r.status == "skipped")
        dry_run = sum(1 for r in self.results if r.status == "dry_run")
        self.logger.info(f"Done: {ok} ok, {failed} failed, {skipped} skipped, {dry_run} dry_run")


if __name__ == "__main__":
    config = Config.from_yaml()

    tm = TransferManager(config)
    tm.run()
