Detect file extensions
This commit is contained in:
parent
b5099937f8
commit
0eaa58ab83
@ -8,13 +8,44 @@ import os
|
||||
|
||||
from src.utils.db import get_db
|
||||
from models import Message
|
||||
from config import BOT_COLOR
|
||||
from config import BOT_COLOR, LOG
|
||||
|
||||
|
||||
class Archive(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
async def download_attachments(attachments) -> list:
|
||||
"""Download attachments and return a list of their paths."""
|
||||
paths = []
|
||||
|
||||
for attachment in attachments:
|
||||
async with aiohttp.ClientSession().get(attachment.url) as response:
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if not content_type:
|
||||
LOG.warn(
|
||||
f"Failed to get content type for: {attachment.url}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create a randomized filename
|
||||
file_extension = content_type.split("/")[-1]
|
||||
filename = (
|
||||
"".join(
|
||||
random.choice(string.ascii_letters) for i in range(10)
|
||||
)
|
||||
+ f".{file_extension}"
|
||||
)
|
||||
|
||||
# Save the image to the filesystem
|
||||
with open(f"images/{filename}]", "wb") as file:
|
||||
file.write(await response.read())
|
||||
|
||||
# Add the path to the attachments list
|
||||
paths.append(f"images/{filename}")
|
||||
|
||||
return paths
|
||||
|
||||
@app_commands.command()
|
||||
async def archive(
|
||||
self,
|
||||
@ -23,6 +54,7 @@ class Archive(commands.Cog):
|
||||
amount: int,
|
||||
):
|
||||
"""Archive a channel's messages."""
|
||||
# Ensure valid channel permissions
|
||||
if not channel.permissions_for(
|
||||
interaction.guild.me
|
||||
).read_message_history:
|
||||
@ -32,6 +64,7 @@ class Archive(commands.Cog):
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
# Ensure valid amount
|
||||
if amount < 1:
|
||||
return await interaction.response.send_message(
|
||||
"You must provide a number greater than 0.",
|
||||
@ -46,51 +79,22 @@ class Archive(commands.Cog):
|
||||
async for message in messages:
|
||||
count += 1
|
||||
|
||||
author = message.author
|
||||
stickers = [sticker.name for sticker in message.stickers]
|
||||
role_mentions = [
|
||||
role_mention.id for role_mention in message.role_mentions
|
||||
]
|
||||
mention_everyone = message.mention_everyone
|
||||
mentions = [mention.id for mention in message.mentions]
|
||||
attachments = [
|
||||
attachment.url for attachment in message.attachments
|
||||
]
|
||||
attachments = []
|
||||
|
||||
if not os.path.exists("images"):
|
||||
os.makedirs("images")
|
||||
|
||||
# Download all images before saving everything to database
|
||||
for url in attachments:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
# Create a randomized filename
|
||||
filename = "".join(
|
||||
random.choice(string.ascii_letters)
|
||||
for i in range(10)
|
||||
)
|
||||
|
||||
# Save the image to the filesystem
|
||||
with open(f"images/{filename}.jpg", "wb") as file:
|
||||
file.write(await response.read())
|
||||
|
||||
# Update the attachment URL to the new filename
|
||||
attachments[attachments.index(url)] = (
|
||||
f"images/{filename}.jpg"
|
||||
)
|
||||
|
||||
content = message.content
|
||||
attachments = await self.download_attachments(message.attachments)
|
||||
|
||||
db_message = Message(
|
||||
author_id=author.id,
|
||||
author_id=message.author.id,
|
||||
channel_id=channel.id,
|
||||
stickers=stickers,
|
||||
role_mentions=role_mentions,
|
||||
mention_everyone=mention_everyone,
|
||||
mentions=mentions,
|
||||
stickers=[sticker.name for sticker in message.stickers],
|
||||
role_mentions=[role.id for role in message.role_mentions],
|
||||
mention_everyone=message.mention_everyone,
|
||||
mentions=[mention.id for mention in message.mentions],
|
||||
attachments=attachments,
|
||||
content=content,
|
||||
content=message.content,
|
||||
)
|
||||
|
||||
db.add(db_message)
|
||||
|
Loading…
x
Reference in New Issue
Block a user