#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import time
import argparse
import subprocess
import logging
import json

logger = logging.getLogger(__name__)

TITLE_TARGET = ''


def which(program):

    def is_exe(pth):
        return os.path.isfile(pth) and os.access(pth, os.X_OK)

    fpath, _ = os.path.split(program)
    if fpath:
        if is_exe(program):
            return program
    else:
        for path in os.environ['PATH'].split(os.pathsep):
            exe_file = os.path.join(path, program)
            if is_exe(exe_file):
                return exe_file

    return None


class AttrDict(dict):

    def __getattr__(self, name):
        return AttrDict(self[name]) if isinstance(self[name], dict) else self[name]


class FioJob():
    def __init__(self, title, name, rw, bsize, numjobs=1, iodepth=1,
                 reference={}, **kwargs):
        self.title = f'{title}, {TITLE_TARGET}' if TITLE_TARGET else title
        self.name = name
        self.iodepth = iodepth
        self.rw = rw
        self.numjobs = numjobs
        self.group_reporting = 1
        self.randrepeat = 0
        self.bsize = bsize
        self.raw = None
        self.reference = reference
        self.results = []
        self.metrics = []
        self.results_text = []
        self.kwargs = kwargs

    def args(self):
        args = ['--name=%s' % self.name, '--bs=%dk' % self.bsize,
                '--iodepth=%d' % self.iodepth, '--numjobs=%d' % self.numjobs,
                '--rw=%s' % self.rw, '--group_reporting=%d' % self.group_reporting,
                '--randrepeat=%s' % self.randrepeat]
        for k, v in self.kwargs.items():
            args.append('--{0}={1}'.format(k, v))
        return args

    def parse_result(self):
        fmt = u'%-50s | %s | %8d MiB/s | %8d | %6.1f/%6.1f/%6.1f usec'

        # Считаем пропускную способность
        bw_r = self.raw.read.get('bw_bytes', self.raw.read.bw * 1024) / 1024 / 1024
        bw_w = self.raw.write.get('bw_bytes', self.raw.write.bw * 1024) / 1024 / 1024

        # Latency, в разных версиях fio разные колонки
        if 'lat_ns' in self.raw.read:
            lr = self.raw.read.lat_ns
            lw = self.raw.write.lat_ns
            div = 1000  # наносекунды →в микросекунды
        else:
            lr = self.raw.read.lat
            lw = self.raw.write.lat
            div = 1     # уже в микросекундах

        # Конвертим среднюю/мин/макс
        lat_r_min = lr.min / div
        lat_w_min = lw.min / div
        lat_r_max = lr.max / div
        lat_w_max = lw.max / div

        if 'lat_ns' in self.raw.read:
            # Новый fio, в наносекундах
            lat_r_min = self.raw.read.lat_ns.min / 1000
            lat_r_avg = self.raw.read.lat_ns.mean / 1000
            lat_r_max = self.raw.read.lat_ns.max / 1000

            lat_w_min = self.raw.write.lat_ns.min / 1000
            lat_w_avg = self.raw.write.lat_ns.mean / 1000
            lat_w_max = self.raw.write.lat_ns.max / 1000
        else:
            # Старый fio, в микросекундах
            lat_r_min = self.raw.read.lat.min
            lat_r_avg = self.raw.read.lat.mean
            lat_r_max = self.raw.read.lat.max

            lat_w_min = self.raw.write.lat.min
            lat_w_avg = self.raw.write.lat.mean
            lat_w_max = self.raw.write.lat.max

        add_pref = False
        if bw_r and bw_w:
            add_pref = True

        if bw_r:
            res = {
                'title': self.title,
                'type': 'R',
                'bandwidth': bw_r,
                'iops': int(self.raw.read.iops),
                'lat_min': lat_r_min,
                'lat_avg': lat_r_avg,
                'lat_max': lat_r_max,
            }
            res['human'] = fmt % (
                self.title, 'R', bw_r, int(self.raw.read.iops),
                lat_r_min, lat_r_avg, lat_r_max
            )
            self.results.append(res)

            pref = ', чтение' if add_pref else ''
            ref = self.reference['r'][0]
            self.metrics.append({
                'title': f"{self.title}{pref}, скорость",
                'value': bw_r,
                'reference': ref,
                'format': '{x:.2f} MB/sec',
                'condition': 'greater',
                'colorize': False,
            })
            ref = self.reference['r'][1]
            self.metrics.append({
                'title': f"{self.title}{pref}, IOPS (не менее {ref:d})",
                'value': int(self.raw.read.iops),
                'reference': ref,
                'format': '{x:d}',
                'condition': 'greater',
            })
            # ref = self.reference['r'][2]
            # self.metrics.append({
            #     'title': f"{self.title}{pref}, latency (не более {ref:.1f} usec)",
            #     'value': lat_r_avg,
            #     'reference': ref,
            #     'format': '{x:.1f} usec',
            #     'condition': 'less',
            # })

        if bw_w:
            res = {
                'title': self.title,
                'type': 'W',
                'bandwidth': bw_w,
                'iops': int(self.raw.write.iops),
                'lat_min': lat_w_min,
                'lat_avg': lat_w_avg,
                'lat_max': lat_w_max,
            }
            res['human'] = fmt % (
                self.title, 'W', bw_w, int(self.raw.write.iops),
                lat_w_min, lat_w_avg, lat_w_max
            )
            self.results.append(res)

            pref = ', запись' if add_pref else ''
            ref = self.reference['w'][0]
            self.metrics.append({
                'title': f"{self.title}{pref}, скорость",
                'value': bw_w,
                'reference': ref,
                'format': '{x:.2f} MB/sec',
                'condition': 'greater',
                'colorize': False,
            })
            ref = self.reference['w'][1]
            self.metrics.append({
                'title': f"{self.title}{pref}, IOPS (не менее {ref:d})",
                'value': int(self.raw.write.iops),
                'reference': ref,
                'format': '{x:d}',
                'condition': 'greater',
            })
            # ref = self.reference['w'][2]
            # self.metrics.append({
            #     'title': f"{self.title}{pref}, latency (не более {ref:.1f} usec)",
            #     'value': lat_w_avg,
            #     'reference': ref,
            #     'format': '{x:.1f} usec',
            #     'condition': 'less',
            # })

