diff options
author | Parker <contact@pkrm.dev> | 2025-02-17 21:24:31 -0600 |
---|---|---|
committer | Parker <contact@pkrm.dev> | 2025-02-17 21:24:31 -0600 |
commit | 06034d0b373a9aed5033c2e670950f765e285c2a (patch) | |
tree | 9121832c8535afae443cf5d581e302af8d0fae7b /code/utils | |
parent | b335e82699bc177c689450ee2f732398cdd372ac (diff) |
Support OpenAI AND Groq
Diffstat (limited to 'code/utils')
-rw-r--r-- | code/utils/ai_recommendations.py | 54 | ||||
-rw-r--r-- | code/utils/config.py | 37 |
2 files changed, 56 insertions, 35 deletions
diff --git a/code/utils/ai_recommendations.py b/code/utils/ai_recommendations.py index 1ff5415..14c61c6 100644 --- a/code/utils/ai_recommendations.py +++ b/code/utils/ai_recommendations.py @@ -1,40 +1,52 @@ from lavalink import LoadType import re +from utils.config import AI_CLIENT, AI_MODEL + async def add_song_recommendations( - openai_client, bot_user, player, number, inputs, retries: int = 1 + bot_user, player, number, inputs, retries: int = 1 ): input_list = [f'"{song} by {artist}"' for song, artist in inputs.items()] completion = ( - openai_client.chat.completions.create( + AI_CLIENT.chat.completions.create( messages=[ { - "role": "user", + "role": "system", "content": f""" - BACKGROUND: You're an AI music recommendation system with a knack for understanding - user preferences based on provided input. Your task is to generate a list - of {number} songs that the user might enjoy, derived from a given list of {number} songs. - The input will be in the format of - ["Song-1-Name by Song-1-Artist", "Song-2-Name by Song-2-Artist", ...] - and you need to return a list formatted in the same way. + Given an input list of songs formatted as ["song_name + by artist_name", "song_name by artist_name", ...], generate + a list of 5 new songs that the user may enjoy based on + the input. - When recommending songs, consider the genre, tempo, and mood of the input - songs to suggest similar ones that align with the user's tastes. Also, it - is important to mix up the artists, don't only give the same artists that - are already in the queue. If you cannot find {number} songs that match the - criteria or encounter any issues, return the list ["NOTHING FOUND"]. + Thoroughly analyze each song in the input list, considering + factors such as tempo, beat, mood, genre, lyrical themes, + instrumentation, and overall meaning. Use this analysis to + recommend 5 songs that closely align with the user's musical + preferences. - Please be sure to also only use characters A-Z, a-z, 0-9, and spaces in the - song and artist names. Do not include escape/special characters, emojis, or - quotes in the output. + The output must be formatted in the exact same way: + ["song_name by artist_name", "song_name by artist_name", ...]. - INPUT: {input_list} + If you are unable to find 5 new songs or encounter any issues, + return the following list instead: ["NOTHING_FOUND"]. Do + not return partial results—either provide 5 songs or return + ["NOTHING_FOUND"]. Ensure accuracy in song and artist names. + + DO NOT include any additional information or text in the + output, it should STRICTLY be either a list of the songs + or ["NOTHING_FOUND"]. + """, + }, + { + "role": "user", + "content": f""" + {input_list} """, - } + }, ], - model="gpt-4o-mini", + model=AI_MODEL, ) .choices[0] .message.content.strip() @@ -47,7 +59,7 @@ async def add_song_recommendations( if completion == '["NOTHING FOUND"]': if retries <= 3: await add_song_recommendations( - openai_client, bot_user, player, number, inputs, retries + 1 + bot_user, player, number, inputs, retries + 1 ) else: return False diff --git a/code/utils/config.py b/code/utils/config.py index bacf3fc..19fed87 100644 --- a/code/utils/config.py +++ b/code/utils/config.py @@ -8,7 +8,7 @@ import sys import discord import logging import requests -from datetime import datetime +from groq import Groq from colorlog import ColoredFormatter log_level = logging.DEBUG @@ -39,7 +39,8 @@ SPOTIFY_CLIENT_ID = None SPOTIFY_CLIENT_SECRET = None GENIUS_CLIENT_ID = None GENIUS_CLIENT_SECRET = None -OPENAI_API_KEY = None +AI_CLIENT = None +AI_MODEL = None LAVALINK_HOST = None LAVALINK_PORT = None LAVALINK_PASSWORD = None @@ -82,12 +83,13 @@ schema = { }, "required": ["genius_client_id", "genius_client_secret"], }, - "openai": { + "ai": { "type": "object", "properties": { - "openai_api_key": {"type": "string"}, + "service": {"enum": ["openai", "groq"]}, + "api_key": {"type": "string"}, }, - "required": ["openai_api_key"], + "required": ["service"], }, "lavalink": { "type": "object", @@ -144,9 +146,9 @@ genius: genius_client_id: genius_client_secret: -openai: - openai_api_key: - """ +ai: + service: + api_key: """ ) sys.exit( @@ -160,7 +162,7 @@ openai: # Thouroughly validate all of the options in the config.yaml file def validate_config(file_contents): - global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, LOG_SONGS, YOUTUBE_SUPPORT, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, GENIUS_CLIENT_ID, GENIUS_CLIENT_SECRET, OPENAI_API_KEY, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD + global TOKEN, BOT_COLOR, BOT_INVITE_LINK, FEEDBACK_CHANNEL_ID, BUG_CHANNEL_ID, LOG_SONGS, YOUTUBE_SUPPORT, SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, GENIUS_CLIENT_ID, GENIUS_CLIENT_SECRET, AI_CLIENT, AI_MODEL, LAVALINK_HOST, LAVALINK_PORT, LAVALINK_PASSWORD config = yaml.safe_load(file_contents) try: @@ -270,17 +272,24 @@ def validate_config(file_contents): ) # - # If the OPENAI section is present, make sure the API key is valid + # If the AI section is present, make sure the API key is valid # - if "openai" in config: - client = openai.OpenAI(api_key=config["openai"]["openai_api_key"]) + if "ai" in config: + if config["ai"]["service"] == "openai": + client = openai.OpenAI(api_key=config["ai"]["api_key"]) + model = "gpt-4o-mini" + elif config["ai"]["service"] == "groq": + client = Groq(api_key=config["ai"]["api_key"]) + model = "llama-3.3-70b-specdec" + try: client.models.list() - OPENAI_API_KEY = config["openai"]["openai_api_key"] + AI_CLIENT = client + AI_MODEL = model except openai.AuthenticationError: LOG.critical( - "Error in config.yaml file: OpenAI API key is invalid" + "Error in config.yaml file: OpenAI/Groq API key is invalid" ) # Set appropriate values for all non-optional variables |