import sqlite3
import traceback
import os
import argparse
from configparser import ConfigParser
from typing import List, Tuple

class DBNotFoundError(Exception):
    '''No usable database location is configured.

    Raised by ``get_db_path`` when no non-empty ``DatabaseLocation`` value is
    found in the environment's config file.
    '''

class InvalidMigrationError(Exception):
    '''A file in the migrations directory cannot be used as a migration.

    For example, a ``.sql`` file whose name does not start with a numeric
    ``<version>_`` prefix.
    '''

class KemoverseEnvUnset(Exception):
    '''The KEMOVERSE_ENV environment variable is missing or not recognised.

    ``get_db_path`` accepts only the values "prod" and "dev".
    '''

class ConfigError(Exception):
    '''The config file for the selected environment could not be found.

    Raised by ``get_db_path`` when ``config_<env>.ini`` does not exist in the
    current working directory.
    '''

def get_migrations(migrations_dir: str = 'migrations') -> List[Tuple[int, str]]:
    '''Return the migration scripts in *migrations_dir*, sorted by version.

    A migration is a regular file named ``<version>_<description>.sql``, where
    ``<version>`` is a decimal integer. Every other directory entry is
    reported and skipped.

    Args:
        migrations_dir: Directory to scan. Defaults to ``migrations``,
            resolved relative to the current working directory.

    Returns:
        ``(version, path)`` tuples in ascending version order. ``path`` is
        ``migrations_dir`` joined with the file name.

    Raises:
        InvalidMigrationError: If a ``.sql`` file has no numeric version
            prefix, or if two files share the same version number.
        OSError: If ``migrations_dir`` cannot be listed.
    '''
    # Keyed by version so duplicates are caught. The schema version is the
    # only thing recorded after a migration runs, so two files with the same
    # number would run in arbitrary order, or one would be skipped forever.
    migrations: dict[int, str] = {}

    for filename in os.listdir(migrations_dir):
        joined_path = os.path.join(migrations_dir, filename)

        # Ignore anything that isn't a .sql file
        if not (os.path.isfile(joined_path) and filename.endswith('.sql')):
            print(f'{filename} is not a .sql file, ignoring...')
            continue

        prefix, separator, _ = filename.partition('_')

        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as '²' that int() cannot parse, which would raise ValueError.
        if not separator or not prefix.isdecimal():
            raise InvalidMigrationError(f'Invalid migration file: {filename}')

        version = int(prefix)
        if version in migrations:
            raise InvalidMigrationError(
                f'Duplicate migration version {version}: '
                f'{migrations[version]} and {joined_path}'
            )
        migrations[version] = joined_path

    # Versions are unique, so sorting the (version, path) pairs orders by version.
    return sorted(migrations.items())

def perform_migration(cursor: sqlite3.Cursor, migration: tuple[int, str]) -> None:
    '''Run one migration script and record its version as the schema version.

    This function does not commit. The caller is responsible for committing
    or rolling back the work done on *cursor*.

    Args:
        cursor: Cursor on the target database.
        migration: ``(version, path)`` pair naming the script to execute and
            the version to store under ``schema_version`` in ``config``.

    Raises:
        OSError: If the script file cannot be read.
        sqlite3.Error: If the script or the version bookkeeping fails.
    '''
    version, path = migration
    print(f'Performing migration {path}...')

    # Read the script, then close the file before executing it.
    with open(path, encoding='utf-8') as file:
        script = file.read()
    cursor.executescript(script)

    # Bind the key as a parameter. In SQL, double quotes denote identifiers,
    # and SQLite only treats "schema_version" as a string via its legacy
    # double-quoted-string fallback, which some builds disable.
    cursor.execute('UPDATE config SET value = ? WHERE key = ?', (version, 'schema_version'))
    if cursor.rowcount == 0:
        # No schema_version row exists yet (e.g. the config table was just
        # created without seeding it). Insert it so the version is actually
        # recorded rather than the UPDATE silently matching nothing.
        cursor.execute(
            'INSERT INTO config (key, value) VALUES (?, ?)', ('schema_version', version)
        )

