aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorParker <contact@pkrm.dev>2024-11-16 01:48:58 -0600
committerParker <contact@pkrm.dev>2024-11-16 01:48:58 -0600
commit4d626d423c2e8158205fb525f24f3dc2338a5b1f (patch)
treee01957a5f4373bcdd7c1268e057c2b2b73c665c3
parent6e8f3ee321d703cf454f131c129a147c841467bc (diff)
Add support for MySQL and PostgreSQL
-rw-r--r--config.py75
-rw-r--r--database.py17
-rw-r--r--linklogger.py8
3 files changed, 82 insertions, 18 deletions
diff --git a/config.py b/config.py
index 9b6f23c..54f10fe 100644
--- a/config.py
+++ b/config.py
@@ -23,6 +23,12 @@ LOG.addHandler(stream)
IP_TO_LOCATION = None
API_KEY = None
+DB_NAME = None
+DB_ENGINE = None
+DB_HOST = None
+DB_PORT = None
+DB_USER = None
+DB_PASSWORD = None
schema = {
"type": "object",
@@ -34,7 +40,26 @@ schema = {
"api_key": {"type": "string"},
},
"required": ["ip_to_location"],
- }
+ },
+ "database": {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "default": "linklogger"},
+ "engine": {"type": "string"},
+ "host": {"type": "string"},
+ "port": {"type": "integer"},
+ "user": {"type": "string"},
+ "password": {"type": "string"},
+ },
+ "required": [
+ "name",
+ "engine",
+ "host",
+ "port",
+ "user",
+ "password",
+ ],
+ },
},
"required": ["config"],
}
@@ -50,16 +75,26 @@ def load_config():
try:
with open(file_path, "r") as f:
file_contents = f.read()
- validate_config(file_contents)
+ if not validate_config(file_contents):
+ return False
+ else:
+ return True
except FileNotFoundError:
# Create new config.yaml w/ template
with open(file_path, "w") as f:
f.write(
- """
-config:
+ """config:
ip_to_location: false
- api_key: ''"""
+ api_key: ''
+
+database:
+ engine: 'sqlite'
+ name: ''
+ host: ''
+ port: ''
+ user: ''
+ password: ''"""
)
LOG.critical(
"`config.yaml` was not found, a template has been created."
@@ -67,12 +102,10 @@ config:
)
return False
- return True
-
# Validate the options within config.yaml
def validate_config(file_contents):
- global IP_TO_LOCATION, API_KEY
+ global IP_TO_LOCATION, API_KEY, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD
config = yaml.safe_load(file_contents)
try:
@@ -88,5 +121,31 @@ def validate_config(file_contents):
if IP_TO_LOCATION:
if not config["config"]["api_key"]:
LOG.error("API_KEY is not set")
+ return False
else:
API_KEY = config["config"]["api_key"]
+
+ #
+ # Set/Validate the DATABASE section of the config.yaml
+ #
+ if "database" in config:
+ if config["database"]["engine"] not in [
+ "sqlite",
+ "mysql",
+ "postgresql",
+ ]:
+ LOG.error(
+ "database_engine must be either 'sqlite', 'mysql', or"
+ " 'postgresql'"
+ )
+ return False
+ else:
+ DB_ENGINE = config["database"]["engine"]
+
+ DB_NAME = config["database"]["name"]
+ DB_HOST = config["database"]["host"]
+ DB_PORT = config["database"]["port"]
+ DB_USER = config["database"]["user"]
+ DB_PASSWORD = config["database"]["password"]
+
+ return True
diff --git a/database.py b/database.py
index 544ee05..0166d28 100644
--- a/database.py
+++ b/database.py
@@ -1,13 +1,18 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
-import os
-# Create 'data' directory at root if it doesn't exist
-if not os.path.exists("data"):
- os.makedirs("data")
+import config
-engine = create_engine("sqlite:///data/data.db")
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+if config.DB_ENGINE == "mysql":
+ database_url = f"mysql+pymysql://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
+
+elif config.DB_ENGINE == "postgresql":
+ database_url = f"postgresql+psycopg2://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
+else:
+ database_url = "sqlite:///data/data.db"
+
+engine = create_engine(database_url)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
diff --git a/linklogger.py b/linklogger.py
index c456a16..17837b0 100644
--- a/linklogger.py
+++ b/linklogger.py
@@ -1,12 +1,12 @@
import uvicorn
import config
-from api.main import app
-from database import Base, engine
-Base.metadata.create_all(bind=engine)
-
if __name__ == "__main__":
if config.load_config():
+ from api.main import app
+ from database import Base, engine
+
+ Base.metadata.create_all(bind=engine)
uvicorn.run(app, port=5252)