First commit. Lots of stuff
This commit is contained in:
parent
ae1ac1d731
commit
b5099937f8
3
.gitignore
vendored
3
.gitignore
vendored
@ -168,3 +168,6 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
images/
|
||||
data/
|
||||
config.yaml
|
43
bot.py
Normal file
43
bot.py
Normal file
@ -0,0 +1,43 @@
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import os
|
||||
from database import Base, engine
|
||||
|
||||
import config as config
|
||||
from src.utils.command_tree import Tree
|
||||
|
||||
|
||||
class MyBot(commands.Bot):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
command_prefix="***",
|
||||
activity=discord.Game(name="music!"),
|
||||
intents=discord.Intents.default(),
|
||||
tree_cls=Tree,
|
||||
)
|
||||
|
||||
async def setup_hook(self):
|
||||
for ext in os.listdir("./src/cogs"):
|
||||
if ext.endswith(".py"):
|
||||
await self.load_extension(f"src.cogs.{ext[:-3]}")
|
||||
|
||||
for ext in os.listdir("./src/cogs/owner"):
|
||||
if ext.endswith(".py"):
|
||||
await self.load_extension(f"src.cogs.owner.{ext[:-3]}")
|
||||
|
||||
async def on_ready(self):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
config.LOG.info(f"{bot.user} has connected to Discord.")
|
||||
config.LOG.info(
|
||||
"Startup complete. Sync slash commands by DMing the bot"
|
||||
f" {bot.command_prefix}tree sync (guild id)"
|
||||
)
|
||||
|
||||
|
||||
bot = MyBot()
|
||||
bot.remove_command("help")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.load_config()
|
||||
bot.run(config.TOKEN)
|
189
config.py
Normal file
189
config.py
Normal file
@ -0,0 +1,189 @@
|
||||
import jsonschema
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
import sys
|
||||
import discord
|
||||
import logging
|
||||
from colorlog import ColoredFormatter
|
||||
|
||||
log_level = logging.DEBUG
|
||||
log_format = (
|
||||
" %(log_color)s%(levelname)-8s%(reset)s |"
|
||||
" %(log_color)s%(message)s%(reset)s"
|
||||
)
|
||||
|
||||
logging.root.setLevel(log_level)
|
||||
formatter = ColoredFormatter(log_format)
|
||||
|
||||
stream = logging.StreamHandler()
|
||||
stream.setLevel(log_level)
|
||||
stream.setFormatter(formatter)
|
||||
|
||||
LOG = logging.getLogger("pythonConfig")
|
||||
LOG.setLevel(log_level)
|
||||
LOG.addHandler(stream)
|
||||
|
||||
TOKEN = None
|
||||
BOT_COLOR = None
|
||||
|
||||
DB_NAME = None
|
||||
DB_ENGINE = None
|
||||
DB_HOST = None
|
||||
DB_PORT = None
|
||||
DB_USER = None
|
||||
DB_PASSWORD = None
|
||||
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bot_info": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"token": {"type": "string"},
|
||||
"bot_color": {"type": "string", "default": "#fc5f4e"},
|
||||
},
|
||||
"required": ["token"],
|
||||
},
|
||||
"database": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "default": "disarchive"},
|
||||
"engine": {"type": "string"},
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "integer"},
|
||||
"user": {"type": "string"},
|
||||
"password": {"type": "string"},
|
||||
},
|
||||
"required": [
|
||||
"name",
|
||||
"engine",
|
||||
"host",
|
||||
"port",
|
||||
"user",
|
||||
"password",
|
||||
],
|
||||
},
|
||||
},
|
||||
"required": ["bot_info", "database"],
|
||||
}
|
||||
|
||||
|
||||
# Attempt to load the config file, otherwise create a new template
|
||||
def load_config():
|
||||
if os.path.exists("/.dockerenv"):
|
||||
file_path = "/config/config.yaml"
|
||||
else:
|
||||
file_path = "config.yaml"
|
||||
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
file_contents = f.read()
|
||||
validate_config(file_contents)
|
||||
|
||||
except FileNotFoundError:
|
||||
# Create a new config.yaml file with the template
|
||||
with open(file_path, "w") as f:
|
||||
f.write(
|
||||
"""
|
||||
bot_info:
|
||||
token:
|
||||
bot_color:
|
||||
|
||||
database:
|
||||
name:
|
||||
engine:
|
||||
host:
|
||||
port:
|
||||
user:
|
||||
password: """
|
||||
)
|
||||
|
||||
sys.exit(
|
||||
LOG.critical(
|
||||
"Configuration file `config.yaml` has been generated. Please"
|
||||
" fill out all of the necessary information. Refer to the docs"
|
||||
" for information on what a specific configuration option is."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Thouroughly validate all of the options in the config.yaml file
|
||||
def validate_config(file_contents):
|
||||
global TOKEN, BOT_COLOR, DB_NAME, DB_ENGINE, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD
|
||||
config = yaml.safe_load(file_contents)
|
||||
|
||||
try:
|
||||
jsonschema.validate(config, schema)
|
||||
except jsonschema.ValidationError as e:
|
||||
sys.exit(LOG.critical(f"Error in config.yaml file: {e.message}"))
|
||||
|
||||
# Make sure "bot_color" is a valid hex color
|
||||
hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
|
||||
hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
|
||||
|
||||
if "bot_color" in config["bot_info"]:
|
||||
if not bool(
|
||||
re.match(hex_pattern_one, config["bot_info"]["bot_color"])
|
||||
) and not bool(
|
||||
re.match(hex_pattern_two, config["bot_info"]["bot_color"])
|
||||
):
|
||||
LOG.critical(
|
||||
"Error in config.yaml file: bot_color is not a valid hex color"
|
||||
)
|
||||
else:
|
||||
BOT_COLOR = discord.Color(
|
||||
int((config["bot_info"]["bot_color"]).replace("#", ""), 16)
|
||||
)
|
||||
|
||||
if config["database"]["engine"] not in [
|
||||
"sqlite",
|
||||
"mysql",
|
||||
"postgresql",
|
||||
]:
|
||||
LOG.error(
|
||||
"database_engine must be either 'sqlite', 'mysql', or 'postgresql'"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
DB_ENGINE = config["database"]["engine"]
|
||||
|
||||
DB_NAME = config["database"]["name"]
|
||||
DB_HOST = config["database"]["host"]
|
||||
DB_PORT = config["database"]["port"]
|
||||
DB_USER = config["database"]["user"]
|
||||
DB_PASSWORD = config["database"]["password"]
|
||||
|
||||
TOKEN = config["bot_info"]["token"]
|
||||
|
||||
|
||||
"""
|
||||
Template for embeds
|
||||
"""
|
||||
|
||||
|
||||
def create_embed(
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
color=None,
|
||||
footer=None,
|
||||
thumbnail=None,
|
||||
):
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
description=description,
|
||||
color=color if color else BOT_COLOR,
|
||||
)
|
||||
|
||||
if footer:
|
||||
embed.set_footer(text=footer)
|
||||
# else:
|
||||
# embed.set_footer(
|
||||
# text=datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") + " UTC"
|
||||
# )
|
||||
|
||||
if thumbnail:
|
||||
embed.set_thumbnail(url=thumbnail)
|
||||
|
||||
return embed
|
21
database.py
Normal file
21
database.py
Normal file
@ -0,0 +1,21 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import os
|
||||
|
||||
import config
|
||||
|
||||
if config.DB_ENGINE == "mysql":
|
||||
database_url = f"mysql+pymysql://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
|
||||
|
||||
elif config.DB_ENGINE == "postgresql":
|
||||
database_url = f"postgresql+psycopg2://{config.DB_USER}:{config.DB_PASSWORD}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
|
||||
|
||||
else:
|
||||
if not os.path.exists("data"):
|
||||
os.makedirs("data")
|
||||
database_url = "sqlite:///data/data.db"
|
||||
|
||||
engine = create_engine(database_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base = declarative_base()
|
17
models.py
Normal file
17
models.py
Normal file
@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Column, Boolean, Integer, String, JSON
|
||||
|
||||
from database import Base
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
author_id = Column(Integer, nullable=False)
|
||||
channel_id = Column(Integer, nullable=False)
|
||||
stickers = Column(JSON, nullable=False, default=list)
|
||||
role_mentions = Column(JSON, nullable=False, default=list)
|
||||
mention_everyone = Column(Boolean, nullable=False, default=False)
|
||||
mentions = Column(JSON, nullable=False, default=list)
|
||||
attachments = Column(JSON, nullable=False, default=list)
|
||||
content = Column(String, nullable=False)
|
101
src/cogs/archive.py
Normal file
101
src/cogs/archive.py
Normal file
@ -0,0 +1,101 @@
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
import aiohttp
|
||||
import random
|
||||
import string
|
||||
import os
|
||||
|
||||
from src.utils.db import get_db
|
||||
from models import Message
|
||||
from config import BOT_COLOR
|
||||
|
||||
|
||||
class Archive(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
@app_commands.command()
|
||||
async def archive(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
channel: discord.TextChannel,
|
||||
amount: int,
|
||||
):
|
||||
"""Archive a channel's messages."""
|
||||
if not channel.permissions_for(
|
||||
interaction.guild.me
|
||||
).read_message_history:
|
||||
return await interaction.response.send_message(
|
||||
"I do not have permission to read message history in that"
|
||||
" channel.",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
if amount < 1:
|
||||
return await interaction.response.send_message(
|
||||
"You must provide a number greater than 0.",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
await interaction.response.send_message("Archiving messages now.")
|
||||
|
||||
db = next(get_db())
|
||||
count = 0
|
||||
messages = channel.history(limit=amount)
|
||||
async for message in messages:
|
||||
count += 1
|
||||
|
||||
author = message.author
|
||||
stickers = [sticker.name for sticker in message.stickers]
|
||||
role_mentions = [
|
||||
role_mention.id for role_mention in message.role_mentions
|
||||
]
|
||||
mention_everyone = message.mention_everyone
|
||||
mentions = [mention.id for mention in message.mentions]
|
||||
attachments = [
|
||||
attachment.url for attachment in message.attachments
|
||||
]
|
||||
|
||||
if not os.path.exists("images"):
|
||||
os.makedirs("images")
|
||||
|
||||
# Download all images before saving everything to database
|
||||
for url in attachments:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
# Create a randomized filename
|
||||
filename = "".join(
|
||||
random.choice(string.ascii_letters)
|
||||
for i in range(10)
|
||||
)
|
||||
|
||||
# Save the image to the filesystem
|
||||
with open(f"images/{filename}.jpg", "wb") as file:
|
||||
file.write(await response.read())
|
||||
|
||||
# Update the attachment URL to the new filename
|
||||
attachments[attachments.index(url)] = (
|
||||
f"images/{filename}.jpg"
|
||||
)
|
||||
|
||||
content = message.content
|
||||
|
||||
db_message = Message(
|
||||
author_id=author.id,
|
||||
channel_id=channel.id,
|
||||
stickers=stickers,
|
||||
role_mentions=role_mentions,
|
||||
mention_everyone=mention_everyone,
|
||||
mentions=mentions,
|
||||
attachments=attachments,
|
||||
content=content,
|
||||
)
|
||||
|
||||
db.add(db_message)
|
||||
db.commit()
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(Archive(bot))
|
87
src/cogs/owner/sync.py
Normal file
87
src/cogs/owner/sync.py
Normal file
@ -0,0 +1,87 @@
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
|
||||
class TreeSync(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
@commands.group(invoke_without_command=True)
|
||||
@commands.dm_only()
|
||||
@commands.is_owner()
|
||||
async def tree(self, ctx):
|
||||
await ctx.author.send(
|
||||
"This is a group command. Use either"
|
||||
f" `{self.bot.command_prefix}tree sync` or"
|
||||
f" `{self.bot.command_prefix}tree clear` followed by an optional"
|
||||
" guild ID."
|
||||
)
|
||||
|
||||
@commands.dm_only()
|
||||
@commands.is_owner()
|
||||
@tree.command()
|
||||
async def sync(
|
||||
self, ctx: commands.Context, *, guild: discord.Object = None
|
||||
):
|
||||
"""Sync the command tree to a guild or globally."""
|
||||
if guild:
|
||||
self.bot.tree.copy_global_to(guild=guild)
|
||||
await self.bot.tree.sync(guild=guild)
|
||||
return await ctx.author.send(
|
||||
"Synced the command tree to"
|
||||
f" `{self.bot.get_guild(guild.id).name}`"
|
||||
)
|
||||
else:
|
||||
await self.bot.tree.sync()
|
||||
return await ctx.author.send("Synced the command tree globally.")
|
||||
|
||||
@sync.error
|
||||
async def tree_sync_error(self, ctx, error):
|
||||
if isinstance(error, commands.ObjectNotFound):
|
||||
return await ctx.author.send(
|
||||
"The guild you provided does not exist."
|
||||
)
|
||||
if isinstance(error, commands.CommandInvokeError):
|
||||
return await ctx.author.send(
|
||||
"Guild ID provided is not a guild that the bot is in."
|
||||
)
|
||||
else:
|
||||
return await ctx.author.send(
|
||||
"An unknown error occurred. Perhaps you've been rate limited."
|
||||
)
|
||||
|
||||
@commands.dm_only()
|
||||
@commands.is_owner()
|
||||
@tree.command()
|
||||
async def clear(self, ctx: commands.Context, *, guild: discord.Object):
|
||||
"""Clear the command tree from a guild."""
|
||||
self.bot.tree.clear_commands(guild=guild)
|
||||
await self.bot.tree.sync(guild=guild)
|
||||
return await ctx.author.send(
|
||||
"Cleared the command tree from"
|
||||
f" `{self.bot.get_guild(guild.id).name}`"
|
||||
)
|
||||
|
||||
@clear.error
|
||||
async def tree_sync_error(self, ctx, error):
|
||||
if isinstance(error, commands.MissingRequiredArgument):
|
||||
return await ctx.author.send(
|
||||
"You need to provide a guild ID to clear the command tree"
|
||||
" from."
|
||||
)
|
||||
if isinstance(error, commands.ObjectNotFound):
|
||||
return await ctx.author.send(
|
||||
"The guild you provided does not exist."
|
||||
)
|
||||
if isinstance(error, commands.CommandInvokeError):
|
||||
return await ctx.author.send(
|
||||
"Guild ID provided is not a guild that the bot is in."
|
||||
)
|
||||
else:
|
||||
return await ctx.author.send(
|
||||
"An unknown error occurred. Perhaps you've been rate limited."
|
||||
)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TreeSync(bot))
|
12
src/schemas/message_schema.py
Normal file
12
src/schemas/message_schema.py
Normal file
@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
author_id: int
|
||||
channel_id: int
|
||||
stickers: list[str]
|
||||
role_mentions: list[int]
|
||||
mention_everyone: bool
|
||||
mentions: list[int]
|
||||
attachments: list[str]
|
||||
content: str
|
19
src/utils/command_tree.py
Normal file
19
src/utils/command_tree.py
Normal file
@ -0,0 +1,19 @@
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext.commands.errors import *
|
||||
|
||||
from config import create_embed
|
||||
|
||||
|
||||
class Tree(app_commands.CommandTree):
|
||||
async def on_error(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
error: app_commands.AppCommandError,
|
||||
):
|
||||
|
||||
if isinstance(error, CommandNotFound):
|
||||
return
|
||||
|
||||
else:
|
||||
raise error
|
9
src/utils/db.py
Normal file
9
src/utils/db.py
Normal file
@ -0,0 +1,9 @@
|
||||
from database import SessionLocal
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
Loading…
x
Reference in New Issue
Block a user