Revamp OpenAI recommendations

This commit is contained in:
Parker M. 2024-07-11 00:43:06 -05:00
parent be8e929be5
commit c884b640f7
No known key found for this signature in database
GPG Key ID: 95CD2E0C7E329F2A
5 changed files with 47 additions and 36 deletions

View File

@ -1,34 +1,38 @@
from lavalink import LoadType
import re
from config import CLIENT
from config import OPENAI_API_KEY
import openai
async def add_song_recommendations(bot_user, player, number, inputs, retries: int = 1):
input_string = ""
for song, artist in inputs.items():
input_string += f"{song} - {artist}, "
# Remove the final ", "
input_string = input_string[:-2]
async def add_song_recommendations(openai_client, bot_user, player, number, inputs, retries: int = 1):
input_list = [f'"{song} by {artist}"' for song, artist in inputs.items()]
completion = (
CLIENT.chat.completions.create(
openai_client.chat.completions.create(
messages=[
{
"role": "user",
"content": f"""I need songs that are similar in nature to ones that I list.
Send {number} songs formatted as:
"content": f"""
BACKGROUND: You're an AI music recommendation system with a knack for understanding
user preferences based on provided input. Your task is to generate a list
of {number} songs that the user might enjoy, derived from a given list of {number} songs.
The input will be in the format of
["Song-1-Name by Song-1-Artist", "Song-2-Name by Song-2-Artist", ...]
and you need to return a list formatted in the same way.
SONG NAME - ARTIST NAME
SONG NAME - ARTIST NAME
...
When recommending songs, consider the genre, tempo, and mood of the input
songs to suggest similar ones that align with the user's tastes. Also, it
is important to mix up the artists, don't only give the same artists that
are already in the queue. If you cannot find {number} songs that match the
criteria or encounter any issues, return the list ["NOTHING FOUND"].
Do not provide anything except for the exactly what I need, no
list numbers, no quotations, only what I have shown.
Please be sure to also only use characters A-Z, a-z, 0-9, and spaces in the
song and artist names. Do not include escape/special characters, emojis, or
quotes in the output.
The songs you should base the list off of are: {input_string}
NOTE: If you believe that there are not many songs that are similar to the ones I list, then please just respond with the message "SONG FIND ERROR"
""",
INPUT: {input_list}
""",
}
],
model="gpt-3.5-turbo",
@ -38,28 +42,32 @@ async def add_song_recommendations(bot_user, player, number, inputs, retries: in
.strip('"')
)
# Sometimes, we get false failures, so we check for a failure, and it we haven't tried
# at least 3 times, then continue retrying, otherwise, we actually can't get any songs
if completion == "SONG FIND ERROR":
# Sometimes ChatGPT will return `["NOTHING FOUND"]` even if it should
# have found something, so we check each prompt up to 3 times before
# giving up.
if completion == '["NOTHING FOUND"]':
if retries <= 3:
await add_song_recommendations(
bot_user, player, number, inputs, retries + 1
openai_client, bot_user, player, number, inputs, retries + 1
)
else:
return False
else:
for entry in completion.split("\n"):
song, artist = entry.split(" - ")
# Clean up the completion string to remove any potential issues
# with the eval function (e.g. OUTPUT: prefix, escaped quotes, etc.)
completion = re.sub(r"[\\\'\[\]\n]+|OUTPUT: ", "", completion)
ytsearch = f"ytsearch:{song} {artist} audio"
for entry in eval(completion):
song, artist = entry.split(" by ")
ytsearch = f"ytsearch:{song} by {artist} audio"
results = await player.node.get_tracks(ytsearch)
if not results.tracks or results.load_type in (
LoadType.EMPTY,
LoadType.ERROR,
):
dzsearch = f"dzsearch:{song} {artist}"
dzsearch = f"dzsearch:{song}"
results = await player.node.get_tracks(dzsearch)
if not results.tracks or results.load_type in (

View File

@ -2,6 +2,7 @@ import discord
from discord.ext import commands, tasks
import os
import requests
import openai
import config
from tree import Tree
@ -32,6 +33,8 @@ class MyBot(commands.Bot):
if ext.endswith(".py"):
await self.load_extension(f"cogs.owner.{ext[:-3]}")
bot.openai = openai.OpenAI(api_key=config.OPENAI_API_KEY)
bot = MyBot()
bot.remove_command("help")
@ -65,4 +68,4 @@ async def get_access_token():
if __name__ == "__main__":
config.load_config()
bot.run(config.TOKEN)
bot.run(config.TOKEN)

View File

@ -4,8 +4,8 @@ from discord import app_commands
from discord.ext import commands
from cogs.music import Music
from typing import Literal
from ai_recommendations import add_song_recommendations
from ai_recommendations import add_song_recommendations
from config import BOT_COLOR
@ -67,7 +67,7 @@ class Autoplay(commands.Cog):
)
await interaction.response.send_message(embed=embed)
if await add_song_recommendations(self.bot.user, player, 5, inputs):
if await add_song_recommendations(self.bot.openai, self.bot.user, player, 5, inputs):
self.bot.autoplay.append(interaction.guild.id)
embed = discord.Embed(
title=":infinity: Autoplay Enabled :infinity:",

View File

@ -214,7 +214,7 @@ class Music(commands.Cog):
inputs = {}
for song in event.player.queue[:10]:
inputs[song.title] = song.author
await add_song_recommendations(self.bot.user, event.player, 5, inputs)
await add_song_recommendations(self.bot.openai, self.bot.user, event.player, 5, inputs)
async def setup(bot):

View File

@ -31,7 +31,7 @@ FEEDBACK_CHANNEL_ID = None
BUG_CHANNEL_ID = None
SPOTIFY_CLIENT_ID = None
SPOTIFY_CLIENT_SECRET = None
CLIENT = None
OPENAI_API_KEY = None
LAVALINK_HOST = None
LAVALINK_PORT = None
LAVALINK_PASSWORD = None
@ -96,7 +96,7 @@ Validate all of the options in the config.ini file.
def validate_config(file_contents):
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
config = configparser.ConfigParser()
config.read_string(file_contents)
@ -175,7 +175,7 @@ def validate_config(file_contents):
TOKEN = config["BOT_INFO"]["TOKEN"]
SPOTIFY_CLIENT_ID = config["SPOTIFY"]["SPOTIFY_CLIENT_ID"]
SPOTIFY_CLIENT_SECRET = config["SPOTIFY"]["SPOTIFY_CLIENT_SECRET"]
CLIENT = openai.OpenAI(api_key=config["OPENAI"]["OPENAI_API_KEY"])
OPENAI_API_KEY = config["OPENAI"]["OPENAI_API_KEY"]
LAVALINK_HOST = config["LAVALINK"]["HOST"]
LAVALINK_PORT = config["LAVALINK"]["PORT"]
LAVALINK_PASSWORD = config["LAVALINK"]["PASSWORD"]
@ -194,7 +194,7 @@ Validate all of the environment variables.
def validate_env_vars():
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
@ -254,7 +254,7 @@ def validate_env_vars():
TOKEN = os.environ["TOKEN"]
SPOTIFY_CLIENT_ID = os.environ["SPOTIFY_CLIENT_ID"]
SPOTIFY_CLIENT_SECRET = os.environ["SPOTIFY_CLIENT_SECRET"]
CLIENT = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
LAVALINK_HOST = os.environ["LAVALINK_HOST"]
LAVALINK_PORT = os.environ["LAVALINK_PORT"]
LAVALINK_PASSWORD = os.environ["LAVALINK_PASSWORD"]