aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorParker <contact@pkrm.dev>2025-04-04 16:46:27 -0500
committerParker <contact@pkrm.dev>2025-04-04 16:46:27 -0500
commit382f0f271f3cd5d5b0444a2ffa73a4f700c4d59e (patch)
treedfad3795d83e59f5572823abee29cfe698c29edb
parent0b76123301627c69a2a83b8302199d006c5039bb (diff)
Support multiple file naming schemes
-rw-r--r--README.md2
-rw-r--r--config.py31
-rw-r--r--config.yaml.example2
-rw-r--r--src/cogs/archive.py58
-rw-r--r--src/utils/attachments.py68
5 files changed, 99 insertions, 62 deletions
diff --git a/README.md b/README.md
index 8a64151..31d68ca 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ To run DisArchive, follow the steps below.
> [!NOTE]
> If a configuration option is labeled as optional, do not just use an empty value for that field, remove the field entirely from the config file.
-### BOT_INFO | REQUIRED
+### GENERAL | REQUIRED
Field | Description | Requirement
--- | --- | ---
TOKEN | The token for your bot. Create a bot at [discord.com/developers](https://discord.com/developers) | **REQUIRED** - *message content intent is REQUIRED*
diff --git a/config.py b/config.py
index 1c44585..4ea598f 100644
--- a/config.py
+++ b/config.py
@@ -25,6 +25,7 @@ LOG.setLevel(log_level)
LOG.addHandler(stream)
TOKEN = None
+NAMING_SCHEME = None
BOT_COLOR = None
SQLITE_NAME = "disarchive"
@@ -38,11 +39,14 @@ DB_PASSWORD = None
schema = {
"type": "object",
"properties": {
- "bot_info": {
+ "general": {
"type": "object",
"properties": {
"token": {"type": "string"},
"bot_color": {"type": "string", "default": "#fc5f4e"},
+ "naming_scheme": {
+ "enum": ["random", "timestamp", "id", "original"]
+ },
},
"required": ["token"],
},
@@ -88,12 +92,16 @@ schema = {
],
},
},
- "required": ["bot_info"],
+ "required": ["general"],
}
# Load config file or alert user if not found
def load_config():
+ # create images directory if it doesn't exist
+ if not os.path.exists("images"):
+ os.makedirs("images")
+
if os.path.exists("/.dockerenv"):
file_path = "/config/config.yaml"
else:
@@ -115,7 +123,7 @@ def load_config():
# Validate the config file against the schema
def validate_config(file_contents):
- global TOKEN, BOT_COLOR, SQLITE_NAME, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD
+ global TOKEN, NAMING_SCHEME, BOT_COLOR, SQLITE_NAME, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD
config = yaml.safe_load(file_contents)
try:
@@ -128,20 +136,27 @@ def validate_config(file_contents):
hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
# Check if the bot_color is a valid hex color
- if "bot_color" in config["bot_info"]:
+ if "bot_color" in config["general"]:
if not bool(
- re.match(hex_pattern_one, config["bot_info"]["bot_color"])
+ re.match(hex_pattern_one, config["general"]["bot_color"])
) and not bool(
- re.match(hex_pattern_two, config["bot_info"]["bot_color"])
+ re.match(hex_pattern_two, config["general"]["bot_color"])
):
LOG.warn(
"bot_color is not a valid hex color... defaulting to #26dfc9"
)
else:
BOT_COLOR = discord.Color(
- int((config["bot_info"]["bot_color"]).replace("#", ""), 16)
+ int((config["general"]["bot_color"]).replace("#", ""), 16)
)
+ # Naming scheme
+ if "naming_scheme" in config["general"]:
+ NAMING_SCHEME = config["general"]["naming_scheme"]
+ else:
+ LOG.info("No naming scheme specified... defaulting to random")
+ NAMING_SCHEME = "random"
+
# Assign database variables
if "sqlite" in config:
DB_ENGINE = "sqlite"
@@ -168,4 +183,4 @@ def validate_config(file_contents):
LOG.warn("No database engine specified. Defaulting to SQLite.")
DB_ENGINE = "sqlite"
- TOKEN = config["bot_info"]["token"]
+ TOKEN = config["general"]["token"]
diff --git a/config.yaml.example b/config.yaml.example
index 911e52f..e53ecc5 100644
--- a/config.yaml.example
+++ b/config.yaml.example
@@ -1,4 +1,4 @@
-bot_info:
+general:
token: "BOT TOKEN"
bot_color: 26dfc9 #optional - default is 26dfc9
diff --git a/src/cogs/archive.py b/src/cogs/archive.py
index c6fdb08..0c7e852 100644
--- a/src/cogs/archive.py
+++ b/src/cogs/archive.py
@@ -1,61 +1,17 @@
import discord
from discord import app_commands
from discord.ext import commands
-import aiohttp
-import random
-import string
-import os
+from utils.attachments import save_attachments
from src.utils.db import get_db
from models import Message
-from config import BOT_COLOR, LOG
+from config import BOT_COLOR
class Archive(commands.Cog):
def __init__(self, bot):
self.bot = bot
- async def download_attachments(self, attachments) -> list:
- """Download attachments and return a list of their paths."""
- paths = []
-
- for attachment in attachments:
- async with aiohttp.ClientSession() as session:
- async with session.get(attachment.url) as response:
- # Check if the request was successful
- if response.status != 200:
- LOG.warn(
- f"Failed to download attachment: {attachment.url}"
- )
- continue
-
- # Check for content type
- content_type = response.headers.get("Content-Type")
- if not content_type:
- LOG.warn(
- f"Failed to get content type for: {attachment.url}"
- )
- continue
-
- # Create a randomized filename
- file_extension = content_type.split("/")[-1]
- filename = (
- "".join(
- random.choice(string.ascii_letters)
- for i in range(10)
- )
- + f".{file_extension}"
- )
-
- # Save the attachment
- with open(f"images/{filename}", "wb") as file:
- file.write(await response.read())
-
- # Add the path to the attachments list
- paths.append(f"images/{filename}")
-
- return paths
-
@app_commands.command()
async def archive(
self,
@@ -92,16 +48,13 @@ class Archive(commands.Cog):
)
await interaction.response.send_message(embed=embed, ephemeral=True)
+ # get database session and begin archiving
db = next(get_db())
count = 0
messages = channel.history(limit=amount)
async for message in messages:
count += 1
-
- if not os.path.exists("images"):
- os.makedirs("images")
-
- attachments = await self.download_attachments(message.attachments)
+ paths = await save_attachments(message)
db_message = Message(
timestamp=message.created_at.isoformat(),
@@ -112,7 +65,7 @@ class Archive(commands.Cog):
role_mentions=[role.id for role in message.role_mentions],
mention_everyone=message.mention_everyone,
mentions=[mention.id for mention in message.mentions],
- attachments=attachments,
+ attachments=paths,
content=message.content,
)
@@ -122,6 +75,7 @@ class Archive(commands.Cog):
db.commit()
count = 0
+ # commit any remaining messages
db.commit()
embed = discord.Embed(
diff --git a/src/utils/attachments.py b/src/utils/attachments.py
new file mode 100644
index 0000000..986dae8
--- /dev/null
+++ b/src/utils/attachments.py
@@ -0,0 +1,68 @@
+import os
+import aiohttp
+import random
+import string
+
+from config import NAMING_SCHEME, LOG
+
+
+async def save_attachments(message) -> list:
+ """Download attachments and return a list of their paths."""
+ paths = []
+
+ for attachment in message.attachments:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(attachment.url) as response:
+ # Check if the request was successful
+ if response.status != 200:
+ LOG.warn(
+ f"Failed to download attachment: {attachment.url}"
+ )
+ continue
+
+ # Check for content type
+ content_type = response.headers.get("Content-Type")
+ if not content_type:
+ LOG.warn(
+ f"Failed to get content type for: {attachment.url}"
+ )
+ continue
+
+ filename = get_filename(
+ attachment, message, content_type.split("/")[-1]
+ )
+
+ # Save the attachment
+ with open(f"images/{filename}", "wb") as file:
+ file.write(await response.read())
+
+ # Add the path to the attachments list
+ paths.append(f"images/{filename}")
+
+ return paths
+
+
+def get_filename(attachment, message, file_extension) -> str:
+ """Generate a filename based on the naming scheme."""
+ if NAMING_SCHEME == "original":
+ i = 1
+ filename = attachment.filename
+ # account for duplicate filenames
+ while os.path.exists(f"images/{filename}.{file_extension}"):
+ filename = f"{attachment.filename}_{i}"
+ i += 1
+ elif NAMING_SCHEME == "timestamp":
+ i = 1
+ filename = message.created_at.isoformat()
+ # account for multiple attachments from the same message
+ while os.path.exists(f"images/{filename}.{file_extension}"):
+ filename = f"{message.created_at.isoformat()}_{i}"
+ i += 1
+ elif NAMING_SCHEME == "id":
+ filename = str(attachment.id)
+ else: # random
+ filename = "".join(
+ random.choice(string.ascii_letters) for _ in range(15)
+ )
+
+ return f"{filename}.{file_extension}"