From 4d626d423c2e8158205fb525f24f3dc2338a5b1f Mon Sep 17 00:00:00 2001 From: Parker Date: Sat, 16 Nov 2024 01:48:58 -0600 Subject: [PATCH] Add support for MySQL and PostgreSQL --- config.py | 75 +++++++++++++++++++++++++++++++++++++++++++++------ database.py | 17 +++++++----- linklogger.py | 8 +++--- 3 files changed, 82 insertions(+), 18 deletions(-) diff --git a/config.py b/config.py index 9b6f23c..54f10fe 100644 --- a/config.py +++ b/config.py @@ -23,6 +23,12 @@ LOG.addHandler(stream) IP_TO_LOCATION = None API_KEY = None +DB_NAME = None +DB_ENGINE = None +DB_HOST = None +DB_PORT = None +DB_USER = None +DB_PASSWORD = None schema = { "type": "object", @@ -34,7 +40,26 @@ schema = { "api_key": {"type": "string"}, }, "required": ["ip_to_location"], - } + }, + "database": { + "type": "object", + "properties": { + "name": {"type": "string", "default": "linklogger"}, + "engine": {"type": "string"}, + "host": {"type": "string"}, + "port": {"type": "integer"}, + "user": {"type": "string"}, + "password": {"type": "string"}, + }, + "required": [ + "name", + "engine", + "host", + "port", + "user", + "password", + ], + }, }, "required": ["config"], } @@ -50,16 +75,26 @@ def load_config(): try: with open(file_path, "r") as f: file_contents = f.read() - validate_config(file_contents) + if not validate_config(file_contents): + return False + else: + return True except FileNotFoundError: # Create new config.yaml w/ template with open(file_path, "w") as f: f.write( - """ -config: + """config: ip_to_location: false - api_key: ''""" + api_key: '' + +database: + engine: 'sqlite' + name: '' + host: '' + port: '' + user: '' + password: ''""" ) LOG.critical( "`config.yaml` was not found, a template has been created." @@ -67,12 +102,10 @@ config: ) return False - return True - # Validate the options within config.yaml def validate_config(file_contents): - global IP_TO_LOCATION, API_KEY + global IP_TO_LOCATION, API_KEY, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD config = yaml.safe_load(file_contents) try: @@ -88,5 +121,31 @@ def validate_config(file_contents): if IP_TO_LOCATION: if not config["config"]["api_key"]: LOG.error("API_KEY is not set") + return False else: API_KEY = config["config"]["api_key"] + + # + # Set/Validate the DATABASE section of the config.yaml + # + if "database" in config: + if config["database"]["engine"] not in [ + "sqlite", + "mysql", + "postgresql", + ]: + LOG.error( + "database_engine must be either 'sqlite', 'mysql', or" + " 'postgresql'" + ) + return False + else: + DB_ENGINE = config["database"]["engine"] + + DB_NAME = config["database"]["name"] + DB_HOST = config["database"]["host"] + DB_PORT = config["database"]["port"] + DB_USER = config["database"]["user"] + DB_PASSWORD = config["database"]["password"] + + return True diff --git a/database.py b/database.py index 544ee05..0166d28 100644 --- a/database.py +++ b/database.py @@ -1,13 +1,18 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -import os -# Create 'data' directory at root if it doesn't exist -if not os.path.exists("data"): - os.makedirs("data") +import config -engine = create_engine("sqlite:///data/data.db") +if config.DB_ENGINE == "mysql": + database_url = f"mysql+pymysql://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}" + +elif config.DB_ENGINE == "postgresql": + database_url = f"postgresql+psycopg2://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}" + +else: + database_url = "sqlite:///data/data.db" + +engine = create_engine(database_url) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base = declarative_base() diff --git a/linklogger.py b/linklogger.py index c456a16..17837b0 100644 --- a/linklogger.py +++ b/linklogger.py @@ -1,12 +1,12 @@ import uvicorn import config -from api.main import app -from database import Base, engine -Base.metadata.create_all(bind=engine) - if __name__ == "__main__": if config.load_config(): + from api.main import app + from database import Base, engine + + Base.metadata.create_all(bind=engine) uvicorn.run(app, port=5252)