aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorParker <contact@pkrm.dev>2025-04-03 11:24:41 -0500
committerParker <contact@pkrm.dev>2025-04-03 11:24:41 -0500
commit0eaa58ab83091ab3236ae3633c9f4738d56adbe0 (patch)
treeb3ad78ea8f43906f3c118b46dd83128742fb01e1
parentb5099937f84e2fb58b69d97b4b8fff17363fe7d9 (diff)
Detect file extensions
-rw-r--r--src/cogs/archive.py80
1 files changed, 42 insertions, 38 deletions
diff --git a/src/cogs/archive.py b/src/cogs/archive.py
index 36c7044..b5572fe 100644
--- a/src/cogs/archive.py
+++ b/src/cogs/archive.py
@@ -8,13 +8,44 @@ import os
from src.utils.db import get_db
from models import Message
-from config import BOT_COLOR
+from config import BOT_COLOR, LOG
class Archive(commands.Cog):
def __init__(self, bot):
self.bot = bot
+ async def download_attachments(attachments) -> list:
+ """Download attachments and return a list of their paths."""
+ paths = []
+
+ for attachment in attachments:
+ async with aiohttp.ClientSession().get(attachment.url) as response:
+ content_type = response.headers.get("Content-Type")
+ if not content_type:
+ LOG.warn(
+ f"Failed to get content type for: {attachment.url}"
+ )
+ continue
+
+ # Create a randomized filename
+ file_extension = content_type.split("/")[-1]
+ filename = (
+ "".join(
+ random.choice(string.ascii_letters) for i in range(10)
+ )
+ + f".{file_extension}"
+ )
+
+ # Save the image to the filesystem
+ with open(f"images/{filename}]", "wb") as file:
+ file.write(await response.read())
+
+ # Add the path to the attachments list
+ paths.append(f"images/{filename}")
+
+ return paths
+
@app_commands.command()
async def archive(
self,
@@ -23,6 +54,7 @@ class Archive(commands.Cog):
amount: int,
):
"""Archive a channel's messages."""
+ # Ensure valid channel permissions
if not channel.permissions_for(
interaction.guild.me
).read_message_history:
@@ -32,6 +64,7 @@ class Archive(commands.Cog):
ephemeral=True,
)
+ # Ensure valid amount
if amount < 1:
return await interaction.response.send_message(
"You must provide a number greater than 0.",
@@ -46,51 +79,22 @@ class Archive(commands.Cog):
async for message in messages:
count += 1
- author = message.author
- stickers = [sticker.name for sticker in message.stickers]
- role_mentions = [
- role_mention.id for role_mention in message.role_mentions
- ]
- mention_everyone = message.mention_everyone
- mentions = [mention.id for mention in message.mentions]
- attachments = [
- attachment.url for attachment in message.attachments
- ]
+ attachments = []
if not os.path.exists("images"):
os.makedirs("images")
- # Download all images before saving everything to database
- for url in attachments:
- async with aiohttp.ClientSession() as session:
- async with session.get(url) as response:
- if response.status == 200:
- # Create a randomized filename
- filename = "".join(
- random.choice(string.ascii_letters)
- for i in range(10)
- )
-
- # Save the image to the filesystem
- with open(f"images/{filename}.jpg", "wb") as file:
- file.write(await response.read())
-
- # Update the attachment URL to the new filename
- attachments[attachments.index(url)] = (
- f"images/{filename}.jpg"
- )
-
- content = message.content
+ attachments = await self.download_attachments(message.attachments)
db_message = Message(
- author_id=author.id,
+ author_id=message.author.id,
channel_id=channel.id,
- stickers=stickers,
- role_mentions=role_mentions,
- mention_everyone=mention_everyone,
- mentions=mentions,
+ stickers=[sticker.name for sticker in message.stickers],
+ role_mentions=[role.id for role in message.role_mentions],
+ mention_everyone=message.mention_everyone,
+ mentions=[mention.id for mention in message.mentions],
attachments=attachments,
- content=content,
+ content=message.content,
)
db.add(db_message)