#!/usr/bin/env python3
import sys
import json
import argparse
import time
from urllib.parse import urlparse
import os

class Colors:
	GREEN = '\033[92m'
	YELLOW = '\033[93m'
	RED = '\033[91m'
	BOLD = '\033[1m'
	END = '\033[0m'

try:
	import psycopg2
except ImportError:
	print(f"{Colors.RED}Ошибка: Установите psycopg2:")
	print(f"{Colors.YELLOW}  - apt install python3-psycopg2")
	print(f"  - python3 -m pip install psycopg2{Colors.END}")
	sys.exit(1)

class PGBench:
	def __init__(self, args):
		# Проверяем указана ли переменная с URL подключения к postgresql
		if os.getenv('EVA_CONFIG__DB_URL'):
			parsed = urlparse(os.getenv('EVA_CONFIG__DB_URL'))
			username = parsed.username or "postgres"
			password = parsed.password or None
			host = parsed.hostname
			port = parsed.port or 5432  # стандартный порт PostgreSQL
			database = parsed.path.lstrip('/').split('/', 1)[0] or "evadb"
		else:
			username = None
			password = None
			host = ""
			port = None
			database = None
		
		# Даже если переменная с URL подключения указана, приоритет аргументам
		if args.username:
			username = args.username
		elif not username and os.getenv('PGUSER'):
			username = os.getenv('PGUSER')
		elif not username:
			username = "postgres"

		if args.password:
			password = args.password
		elif not password and os.getenv('PGPASSWORD'):
			password = os.getenv('PGPASSWORD')
		elif not password:
			password = ""

		if args.host:
			host = args.host
		elif not host and os.getenv('PGHOST'):
			host = os.getenv('PGHOST')
		elif not host and os.getenv('HOSTADDR'):
			host = os.getenv('HOSTADDR')
		elif not host:
			# Если host не указывался, делаем его пустой строкой, чтобы не было ошибок
			host = ""

		if args.port:
			port = args.port
		elif not port and os.getenv('PGPORT'):
			port = os.getenv('PGPORT')
		elif not port:
			port = 5432

		if args.dbname:
			database = args.dbname
		elif not database and os.getenv('PGPASSWORD'):
			database = os.getenv('PGDATABASE')
		elif not database:
			database = "evadb"

		if not args.json:
			if password is None or password.strip() == "":
				print(f'{Colors.YELLOW}Warning: Переменная PGPASSWORD не задана, ключ -p/--password не был указан{Colors.END}', file=sys.stderr)
		
		self.table = args.table

		try:
			# Если host пустой, значит подключаемся к socket без host и port
			if not host:
				self.conn = psycopg2.connect(
					dbname=database,
					user=username,
					password=password
				)
			else:
				self.conn = psycopg2.connect(
					dbname=database,
					user=username,
					password=password,
					host=host,
					port=str(port)
				)
		except Exception:
			print(f"Не удалось подключиться к БД:")
			print(f"postgres://{username}:{password}@{host}:{port}/{database}")
			sys.exit(1)

		self.conn.autocommit = False
		self.large_data = b'X' * 10_485_760  # 10 МБ данных
		self.large_data_ids = []

	def bench_small_queries(self, duration=1):
		"""Замер скорости выполнения мелких запросов к PostgreSQL"""
		cursor = self.conn.cursor()
		
		# Подготовим простой запрос
		query = "SELECT 1"
		
		operations = 0
		start_time = time.time()
		end_time = start_time + duration
		
		# Выполняем запросы в течение 1 секунды
		while time.time() < end_time:
			cursor.execute(query)
			cursor.fetchone()  # Вычитываем результат
			operations += 1
		
		# Коммитим транзакцию
		self.conn.commit()
		
		elapsed = time.time() - start_time
		ops_per_sec = operations / elapsed / duration
		
		reference = 1000 # 100
		format_str = "{x:.2f} ops/sec"
		cursor.close()

		return {
			"title": f"PostgreSQL время выполнения мелких запросов - latency (не менее {format_str.format(x=reference)})",
			"value": ops_per_sec,
			"reference": reference,
			"format": format_str,
			"condition": "greater"
		}

	def large_write_benchmark(self, duration=5):
		"""Бенчмарк записи больших данных"""
		cursor = self.conn.cursor()
		
		# Подготовка таблицы
		cursor.execute(f"""
			DROP TABLE IF EXISTS {self.table};
			CREATE TABLE {self.table} (id SERIAL PRIMARY KEY, data BYTEA);
		""")
		
		bytes_written = 0
		start_time = time.time()
		end = start_time + duration
		
		while time.time() < end:
			cursor.execute(f"INSERT INTO {self.table} (data) VALUES (%s) RETURNING octet_length(data), id", 
					   (self.large_data,))
			result, value_id = cursor.fetchone()
			bytes_written += result
			self.large_data_ids.append(value_id)
		
		self.conn.commit()

		self.large_data_written = True
		
		write_time = time.time() - start_time
		mb_per_sec = bytes_written / 1024 / 1024 / write_time

		reference = 10
		format_str = "{x:.2f} MB/sec"
		cursor.close()
		
		return {
			"title": f"PostgreSQL скорость записи больших данных (не менее {format_str.format(x=reference)})",
			"value": mb_per_sec,
			"reference": reference,
			"format": format_str,
			"condition": "greater"
		}

	def large_read_benchmark(self, duration=2):
		"""
		Замер скорости чтения больших данных из БД
		Возвращает скорость в МБ/сек
		"""
		if not hasattr(self, 'large_data_written') or not self.large_data_written:
			# Если нет записанных данных, создаем тестовые
			self.large_write_benchmark()
		
		cursor = self.conn.cursor()
		bytes_read = 1
		start_time = time.time()
		end_time = start_time + duration
		
		# Читаем записи по кругу
		record_count = len(self.large_data_ids)
		idx = 1
		
		while time.time() < end_time:
			# Читаем следующую запись
			cursor.execute(
				f"SELECT data FROM {self.table} WHERE id = %s",
				(self.large_data_ids[idx % record_count],)
			)
			data = cursor.fetchone()[0]
			bytes_read += len(data)
			idx += 1
		
		read_time = time.time() - start_time
		mb_per_sec = (bytes_read / 1024 / 1024) / read_time
		
		reference = 20
		format_str = "{x:.2f} MB/sec"
		cursor.close()
		
		return {
			"title": f"PostgreSQL скорость чтения больших данных (не менее {format_str.format(x=reference)})",
			"value": mb_per_sec,
			"reference": reference,
			"format": format_str,
			"condition": "greater"
		}

	def cleanup(self):
		"""Очистка"""
		try:
			cursor = self.conn.cursor()
			cursor.execute(f"DROP TABLE IF EXISTS {self.table};")
			self.conn.commit()
		finally:
			cursor.close()
			self.conn.close()

