aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--code/ai_recommendations.py62
-rw-r--r--code/bot.py5
-rw-r--r--code/cogs/autoplay.py4
-rw-r--r--code/cogs/music.py2
-rw-r--r--code/config.py10
5 files changed, 47 insertions, 36 deletions
diff --git a/code/ai_recommendations.py b/code/ai_recommendations.py
index 9f73976..9c0831f 100644
--- a/code/ai_recommendations.py
+++ b/code/ai_recommendations.py
@@ -1,34 +1,38 @@
from lavalink import LoadType
+import re
-from config import CLIENT
+from config import OPENAI_API_KEY
+import openai
-async def add_song_recommendations(bot_user, player, number, inputs, retries: int = 1):
- input_string = ""
- for song, artist in inputs.items():
- input_string += f"{song} - {artist}, "
- # Remove the final ", "
- input_string = input_string[:-2]
+async def add_song_recommendations(openai_client, bot_user, player, number, inputs, retries: int = 1):
+ input_list = [f'"{song} by {artist}"' for song, artist in inputs.items()]
completion = (
- CLIENT.chat.completions.create(
+ openai_client.chat.completions.create(
messages=[
{
"role": "user",
- "content": f"""I need songs that are similar in nature to ones that I list.
- Send {number} songs formatted as:
+ "content": f"""
+ BACKGROUND: You're an AI music recommendation system with a knack for understanding
+ user preferences based on provided input. Your task is to generate a list
+ of {number} songs that the user might enjoy, derived from a given list of {number} songs.
+ The input will be in the format of
+ ["Song-1-Name by Song-1-Artist", "Song-2-Name by Song-2-Artist", ...]
+ and you need to return a list formatted in the same way.
- SONG NAME - ARTIST NAME
- SONG NAME - ARTIST NAME
- ...
+ When recommending songs, consider the genre, tempo, and mood of the input
+ songs to suggest similar ones that align with the user's tastes. Also, it
+ is important to mix up the artists, don't only give the same artists that
+ are already in the queue. If you cannot find {number} songs that match the
+ criteria or encounter any issues, return the list ["NOTHING FOUND"].
- Do not provide anything except for the exactly what I need, no
- list numbers, no quotations, only what I have shown.
+ Please be sure to also only use characters A-Z, a-z, 0-9, and spaces in the
+ song and artist names. Do not include escape/special characters, emojis, or
+ quotes in the output.
- The songs you should base the list off of are: {input_string}
-
- NOTE: If you believe that there are not many songs that are similar to the ones I list, then please just respond with the message "SONG FIND ERROR"
- """,
+ INPUT: {input_list}
+ """,
}
],
model="gpt-3.5-turbo",
@@ -38,28 +42,32 @@ async def add_song_recommendations(bot_user, player, number, inputs, retries: in
.strip('"')
)
- # Sometimes, we get false failures, so we check for a failure, and it we haven't tried
- # at least 3 times, then continue retrying, otherwise, we actually can't get any songs
- if completion == "SONG FIND ERROR":
+ # Sometimes ChatGPT will return `["NOTHING FOUND"]` even if it should
+ # have found something, so we check each prompt up to 3 times before
+ # giving up.
+ if completion == '["NOTHING FOUND"]':
if retries <= 3:
await add_song_recommendations(
- bot_user, player, number, inputs, retries + 1
+ openai_client, bot_user, player, number, inputs, retries + 1
)
else:
return False
else:
- for entry in completion.split("\n"):
- song, artist = entry.split(" - ")
+ # Clean up the completion string to remove any potential issues
+ # with the eval function (e.g. OUTPUT: prefix, escaped quotes, etc.)
+ completion = re.sub(r"[\\\'\[\]\n]+|OUTPUT: ", "", completion)
- ytsearch = f"ytsearch:{song} {artist} audio"
+ for entry in eval(completion):
+ song, artist = entry.split(" by ")
+ ytsearch = f"ytsearch:{song} by {artist} audio"
results = await player.node.get_tracks(ytsearch)
if not results.tracks or results.load_type in (
LoadType.EMPTY,
LoadType.ERROR,
):
- dzsearch = f"dzsearch:{song} {artist}"
+ dzsearch = f"dzsearch:{song}"
results = await player.node.get_tracks(dzsearch)
if not results.tracks or results.load_type in (
diff --git a/code/bot.py b/code/bot.py
index 6735522..da23a28 100644
--- a/code/bot.py
+++ b/code/bot.py
@@ -2,6 +2,7 @@ import discord
from discord.ext import commands, tasks
import os
import requests
+import openai
import config
from tree import Tree
@@ -32,6 +33,8 @@ class MyBot(commands.Bot):
if ext.endswith(".py"):
await self.load_extension(f"cogs.owner.{ext[:-3]}")
+ bot.openai = openai.OpenAI(api_key=config.OPENAI_API_KEY)
+
bot = MyBot()
bot.remove_command("help")
@@ -65,4 +68,4 @@ async def get_access_token():
if __name__ == "__main__":
config.load_config()
- bot.run(config.TOKEN)
+ bot.run(config.TOKEN) \ No newline at end of file
diff --git a/code/cogs/autoplay.py b/code/cogs/autoplay.py
index 4f5dc7a..0779e69 100644
--- a/code/cogs/autoplay.py
+++ b/code/cogs/autoplay.py
@@ -4,8 +4,8 @@ from discord import app_commands
from discord.ext import commands
from cogs.music import Music
from typing import Literal
-from ai_recommendations import add_song_recommendations
+from ai_recommendations import add_song_recommendations
from config import BOT_COLOR
@@ -67,7 +67,7 @@ class Autoplay(commands.Cog):
)
await interaction.response.send_message(embed=embed)
- if await add_song_recommendations(self.bot.user, player, 5, inputs):
+ if await add_song_recommendations(self.bot.openai, self.bot.user, player, 5, inputs):
self.bot.autoplay.append(interaction.guild.id)
embed = discord.Embed(
title=":infinity: Autoplay Enabled :infinity:",
diff --git a/code/cogs/music.py b/code/cogs/music.py
index d118643..9b12c4c 100644
--- a/code/cogs/music.py
+++ b/code/cogs/music.py
@@ -214,7 +214,7 @@ class Music(commands.Cog):
inputs = {}
for song in event.player.queue[:10]:
inputs[song.title] = song.author
- await add_song_recommendations(self.bot.user, event.player, 5, inputs)
+ await add_song_recommendations(self.bot.openai, self.bot.user, event.player, 5, inputs)
async def setup(bot):
diff --git a/code/config.py b/code/config.py
index 4da47f7..8f47e2f 100644
--- a/code/config.py
+++ b/code/config.py
@@ -31,7 +31,7 @@ FEEDBACK_CHANNEL_ID = None
BUG_CHANNEL_ID = None
SPOTIFY_CLIENT_ID = None
SPOTIFY_CLIENT_SECRET = None
-CLIENT = None
+OPENAI_API_KEY = None
LAVALINK_HOST = None
LAVALINK_PORT = None
LAVALINK_PASSWORD = None
@@ -96,7 +96,7 @@ Validate all of the options in the config.ini file.
def validate_config(file_contents):
- global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
+ global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
config = configparser.ConfigParser()
config.read_string(file_contents)
@@ -175,7 +175,7 @@ def validate_config(file_contents):
TOKEN = config["BOT_INFO"]["TOKEN"]
SPOTIFY_CLIENT_ID = config["SPOTIFY"]["SPOTIFY_CLIENT_ID"]
SPOTIFY_CLIENT_SECRET = config["SPOTIFY"]["SPOTIFY_CLIENT_SECRET"]
- CLIENT = openai.OpenAI(api_key=config["OPENAI"]["OPENAI_API_KEY"])
+ OPENAI_API_KEY = config["OPENAI"]["OPENAI_API_KEY"]
LAVALINK_HOST = config["LAVALINK"]["HOST"]
LAVALINK_PORT = config["LAVALINK"]["PORT"]
LAVALINK_PASSWORD = config["LAVALINK"]["PASSWORD"]
@@ -194,7 +194,7 @@ Validate all of the environment variables.
def validate_env_vars():
- global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
+ global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
@@ -254,7 +254,7 @@ def validate_env_vars():
TOKEN = os.environ["TOKEN"]
SPOTIFY_CLIENT_ID = os.environ["SPOTIFY_CLIENT_ID"]
SPOTIFY_CLIENT_SECRET = os.environ["SPOTIFY_CLIENT_SECRET"]
- CLIENT = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+ OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
LAVALINK_HOST = os.environ["LAVALINK_HOST"]
LAVALINK_PORT = os.environ["LAVALINK_PORT"]
LAVALINK_PASSWORD = os.environ["LAVALINK_PASSWORD"] \ No newline at end of file