kemoverse/setup_db.py
2025-05-29 13:27:56 +01:00

129 lines
4.2 KiB
Python

import sqlite3
import traceback
import os
import argparse
from configparser import ConfigParser
from typing import List, Tuple
class DBNotFoundError(Exception):
    '''Raised when no database location could be determined.'''
class InvalidMigrationError(Exception):
    '''Raised when a migration file name does not follow the expected format.'''
class KemoverseEnvUnset(Exception):
    '''Raised when KEMOVERSE_ENV is missing or holds an unsupported value.'''
class ConfigError(Exception):
    '''Raised when the config file for the current environment is missing.'''
def get_migrations(migrations_dir: str = 'migrations') -> List[Tuple[int, str]]:
    '''Returns the migration files found in a directory, in numeric order.

    Only regular files whose name ends in ``.sql`` are considered; any other
    directory entry is reported on stdout and skipped. A migration file must
    be named ``<number>_<description>.sql``.

    Args:
        migrations_dir: Directory to scan. Defaults to ``'migrations'``,
            which is resolved against the current working directory.

    Returns:
        A list of ``(migration number, file path)`` tuples sorted by
        migration number. Files sharing a number are ordered by path so the
        result does not depend on directory listing order.

    Raises:
        InvalidMigrationError: A ``.sql`` file name has no underscore or its
            prefix before the first underscore is not a decimal number.
        OSError: The directory cannot be listed.
    '''
    # Store migration number and file path separately
    sql_files: List[Tuple[int, str]] = []

    for filename in os.listdir(migrations_dir):
        joined_path = os.path.join(migrations_dir, filename)

        # Ignore anything that isn't a .sql file
        if not (os.path.isfile(joined_path) and filename.endswith('.sql')):
            print(f'{filename} is not a .sql file, ignoring...')
            continue

        prefix, separator, _ = filename.partition('_')

        # Invalid filename format. isdecimal() rather than isdigit(): the
        # latter also accepts characters such as superscript digits that
        # int() cannot parse.
        if not separator or not prefix.isdecimal():
            raise InvalidMigrationError(f'Invalid migration file: {filename}')

        sql_files.append((int(prefix), joined_path))

    # Sort on the whole tuple: by migration number first, then by path
    sql_files.sort()
    return sql_files
def perform_migration(cursor: sqlite3.Cursor, migration: tuple[int, str]) -> None:
    '''Runs one migration script and records its number as the schema version.

    This function does not call commit or rollback itself.

    Args:
        cursor: Cursor on the database being migrated.
        migration: ``(migration number, path to the .sql script)``.

    Raises:
        OSError: The script file cannot be opened or read.
        sqlite3.Error: The script or the schema version update fails.
    '''
    version, script_path = migration
    print(f'Performing migration {script_path}...')

    # Open and execute the sql script
    with open(script_path, encoding='utf-8') as file:
        script = file.read()
    # NOTE(review): on a connection using legacy transaction control,
    # executescript() issues an implicit COMMIT of any pending transaction
    # before running the script - confirm callers connect with
    # autocommit=False if the script must be rolled back on failure.
    cursor.executescript(script)

    # Update the schema version. The key is bound as a parameter (as the
    # version lookup elsewhere in this file does) instead of being written
    # as a double-quoted literal, which SQLite parses as an identifier first.
    cursor.execute(
        'UPDATE config SET value = ? WHERE key = ?',
        (version, 'schema_version'),
    )
def get_db_path() -> str:
    '''Gets the DB path from the config file of the current environment.

    The environment is read from the ``KEMOVERSE_ENV`` environment variable
    and selects the file ``config_<env>.ini`` in the current working
    directory. The path is taken from the ``DatabaseLocation`` option of the
    ``[application]`` section.

    Returns:
        The configured database location.

    Raises:
        KemoverseEnvUnset: ``KEMOVERSE_ENV`` is unset or is neither
            ``'prod'`` nor ``'dev'``.
        ConfigError: The config file for the environment does not exist.
        DBNotFoundError: The ``[application]`` section or its
            ``DatabaseLocation`` option is missing or empty.
    '''
    env = os.environ.get('KEMOVERSE_ENV')
    if env not in ('prod', 'dev'):
        raise KemoverseEnvUnset('KEMOVERSE_ENV must be set to "dev" or "prod"')

    print(f'Running in "{env}" mode')

    config_path = f'config_{env}.ini'
    if not os.path.isfile(config_path):
        raise ConfigError(f'Could not find {config_path}')

    config = ConfigParser()
    config.read(config_path)

    # A fallback makes a missing section or option yield '' instead of
    # raising KeyError, so every "no location" case ends in DBNotFoundError
    db_path = config.get('application', 'DatabaseLocation', fallback='')
    if not db_path:
        raise DBNotFoundError(
            f'No DatabaseLocation set in the [application] section of {config_path}'
        )
    return db_path
def get_current_migration(cursor: sqlite3.Cursor) -> int:
    '''Looks up the schema version stored in the database.

    Reads the ``value`` of the ``schema_version`` row of the ``config``
    table.

    Args:
        cursor: Cursor on the database to inspect.

    Returns:
        The stored schema version as an int, or -1 when the row does not
        exist or the query raises ``sqlite3.Error``.
    '''
    try:
        row = cursor.execute(
            'SELECT value FROM config WHERE key = ?',
            ('schema_version',),
        ).fetchone()
        if row:
            return int(row[0])
        return -1
    except sqlite3.Error:
        # Treated as "database has not been initialized yet"
        print('Error getting schema version')
        return -1
def main():
    '''Applies every pending migration to the configured database.

    Resolves the database path, collects the migration files, reads the
    current schema version and then runs each migration with a higher
    number, committing after each one. The first failing migration is rolled
    back and stops the run. Configuration and migration discovery problems
    are reported on stdout and end the run without raising.
    '''
    # Resolve the DB location
    try:
        db_path = get_db_path()
    except ConfigError as ex:
        print(ex)
        return
    except KemoverseEnvUnset:
        print('Error: KEMOVERSE_ENV is either not set or has an invalid value.')
        print('Please set KEMOVERSE_ENV to either "dev" or "prod" before running.')
        print(traceback.format_exc())
        return
    except DBNotFoundError:
        print('Error: the config file does not specify a database location.')
        return

    # Obtain list of migrations to run. Done before connecting so that an
    # unusable migrations directory is reported before a connection is opened.
    try:
        migrations = get_migrations()
    except (InvalidMigrationError, OSError) as ex:
        print(f'Error: could not load migrations: {ex}')
        return

    # Connect to the DB
    conn = sqlite3.connect(db_path, autocommit=False)
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Determine schema version
        current_migration = get_current_migration(cursor)
        print(f'Current schema version: {current_migration}')

        # Run any migrations newer than current schema
        for migration in migrations:
            if migration[0] <= current_migration:
                print(f'Migration already up: {migration[1]}')
                continue
            try:
                perform_migration(cursor, migration)
                conn.commit()
            except Exception as ex:
                print(f'An error occurred while applying {migration[1]}: {ex}, aborting...')
                print(traceback.format_exc())
                conn.rollback()
                break
    finally:
        # Close the connection even if something unexpected is raised above
        conn.close()
# Run the migrations only when executed as a script, not when imported
if __name__ == '__main__':
    main()