forked from waifu/kemoverse
		
	Merge pull request 'Implement Database Connection Pooling' (#31) from 27_pool_db_connections into dev
Reviewed-on: waifu/kemoverse#31
This commit is contained in:
		
						commit
						d210f44efc
					
				
					 4 changed files with 80 additions and 93 deletions
				
			
		|  | @ -1,11 +1,12 @@ | |||
| import requests | ||||
| from misskey.exceptions import MisskeyAPIException | ||||
| from client import client_connection | ||||
| from db_utils import get_db_connection | ||||
| from db_utils import insert_character | ||||
| 
 | ||||
| def add_character(name: str, rarity: int, weight: float, image_url: str) -> tuple[int, str]: | ||||
|     """ | ||||
|     Adds a character to the database, uploading the image from a public URL to the bot's Misskey Drive. | ||||
|     Adds a character to the database, uploading the image from a public URL to | ||||
|     the bot's Misskey Drive. | ||||
| 
 | ||||
|     Args: | ||||
|         name (str): Character name. | ||||
|  | @ -20,41 +21,29 @@ def add_character(name: str, rarity: int, weight: float, image_url: str) -> tupl | |||
|         ValueError: If inputs are invalid. | ||||
|         RuntimeError: If image download/upload or database operation fails. | ||||
|     """ | ||||
|     # Validate inputs | ||||
|     if not name or not name.strip(): | ||||
|         raise ValueError("Character name cannot be empty.") | ||||
|     if not isinstance(rarity, int) or rarity < 1: | ||||
|         raise ValueError("Rarity must be a positive integer.") | ||||
|     if not isinstance(weight, (int, float)) or weight <= 0: | ||||
|         raise ValueError("Weight must be a positive number.") | ||||
|     if not image_url: | ||||
|         raise ValueError("Image URL must be provided.") | ||||
| 
 | ||||
|     # Download image | ||||
|     response = requests.get(image_url, stream=True, timeout=30) | ||||
|     if response.status_code != 200: | ||||
|         raise RuntimeError(f"Failed to download image from {image_url}") | ||||
| 
 | ||||
|     # Upload to bot's Drive | ||||
|     mk = client_connection() | ||||
|     try: | ||||
|         # Validate inputs | ||||
|         if not name or not name.strip(): | ||||
|             raise ValueError("Character name cannot be empty.") | ||||
|         if not isinstance(rarity, int) or rarity < 1: | ||||
|             raise ValueError("Rarity must be a positive integer.") | ||||
|         if not isinstance(weight, (int, float)) or weight <= 0: | ||||
|             raise ValueError("Weight must be a positive number.") | ||||
|         if not image_url: | ||||
|             raise ValueError("Image URL must be provided.") | ||||
|         media = mk.drive_files_create(response.raw) | ||||
|         file_id = media["id"] | ||||
|     except MisskeyAPIException as e: | ||||
|         raise RuntimeError(f"Failed to upload image to bot's Drive: {e}") from e | ||||
| 
 | ||||
|         # Download image | ||||
|         response = requests.get(image_url, stream=True, timeout=30) | ||||
|         if response.status_code != 200: | ||||
|             raise RuntimeError(f"Failed to download image from {image_url}") | ||||
| 
 | ||||
|         # Upload to bot's Drive | ||||
|         mk = client_connection() | ||||
|         try: | ||||
|             media = mk.drive_files_create(response.raw) | ||||
|             file_id = media["id"] | ||||
|         except MisskeyAPIException as e: | ||||
|             raise RuntimeError(f"Failed to upload image to bot's Drive: {e}") from e | ||||
| 
 | ||||
|         # Insert into database | ||||
|         conn = get_db_connection() | ||||
|         cur = conn.cursor() | ||||
|         cur.execute( | ||||
|             'INSERT INTO characters (name, rarity, weight, file_id) VALUES (?, ?, ?, ?)', | ||||
|             (name.strip(), rarity, float(weight), file_id) | ||||
|         ) | ||||
|         conn.commit() | ||||
|         character_id = cur.lastrowid | ||||
| 
 | ||||
|         return character_id, file_id | ||||
|     finally: | ||||
|         if 'conn' in locals(): | ||||
|             conn.close() | ||||
|     # Insert into database | ||||
|     character_id = insert_character(name.strip(), rarity, float(weight), file_id) | ||||
|     return character_id, file_id | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| import time | ||||
| import misskey as misskey | ||||
| from client import client_connection | ||||
| import db_utils as db | ||||
| 
 | ||||
| from config import NOTIFICATION_POLL_INTERVAL | ||||
| from notification import process_notifications | ||||
|  | @ -8,6 +9,9 @@ from notification import process_notifications | |||
| if __name__ == '__main__': | ||||
|     # Initialize the Misskey client | ||||
|     client = client_connection() | ||||
|     # Connect to DB | ||||
|     db.connect() | ||||
| 
 | ||||
|     print('Listening for notifications...') | ||||
|     while True: | ||||
|         if not process_notifications(client): | ||||
|  |  | |||
|  | @ -1,68 +1,79 @@ | |||
| from random import choices | ||||
| import sqlite3 | ||||
| import config | ||||
| 
 | ||||
| DB_PATH = config.DB_PATH | ||||
| CONNECTION: sqlite3.Connection | ||||
| CURSOR: sqlite3.Cursor | ||||
| 
 | ||||
| def get_db_connection(): | ||||
| def connect() -> None: | ||||
|     '''Creates a connection to the database''' | ||||
|     conn = sqlite3.connect(DB_PATH) | ||||
|     conn.row_factory = sqlite3.Row | ||||
|     return conn | ||||
|     print('Connecting to the database...') | ||||
|     global CONNECTION | ||||
|     global CURSOR | ||||
|     CONNECTION = sqlite3.connect(DB_PATH, autocommit=True) | ||||
|     CONNECTION.row_factory = sqlite3.Row | ||||
|     CURSOR = CONNECTION.cursor() | ||||
| 
 | ||||
| def get_random_character(): | ||||
|     ''' Gets a random character from the database''' | ||||
|     CURSOR.execute('SELECT * FROM characters') | ||||
|     characters = CURSOR.fetchall() | ||||
| 
 | ||||
|     if not characters: | ||||
|         return None, None, None, None | ||||
| 
 | ||||
|     weights = [c['weight'] for c in characters] | ||||
|     chosen = choices(characters, weights=weights, k=1)[0] | ||||
| 
 | ||||
|     return chosen['id'], chosen['name'], chosen['file_id'], chosen['rarity'] | ||||
| 
 | ||||
| def get_or_create_user(username): | ||||
|     '''Retrieves an ID for a given user, if the user does not exist, it will be | ||||
|     created.''' | ||||
|     conn = get_db_connection() | ||||
|     conn.row_factory = sqlite3.Row | ||||
|     cur = conn.cursor() | ||||
|     cur.execute('SELECT id FROM users WHERE username = ?', (username,)) | ||||
|     user = cur.fetchone() | ||||
|     CURSOR.execute('SELECT id FROM users WHERE username = ?', (username,)) | ||||
|     user = CURSOR.fetchone() | ||||
|     if user: | ||||
|         conn.close() | ||||
|         return user[0] | ||||
| 
 | ||||
|     # New user starts with has_rolled = False | ||||
|     cur.execute( | ||||
|     CURSOR.execute( | ||||
|         'INSERT INTO users (username, has_rolled) VALUES (?, ?)', | ||||
|         (username, False) | ||||
|     ) | ||||
|     conn.commit() | ||||
|     user_id = cur.lastrowid | ||||
|     conn.close() | ||||
|     user_id = CURSOR.lastrowid | ||||
|     return user_id | ||||
| 
 | ||||
| def add_pull(user_id, character_id): | ||||
| def insert_character(name: str, rarity: int, weight: float, file_id: str) -> int: | ||||
|     '''Inserts a character''' | ||||
|     CURSOR.execute( | ||||
|         'INSERT INTO characters (name, rarity, weight, file_id) VALUES (?, ?, ?, ?)', | ||||
|         (name, rarity, weight, file_id) | ||||
|     ) | ||||
|     character_id = CURSOR.lastrowid | ||||
|     return character_id if character_id else 0 | ||||
| 
 | ||||
| def insert_pull(user_id, character_id): | ||||
|     '''Creates a pull in the database''' | ||||
|     conn = get_db_connection() | ||||
|     cur = conn.cursor() | ||||
|     cur.execute('INSERT INTO pulls (user_id, character_id) VALUES (?, ?)', (user_id, character_id)) | ||||
|     conn.commit() | ||||
|     conn.close() | ||||
|     CURSOR.execute( | ||||
|         'INSERT INTO pulls (user_id, character_id) VALUES (?, ?)', | ||||
|         (user_id, character_id) | ||||
|     ) | ||||
| 
 | ||||
| def get_last_rolled_at(user_id): | ||||
|     '''Gets the timestamp when the user last rolled''' | ||||
|     conn = get_db_connection() | ||||
|     cur = conn.cursor() | ||||
|     cur.execute("SELECT timestamp FROM pulls WHERE user_id = ? ORDER BY timestamp DESC", \ | ||||
|     CURSOR.execute("SELECT timestamp FROM pulls WHERE user_id = ? ORDER BY timestamp DESC", \ | ||||
|             (user_id,)) | ||||
|     row = cur.fetchone() | ||||
|     conn.close() | ||||
|     row = CURSOR.fetchone() | ||||
|     return row[0] if row else None | ||||
| 
 | ||||
| 
 | ||||
| def get_config(key): | ||||
|     '''Reads the value for a specified config key from the db''' | ||||
|     conn = get_db_connection() | ||||
|     cur = conn.cursor() | ||||
|     cur.execute("SELECT value FROM config WHERE key = ?", (key,)) | ||||
|     row = cur.fetchone() | ||||
|     conn.close() | ||||
|     CURSOR.execute("SELECT value FROM config WHERE key = ?", (key,)) | ||||
|     row = CURSOR.fetchone() | ||||
|     return row[0] if row else None | ||||
| 
 | ||||
| def set_config(key, value): | ||||
|     '''Writes the value for a specified config key to the db''' | ||||
|     conn = get_db_connection() | ||||
|     cur = conn.cursor() | ||||
|     cur.execute("INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", (key, value)) | ||||
|     conn.commit() | ||||
|     conn.close() | ||||
|     CURSOR.execute("INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", (key, value)) | ||||
|  |  | |||
|  | @ -1,25 +1,8 @@ | |||
| import random | ||||
| from datetime import datetime, timedelta, timezone | ||||
| from db_utils import get_or_create_user, add_pull, get_db_connection, get_last_rolled_at | ||||
| from db_utils import get_or_create_user, insert_pull, get_last_rolled_at, get_random_character | ||||
| from add_character import add_character | ||||
| from config import GACHA_ROLL_INTERVAL | ||||
| 
 | ||||
| def get_character(): | ||||
|     ''' Gets a random character from the database''' | ||||
|     conn = get_db_connection() | ||||
|     cur = conn.cursor() | ||||
|     cur.execute('SELECT * FROM characters') | ||||
|     characters = cur.fetchall() | ||||
|     conn.close() | ||||
| 
 | ||||
|     if not characters: | ||||
|         return None, None, None, None | ||||
| 
 | ||||
|     weights = [c['weight'] for c in characters] | ||||
|     chosen = random.choices(characters, weights=weights, k=1)[0] | ||||
| 
 | ||||
|     return chosen['id'], chosen['name'], chosen['file_id'], chosen['rarity'] | ||||
| 
 | ||||
| def do_roll(full_user): | ||||
|     '''Determines whether the user can roll, then pulls a random character''' | ||||
|     user_id = get_or_create_user(full_user) | ||||
|  | @ -50,12 +33,12 @@ def do_roll(full_user): | |||
| 
 | ||||
|             return f'{full_user} ⏱️ Please wait another {remaining_duration} before rolling again.' | ||||
| 
 | ||||
|     character_id, character_name, file_id, rarity = get_character() | ||||
|     character_id, character_name, file_id, rarity = get_random_character() | ||||
| 
 | ||||
|     if not character_id: | ||||
|         return f'{full_user} Uwaaa... something went wrong! No characters found. 😿' | ||||
| 
 | ||||
|     add_pull(user_id,character_id) | ||||
|     insert_pull(user_id,character_id) | ||||
|     stars = '⭐️' * rarity | ||||
|     return([f"@{full_user} 🎲 Congrats! You rolled {stars} **{character_name}**\n\ | ||||
|             She's all yours now~ 💖✨",[file_id]]) | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue