From c884b640f77c14886c5c04e97c93c89c64fd7f1e Mon Sep 17 00:00:00 2001 From: Parker Date: Thu, 11 Jul 2024 00:43:06 -0500 Subject: [PATCH] Revamp OpenAI recommendations --- code/ai_recommendations.py | 62 +++++++++++++++++++++----------------- code/bot.py | 5 ++- code/cogs/autoplay.py | 4 +-- code/cogs/music.py | 2 +- code/config.py | 10 +++--- 5 files changed, 47 insertions(+), 36 deletions(-) diff --git a/code/ai_recommendations.py b/code/ai_recommendations.py index 9f73976..9c0831f 100644 --- a/code/ai_recommendations.py +++ b/code/ai_recommendations.py @@ -1,34 +1,38 @@ from lavalink import LoadType +import re -from config import CLIENT +from config import OPENAI_API_KEY +import openai -async def add_song_recommendations(bot_user, player, number, inputs, retries: int = 1): - input_string = "" - for song, artist in inputs.items(): - input_string += f"{song} - {artist}, " - # Remove the final ", " - input_string = input_string[:-2] +async def add_song_recommendations(openai_client, bot_user, player, number, inputs, retries: int = 1): + input_list = [f'"{song} by {artist}"' for song, artist in inputs.items()] completion = ( - CLIENT.chat.completions.create( + openai_client.chat.completions.create( messages=[ { "role": "user", - "content": f"""I need songs that are similar in nature to ones that I list. - Send {number} songs formatted as: + "content": f""" + BACKGROUND: You're an AI music recommendation system with a knack for understanding + user preferences based on provided input. Your task is to generate a list + of {number} songs that the user might enjoy, derived from a given list of {number} songs. + The input will be in the format of + ["Song-1-Name by Song-1-Artist", "Song-2-Name by Song-2-Artist", ...] + and you need to return a list formatted in the same way. - SONG NAME - ARTIST NAME - SONG NAME - ARTIST NAME - ... + When recommending songs, consider the genre, tempo, and mood of the input + songs to suggest similar ones that align with the user's tastes. Also, it + is important to mix up the artists, don't only give the same artists that + are already in the queue. If you cannot find {number} songs that match the + criteria or encounter any issues, return the list ["NOTHING FOUND"]. - Do not provide anything except for the exactly what I need, no - list numbers, no quotations, only what I have shown. + Please be sure to also only use characters A-Z, a-z, 0-9, and spaces in the + song and artist names. Do not include escape/special characters, emojis, or + quotes in the output. - The songs you should base the list off of are: {input_string} - - NOTE: If you believe that there are not many songs that are similar to the ones I list, then please just respond with the message "SONG FIND ERROR" - """, + INPUT: {input_list} + """, } ], model="gpt-3.5-turbo", @@ -38,28 +42,32 @@ async def add_song_recommendations(bot_user, player, number, inputs, retries: in .strip('"') ) - # Sometimes, we get false failures, so we check for a failure, and it we haven't tried - # at least 3 times, then continue retrying, otherwise, we actually can't get any songs - if completion == "SONG FIND ERROR": + # Sometimes ChatGPT will return `["NOTHING FOUND"]` even if it should + # have found something, so we check each prompt up to 3 times before + # giving up. + if completion == '["NOTHING FOUND"]': if retries <= 3: await add_song_recommendations( - bot_user, player, number, inputs, retries + 1 + openai_client, bot_user, player, number, inputs, retries + 1 ) else: return False else: - for entry in completion.split("\n"): - song, artist = entry.split(" - ") + # Clean up the completion string to remove any potential issues + # with the eval function (e.g. OUTPUT: prefix, escaped quotes, etc.) + completion = re.sub(r"[\\\'\[\]\n]+|OUTPUT: ", "", completion) - ytsearch = f"ytsearch:{song} {artist} audio" + for entry in eval(completion): + song, artist = entry.split(" by ") + ytsearch = f"ytsearch:{song} by {artist} audio" results = await player.node.get_tracks(ytsearch) if not results.tracks or results.load_type in ( LoadType.EMPTY, LoadType.ERROR, ): - dzsearch = f"dzsearch:{song} {artist}" + dzsearch = f"dzsearch:{song}" results = await player.node.get_tracks(dzsearch) if not results.tracks or results.load_type in ( diff --git a/code/bot.py b/code/bot.py index 6735522..da23a28 100644 --- a/code/bot.py +++ b/code/bot.py @@ -2,6 +2,7 @@ import discord from discord.ext import commands, tasks import os import requests +import openai import config from tree import Tree @@ -32,6 +33,8 @@ class MyBot(commands.Bot): if ext.endswith(".py"): await self.load_extension(f"cogs.owner.{ext[:-3]}") + bot.openai = openai.OpenAI(api_key=config.OPENAI_API_KEY) + bot = MyBot() bot.remove_command("help") @@ -65,4 +68,4 @@ async def get_access_token(): if __name__ == "__main__": config.load_config() - bot.run(config.TOKEN) + bot.run(config.TOKEN) \ No newline at end of file diff --git a/code/cogs/autoplay.py b/code/cogs/autoplay.py index 4f5dc7a..0779e69 100644 --- a/code/cogs/autoplay.py +++ b/code/cogs/autoplay.py @@ -4,8 +4,8 @@ from discord import app_commands from discord.ext import commands from cogs.music import Music from typing import Literal -from ai_recommendations import add_song_recommendations +from ai_recommendations import add_song_recommendations from config import BOT_COLOR @@ -67,7 +67,7 @@ class Autoplay(commands.Cog): ) await interaction.response.send_message(embed=embed) - if await add_song_recommendations(self.bot.user, player, 5, inputs): + if await add_song_recommendations(self.bot.openai, self.bot.user, player, 5, inputs): self.bot.autoplay.append(interaction.guild.id) embed = discord.Embed( title=":infinity: Autoplay Enabled :infinity:", diff --git a/code/cogs/music.py b/code/cogs/music.py index d118643..9b12c4c 100644 --- a/code/cogs/music.py +++ b/code/cogs/music.py @@ -214,7 +214,7 @@ class Music(commands.Cog): inputs = {} for song in event.player.queue[:10]: inputs[song.title] = song.author - await add_song_recommendations(self.bot.user, event.player, 5, inputs) + await add_song_recommendations(self.bot.openai, self.bot.user, event.player, 5, inputs) async def setup(bot): diff --git a/code/config.py b/code/config.py index 4da47f7..8f47e2f 100644 --- a/code/config.py +++ b/code/config.py @@ -31,7 +31,7 @@ FEEDBACK_CHANNEL_ID = None BUG_CHANNEL_ID = None SPOTIFY_CLIENT_ID = None SPOTIFY_CLIENT_SECRET = None -CLIENT = None +OPENAI_API_KEY = None LAVALINK_HOST = None LAVALINK_PORT = None LAVALINK_PASSWORD = None @@ -96,7 +96,7 @@ Validate all of the options in the config.ini file. def validate_config(file_contents): - global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD + global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD config = configparser.ConfigParser() config.read_string(file_contents) @@ -175,7 +175,7 @@ def validate_config(file_contents): TOKEN = config["BOT_INFO"]["TOKEN"] SPOTIFY_CLIENT_ID = config["SPOTIFY"]["SPOTIFY_CLIENT_ID"] SPOTIFY_CLIENT_SECRET = config["SPOTIFY"]["SPOTIFY_CLIENT_SECRET"] - CLIENT = openai.OpenAI(api_key=config["OPENAI"]["OPENAI_API_KEY"]) + OPENAI_API_KEY = config["OPENAI"]["OPENAI_API_KEY"] LAVALINK_HOST = config["LAVALINK"]["HOST"] LAVALINK_PORT = config["LAVALINK"]["PORT"] LAVALINK_PASSWORD = config["LAVALINK"]["PASSWORD"] @@ -194,7 +194,7 @@ Validate all of the environment variables. def validate_env_vars(): - global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD + global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$" hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$" @@ -254,7 +254,7 @@ def validate_env_vars(): TOKEN = os.environ["TOKEN"] SPOTIFY_CLIENT_ID = os.environ["SPOTIFY_CLIENT_ID"] SPOTIFY_CLIENT_SECRET = os.environ["SPOTIFY_CLIENT_SECRET"] - CLIENT = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] LAVALINK_HOST = os.environ["LAVALINK_HOST"] LAVALINK_PORT = os.environ["LAVALINK_PORT"] LAVALINK_PASSWORD = os.environ["LAVALINK_PASSWORD"] \ No newline at end of file