From bdd2f20b842ccea9de368bddae849a5a8e28194b Mon Sep 17 00:00:00 2001 From: VD15 Date: Mon, 26 May 2025 13:11:48 +0100 Subject: [PATCH] Implement connection pooling --- bot/add_character.py | 67 +++++++++++++++--------------------- bot/bot_app.py | 4 +++ bot/db_utils.py | 81 +++++++++++++++++++++++++------------------- bot/response.py | 23 ++----------- 4 files changed, 81 insertions(+), 94 deletions(-) diff --git a/bot/add_character.py b/bot/add_character.py index aae3fb3..e17a4c4 100644 --- a/bot/add_character.py +++ b/bot/add_character.py @@ -1,11 +1,12 @@ import requests from misskey.exceptions import MisskeyAPIException from client import client_connection -from db_utils import get_db_connection +from db_utils import insert_character def add_character(name: str, rarity: int, weight: float, image_url: str) -> tuple[int, str]: """ - Adds a character to the database, uploading the image from a public URL to the bot's Misskey Drive. + Adds a character to the database, uploading the image from a public URL to + the bot's Misskey Drive. Args: name (str): Character name. @@ -20,41 +21,29 @@ def add_character(name: str, rarity: int, weight: float, image_url: str) -> tupl ValueError: If inputs are invalid. RuntimeError: If image download/upload or database operation fails. """ + # Validate inputs + if not name or not name.strip(): + raise ValueError("Character name cannot be empty.") + if not isinstance(rarity, int) or rarity < 1: + raise ValueError("Rarity must be a positive integer.") + if not isinstance(weight, (int, float)) or weight <= 0: + raise ValueError("Weight must be a positive number.") + if not image_url: + raise ValueError("Image URL must be provided.") + + # Download image + response = requests.get(image_url, stream=True, timeout=30) + if response.status_code != 200: + raise RuntimeError(f"Failed to download image from {image_url}") + + # Upload to bot's Drive + mk = client_connection() try: - # Validate inputs - if not name or not name.strip(): - raise ValueError("Character name cannot be empty.") - if not isinstance(rarity, int) or rarity < 1: - raise ValueError("Rarity must be a positive integer.") - if not isinstance(weight, (int, float)) or weight <= 0: - raise ValueError("Weight must be a positive number.") - if not image_url: - raise ValueError("Image URL must be provided.") - - # Download image - response = requests.get(image_url, stream=True, timeout=30) - if response.status_code != 200: - raise RuntimeError(f"Failed to download image from {image_url}") - - # Upload to bot's Drive - mk = client_connection() - try: - media = mk.drive_files_create(response.raw) - file_id = media["id"] - except MisskeyAPIException as e: - raise RuntimeError(f"Failed to upload image to bot's Drive: {e}") from e - - # Insert into database - conn = get_db_connection() - cur = conn.cursor() - cur.execute( - 'INSERT INTO characters (name, rarity, weight, file_id) VALUES (?, ?, ?, ?)', - (name.strip(), rarity, float(weight), file_id) - ) - conn.commit() - character_id = cur.lastrowid - - return character_id, file_id - finally: - if 'conn' in locals(): - conn.close() + media = mk.drive_files_create(response.raw) + file_id = media["id"] + except MisskeyAPIException as e: + raise RuntimeError(f"Failed to upload image to bot's Drive: {e}") from e + + # Insert into database + character_id = insert_character(name.strip(), rarity, float(weight), file_id) + return character_id, file_id diff --git a/bot/bot_app.py b/bot/bot_app.py index b65ef3a..825695e 100644 --- a/bot/bot_app.py +++ b/bot/bot_app.py @@ -1,6 +1,7 @@ import time import misskey as misskey from client import client_connection +import db_utils as db from config import NOTIFICATION_POLL_INTERVAL from notification import process_notifications @@ -8,6 +9,9 @@ from notification import process_notifications if __name__ == '__main__': # Initialize the Misskey client client = client_connection() + # Connect to DB + db.connect() + print('Listening for notifications...') while True: if not process_notifications(client): diff --git a/bot/db_utils.py b/bot/db_utils.py index a521ec5..daff8c5 100644 --- a/bot/db_utils.py +++ b/bot/db_utils.py @@ -1,68 +1,79 @@ +from random import choices import sqlite3 import config DB_PATH = config.DB_PATH +CONNECTION: sqlite3.Connection +CURSOR: sqlite3.Cursor -def get_db_connection(): +def connect() -> None: '''Creates a connection to the database''' - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - return conn + print('Connecting to the database...') + global CONNECTION + global CURSOR + CONNECTION = sqlite3.connect(DB_PATH, autocommit=True) + CONNECTION.row_factory = sqlite3.Row + CURSOR = CONNECTION.cursor() + +def get_random_character(): + ''' Gets a random character from the database''' + CURSOR.execute('SELECT * FROM characters') + characters = CURSOR.fetchall() + + if not characters: + return None, None, None, None + + weights = [c['weight'] for c in characters] + chosen = choices(characters, weights=weights, k=1)[0] + + return chosen['id'], chosen['name'], chosen['file_id'], chosen['rarity'] def get_or_create_user(username): '''Retrieves an ID for a given user, if the user does not exist, it will be created.''' - conn = get_db_connection() - conn.row_factory = sqlite3.Row - cur = conn.cursor() - cur.execute('SELECT id FROM users WHERE username = ?', (username,)) - user = cur.fetchone() + CURSOR.execute('SELECT id FROM users WHERE username = ?', (username,)) + user = CURSOR.fetchone() if user: - conn.close() return user[0] # New user starts with has_rolled = False - cur.execute( + CURSOR.execute( 'INSERT INTO users (username, has_rolled) VALUES (?, ?)', (username, False) ) - conn.commit() - user_id = cur.lastrowid - conn.close() + user_id = CURSOR.lastrowid return user_id -def add_pull(user_id, character_id): +def insert_character(name: str, rarity: int, weight: float, file_id: str) -> int: + '''Inserts a character''' + CURSOR.execute( + 'INSERT INTO characters (name, rarity, weight, file_id) VALUES (?, ?, ?, ?)', + (name, rarity, weight, file_id) + ) + character_id = CURSOR.lastrowid + return character_id if character_id else 0 + +def insert_pull(user_id, character_id): '''Creates a pull in the database''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute('INSERT INTO pulls (user_id, character_id) VALUES (?, ?)', (user_id, character_id)) - conn.commit() - conn.close() + CURSOR.execute( + 'INSERT INTO pulls (user_id, character_id) VALUES (?, ?)', + (user_id, character_id) + ) def get_last_rolled_at(user_id): '''Gets the timestamp when the user last rolled''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute("SELECT timestamp FROM pulls WHERE user_id = ? ORDER BY timestamp DESC", \ + CURSOR.execute("SELECT timestamp FROM pulls WHERE user_id = ? ORDER BY timestamp DESC", \ (user_id,)) - row = cur.fetchone() - conn.close() + row = CURSOR.fetchone() return row[0] if row else None def get_config(key): '''Reads the value for a specified config key from the db''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute("SELECT value FROM config WHERE key = ?", (key,)) - row = cur.fetchone() - conn.close() + CURSOR.execute("SELECT value FROM config WHERE key = ?", (key,)) + row = CURSOR.fetchone() return row[0] if row else None def set_config(key, value): '''Writes the value for a specified config key to the db''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute("INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", (key, value)) - conn.commit() - conn.close() + CURSOR.execute("INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", (key, value)) diff --git a/bot/response.py b/bot/response.py index 55100ad..ea0abe5 100644 --- a/bot/response.py +++ b/bot/response.py @@ -1,25 +1,8 @@ -import random from datetime import datetime, timedelta, timezone -from db_utils import get_or_create_user, add_pull, get_db_connection, get_last_rolled_at +from db_utils import get_or_create_user, insert_pull, get_last_rolled_at, get_random_character from add_character import add_character from config import GACHA_ROLL_INTERVAL -def get_character(): - ''' Gets a random character from the database''' - conn = get_db_connection() - cur = conn.cursor() - cur.execute('SELECT * FROM characters') - characters = cur.fetchall() - conn.close() - - if not characters: - return None, None, None, None - - weights = [c['weight'] for c in characters] - chosen = random.choices(characters, weights=weights, k=1)[0] - - return chosen['id'], chosen['name'], chosen['file_id'], chosen['rarity'] - def do_roll(full_user): '''Determines whether the user can roll, then pulls a random character''' user_id = get_or_create_user(full_user) @@ -50,12 +33,12 @@ def do_roll(full_user): return f'{full_user} ⏱️ Please wait another {remaining_duration} before rolling again.' - character_id, character_name, file_id, rarity = get_character() + character_id, character_name, file_id, rarity = get_random_character() if not character_id: return f'{full_user} Uwaaa... something went wrong! No characters found. 😿' - add_pull(user_id,character_id) + insert_pull(user_id,character_id) stars = '⭐️' * rarity return([f"@{full_user} 🎲 Congrats! You rolled {stars} **{character_name}**\n\ She's all yours now~ 💖✨",[file_id]])