129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
import sqlite3
|
|
import traceback
|
|
import os
|
|
import argparse
|
|
from configparser import ConfigParser
|
|
from typing import List, Tuple
|
|
|
|
class DBNotFoundError(Exception):
    '''Raised when no database location can be determined.

    ``get_db_path`` raises it when the configured ``DatabaseLocation``
    value is empty.
    '''
|
|
|
|
class InvalidMigrationError(Exception):
    '''Raised when a migration ``.sql`` file has an unusable name.

    Migration files are expected to be named ``<number>_<description>.sql``.
    '''
|
|
|
|
class KemoverseEnvUnset(Exception):
    '''Raised when ``KEMOVERSE_ENV`` is missing or not a recognised value.

    ``get_db_path`` only accepts ``'prod'`` and ``'dev'``.
    '''
|
|
|
|
class ConfigError(Exception):
    '''Raised when the environment's config file cannot be found.

    ``get_db_path`` looks for ``config_<env>.ini`` in the working directory.
    '''
|
|
|
|
def get_migrations(migrations_dir: str = 'migrations') -> List[Tuple[int, str]]:
    '''Return the migration scripts in ``migrations_dir``, ordered by number.

    Every regular file ending in ``.sql`` must be named
    ``<number>_<description>.sql``. Anything else in the directory (other
    extensions, sub-directories) is reported and skipped.

    Args:
        migrations_dir: Directory to scan. Defaults to ``'migrations'``,
            resolved relative to the current working directory.

    Returns:
        A list of ``(migration_number, path)`` tuples sorted by migration
        number. Each ``path`` is ``migrations_dir`` joined with the filename.

    Raises:
        InvalidMigrationError: A ``.sql`` file has no numeric prefix followed
            by ``_``, or two files share the same migration number.
        OSError: ``migrations_dir`` cannot be listed (e.g. it does not exist).
    '''
    # Remember which file claimed each number so duplicates can be reported.
    by_number: dict[int, str] = {}

    for filename in os.listdir(migrations_dir):
        joined_path = os.path.join(migrations_dir, filename)

        # Ignore anything that isn't a .sql file
        if not (os.path.isfile(joined_path) and filename.endswith('.sql')):
            print(f'{filename} is not a .sql file, ignoring...')
            continue

        prefix, sep, _ = filename.partition('_')
        # isdecimal() (unlike isdigit()) only accepts characters int() can
        # parse, so e.g. a superscript digit cannot slip through and crash int().
        if not sep or not prefix.isdecimal():
            raise InvalidMigrationError(f'Invalid migration file: {filename}')

        number = int(prefix)
        # Two scripts sharing a number would run in arbitrary listdir order,
        # and one added later would be silently skipped as "already up".
        if number in by_number:
            raise InvalidMigrationError(
                f'Duplicate migration number {number}: '
                f'{os.path.basename(by_number[number])} and {filename}'
            )
        by_number[number] = joined_path

    # Numbers are unique, so sorting the (number, path) pairs orders by number.
    return sorted(by_number.items())
|
|
|
|
def perform_migration(cursor: sqlite3.Cursor, migration: tuple[int, str]) -> None:
    '''Apply one migration script and record its number as the schema version.

    This function does not commit. The caller is expected to commit or roll
    back afterwards.

    Args:
        cursor: Cursor on the database being migrated.
        migration: ``(migration_number, path_to_sql_file)``.

    Raises:
        OSError: The script file cannot be read.
        sqlite3.Error: The script or the schema-version update fails.
    '''
    number, path = migration
    print(f'Performing migration {path}...')

    # Open and execute the sql script
    with open(path, encoding='utf-8') as file:
        script = file.read()
    cursor.executescript(script)

    # Update the schema version. The key is bound as a parameter. In SQLite a
    # double-quoted "schema_version" is an identifier. It only falls back to a
    # string literal when no such column exists, and builds that disable
    # double-quoted string literals reject it outright.
    # NOTE(review): this updates 0 rows if no schema_version row exists yet;
    # presumably an early migration inserts it -- confirm.
    cursor.execute(
        'UPDATE config SET value = ? WHERE key = ?',
        (number, 'schema_version'),
    )
