# kemoverse/setup_db.py
#
# 99 lines
# 3.2 KiB
# Python

import sqlite3
import os
import argparse
from configparser import ConfigParser
from typing import List, Tuple
class DBNotFoundError(Exception):
    '''Raised when config.ini yields no usable database location.'''
class InvalidMigrationError(Exception):
    '''Raised when a .sql file in the migrations directory is badly named.

    Valid names look like ``<number>_<description>.sql``.
    '''
def get_migrations(migrations_dir: str = 'migrations') -> List[Tuple[int, str]]:
    '''Return the migration scripts in ``migrations_dir``, in numeric order.

    Migration files must be named ``<number>_<description>.sql``. Anything
    that is not a regular ``.sql`` file is reported and skipped.

    Args:
        migrations_dir: Directory to scan. Defaults to ``'migrations'``,
            resolved relative to the current working directory.

    Returns:
        A list of ``(migration_number, path)`` tuples sorted by number.
        ``path`` is ``migrations_dir`` joined with the filename.

    Raises:
        InvalidMigrationError: If a ``.sql`` filename does not start with a
            decimal number followed by ``_``, or if two files share a number.
        OSError: If ``migrations_dir`` cannot be listed.
    '''
    # Store migration number and path separately
    sql_files: List[Tuple[int, str]] = []
    # Remember which file claimed each number so duplicates can be reported;
    # otherwise both would run, in arbitrary os.listdir() order.
    claimed_by: dict[int, str] = {}
    for filename in os.listdir(migrations_dir):
        joined_path = os.path.join(migrations_dir, filename)
        # Ignore anything that isn't a .sql file
        if not (os.path.isfile(joined_path) and filename.endswith('.sql')):
            print(f'{filename} is not a .sql file, ignoring...')
            continue
        prefix, separator, _ = filename.partition('_')
        # isdecimal() (not isdigit()) guarantees int() accepts the prefix:
        # isdigit() is also true for characters such as '²' that int() rejects.
        if not separator or not prefix.isdecimal():
            raise InvalidMigrationError(f'Invalid migration file: {filename}')
        number = int(prefix)
        if number in claimed_by:
            raise InvalidMigrationError(
                f'Duplicate migration number {number}: '
                f'{claimed_by[number]} and {filename}'
            )
        claimed_by[number] = filename
        sql_files.append((number, joined_path))
    # Get sorted list of files by migration number
    sql_files.sort(key=lambda entry: entry[0])
    return sql_files
def perform_migration(cursor: sqlite3.Cursor, migration: tuple[int, str]) -> None:
    '''Run one migration script and record its number as the schema version.

    Args:
        cursor: Cursor on the target database. This function does not call
            commit() or rollback(); the caller owns the transaction. Under
            legacy transaction control, executescript() itself COMMITs any
            pending transaction before running the script.
        migration: ``(migration_number, path_to_sql_script)``, as produced
            by ``get_migrations``.

    Raises:
        OSError: If the script file cannot be read.
        sqlite3.Error: If the script or the version update fails.
    '''
    version, script_path = migration[0], migration[1]
    print(f'Performing migration {script_path}...')
    # Open and execute the sql script
    with open(script_path, encoding='utf-8') as file:
        script = file.read()
    cursor.executescript(script)
    # Bind the key as a parameter. A double-quoted "schema_version" is an SQL
    # identifier and only worked via SQLite's double-quoted-string fallback,
    # which some builds disable.
    # NOTE(review): if no schema_version row exists this UPDATE is a silent
    # no-op; presumably an early migration inserts it — confirm.
    cursor.execute(
        'UPDATE config SET value = ? WHERE key = ?',
        (version, 'schema_version'),
    )
def get_db_path(config_path: str = 'config.ini') -> str:
    '''Read the database location from the ``[application]`` section.

    Args:
        config_path: INI file to read. Defaults to ``'config.ini'`` in the
            current working directory.

    Returns:
        The ``DatabaseLocation`` value.

    Raises:
        DBNotFoundError: If the file, the ``[application]`` section, or the
            ``DatabaseLocation`` option is missing, or the value is empty.
    '''
    config = ConfigParser()
    # read() silently skips files it cannot open, so a missing config file
    # surfaces below as a missing section instead of an I/O error.
    config.read(config_path)
    # fallback='' folds "no section" / "no option" into the empty-value case,
    # so every failure raises DBNotFoundError instead of a bare KeyError.
    db_path = config.get('application', 'DatabaseLocation', fallback='')
    if not db_path:
        raise DBNotFoundError(
            f'DatabaseLocation is not set in [application] of {config_path}'
        )
    return db_path
def get_current_migration(cursor: sqlite3.Cursor) -> int:
    '''Return the schema version stored in the ``config`` table.

    Returns -1 when no ``schema_version`` row exists, or when querying
    raises any ``sqlite3.Error`` (for example, the ``config`` table does not
    exist yet because the database was never initialized). In that case a
    message is printed.
    '''
    lookup = 'SELECT value FROM config WHERE key = ?'
    try:
        cursor.execute(lookup, ('schema_version',))
        row = cursor.fetchone()
    except sqlite3.Error:
        print('Error getting schema version')
        # Database has not been initialized yet
        return -1
    if row:
        return int(row[0])
    return -1
def main():
    '''Apply every pending migration to the configured SQLite database.

    Reads the DB path via get_db_path() and compares the stored schema
    version from get_current_migration() against the numbered scripts from
    get_migrations(). Scripts numbered above the current version are applied
    in order, with a commit after each one. The first failing migration is
    rolled back, and no later migrations are attempted.
    '''
    # Connect to the DB
    db_path = get_db_path()
    # autocommit=False (Python 3.12+) keeps work inside an explicit
    # transaction so a failing migration can be rolled back.
    conn = sqlite3.connect(db_path, autocommit=False)
    # Close the connection even if discovering migrations or reading the
    # schema version raises; with autocommit=False, close() also rolls back
    # any uncommitted work.
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        # Obtain list of migrations to run
        migrations = get_migrations()
        # Determine schema version
        current_migration = get_current_migration(cursor)
        print(f'Current schema version: {current_migration}')
        # Run any migrations newer than current schema
        for migration in migrations:
            if migration[0] <= current_migration:
                print(f'Migration already up: {migration[1]}')
                continue
            try:
                perform_migration(cursor, migration)
                conn.commit()
            except Exception as ex:
                # Top-level boundary: report, undo the partial migration, stop
                print(f'An error occurred while applying migration: {ex}, aborting...')
                conn.rollback()
                break
    finally:
        conn.close()
# Run the migrations only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()