diff options
author | Parker <contact@pkrm.dev> | 2025-04-04 16:46:27 -0500 |
---|---|---|
committer | Parker <contact@pkrm.dev> | 2025-04-04 16:46:27 -0500 |
commit | 382f0f271f3cd5d5b0444a2ffa73a4f700c4d59e (patch) | |
tree | dfad3795d83e59f5572823abee29cfe698c29edb | |
parent | 0b76123301627c69a2a83b8302199d006c5039bb (diff) |
Support multiple file naming schemes
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | config.py | 31 | ||||
-rw-r--r-- | config.yaml.example | 2 | ||||
-rw-r--r-- | src/cogs/archive.py | 58 | ||||
-rw-r--r-- | src/utils/attachments.py | 68 |
5 files changed, 99 insertions, 62 deletions
@@ -36,7 +36,7 @@ To run DisArchive, follow the steps below. > [!NOTE] > If a configuration option is labeled as optional, do not just use an empty value for that field, remove the field entirely from the config file. -### BOT_INFO | REQUIRED +### GENERAL | REQUIRED Field | Description | Requirement --- | --- | --- TOKEN | The token for your bot. Create a bot at [discord.com/developers](https://discord.com/developers) | **REQUIRED** - *message content intent is REQUIRED* @@ -25,6 +25,7 @@ LOG.setLevel(log_level) LOG.addHandler(stream) TOKEN = None +NAMING_SCHEME = None BOT_COLOR = None SQLITE_NAME = "disarchive" @@ -38,11 +39,14 @@ DB_PASSWORD = None schema = { "type": "object", "properties": { - "bot_info": { + "general": { "type": "object", "properties": { "token": {"type": "string"}, "bot_color": {"type": "string", "default": "#fc5f4e"}, + "naming_scheme": { + "enum": ["random", "timestamp", "id", "original"] + }, }, "required": ["token"], }, @@ -88,12 +92,16 @@ schema = { ], }, }, - "required": ["bot_info"], + "required": ["general"], } # Load config file or alert user if not found def load_config(): + # create images directory if it doesn't exist + if not os.path.exists("images"): + os.makedirs("images") + if os.path.exists("/.dockerenv"): file_path = "/config/config.yaml" else: @@ -115,7 +123,7 @@ def load_config(): # Validate the config file against the schema def validate_config(file_contents): - global TOKEN, BOT_COLOR, SQLITE_NAME, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD + global TOKEN, NAMING_SCHEME, BOT_COLOR, SQLITE_NAME, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD config = yaml.safe_load(file_contents) try: @@ -128,20 +136,27 @@ def validate_config(file_contents): hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$" # Check if the bot_color is a valid hex color - if "bot_color" in config["bot_info"]: + if "bot_color" in config["general"]: if not bool( - re.match(hex_pattern_one, config["bot_info"]["bot_color"]) + re.match(hex_pattern_one, config["general"]["bot_color"]) ) and not bool( - re.match(hex_pattern_two, config["bot_info"]["bot_color"]) + re.match(hex_pattern_two, config["general"]["bot_color"]) ): LOG.warn( "bot_color is not a valid hex color... defaulting to #26dfc9" ) else: BOT_COLOR = discord.Color( - int((config["bot_info"]["bot_color"]).replace("#", ""), 16) + int((config["general"]["bot_color"]).replace("#", ""), 16) ) + # Naming scheme + if "naming_scheme" in config["general"]: + NAMING_SCHEME = config["general"]["naming_scheme"] + else: + LOG.info("No naming scheme specified... defaulting to random") + NAMING_SCHEME = "random" + # Assign database variables if "sqlite" in config: DB_ENGINE = "sqlite" @@ -168,4 +183,4 @@ def validate_config(file_contents): LOG.warn("No database engine specified. Defaulting to SQLite.") DB_ENGINE = "sqlite" - TOKEN = config["bot_info"]["token"] + TOKEN = config["general"]["token"] diff --git a/config.yaml.example b/config.yaml.example index 911e52f..e53ecc5 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -1,4 +1,4 @@ -bot_info: +general: token: "BOT TOKEN" bot_color: 26dfc9 #optional - default is 26dfc9 diff --git a/src/cogs/archive.py b/src/cogs/archive.py index c6fdb08..0c7e852 100644 --- a/src/cogs/archive.py +++ b/src/cogs/archive.py @@ -1,61 +1,17 @@ import discord from discord import app_commands from discord.ext import commands -import aiohttp -import random -import string -import os +from utils.attachments import save_attachments from src.utils.db import get_db from models import Message -from config import BOT_COLOR, LOG +from config import BOT_COLOR class Archive(commands.Cog): def __init__(self, bot): self.bot = bot - async def download_attachments(self, attachments) -> list: - """Download attachments and return a list of their paths.""" - paths = [] - - for attachment in attachments: - async with aiohttp.ClientSession() as session: - async with session.get(attachment.url) as response: - # Check if the request was successful - if response.status != 200: - LOG.warn( - f"Failed to download attachment: {attachment.url}" - ) - continue - - # Check for content type - content_type = response.headers.get("Content-Type") - if not content_type: - LOG.warn( - f"Failed to get content type for: {attachment.url}" - ) - continue - - # Create a randomized filename - file_extension = content_type.split("/")[-1] - filename = ( - "".join( - random.choice(string.ascii_letters) - for i in range(10) - ) - + f".{file_extension}" - ) - - # Save the attachment - with open(f"images/{filename}", "wb") as file: - file.write(await response.read()) - - # Add the path to the attachments list - paths.append(f"images/{filename}") - - return paths - @app_commands.command() async def archive( self, @@ -92,16 +48,13 @@ class Archive(commands.Cog): ) await interaction.response.send_message(embed=embed, ephemeral=True) + # get database session and begin archiving db = next(get_db()) count = 0 messages = channel.history(limit=amount) async for message in messages: count += 1 - - if not os.path.exists("images"): - os.makedirs("images") - - attachments = await self.download_attachments(message.attachments) + paths = await save_attachments(message) db_message = Message( timestamp=message.created_at.isoformat(), @@ -112,7 +65,7 @@ class Archive(commands.Cog): role_mentions=[role.id for role in message.role_mentions], mention_everyone=message.mention_everyone, mentions=[mention.id for mention in message.mentions], - attachments=attachments, + attachments=paths, content=message.content, ) @@ -122,6 +75,7 @@ class Archive(commands.Cog): db.commit() count = 0 + # commit any remaining messages db.commit() embed = discord.Embed( diff --git a/src/utils/attachments.py b/src/utils/attachments.py new file mode 100644 index 0000000..986dae8 --- /dev/null +++ b/src/utils/attachments.py @@ -0,0 +1,68 @@ +import os +import aiohttp +import random +import string + +from config import NAMING_SCHEME, LOG + + +async def save_attachments(message) -> list: + """Download attachments and return a list of their paths.""" + paths = [] + + for attachment in message.attachments: + async with aiohttp.ClientSession() as session: + async with session.get(attachment.url) as response: + # Check if the request was successful + if response.status != 200: + LOG.warn( + f"Failed to download attachment: {attachment.url}" + ) + continue + + # Check for content type + content_type = response.headers.get("Content-Type") + if not content_type: + LOG.warn( + f"Failed to get content type for: {attachment.url}" + ) + continue + + filename = get_filename( + attachment, message, content_type.split("/")[-1] + ) + + # Save the attachment + with open(f"images/{filename}", "wb") as file: + file.write(await response.read()) + + # Add the path to the attachments list + paths.append(f"images/{filename}") + + return paths + + +def get_filename(attachment, message, file_extension) -> str: + """Generate a filename based on the naming scheme.""" + if NAMING_SCHEME == "original": + i = 1 + filename = attachment.filename + # account for duplicate filenames + while os.path.exists(f"images/{filename}.{file_extension}"): + filename = f"{attachment.filename}_{i}" + i += 1 + elif NAMING_SCHEME == "timestamp": + i = 1 + filename = message.created_at.isoformat() + # account for multiple attachments from the same message + while os.path.exists(f"images/{filename}.{file_extension}"): + filename = f"{message.created_at.isoformat()}_{i}" + i += 1 + elif NAMING_SCHEME == "id": + filename = str(attachment.id) + else: # random + filename = "".join( + random.choice(string.ascii_letters) for _ in range(15) + ) + + return f"{filename}.{file_extension}" |