|
|
|
|
def get_db_path() -> str:
    '''Return the database path configured for the current environment.

    The environment comes from ``KEMOVERSE_ENV`` (``'prod'`` or ``'dev'``).
    It selects ``config_<env>.ini`` in the current working directory, and the
    path is read from ``DatabaseLocation`` in that file's ``[application]``
    section.

    Raises:
        KemoverseEnvUnset: ``KEMOVERSE_ENV`` is unset or not ``prod``/``dev``.
        ConfigError: ``config_<env>.ini`` does not exist.
        DBNotFoundError: The ``[application]`` section or its
            ``DatabaseLocation`` entry is missing or empty.
    '''
    env = os.environ.get('KEMOVERSE_ENV')
    # None is not in the tuple either, so no separate "is set" check is needed.
    if env not in ('prod', 'dev'):
        raise KemoverseEnvUnset(f'Invalid KEMOVERSE_ENV: {env!r}')

    print(f'Running in "{env}" mode')

    config_path = f'config_{env}.ini'
    if not os.path.isfile(config_path):
        raise ConfigError(f'Could not find {config_path}')

    config = ConfigParser()
    config.read(config_path)

    # fallback='' maps a missing section or key onto the same "no location"
    # error as an empty value, instead of leaking a bare KeyError.
    db_path = config.get('application', 'DatabaseLocation', fallback='')
    if not db_path:
        raise DBNotFoundError(f'No [application] DatabaseLocation set in {config_path}')
    return db_path
|
|
|
|
def get_current_migration(cursor: sqlite3.Cursor) -> int:
    '''Return the schema version recorded in the database's ``config`` table.

    Returns:
        The integer stored under ``key = 'schema_version'``. Returns ``-1``
        if the ``config`` table does not exist yet, or if it has no such row.

    Raises:
        sqlite3.Error: Any database error other than the table being absent
            (e.g. a locked or corrupt database, or a missing column). These
            are not mistaken for an uninitialized database.
        ValueError: The stored version is not an integer.
    '''
    # Only a missing config table means "not initialized yet". Checking for it
    # explicitly keeps other failures from being swallowed as version -1.
    # SQLite table names are case-insensitive, hence COLLATE NOCASE.
    cursor.execute(
        "SELECT 1 FROM sqlite_master "
        "WHERE type IN ('table', 'view') AND name = ? COLLATE NOCASE",
        ('config',),
    )
    if cursor.fetchone() is None:
        print('No config table found, database has not been initialized yet')
        return -1

    cursor.execute('SELECT value FROM config WHERE key = ?', ('schema_version',))
    row = cursor.fetchone()
    return -1 if row is None else int(row[0])
|
|
|
|
def main() -> int:
    '''Apply every pending migration to the configured database.

    Resolves the database path with ``get_db_path`` and lists the scripts
    with ``get_migrations``. It then applies, in order, each script numbered
    above the current schema version, committing after each one. The first
    failing migration is rolled back and stops the run.

    Returns:
        A process exit status. ``0`` means every pending migration was
        applied (or none were pending); ``1`` means the run failed.
    '''
    try:
        db_path = get_db_path()
    except ConfigError as ex:
        print(ex)
        return 1
    except DBNotFoundError as ex:
        print(str(ex) or 'Error: no database location is configured.')
        return 1
    except KemoverseEnvUnset:
        print('Error: KEMOVERSE_ENV is either not set or has an invalid value.')
        print('Please set KEMOVERSE_ENV to either "dev" or "prod" before running.')
        print(traceback.format_exc())
        return 1

    # Obtain the list of migrations before touching the database, so a bad
    # migrations directory fails fast without creating or opening the DB file.
    try:
        migrations = get_migrations()
    except (InvalidMigrationError, OSError) as ex:
        print(f'Error: could not load migrations: {ex}')
        return 1

    conn = sqlite3.connect(db_path, autocommit=False)
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Determine schema version
        current_migration = get_current_migration(cursor)
        print(f'Current schema version: {current_migration}')

        # Run any migrations newer than current schema
        for migration in migrations:
            if migration[0] <= current_migration:
                print(f'Migration already up: {migration[1]}')
                continue
            try:
                perform_migration(cursor, migration)
                conn.commit()
            except Exception as ex:
                # Top-level boundary: report, undo this migration, and stop so
                # later migrations never run on top of a failed one.
                print(f'An error occurred while applying {migration[1]}: {ex}, aborting...')
                print(traceback.format_exc())
                conn.rollback()
                return 1
    finally:
        # Release the connection even if an unexpected error escapes.
        conn.close()
    return 0
|
|
|
|
if __name__ == '__main__':
    # Use main()'s return value (if any) as the process exit status, so shell
    # scripts and CI can detect a failed migration run. A None result still
    # exits with status 0.
    raise SystemExit(main())
|