diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cogs/archive.py | 58 | ||||
-rw-r--r-- | src/utils/attachments.py | 68 |
2 files changed, 74 insertions, 52 deletions
diff --git a/src/cogs/archive.py b/src/cogs/archive.py index c6fdb08..0c7e852 100644 --- a/src/cogs/archive.py +++ b/src/cogs/archive.py @@ -1,61 +1,17 @@ import discord from discord import app_commands from discord.ext import commands -import aiohttp -import random -import string -import os +from utils.attachments import save_attachments from src.utils.db import get_db from models import Message -from config import BOT_COLOR, LOG +from config import BOT_COLOR class Archive(commands.Cog): def __init__(self, bot): self.bot = bot - async def download_attachments(self, attachments) -> list: - """Download attachments and return a list of their paths.""" - paths = [] - - for attachment in attachments: - async with aiohttp.ClientSession() as session: - async with session.get(attachment.url) as response: - # Check if the request was successful - if response.status != 200: - LOG.warn( - f"Failed to download attachment: {attachment.url}" - ) - continue - - # Check for content type - content_type = response.headers.get("Content-Type") - if not content_type: - LOG.warn( - f"Failed to get content type for: {attachment.url}" - ) - continue - - # Create a randomized filename - file_extension = content_type.split("/")[-1] - filename = ( - "".join( - random.choice(string.ascii_letters) - for i in range(10) - ) - + f".{file_extension}" - ) - - # Save the attachment - with open(f"images/{filename}", "wb") as file: - file.write(await response.read()) - - # Add the path to the attachments list - paths.append(f"images/{filename}") - - return paths - @app_commands.command() async def archive( self, @@ -92,16 +48,13 @@ class Archive(commands.Cog): ) await interaction.response.send_message(embed=embed, ephemeral=True) + # get database session and begin archiving db = next(get_db()) count = 0 messages = channel.history(limit=amount) async for message in messages: count += 1 - - if not os.path.exists("images"): - os.makedirs("images") - - attachments = await self.download_attachments(message.attachments) + paths = await save_attachments(message) db_message = Message( timestamp=message.created_at.isoformat(), @@ -112,7 +65,7 @@ class Archive(commands.Cog): role_mentions=[role.id for role in message.role_mentions], mention_everyone=message.mention_everyone, mentions=[mention.id for mention in message.mentions], - attachments=attachments, + attachments=paths, content=message.content, ) @@ -122,6 +75,7 @@ class Archive(commands.Cog): db.commit() count = 0 + # commit any remaining messages db.commit() embed = discord.Embed( diff --git a/src/utils/attachments.py b/src/utils/attachments.py new file mode 100644 index 0000000..986dae8 --- /dev/null +++ b/src/utils/attachments.py @@ -0,0 +1,68 @@ +import os +import aiohttp +import random +import string + +from config import NAMING_SCHEME, LOG + + +async def save_attachments(message) -> list: + """Download attachments and return a list of their paths.""" + paths = [] + + for attachment in message.attachments: + async with aiohttp.ClientSession() as session: + async with session.get(attachment.url) as response: + # Check if the request was successful + if response.status != 200: + LOG.warn( + f"Failed to download attachment: {attachment.url}" + ) + continue + + # Check for content type + content_type = response.headers.get("Content-Type") + if not content_type: + LOG.warn( + f"Failed to get content type for: {attachment.url}" + ) + continue + + filename = get_filename( + attachment, message, content_type.split("/")[-1] + ) + + # Save the attachment + with open(f"images/{filename}", "wb") as file: + file.write(await response.read()) + + # Add the path to the attachments list + paths.append(f"images/{filename}") + + return paths + + +def get_filename(attachment, message, file_extension) -> str: + """Generate a filename based on the naming scheme.""" + if NAMING_SCHEME == "original": + i = 1 + filename = attachment.filename + # account for duplicate filenames + while os.path.exists(f"images/{filename}.{file_extension}"): + filename = f"{attachment.filename}_{i}" + i += 1 + elif NAMING_SCHEME == "timestamp": + i = 1 + filename = message.created_at.isoformat() + # account for multiple attachments from the same message + while os.path.exists(f"images/{filename}.{file_extension}"): + filename = f"{message.created_at.isoformat()}_{i}" + i += 1 + elif NAMING_SCHEME == "id": + filename = str(attachment.id) + else: # random + filename = "".join( + random.choice(string.ascii_letters) for _ in range(15) + ) + + return f"{filename}.{file_extension}" |