Detect file extensions

This commit is contained in:
Parker M. 2025-04-03 11:24:41 -05:00
parent b5099937f8
commit 0eaa58ab83
Signed by: parker
GPG Key ID: 505ED36FC12B5D5E

View File

@ -8,13 +8,44 @@ import os
from src.utils.db import get_db from src.utils.db import get_db
from models import Message from models import Message
from config import BOT_COLOR from config import BOT_COLOR, LOG
class Archive(commands.Cog): class Archive(commands.Cog):
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
async def download_attachments(attachments) -> list:
"""Download attachments and return a list of their paths."""
paths = []
for attachment in attachments:
async with aiohttp.ClientSession().get(attachment.url) as response:
content_type = response.headers.get("Content-Type")
if not content_type:
LOG.warn(
f"Failed to get content type for: {attachment.url}"
)
continue
# Create a randomized filename
file_extension = content_type.split("/")[-1]
filename = (
"".join(
random.choice(string.ascii_letters) for i in range(10)
)
+ f".{file_extension}"
)
# Save the image to the filesystem
with open(f"images/{filename}]", "wb") as file:
file.write(await response.read())
# Add the path to the attachments list
paths.append(f"images/{filename}")
return paths
@app_commands.command() @app_commands.command()
async def archive( async def archive(
self, self,
@ -23,6 +54,7 @@ class Archive(commands.Cog):
amount: int, amount: int,
): ):
"""Archive a channel's messages.""" """Archive a channel's messages."""
# Ensure valid channel permissions
if not channel.permissions_for( if not channel.permissions_for(
interaction.guild.me interaction.guild.me
).read_message_history: ).read_message_history:
@ -32,6 +64,7 @@ class Archive(commands.Cog):
ephemeral=True, ephemeral=True,
) )
# Ensure valid amount
if amount < 1: if amount < 1:
return await interaction.response.send_message( return await interaction.response.send_message(
"You must provide a number greater than 0.", "You must provide a number greater than 0.",
@ -46,51 +79,22 @@ class Archive(commands.Cog):
async for message in messages: async for message in messages:
count += 1 count += 1
author = message.author attachments = []
stickers = [sticker.name for sticker in message.stickers]
role_mentions = [
role_mention.id for role_mention in message.role_mentions
]
mention_everyone = message.mention_everyone
mentions = [mention.id for mention in message.mentions]
attachments = [
attachment.url for attachment in message.attachments
]
if not os.path.exists("images"): if not os.path.exists("images"):
os.makedirs("images") os.makedirs("images")
# Download all images before saving everything to database attachments = await self.download_attachments(message.attachments)
for url in attachments:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
# Create a randomized filename
filename = "".join(
random.choice(string.ascii_letters)
for i in range(10)
)
# Save the image to the filesystem
with open(f"images/{filename}.jpg", "wb") as file:
file.write(await response.read())
# Update the attachment URL to the new filename
attachments[attachments.index(url)] = (
f"images/{filename}.jpg"
)
content = message.content
db_message = Message( db_message = Message(
author_id=author.id, author_id=message.author.id,
channel_id=channel.id, channel_id=channel.id,
stickers=stickers, stickers=[sticker.name for sticker in message.stickers],
role_mentions=role_mentions, role_mentions=[role.id for role in message.role_mentions],
mention_everyone=mention_everyone, mention_everyone=message.mention_everyone,
mentions=mentions, mentions=[mention.id for mention in message.mentions],
attachments=attachments, attachments=attachments,
content=content, content=message.content,
) )
db.add(db_message) db.add(db_message)