forked from waifu/kemoverse
Add database migration system
This commit is contained in:
parent
d210f44efc
commit
25a72b3002
4 changed files with 128 additions and 64 deletions
64
db.py
64
db.py
|
@ -1,64 +0,0 @@
|
|||
import sqlite3
|
||||
|
||||
# Connect to SQLite database (or create it if it doesn't exist)
|
||||
conn = sqlite3.connect('gacha_game.db')
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create tables
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
has_rolled BOOLEAN NOT NULL DEFAULT 0
|
||||
)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS characters (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
rarity INTEGER NOT NULL,
|
||||
weight REAL NOT NULL,
|
||||
file_id TEXT NOT NULL
|
||||
)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS pulls (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
character_id INTEGER,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id),
|
||||
FOREIGN KEY (character_id) REFERENCES characters(id)
|
||||
)
|
||||
''')
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# Initialize essential config key
|
||||
cursor.execute('INSERT INTO config VALUES ("last_seen_notif_id", 0)')
|
||||
|
||||
""" # Insert example characters into the database if they don't already exist
|
||||
characters = [
|
||||
('Murakami-san', 1, 0.35),
|
||||
('Mastodon-kun', 2, 0.25),
|
||||
('Pleroma-tan', 3, 0.2),
|
||||
('Misskey-tan', 4, 0.15),
|
||||
('Syuilo-mama', 5, 0.05)
|
||||
]
|
||||
|
||||
|
||||
cursor.executemany('''
|
||||
INSERT OR IGNORE INTO characters (name, rarity, weight) VALUES (?, ?, ?)
|
||||
''', characters)
|
||||
"""
|
||||
|
||||
# Commit changes and close
|
||||
conn.commit()
|
||||
conn.close()
|
28
migrations/0000_setup.sql
Normal file
28
migrations/0000_setup.sql
Normal file
|
@ -0,0 +1,28 @@
|
|||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
has_rolled BOOLEAN NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS characters (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
rarity INTEGER NOT NULL,
|
||||
weight REAL NOT NULL,
|
||||
file_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pulls (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
character_id INTEGER,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id),
|
||||
FOREIGN KEY (character_id) REFERENCES characters(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
);
|
||||
INSERT OR IGNORE INTO config VALUES ("schema_version", 0);
|
1
migrations/0001_fix_notif_id.sql
Normal file
1
migrations/0001_fix_notif_id.sql
Normal file
|
@ -0,0 +1 @@
|
|||
INSERT OR IGNORE INTO config VALUES ("last_seen_notif_id", 0);
|
99
setup_db.py
Normal file
99
setup_db.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import sqlite3
|
||||
import os
|
||||
import argparse
|
||||
from configparser import ConfigParser
|
||||
from typing import List, Tuple
|
||||
|
||||
class DBNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidMigrationError(Exception):
|
||||
pass
|
||||
|
||||
def get_migrations() -> List[Tuple[int, str]] | InvalidMigrationError:
|
||||
'''Returns a list of migration files in numeric order.'''
|
||||
# Store transaction id and filename separately
|
||||
sql_files: List[Tuple[int, str]] = []
|
||||
migrations_dir = 'migrations'
|
||||
|
||||
for filename in os.listdir(migrations_dir):
|
||||
joined_path = os.path.join(migrations_dir, filename)
|
||||
|
||||
# Ignore anything that isn't a .sql file
|
||||
if not (os.path.isfile(joined_path) and filename.endswith('.sql')):
|
||||
print(f'{filename} is not a .sql file, ignoring...')
|
||||
continue
|
||||
|
||||
parts = filename.split('_', 1)
|
||||
|
||||
# Invalid filename format
|
||||
if len(parts) < 2 or not parts[0].isdigit():
|
||||
raise InvalidMigrationError(f'Invalid migration file: {filename}')
|
||||
|
||||
sql_files.append((int(parts[0]), joined_path))
|
||||
|
||||
# Get sorted list of files by migration number
|
||||
sql_files.sort(key=lambda x: x[0])
|
||||
return sql_files
|
||||
|
||||
def perform_migration(cursor: sqlite3.Cursor, migration: tuple[int, str]) -> None:
|
||||
'''Performs a migration on the DB'''
|
||||
print(f'Performing migration {migration[1]}...')
|
||||
|
||||
# Open and execute the sql script
|
||||
with open(migration[1], encoding='utf-8') as file:
|
||||
script = file.read()
|
||||
cursor.executescript(script)
|
||||
# Update the schema version
|
||||
cursor.execute('UPDATE config SET value = ? WHERE key = "schema_version"', (migration[0],))
|
||||
|
||||
def get_db_path() -> str | DBNotFoundError:
|
||||
'''Gets the DB path from config.ini'''
|
||||
config = ConfigParser()
|
||||
config.read('config.ini')
|
||||
db_path = config['application']['DatabaseLocation']
|
||||
if not db_path:
|
||||
raise DBNotFoundError
|
||||
return db_path
|
||||
|
||||
def get_current_migration(cursor: sqlite3.Cursor) -> int:
|
||||
'''Gets the current schema version of the database'''
|
||||
try:
|
||||
cursor.execute('SELECT value FROM config WHERE key = ?', ('schema_version',))
|
||||
version = cursor.fetchone()
|
||||
return -1 if not version else int(version[0])
|
||||
except sqlite3.Error:
|
||||
print('Error getting schema version')
|
||||
# Database has not been initialized yet
|
||||
return -1
|
||||
|
||||
def main():
|
||||
'''Does the thing'''
|
||||
# Connect to the DB
|
||||
db_path = get_db_path()
|
||||
conn = sqlite3.connect(db_path, autocommit=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Obtain list of migrations to run
|
||||
migrations = get_migrations()
|
||||
# Determine schema version
|
||||
current_migration = get_current_migration(cursor)
|
||||
print(f'Current schema version: {current_migration}')
|
||||
|
||||
# Run any migrations newer than current schema
|
||||
for migration in migrations:
|
||||
if migration[0] <= current_migration:
|
||||
print(f'Migration already up: {migration[1]}')
|
||||
continue
|
||||
try:
|
||||
perform_migration(cursor, migration)
|
||||
conn.commit()
|
||||
except Exception as ex:
|
||||
print(f'An error occurred while applying migration: {ex}, aborting...')
|
||||
conn.rollback()
|
||||
break
|
||||
conn.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Add table
Reference in a new issue