Revamp OpenAI recommendations
This commit is contained in:
parent
be8e929be5
commit
c884b640f7
@ -1,34 +1,38 @@
|
|||||||
from lavalink import LoadType
|
from lavalink import LoadType
|
||||||
|
import re
|
||||||
|
|
||||||
from config import CLIENT
|
from config import OPENAI_API_KEY
|
||||||
|
import openai
|
||||||
|
|
||||||
|
|
||||||
async def add_song_recommendations(bot_user, player, number, inputs, retries: int = 1):
|
async def add_song_recommendations(openai_client, bot_user, player, number, inputs, retries: int = 1):
|
||||||
input_string = ""
|
input_list = [f'"{song} by {artist}"' for song, artist in inputs.items()]
|
||||||
for song, artist in inputs.items():
|
|
||||||
input_string += f"{song} - {artist}, "
|
|
||||||
# Remove the final ", "
|
|
||||||
input_string = input_string[:-2]
|
|
||||||
|
|
||||||
completion = (
|
completion = (
|
||||||
CLIENT.chat.completions.create(
|
openai_client.chat.completions.create(
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"""I need songs that are similar in nature to ones that I list.
|
"content": f"""
|
||||||
Send {number} songs formatted as:
|
BACKGROUND: You're an AI music recommendation system with a knack for understanding
|
||||||
|
user preferences based on provided input. Your task is to generate a list
|
||||||
|
of {number} songs that the user might enjoy, derived from a given list of {number} songs.
|
||||||
|
The input will be in the format of
|
||||||
|
["Song-1-Name by Song-1-Artist", "Song-2-Name by Song-2-Artist", ...]
|
||||||
|
and you need to return a list formatted in the same way.
|
||||||
|
|
||||||
SONG NAME - ARTIST NAME
|
When recommending songs, consider the genre, tempo, and mood of the input
|
||||||
SONG NAME - ARTIST NAME
|
songs to suggest similar ones that align with the user's tastes. Also, it
|
||||||
...
|
is important to mix up the artists, don't only give the same artists that
|
||||||
|
are already in the queue. If you cannot find {number} songs that match the
|
||||||
|
criteria or encounter any issues, return the list ["NOTHING FOUND"].
|
||||||
|
|
||||||
Do not provide anything except for the exactly what I need, no
|
Please be sure to also only use characters A-Z, a-z, 0-9, and spaces in the
|
||||||
list numbers, no quotations, only what I have shown.
|
song and artist names. Do not include escape/special characters, emojis, or
|
||||||
|
quotes in the output.
|
||||||
|
|
||||||
The songs you should base the list off of are: {input_string}
|
INPUT: {input_list}
|
||||||
|
""",
|
||||||
NOTE: If you believe that there are not many songs that are similar to the ones I list, then please just respond with the message "SONG FIND ERROR"
|
|
||||||
""",
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
@ -38,28 +42,32 @@ async def add_song_recommendations(bot_user, player, number, inputs, retries: in
|
|||||||
.strip('"')
|
.strip('"')
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sometimes, we get false failures, so we check for a failure, and it we haven't tried
|
# Sometimes ChatGPT will return `["NOTHING FOUND"]` even if it should
|
||||||
# at least 3 times, then continue retrying, otherwise, we actually can't get any songs
|
# have found something, so we check each prompt up to 3 times before
|
||||||
if completion == "SONG FIND ERROR":
|
# giving up.
|
||||||
|
if completion == '["NOTHING FOUND"]':
|
||||||
if retries <= 3:
|
if retries <= 3:
|
||||||
await add_song_recommendations(
|
await add_song_recommendations(
|
||||||
bot_user, player, number, inputs, retries + 1
|
openai_client, bot_user, player, number, inputs, retries + 1
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for entry in completion.split("\n"):
|
# Clean up the completion string to remove any potential issues
|
||||||
song, artist = entry.split(" - ")
|
# with the eval function (e.g. OUTPUT: prefix, escaped quotes, etc.)
|
||||||
|
completion = re.sub(r"[\\\'\[\]\n]+|OUTPUT: ", "", completion)
|
||||||
|
|
||||||
ytsearch = f"ytsearch:{song} {artist} audio"
|
for entry in eval(completion):
|
||||||
|
song, artist = entry.split(" by ")
|
||||||
|
ytsearch = f"ytsearch:{song} by {artist} audio"
|
||||||
results = await player.node.get_tracks(ytsearch)
|
results = await player.node.get_tracks(ytsearch)
|
||||||
|
|
||||||
if not results.tracks or results.load_type in (
|
if not results.tracks or results.load_type in (
|
||||||
LoadType.EMPTY,
|
LoadType.EMPTY,
|
||||||
LoadType.ERROR,
|
LoadType.ERROR,
|
||||||
):
|
):
|
||||||
dzsearch = f"dzsearch:{song} {artist}"
|
dzsearch = f"dzsearch:{song}"
|
||||||
results = await player.node.get_tracks(dzsearch)
|
results = await player.node.get_tracks(dzsearch)
|
||||||
|
|
||||||
if not results.tracks or results.load_type in (
|
if not results.tracks or results.load_type in (
|
||||||
|
@ -2,6 +2,7 @@ import discord
|
|||||||
from discord.ext import commands, tasks
|
from discord.ext import commands, tasks
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
|
import openai
|
||||||
|
|
||||||
import config
|
import config
|
||||||
from tree import Tree
|
from tree import Tree
|
||||||
@ -32,6 +33,8 @@ class MyBot(commands.Bot):
|
|||||||
if ext.endswith(".py"):
|
if ext.endswith(".py"):
|
||||||
await self.load_extension(f"cogs.owner.{ext[:-3]}")
|
await self.load_extension(f"cogs.owner.{ext[:-3]}")
|
||||||
|
|
||||||
|
bot.openai = openai.OpenAI(api_key=config.OPENAI_API_KEY)
|
||||||
|
|
||||||
|
|
||||||
bot = MyBot()
|
bot = MyBot()
|
||||||
bot.remove_command("help")
|
bot.remove_command("help")
|
||||||
@ -65,4 +68,4 @@ async def get_access_token():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config.load_config()
|
config.load_config()
|
||||||
bot.run(config.TOKEN)
|
bot.run(config.TOKEN)
|
@ -4,8 +4,8 @@ from discord import app_commands
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from cogs.music import Music
|
from cogs.music import Music
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from ai_recommendations import add_song_recommendations
|
|
||||||
|
|
||||||
|
from ai_recommendations import add_song_recommendations
|
||||||
from config import BOT_COLOR
|
from config import BOT_COLOR
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ class Autoplay(commands.Cog):
|
|||||||
)
|
)
|
||||||
await interaction.response.send_message(embed=embed)
|
await interaction.response.send_message(embed=embed)
|
||||||
|
|
||||||
if await add_song_recommendations(self.bot.user, player, 5, inputs):
|
if await add_song_recommendations(self.bot.openai, self.bot.user, player, 5, inputs):
|
||||||
self.bot.autoplay.append(interaction.guild.id)
|
self.bot.autoplay.append(interaction.guild.id)
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title=":infinity: Autoplay Enabled :infinity:",
|
title=":infinity: Autoplay Enabled :infinity:",
|
||||||
|
@ -214,7 +214,7 @@ class Music(commands.Cog):
|
|||||||
inputs = {}
|
inputs = {}
|
||||||
for song in event.player.queue[:10]:
|
for song in event.player.queue[:10]:
|
||||||
inputs[song.title] = song.author
|
inputs[song.title] = song.author
|
||||||
await add_song_recommendations(self.bot.user, event.player, 5, inputs)
|
await add_song_recommendations(self.bot.openai, self.bot.user, event.player, 5, inputs)
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
|
@ -31,7 +31,7 @@ FEEDBACK_CHANNEL_ID = None
|
|||||||
BUG_CHANNEL_ID = None
|
BUG_CHANNEL_ID = None
|
||||||
SPOTIFY_CLIENT_ID = None
|
SPOTIFY_CLIENT_ID = None
|
||||||
SPOTIFY_CLIENT_SECRET = None
|
SPOTIFY_CLIENT_SECRET = None
|
||||||
CLIENT = None
|
OPENAI_API_KEY = None
|
||||||
LAVALINK_HOST = None
|
LAVALINK_HOST = None
|
||||||
LAVALINK_PORT = None
|
LAVALINK_PORT = None
|
||||||
LAVALINK_PASSWORD = None
|
LAVALINK_PASSWORD = None
|
||||||
@ -96,7 +96,7 @@ Validate all of the options in the config.ini file.
|
|||||||
|
|
||||||
|
|
||||||
def validate_config(file_contents):
|
def validate_config(file_contents):
|
||||||
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
|
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read_string(file_contents)
|
config.read_string(file_contents)
|
||||||
|
|
||||||
@ -175,7 +175,7 @@ def validate_config(file_contents):
|
|||||||
TOKEN = config["BOT_INFO"]["TOKEN"]
|
TOKEN = config["BOT_INFO"]["TOKEN"]
|
||||||
SPOTIFY_CLIENT_ID = config["SPOTIFY"]["SPOTIFY_CLIENT_ID"]
|
SPOTIFY_CLIENT_ID = config["SPOTIFY"]["SPOTIFY_CLIENT_ID"]
|
||||||
SPOTIFY_CLIENT_SECRET = config["SPOTIFY"]["SPOTIFY_CLIENT_SECRET"]
|
SPOTIFY_CLIENT_SECRET = config["SPOTIFY"]["SPOTIFY_CLIENT_SECRET"]
|
||||||
CLIENT = openai.OpenAI(api_key=config["OPENAI"]["OPENAI_API_KEY"])
|
OPENAI_API_KEY = config["OPENAI"]["OPENAI_API_KEY"]
|
||||||
LAVALINK_HOST = config["LAVALINK"]["HOST"]
|
LAVALINK_HOST = config["LAVALINK"]["HOST"]
|
||||||
LAVALINK_PORT = config["LAVALINK"]["PORT"]
|
LAVALINK_PORT = config["LAVALINK"]["PORT"]
|
||||||
LAVALINK_PASSWORD = config["LAVALINK"]["PASSWORD"]
|
LAVALINK_PASSWORD = config["LAVALINK"]["PASSWORD"]
|
||||||
@ -194,7 +194,7 @@ Validate all of the environment variables.
|
|||||||
|
|
||||||
|
|
||||||
def validate_env_vars():
|
def validate_env_vars():
|
||||||
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, CLIENT, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
|
global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD
|
||||||
|
|
||||||
hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
|
hex_pattern_one = "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
|
||||||
hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
|
hex_pattern_two = "^([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$"
|
||||||
@ -254,7 +254,7 @@ def validate_env_vars():
|
|||||||
TOKEN = os.environ["TOKEN"]
|
TOKEN = os.environ["TOKEN"]
|
||||||
SPOTIFY_CLIENT_ID = os.environ["SPOTIFY_CLIENT_ID"]
|
SPOTIFY_CLIENT_ID = os.environ["SPOTIFY_CLIENT_ID"]
|
||||||
SPOTIFY_CLIENT_SECRET = os.environ["SPOTIFY_CLIENT_SECRET"]
|
SPOTIFY_CLIENT_SECRET = os.environ["SPOTIFY_CLIENT_SECRET"]
|
||||||
CLIENT = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
|
||||||
LAVALINK_HOST = os.environ["LAVALINK_HOST"]
|
LAVALINK_HOST = os.environ["LAVALINK_HOST"]
|
||||||
LAVALINK_PORT = os.environ["LAVALINK_PORT"]
|
LAVALINK_PORT = os.environ["LAVALINK_PORT"]
|
||||||
LAVALINK_PASSWORD = os.environ["LAVALINK_PASSWORD"]
|
LAVALINK_PASSWORD = os.environ["LAVALINK_PASSWORD"]
|
Loading…
x
Reference in New Issue
Block a user