diff --git a/.gitignore b/.gitignore index 0dbf2f2..805d269 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +images/ +data/ +config.yaml \ No newline at end of file diff --git a/bot.py b/bot.py new file mode 100644 index 0000000..380ace3 --- /dev/null +++ b/bot.py @@ -0,0 +1,43 @@ +import discord +from discord.ext import commands +import os +from database import Base, engine + +import config as config +from src.utils.command_tree import Tree + + +class MyBot(commands.Bot): + def __init__(self): + super().__init__( + command_prefix="***", + activity=discord.Game(name="music!"), + intents=discord.Intents.default(), + tree_cls=Tree, + ) + + async def setup_hook(self): + for ext in os.listdir("./src/cogs"): + if ext.endswith(".py"): + await self.load_extension(f"src.cogs.{ext[:-3]}") + + for ext in os.listdir("./src/cogs/owner"): + if ext.endswith(".py"): + await self.load_extension(f"src.cogs.owner.{ext[:-3]}") + + async def on_ready(self): + Base.metadata.create_all(bind=engine) + config.LOG.info(f"{bot.user} has connected to Discord.") + config.LOG.info( + "Startup complete. Sync slash commands by DMing the bot" + f" {bot.command_prefix}tree sync (guild id)" + ) + + +bot = MyBot() +bot.remove_command("help") + + +if __name__ == "__main__": + config.load_config() + bot.run(config.TOKEN) diff --git a/config.py b/config.py new file mode 100644 index 0000000..720ecdd --- /dev/null +++ b/config.py @@ -0,0 +1,189 @@ +import jsonschema +import os +import re +import yaml +import sys +import discord +import logging +from colorlog import ColoredFormatter + +log_level = logging.DEBUG +log_format = ( + " %(log_color)s%(levelname)-8s%(reset)s |" + " %(log_color)s%(message)s%(reset)s" +) + +logging.root.setLevel(log_level) +formatter = ColoredFormatter(log_format) + +stream = logging.StreamHandler() +stream.setLevel(log_level) +stream.setFormatter(formatter) + +LOG = logging.getLogger("pythonConfig") +LOG.setLevel(log_level) +LOG.addHandler(stream) + +TOKEN = None +BOT_COLOR = None + +DB_NAME = None +DB_ENGINE = None +DB_HOST = None +DB_PORT = None +DB_USER = None +DB_PASSWORD = None + + +schema = { + "type": "object", + "properties": { + "bot_info": { + "type": "object", + "properties": { + "token": {"type": "string"}, + "bot_color": {"type": "string", "default": "#fc5f4e"}, + }, + "required": ["token"], + }, + "database": { + "type": "object", + "properties": { + "name": {"type": "string", "default": "disarchive"}, + "engine": {"type": "string"}, + "host": {"type": "string"}, + "port": {"type": "integer"}, + "user": {"type": "string"}, + "password": {"type": "string"}, + }, + "required": [ + "name", + "engine", + "host", + "port", + "user", + "password", + ], + }, + }, + "required": ["bot_info", "database"], +} + + +# Attempt to load the config file, otherwise create a new template +def load_config(): + if os.path.exists("/.dockerenv"): + file_path = "/config/config.yaml" + else: + file_path = "config.yaml" + + try: + with open(file_path, "r") as f: + file_contents = f.read() + validate_config(file_contents) + + except FileNotFoundError: + # Create a new config.yaml file with the template + with open(file_path, "w") as f: + f.write( + """ +bot_info: + token: + bot_color: + +database: + name: + engine: + host: + port: + user: + password: """ + ) + + sys.exit( + LOG.critical( + "Configuration file `config.yaml` has been generated. Please" + " fill out all of the necessary information. Refer to the docs" + " for information on what a specific configuration option is." + ) + ) + + +# Thouroughly validate all of the options in the config.yaml file +def validate_config(file_contents): + global TOKEN, BOT_COLOR, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD + config = yaml.safe_load(file_contents) + + try: + jsonschema.validate(config, schema) + except jsonschema.ValidationError as e: + sys.exit(LOG.critical(f"Error in config.yaml file: {e.message}")) + + # Make sure "bot_color" is a valid hex color + hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$" + hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$" + + if "bot_color" in config["bot_info"]: + if not bool( + re.match(hex_pattern_one, config["bot_info"]["bot_color"]) + ) and not bool( + re.match(hex_pattern_two, config["bot_info"]["bot_color"]) + ): + LOG.critical( + "Error in config.yaml file: bot_color is not a valid hex color" + ) + else: + BOT_COLOR = discord.Color( + int((config["bot_info"]["bot_color"]).replace("#", ""), 16) + ) + + if config["database"]["engine"] not in [ + "sqlite", + "mysql", + "postgresql", + ]: + LOG.error( + "database_engine must be either 'sqlite', 'mysql', or 'postgresql'" + ) + return False + else: + DB_ENGINE = config["database"]["engine"] + + DB_NAME = config["database"]["name"] + DB_HOST = config["database"]["host"] + DB_PORT = config["database"]["port"] + DB_USER = config["database"]["user"] + DB_PASSWORD = config["database"]["password"] + + TOKEN = config["bot_info"]["token"] + + +""" +Template for embeds +""" + + +def create_embed( + title: str = None, + description: str = None, + color=None, + footer=None, + thumbnail=None, +): + embed = discord.Embed( + title=title, + description=description, + color=color if color else BOT_COLOR, + ) + + if footer: + embed.set_footer(text=footer) + # else: + # embed.set_footer( + # text=datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") + " UTC" + # ) + + if thumbnail: + embed.set_thumbnail(url=thumbnail) + + return embed diff --git a/database.py b/database.py new file mode 100644 index 0000000..1258e42 --- /dev/null +++ b/database.py @@ -0,0 +1,21 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +import os + +import config + +if config.DB_ENGINE == "mysql": + database_url = f"mysql+pymysql://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}" + +elif config.DB_ENGINE == "postgresql": + database_url = f"postgresql+psycopg2://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}" + +else: + if not os.path.exists("data"): + os.makedirs("data") + database_url = "sqlite:///data/data.db" + +engine = create_engine(database_url) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +Base = declarative_base() diff --git a/models.py b/models.py new file mode 100644 index 0000000..2c04f67 --- /dev/null +++ b/models.py @@ -0,0 +1,17 @@ +from sqlalchemy import Column, Boolean, Integer, String, JSON + +from database import Base + + +class Message(Base): + __tablename__ = "messages" + + id = Column(Integer, primary_key=True, autoincrement=True) + author_id = Column(Integer, nullable=False) + channel_id = Column(Integer, nullable=False) + stickers = Column(JSON, nullable=False, default=list) + role_mentions = Column(JSON, nullable=False, default=list) + mention_everyone = Column(Boolean, nullable=False, default=False) + mentions = Column(JSON, nullable=False, default=list) + attachments = Column(JSON, nullable=False, default=list) + content = Column(String, nullable=False) diff --git a/src/cogs/archive.py b/src/cogs/archive.py new file mode 100644 index 0000000..36c7044 --- /dev/null +++ b/src/cogs/archive.py @@ -0,0 +1,101 @@ +import discord +from discord import app_commands +from discord.ext import commands +import aiohttp +import random +import string +import os + +from src.utils.db import get_db +from models import Message +from config import BOT_COLOR + + +class Archive(commands.Cog): + def __init__(self, bot): + self.bot = bot + + @app_commands.command() + async def archive( + self, + interaction: discord.Interaction, + channel: discord.TextChannel, + amount: int, + ): + """Archive a channel's messages.""" + if not channel.permissions_for( + interaction.guild.me + ).read_message_history: + return await interaction.response.send_message( + "I do not have permission to read message history in that" + " channel.", + ephemeral=True, + ) + + if amount < 1: + return await interaction.response.send_message( + "You must provide a number greater than 0.", + ephemeral=True, + ) + + await interaction.response.send_message("Archiving messages now.") + + db = next(get_db()) + count = 0 + messages = channel.history(limit=amount) + async for message in messages: + count += 1 + + author = message.author + stickers = [sticker.name for sticker in message.stickers] + role_mentions = [ + role_mention.id for role_mention in message.role_mentions + ] + mention_everyone = message.mention_everyone + mentions = [mention.id for mention in message.mentions] + attachments = [ + attachment.url for attachment in message.attachments + ] + + if not os.path.exists("images"): + os.makedirs("images") + + # Download all images before saving everything to database + for url in attachments: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + # Create a randomized filename + filename = "".join( + random.choice(string.ascii_letters) + for i in range(10) + ) + + # Save the image to the filesystem + with open(f"images/{filename}.jpg", "wb") as file: + file.write(await response.read()) + + # Update the attachment URL to the new filename + attachments[attachments.index(url)] = ( + f"images/{filename}.jpg" + ) + + content = message.content + + db_message = Message( + author_id=author.id, + channel_id=channel.id, + stickers=stickers, + role_mentions=role_mentions, + mention_everyone=mention_everyone, + mentions=mentions, + attachments=attachments, + content=content, + ) + + db.add(db_message) + db.commit() + + +async def setup(bot): + await bot.add_cog(Archive(bot)) diff --git a/src/cogs/owner/sync.py b/src/cogs/owner/sync.py new file mode 100644 index 0000000..d6647b5 --- /dev/null +++ b/src/cogs/owner/sync.py @@ -0,0 +1,87 @@ +import discord +from discord.ext import commands + + +class TreeSync(commands.Cog): + def __init__(self, bot): + self.bot = bot + + @commands.group(invoke_without_command=True) + @commands.dm_only() + @commands.is_owner() + async def tree(self, ctx): + await ctx.author.send( + "This is a group command. Use either" + f" `{self.bot.command_prefix}tree sync` or" + f" `{self.bot.command_prefix}tree clear` followed by an optional" + " guild ID." + ) + + @commands.dm_only() + @commands.is_owner() + @tree.command() + async def sync( + self, ctx: commands.Context, *, guild: discord.Object = None + ): + """Sync the command tree to a guild or globally.""" + if guild: + self.bot.tree.copy_global_to(guild=guild) + await self.bot.tree.sync(guild=guild) + return await ctx.author.send( + "Synced the command tree to" + f" `{self.bot.get_guild(guild.id).name}`" + ) + else: + await self.bot.tree.sync() + return await ctx.author.send("Synced the command tree globally.") + + @sync.error + async def tree_sync_error(self, ctx, error): + if isinstance(error, commands.ObjectNotFound): + return await ctx.author.send( + "The guild you provided does not exist." + ) + if isinstance(error, commands.CommandInvokeError): + return await ctx.author.send( + "Guild ID provided is not a guild that the bot is in." + ) + else: + return await ctx.author.send( + "An unknown error occurred. Perhaps you've been rate limited." + ) + + @commands.dm_only() + @commands.is_owner() + @tree.command() + async def clear(self, ctx: commands.Context, *, guild: discord.Object): + """Clear the command tree from a guild.""" + self.bot.tree.clear_commands(guild=guild) + await self.bot.tree.sync(guild=guild) + return await ctx.author.send( + "Cleared the command tree from" + f" `{self.bot.get_guild(guild.id).name}`" + ) + + @clear.error + async def tree_sync_error(self, ctx, error): + if isinstance(error, commands.MissingRequiredArgument): + return await ctx.author.send( + "You need to provide a guild ID to clear the command tree" + " from." + ) + if isinstance(error, commands.ObjectNotFound): + return await ctx.author.send( + "The guild you provided does not exist." + ) + if isinstance(error, commands.CommandInvokeError): + return await ctx.author.send( + "Guild ID provided is not a guild that the bot is in." + ) + else: + return await ctx.author.send( + "An unknown error occurred. Perhaps you've been rate limited." + ) + + +async def setup(bot): + await bot.add_cog(TreeSync(bot)) diff --git a/src/schemas/message_schema.py b/src/schemas/message_schema.py new file mode 100644 index 0000000..79c9659 --- /dev/null +++ b/src/schemas/message_schema.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class Message(BaseModel): + author_id: int + channel_id: int + stickers: list[str] + role_mentions: list[int] + mention_everyone: bool + mentions: list[int] + attachments: list[str] + content: str diff --git a/src/utils/command_tree.py b/src/utils/command_tree.py new file mode 100644 index 0000000..31ec2c1 --- /dev/null +++ b/src/utils/command_tree.py @@ -0,0 +1,19 @@ +import discord +from discord import app_commands +from discord.ext.commands.errors import * + +from config import create_embed + + +class Tree(app_commands.CommandTree): + async def on_error( + self, + interaction: discord.Interaction, + error: app_commands.AppCommandError, + ): + + if isinstance(error, CommandNotFound): + return + + else: + raise error diff --git a/src/utils/db.py b/src/utils/db.py new file mode 100644 index 0000000..a6734ea --- /dev/null +++ b/src/utils/db.py @@ -0,0 +1,9 @@ +from database import SessionLocal + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close()