def get_db_path() -> str:
    '''Return the database location configured for the current environment.

    The environment comes from the ``KEMOVERSE_ENV`` variable. It selects
    ``config_<env>.ini`` in the current working directory, and the path is
    read from ``DatabaseLocation`` in that file's ``[application]`` section.

    Returns:
        The configured database location (never empty).

    Raises:
        KemoverseEnvUnset: If KEMOVERSE_ENV is unset or not "prod"/"dev".
        ConfigError: If ``config_<env>.ini`` does not exist.
        DBNotFoundError: If the ``[application]`` section or its
            ``DatabaseLocation`` key is missing, or the value is empty.
    '''
    env = os.environ.get('KEMOVERSE_ENV')
    if env not in ('prod', 'dev'):
        raise KemoverseEnvUnset

    print(f'Running in "{env}" mode')

    config_path = f'config_{env}.ini'
    if not os.path.isfile(config_path):
        raise ConfigError(f'Could not find {config_path}')

    config = ConfigParser()
    config.read(config_path)
    # fallback='' sends a missing section or key through the same
    # DBNotFoundError as an empty value, instead of leaking a KeyError.
    db_path = config.get('application', 'DatabaseLocation', fallback='')
    if not db_path:
        raise DBNotFoundError(f'No DatabaseLocation set in [application] of {config_path}')
    return db_path

def get_current_migration(cursor: sqlite3.Cursor) -> int:
    '''Look up the schema version stored in the ``config`` table.

    The ``schema_version`` row of ``config`` holds the version. If that row is
    absent, or the query raises ``sqlite3.Error`` (for instance because
    ``config`` does not exist yet), -1 is returned. If the query fails, a
    short notice is also printed.

    Raises:
        ValueError / TypeError: If the stored value cannot be passed to ``int()``.
    '''
    try:
        cursor.execute('SELECT value FROM config WHERE key = ?', ('schema_version',))
        row = cursor.fetchone()
    except sqlite3.Error:
        # Most likely the database has not been initialised yet.
        print('Error getting schema version')
        return -1

    if not row:
        return -1
    return int(row[0])

def main():
    '''Apply every pending migration to the configured database.

    Steps:
    1. Resolve the database path from the environment-specific config file.
    2. Collect the migration scripts.
    3. Read the current schema version.
    4. Run each migration whose number is greater than that version,
       committing after each one.

    The first migration that fails is rolled back, and no further migrations
    are attempted. Configuration and migration-discovery errors are reported
    and end the run before any database connection is opened.
    '''
    try:
        db_path = get_db_path()
    except ConfigError as ex:
        print(ex)
        return
    except KemoverseEnvUnset:
        print('Error: KEMOVERSE_ENV is either not set or has an invalid value.')
        print('Please set KEMOVERSE_ENV to either "dev" or "prod" before running.')
        print(traceback.format_exc())
        return
    except DBNotFoundError:
        print('Error: DatabaseLocation is not set in the [application] section of the config file.')
        return

    # Collect migrations before connecting. A bad migrations directory then
    # neither leaks an open connection nor creates an empty database file.
    try:
        migrations = get_migrations()
    except (InvalidMigrationError, OSError) as ex:
        print(f'Error: could not load migrations: {ex}')
        return

    conn = sqlite3.connect(db_path, autocommit=False)
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Determine schema version
        current_migration = get_current_migration(cursor)
        print(f'Current schema version: {current_migration}')

        # Run any migrations newer than current schema
        for migration in migrations:
            if migration[0] <= current_migration:
                print(f'Migration already up: {migration[1]}')
                continue
            try:
                perform_migration(cursor, migration)
                conn.commit()
            except Exception as ex:  # top-level boundary: report, roll back, stop
                print(f'An error occurred while applying {migration[1]}: {ex}, aborting...')
                print(traceback.format_exc())
                conn.rollback()
                break
    finally:
        # Always release the connection, even if an unexpected error escapes.
        conn.close()

# Run the migrator only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()