From 0eaa58ab83091ab3236ae3633c9f4738d56adbe0 Mon Sep 17 00:00:00 2001 From: Parker Date: Thu, 3 Apr 2025 11:24:41 -0500 Subject: [PATCH] Detect file extensions --- src/cogs/archive.py | 80 ++++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/src/cogs/archive.py b/src/cogs/archive.py index 36c7044..b5572fe 100644 --- a/src/cogs/archive.py +++ b/src/cogs/archive.py @@ -8,13 +8,44 @@ import os from src.utils.db import get_db from models import Message -from config import BOT_COLOR +from config import BOT_COLOR, LOG class Archive(commands.Cog): def __init__(self, bot): self.bot = bot + async def download_attachments(attachments) -> list: + """Download attachments and return a list of their paths.""" + paths = [] + + for attachment in attachments: + async with aiohttp.ClientSession().get(attachment.url) as response: + content_type = response.headers.get("Content-Type") + if not content_type: + LOG.warn( + f"Failed to get content type for: {attachment.url}" + ) + continue + + # Create a randomized filename + file_extension = content_type.split("/")[-1] + filename = ( + "".join( + random.choice(string.ascii_letters) for i in range(10) + ) + + f".{file_extension}" + ) + + # Save the image to the filesystem + with open(f"images/{filename}]", "wb") as file: + file.write(await response.read()) + + # Add the path to the attachments list + paths.append(f"images/{filename}") + + return paths + @app_commands.command() async def archive( self, @@ -23,6 +54,7 @@ class Archive(commands.Cog): amount: int, ): """Archive a channel's messages.""" + # Ensure valid channel permissions if not channel.permissions_for( interaction.guild.me ).read_message_history: @@ -32,6 +64,7 @@ class Archive(commands.Cog): ephemeral=True, ) + # Ensure valid amount if amount < 1: return await interaction.response.send_message( "You must provide a number greater than 0.", @@ -46,51 +79,22 @@ class Archive(commands.Cog): async for message in messages: count += 1 - author = message.author - stickers = [sticker.name for sticker in message.stickers] - role_mentions = [ - role_mention.id for role_mention in message.role_mentions - ] - mention_everyone = message.mention_everyone - mentions = [mention.id for mention in message.mentions] - attachments = [ - attachment.url for attachment in message.attachments - ] + attachments = [] if not os.path.exists("images"): os.makedirs("images") - # Download all images before saving everything to database - for url in attachments: - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - # Create a randomized filename - filename = "".join( - random.choice(string.ascii_letters) - for i in range(10) - ) - - # Save the image to the filesystem - with open(f"images/{filename}.jpg", "wb") as file: - file.write(await response.read()) - - # Update the attachment URL to the new filename - attachments[attachments.index(url)] = ( - f"images/{filename}.jpg" - ) - - content = message.content + attachments = await self.download_attachments(message.attachments) db_message = Message( - author_id=author.id, + author_id=message.author.id, channel_id=channel.id, - stickers=stickers, - role_mentions=role_mentions, - mention_everyone=mention_everyone, - mentions=mentions, + stickers=[sticker.name for sticker in message.stickers], + role_mentions=[role.id for role in message.role_mentions], + mention_everyone=message.mention_everyone, + mentions=[mention.id for mention in message.mentions], attachments=attachments, - content=content, + content=message.content, ) db.add(db_message)