class TestRunner():
    def __init__(self, target, loops, size, runtime=300, numjobs=1,
                 cdm=False, keep_testfile=False, prepare='true', hdd=False):
        self.results = []
        self.metrics = []
        self.jobs = []
        self.prepare = prepare
        self.cdm = cdm
        self.keep_testfile = keep_testfile
        self.loops = loops
        self.size = size
        self.numjobs = numjobs
        self.runtime = runtime
        self.elapsed = 0
        self.start = None
        self.hdd = hdd
        self.end = None
        self.fio_path = which('fio') or '/usr/bin/fio'
        self.targets = [
            f"{os.path.join(target, '.fiomark1.tmp')}",
            f"{os.path.join(target, '.fiomark2.tmp')}",
            f"{os.path.join(target, '.fiomark3.tmp')}",
            f"{os.path.join(target, '.fiomark4.tmp')}",
            f"{os.path.join(target, '.fiomark5.tmp')}",
        ]
        self.target = ':'.join(self.targets)
        if not self.fio_path:
            raise RuntimeError('fio could not be found')
        self.base_args = [
            self.fio_path,
            '--output-format=json',
            '--stonewall',
            '--invalidate=1',
            '--loops=%d' % self.loops,
            '--size=%dM' % self.size,
            '--filename=%s' % self.target,
        ]
        if self.runtime:
            self.base_args.append('--time_based=1')
            # Рантайм влияет на время теста, тестов может быть несколько,
            # и у каждого свой рантайм. Задать общее время работы fio невозможно.
            # Т.е. общее время можно вычислить только когда будет известно
            # количество тестов, а для разных стореджей оно может быть разным.
            # # self.base_args.append('--runtime={0}'.format(self.runtime))

    def add_job(self, job):
        self.jobs.append(job)

    def prepare_targets(self):
        logger.info("Подготавливаем файлы")
        size = self.size / len(self.targets)
        for target in self.targets:
            cmd = ["dd", "bs=1M", f"count={int(size)}", "oflag=direct", "if=/dev/urandom", f"of={target}"]
            logger.debug(" ".join(cmd))
            subprocess.run(
                cmd, 
                check=True,
                text=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        logger.info("Выполняем sync два раза")
        subprocess.run("sync", check=True)
        subprocess.run("sync", check=True)
        logger.info("Сбрасываем page cache (может и не сработать, зависит от среды)")
        try:
            with open("/proc/sys/vm/drop_caches", "w") as outfile:
                subprocess.run(["echo", "1"], stdout=outfile, check=False)
        except OSError:
            logger.debug('Не удалось сбросить кеш (echo 1 > drop_caches)')
            pass
        

    def run_test(self):
        # Подготавливаем файлы и синкаем на диск (если не указали пропустить)
        # becnhmark_runner будет запускать тест с таймаутом
        # важно уложиться, поэтому создать файлы и выполнить сам тест можно в два действия
        if self.prepare != 'skip':
            self.prepare_targets()
        
        if self.prepare == 'only':
            return

        self.add_jobs()
        if not self.jobs:
            raise RuntimeError('No jobs defined')
        args = self.base_args
        logger.debug("Main args: %s", " ".join(args))

        buf = {}
        for job in self.jobs:
            buf[job.name] = job
            logger.debug("Job %s args: %s", job.title, " ".join(job.args()))
            args.extend(job.args())

        logger.debug("fio cmd: %s", " ".join(args))

        self.start = time.time()
        raw = json.loads(subprocess.check_output(args))
        self.raw_json = raw
        self.end = time.time()
        self.elapsed = self.end - self.start

        if not self.keep_testfile:
            try:
                for target in self.targets:
                    os.unlink(target)
            except:
                pass

        for r in raw['jobs']:
            buf[r['jobname']].raw = AttrDict(r)

        for job in self.jobs:
            job.parse_result()
            self.results += job.results
            self.metrics += job.metrics

    def add_jobs(self):
        self.base_args.append('--ioengine=libaio')
        self.base_args.append('--direct=1')

        if self.cdm:
            # Набор джоб - типичный CrystalDiskMark
            self.add_job(FioJob(title='Буферизированное параллельное чтение/запись', name='Bufread',
                                rw='readwrite', bsize=int(self.size / 32), iodepth=1,
                                numjobs=self.numjobs,
                                reference={'r': (300, 1000, 1000), 'w': (300, 1000, 1000)}))
            self.add_job(FioJob(title='Последовательное чтение Q32T1 %d MiB' % int(self.size / 32),
                                name='SeqQ32T1read', rw='read', bsize=int(self.size / 32),
                                numjobs=self.numjobs, iodepth=32,
                                reference={'r': (300, 1000, 1000)}))
            self.add_job(FioJob(title='Последовательная запись Q32T1 %d MiB' % int(self.size / 32),
                                name='SeqQ32T1write', rw='write', bsize=int(self.size / 32),
                                numjobs=self.numjobs, iodepth=32,
                                reference={'w': (300, 1000, 1000)}))
            self.add_job(FioJob(title='Случайное чтение блоками по 4KiB', name='4kread',
                                numjobs=self.numjobs, rw='randread', bsize=4, iodepth=1,
                                reference={'r': (300, 1000, 1000)}))
            self.add_job(FioJob(title='Случайная запись блоками по 4KiB', name='4kwrite',
                                numjobs=self.numjobs, rw='randwrite', bsize=4, iodepth=1,
                                reference={'w': (300, 1000, 1000)}))
            self.add_job(FioJob(title='Случайное чтение блоками по 4KiB Q32T1', name='4kQ32T1read',
                                numjobs=1, rw='randread', bsize=4, iodepth=32,
                                reference={'r': (5, 1000, 1000)}))
            self.add_job(FioJob(title='Случайная запись блоками по 4KiB Q32T1',
                                numjobs=1, rw='randwrite', bsize=4, iodepth=32,
                                name='4kQ32T1write', reference={'w': (5, 1000, 1000)}))
            self.add_job(FioJob(title='Случайное чтение блоками по 4KiB Q8T8', name='4kQ8T8read',
                                numjobs=8, rw='randread', bsize=4, iodepth=8,
                                reference={'r': (300, 1000, 1000)}))
            self.add_job(FioJob(title='Случайная запись блоками по 4KiB Q8T8', name='4kQ8T8write',
                                numjobs=8, rw='randwrite', bsize=4, iodepth=8,
                                reference={'w': (300, 1000, 1000)}))
        else:
            # hdd нельзя
            # reference_r = {'r': (330, 86000, 86000)}
            # reference_w = {'w': (180, 46000, 46000)}
            # if self.hdd:
            #     reference_r = {'r': (2, 500, 500)}
            #     reference_w = {'w': (1, 240, 240)}
            reference_r = {'r': (330, 20000, 20000)}
            reference_w = {'w': (180, 20000, 20000)}
            self.add_job(FioJob(title='Случайное чтение блоками по 4KiB Q32T1',
                                numjobs=1, rw='randread', bsize=4, iodepth=32,
                                 name='4kQ32T1read', reference=reference_r))
            self.add_job(FioJob(title='Случайная запись блоками по 4KiB Q32T1',
                                numjobs=1, rw='randwrite', bsize=4, iodepth=32,
                                name='4kQ32T1write', reference=reference_w))

        # Дефолтный runtime по тестам, задать общий - нельзя.
        runtime = self.runtime / len(self.jobs)
        self.base_args.append('--runtime={0}'.format(round(runtime, 2)))


def format_duration(seconds):
    """Convert seconds to H:M:S format."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60  # keep fractional part

    if hours > 0:
        # HhMMmSS.sss format
        return f"{hours}h{minutes:02d}m{secs:06.3f}s"
    elif minutes > 0:
        # MmSS.sss format (GNU style)
        return f"{minutes}m{secs:06.3f}s"
    else:
        # 0mSS.sss format (GNU style)
        return f"0m{secs:.3f}s"


def configure_logging(debug=False):
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - [%(levelname)s] - %(module)s.%(funcName)s - %(message)s",
    )


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument('--target', help='Путь для проверки, default = /mnt/shared',
                   type=str, default='/mnt/shared')
    p.add_argument('-l', '--loops', default=1, type=int,
                   help='Количество повторений тестов (1 по-умолчанию).'
                   ' Большее количество даст более точную картину')
    # 4096 - чтобы точно выбить кеш nvme
    p.add_argument('-s', '--size', help='Размер файла для тестов (1024 по-умолчанию)',
                   default=10240, type=int)
    p.add_argument('-j', '--numjobs', help='Количество потоков (5 по-умолчанию)',
                   default=5, type=int)
    # Дефолт 300, для local storage 9 тестов, т.е. примерно по 33с на тест.
    # Этого должно хватить, чтобы пробить любой дисковый кеш любого стореджа.
    p.add_argument('-t', '--runtime', default=300, type=int,
                   help='Общее время одного прохода всех тестов. Время выполнение каждого теста'
                   'Будет зависеть от количество тестов (может отличаться оп стреджам).'
                   ' WARNING --loops могут увеличить это время кратно')
    p.add_argument("--cdm", action="store_true",
                   help="Выполнить большой набор тестов, подобный CrystalDiskMark")
    p.add_argument("--hdd", action="store_true",
                   help=argparse.SUPPRESS)
                   # help="Тест проводится на hdd - флаг влияет на референсы (по-умолчанию, считаем, что ssd)")
    p.add_argument("-k", "--keep", action="store_true",
                   help="Оставить тестовый файл, не удалять")
    p.add_argument("--prepare", type=str, default="true",
                   choices=['true', 'only', 'skip'],
                   help="Подготовить файлы перед тестами (по-умолчанию - true)")
    p.add_argument("--json", action="store_true",
                   help="Напечатать вывод в json")
    p.add_argument("--debug", action="store_true", default=False,
                   help='Вывод дебага (при --json всегда отключено)')

    return p.parse_args()


if __name__ == '__main__':
    args = parse_args()
    if not args.json:
        configure_logging(debug=args.debug)

    i, m = divmod(args.size, 32)
    if m != 0 or i == 0:
        logger.error('ERROR! Неподходящий размер файла. Размер в мегабайтах,'
                     ' должен быть не менее 32 и делиться на 32 без остатка')
        sys.exit(1)

    if args.json:
        TITLE_TARGET = args.target

    runner = TestRunner(target=args.target, loops=args.loops, size=args.size,
                        runtime=args.runtime, numjobs=args.numjobs, cdm=args.cdm,
                        keep_testfile=args.keep, prepare=args.prepare, hdd=args.hdd)
    runner.run_test()

    if args.prepare == 'only':
        sys.exit(0)

    if args.json:
        print(json.dumps(runner.metrics, ensure_ascii=False))
    else:
        FMT = u'%-50s | %s | %-14s | %-8s | %-8s'
        THEAD = FMT % (u'Test', u'T', u'Speed', u'IOPS', u'Latency (min/avg/max)')
        logger.info('Testing %s, loops: %d', args.target, args.loops)
        logger.info(u'-' * 115,)
        logger.info("%s", THEAD)
        logger.info(u'-' * 115)
        for result in runner.results:
            logger.info(result['human'])
        logger.info(u'-' * 115)
        logger.info("Total runtime: %s", format_duration(runner.elapsed))