def get_color(speed, limit=160):
	"""Получение цвета для скорости"""
	half_limit = limit / 1.5
	if speed > limit:
		return Colors.GREEN
	elif speed >= half_limit:
		return Colors.YELLOW
	else:
		return Colors.RED + Colors.BOLD

def main():
	# Обработка агрументов скрипта
	parser = argparse.ArgumentParser(description='Скрипт для тестирования производительности PostgreSQL')
	parser.add_argument('-d', '--dbname', type=str,
		help='Имя БД для подключения (default: evadb)')
	parser.add_argument('-u', '--username', type=str,
		help='Имя пользователя БД для подключения (default: postgres)')
	parser.add_argument('-p', '--password', type=str,
		help='Пароль пользователя БД для подключения (можно также установить переменной окружения PGPASSWORD)')
	parser.add_argument('-H', '--host', type=str,
		help='Адрес сервера БД (default: socket)')
	parser.add_argument('-P', '--port', type=int,
		help='Порт сервера БД для подключения (default: 5432)')
	parser.add_argument('-t', '--table', type=str, default='eva_performance_test',
		help='Имя таблицы, которая будет создана в рамках тестирования (default: eva_performance_test)')
	parser.add_argument('--json', action='store_true', help='Вывод в формате JSON')
	parser.add_argument('--nocolor', action='store_true', help='Вывод без цветовых акцентов')
	args = parser.parse_args()

	# Инициализация объекта класса
	test = PGBench(args)
	
	# В try, чтобы в случае неудачи закрылось соединение
	result = []
	try:
		result.append(test.bench_small_queries())
		result.append(test.large_write_benchmark())
		result.append(test.large_read_benchmark())
	finally:
		test.cleanup()

	for value in result:
		if value['value'] is None:
			print(f"Ошибка: Не удалось получить {value['title']}")
			sys.exit(1)
		if not args.json:
			if args.nocolor:
				print(f"{value['title']}: {value['format'].format(x=value['value'])}")
			else:
				print(f"{value['title']}: {get_color(value['value'], value['reference'])}{value['format'].format(x=value['value'])}{Colors.END}")

	if args.json:
		print(json.dumps(result, ensure_ascii=False))
	
	return 0

if __name__ == "__main__":
	main()