#Kemoverse - a gacha-style bot for the Fediverse.
#Copyright © 2025 Waifu
#
#This program is free software: you can redistribute it and/or modify
#it under the terms of the GNU Affero General Public License as
#published by the Free Software Foundation, either version 3 of the
#License, or (at your option) any later version.
#
#This program is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#GNU Affero General Public License for more details.
#
#You should have received a copy of the GNU Affero General Public License
#along with this program. If not, see https://www.gnu.org/licenses/.

import sqlite3
import traceback
import os
import argparse
from configparser import ConfigParser
from typing import List, Tuple


class DBNotFoundError(Exception):
    '''Could not find the database location'''


class InvalidMigrationError(Exception):
    '''Migration file has an invalid name'''


class KemoverseEnvUnset(Exception):
    '''KEMOVERSE_ENV is not set or has an invalid value'''


class ConfigError(Exception):
    '''Could not find the config file for the current environment'''


def get_migrations(migrations_dir: str = 'migrations') -> List[Tuple[int, str]]:
    '''Returns a list of migration files in numeric order.

    Every ``.sql`` file in ``migrations_dir`` must be named
    ``<number>_<description>.sql``; entries that are not regular ``.sql``
    files are reported and skipped.

    :param migrations_dir: directory to scan (relative to the current working
        directory unless absolute). Defaults to ``'migrations'``.
    :returns: ``(migration_number, path)`` tuples sorted by number.
    :raises InvalidMigrationError: a ``.sql`` file lacks a numeric prefix, or
        two files share the same migration number.
    :raises OSError: ``migrations_dir`` cannot be listed.
    '''
    # Store transaction id and filename separately
    sql_files: List[Tuple[int, str]] = []
    # Migration number -> filename, used to reject duplicates
    seen: dict[int, str] = {}

    for filename in os.listdir(migrations_dir):
        joined_path = os.path.join(migrations_dir, filename)

        # Ignore anything that isn't a .sql file
        if not (os.path.isfile(joined_path) and filename.endswith('.sql')):
            print(f'{filename} is not a .sql file, ignoring...')
            continue

        prefix, sep, _ = filename.partition('_')
        # isdecimal() (not isdigit()) guarantees int() accepts the prefix;
        # isdigit() also admits characters such as '²' that int() rejects.
        if not sep or not prefix.isdecimal():
            raise InvalidMigrationError(f'Invalid migration file: {filename}')

        number = int(prefix)
        # Only the highest applied number is recorded as schema_version, so a
        # duplicate number would make ordering ambiguous and a later-added
        # duplicate would be skipped forever.
        if number in seen:
            raise InvalidMigrationError(
                f'Duplicate migration number {number}: {seen[number]} and {filename}')
        seen[number] = filename
        sql_files.append((number, joined_path))

    # Get sorted list of files by migration number
    sql_files.sort(key=lambda x: x[0])
    return sql_files


def perform_migration(cursor: sqlite3.Cursor, migration: Tuple[int, str]) -> None:
    '''Performs a migration on the DB

    Executes the SQL script at ``migration[1]`` and then stores
    ``migration[0]`` as the ``schema_version`` value in the ``config`` table.
    This function does not commit; the caller is responsible for that.

    :param cursor: cursor on the target database.
    :param migration: ``(migration_number, script_path)`` as produced by
        :func:`get_migrations`.
    '''
    print(f'Performing migration {migration[1]}...')

    # Open and execute the sql script
    with open(migration[1], encoding='utf-8') as file:
        script = file.read()
    cursor.executescript(script)

    # Update the schema version. The key is bound as a parameter: a
    # double-quoted "schema_version" is an SQL identifier and only works via
    # SQLite's legacy fallback of treating unknown identifiers as strings.
    cursor.execute('UPDATE config SET value = ? WHERE key = ?',
                   (migration[0], 'schema_version'))
    if cursor.rowcount == 0:
        # NOTE(review): presumably an early migration inserts this row; if it
        # is missing, the version is not recorded and migrations will re-run.
        print('Warning: config has no schema_version row; '
              'the new schema version was not recorded')


def get_db_path() -> str:
    '''Gets the DB path from config.ini

    Reads ``config_<env>.ini`` from the current working directory, where
    ``<env>`` is the value of the ``KEMOVERSE_ENV`` environment variable.

    :returns: the ``DatabaseLocation`` value of the ``[application]`` section.
    :raises KemoverseEnvUnset: ``KEMOVERSE_ENV`` is unset or not "prod"/"dev".
    :raises ConfigError: the config file for the environment does not exist.
    :raises DBNotFoundError: the config file has no non-empty
        ``[application] DatabaseLocation`` entry.
    '''
    env = os.environ.get('KEMOVERSE_ENV')
    if env not in ('prod', 'dev'):
        raise KemoverseEnvUnset(f'KEMOVERSE_ENV={env!r}')

    print(f'Running in "{env}" mode')

    config_path = f'config_{env}.ini'
    if not os.path.isfile(config_path):
        raise ConfigError(f'Could not find {config_path}')

    config = ConfigParser()
    config.read(config_path, encoding='utf-8')
    # fallback covers both a missing section and a missing key, which would
    # otherwise escape as an unhandled KeyError.
    db_path = config.get('application', 'DatabaseLocation', fallback='')
    if not db_path:
        raise DBNotFoundError(
            f'{config_path} does not set DatabaseLocation in [application]')
    return db_path


def get_current_migration(cursor: sqlite3.Cursor) -> int:
    '''Gets the current schema version of the database

    :returns: the stored ``schema_version``, or -1 when the row is missing or
        the query fails (e.g. the ``config`` table does not exist yet).
    '''
    try:
        cursor.execute('SELECT value FROM config WHERE key = ?', ('schema_version',))
        version = cursor.fetchone()
        return -1 if not version else int(version[0])
    except sqlite3.Error:
        print('Error getting schema version')
        # Database has not been initialized yet
        return -1


def main() -> None:
    '''Applies every pending migration to the configured database.

    Migrations numbered above the recorded schema version are applied in
    order, each committed individually; the first failure is rolled back and
    stops the run.
    '''
    try:
        db_path = get_db_path()
    except (ConfigError, DBNotFoundError) as ex:
        print(ex)
        return
    except KemoverseEnvUnset:
        print('Error: KEMOVERSE_ENV is either not set or has an invalid value.')
        print('Please set KEMOVERSE_ENV to either "dev" or "prod" before running.')
        print(traceback.format_exc())
        return

    # Obtain list of migrations to run. Done before connecting so a broken
    # migrations directory is reported before the database file is opened
    # (sqlite3.connect creates the file if it does not exist).
    try:
        migrations = get_migrations()
    except (InvalidMigrationError, OSError) as ex:
        print(f'Could not load migrations: {ex}')
        return

    # Connect to the DB
    conn = sqlite3.connect(db_path, autocommit=False)
    try:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Determine schema version
        current_migration = get_current_migration(cursor)
        print(f'Current schema version: {current_migration}')

        # Run any migrations newer than current schema
        for migration in migrations:
            if migration[0] <= current_migration:
                print(f'Migration already up: {migration[1]}')
                continue
            try:
                perform_migration(cursor, migration)
                conn.commit()
            except Exception as ex:
                print(f'An error occurred while applying {migration[1]}: {ex}, aborting...')
                print(traceback.format_exc())
                conn.rollback()
                break
    finally:
        # Always release the connection, even if something above raises.
        conn.close()


if __name__ == '__main__':
    main()