diff --git a/3p-integrations/using_externally_hosted_llms.ipynb b/3p-integrations/using_externally_hosted_llms.ipynb index 37d628459..ed6343e44 100644 --- a/3p-integrations/using_externally_hosted_llms.ipynb +++ b/3p-integrations/using_externally_hosted_llms.ipynb @@ -12,10 +12,11 @@ "metadata": {}, "source": [ "# **Using externally-hosted LLMs**\n", - "Use llama_cookbook.inference.llm to perform inference using Llama and other models using third party services. At the moment, three services have been incorporated:\n", + "Use llama_cookbook.inference.llm to perform inference using Llama and other models using third party services. At the moment, the following services have been incorporated:\n", "- Together.ai\n", "- Anyscale\n", "- OpenAI\n", + "- MiniMax\n", "\n", "An API token for each service must be obtained and provided to the method before running. " ] } @@ -26,7 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_cookbook.inference.llm import TOGETHER, OPENAI, ANYSCALE\n", + "from llama_cookbook.inference.llm import TOGETHER, OPENAI, ANYSCALE, MINIMAX\n", "\n", "together_example = TOGETHER(\"togethercomputer/llama-2-7b-chat\",\"09e45...\")\n", "print( together_example.query(prompt=\"Why is the sky blue?\"))\n", @@ -37,7 +38,11 @@ "\n", "\n", "anyscale_example = ANYSCALE(\"meta-llama/Llama-2-7b-chat-hf\",\"esecret_c3u4x7...\")\n", - "print( anyscale_example.query(prompt=\"Why is the sky blue?\"))" + "print( anyscale_example.query(prompt=\"Why is the sky blue?\"))\n", + "\n", + "\n", + "minimax_example = MINIMAX(\"MiniMax-M2.7\",\"eyJhbGci...\")\n", + "print( minimax_example.query(prompt=\"Why is the sky blue?\"))" ] } ], @@ -63,4 +68,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} diff --git a/src/llama_cookbook/inference/llm.py b/src/llama_cookbook/inference/llm.py index 9b6e9fefc..fe56fa1e0 100644 --- a/src/llama_cookbook/inference/llm.py +++ b/src/llama_cookbook/inference/llm.py @@ -157,3 +157,38 @@ def 
valid_models(self) -> list[str]:
             "mistralai/Mistral-7B-Instruct-v0.1",
             "HuggingFaceH4/zephyr-7b-beta",
         ]
+
+
+class MINIMAX(LLM):
+    """Accessing MiniMax via OpenAI-compatible API (https://www.minimaxi.com)"""
+    def __init__(self, model: str, api_key: str) -> None:
+        super().__init__(model, api_key)
+        self.client = openai.OpenAI(base_url="https://api.minimax.io/v1", api_key=api_key)  # noqa
+
+    @override
+    def query(self, prompt: str) -> str:
+        # Best-effort suppression of openai log spew; not reliable in multi-threaded use.
+        level = logging.getLogger().level
+        logging.getLogger().setLevel(logging.WARNING)
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "user", "content": prompt},
+                ],
+                max_tokens=MAX_TOKENS,
+                temperature=TEMPERATURE if TEMPERATURE > 0 else 0.01,
+            )
+        finally:
+            logging.getLogger().setLevel(level)
+        return response.choices[0].message.content
+
+    @override
+    def valid_models(self) -> list[str]:
+        return [
+            "MiniMax-M2.7",
+            "MiniMax-M2.7-highspeed",
+            "MiniMax-M2.5",
+            "MiniMax-M1",
+            "MiniMax-M1-80k",
+        ]