diff --git a/bot/db_utils.py b/bot/db_utils.py index 2d030cb..a236980 100644 --- a/bot/db_utils.py +++ b/bot/db_utils.py @@ -47,13 +47,27 @@ def get_or_create_user(username): user_id = CURSOR.lastrowid return user_id -def insert_character(name: str, rarity: int, weight: float, file_id: str) -> int: +def insert_character(name: str, rarity: int, weight: float, file_id: str, stats: dict) -> int: '''Inserts a character''' CURSOR.execute( 'INSERT INTO characters (name, rarity, weight, file_id) VALUES (?, ?, ?, ?)', (name, rarity, weight, file_id) ) character_id = CURSOR.lastrowid + + # Insert stats + columns = ', '.join(stats.keys()) + placeholders = ', '.join(['?'] * len(stats)) + updates = ', '.join([f"{col}=excluded.{col}" for col in stats.keys()]) + values = list(stats.values()) + + sql = f''' + INSERT INTO character_stats (character_id, {columns}) + VALUES (?, {placeholders}) + ON CONFLICT(character_id) DO UPDATE SET {updates} + ''' + CURSOR.execute(sql, [character_id] + values) + return character_id if character_id else 0 def insert_pull(user_id, character_id): @@ -85,69 +99,7 @@ def set_config(key, value): # Character stat functions -def add_character_stats(character_id, stats): - ''' - Adds or updates character stats in the character_stats table. - `stats` should be a dictionary like {'power': 5, 'charm': 3} - ''' - if not stats: - return - - conn = get_db_connection() - cur = conn.cursor() - - columns = ', '.join(stats.keys()) - placeholders = ', '.join(['?'] * len(stats)) - updates = ', '.join([f"{col}=excluded.{col}" for col in stats.keys()]) - - values = list(stats.values()) - - sql = f''' - INSERT INTO character_stats (character_id, {columns}) - VALUES (?, {placeholders}) - ON CONFLICT(character_id) DO UPDATE SET {updates} - ''' - cur.execute(sql, [character_id] + values) - conn.commit() - conn.close() - - -def update_character_stat(character_id, stat_name, value): - '''Updates a single stat field for a character''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute(f''' - UPDATE character_stats SET {stat_name} = ? WHERE character_id = ? - ''', (value, character_id)) - conn.commit() - conn.close() - -def get_character_stats(character_id): - '''Retrieves all stats for a single character dynamically''' - conn = get_db_connection() - conn.row_factory = sqlite3.Row # Enables dict-style access to rows - cur = conn.cursor() - cur.execute('SELECT * FROM character_stats WHERE character_id = ?', (character_id,)) - row = cur.fetchone() - conn.close() - - if row: - return {key: row[key] for key in row.keys() if key != 'character_id'} - else: - return {} - -def get_character_stat(character_id, stat_name): - '''Retrieves a single stat value for a character''' - if stat_name not in ('power', 'charm'): - raise ValueError("Invalid stat name") - conn = get_db_connection() - cur = conn.cursor() - cur.execute(f'SELECT {stat_name} FROM character_stats WHERE character_id = ?', (character_id,)) - row = cur.fetchone() - conn.close() - return row[0] if row else 0 - -def get_stats_for_multiple_characters(character_ids): +def get_characters(character_ids): ''' Retrieves stats for a list of character IDs. Returns a dictionary of character_id -> {stat_name: value, ...} @@ -162,12 +114,10 @@ def get_stats_for_multiple_characters(character_ids): WHERE character_id IN ({placeholders}) ''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute(query, character_ids) + + CURSOR.execute(query, character_ids) rows = cur.fetchall() col_names = [desc[0] for desc in cur.description] - conn.close() result = {} for row in rows: