Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions 3p-integrations/using_externally_hosted_llms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
"metadata": {},
"source": [
"# **Using externally-hosted LLMs**\n",
"Use llama_cookbook.inference.llm to perform inference using Llama and other models using third party services. At the moment, three services have been incorporated:\n",
"Use llama_cookbook.inference.llm to perform inference using Llama and other models using third party services. At the moment, the following services have been incorporated:\n",
"- Together.ai\n",
"- Anyscale\n",
"- OpenAI\n",
"- MiniMax\n",
"\n",
"An API token for each service must be obtained and provided to the method before running. "
]
Expand All @@ -26,7 +27,7 @@
"metadata": {},
"outputs": [],
"source": [
"from llama_cookbook.inference.llm import TOGETHER, OPENAI, ANYSCALE\n",
"from llama_cookbook.inference.llm import TOGETHER, OPENAI, ANYSCALE, MINIMAX\n",
"\n",
"together_example = TOGETHER(\"togethercomputer/llama-2-7b-chat\",\"09e45...\")\n",
"print( together_example.query(prompt=\"Why is the sky blue?\"))\n",
Expand All @@ -37,7 +38,11 @@
"\n",
"\n",
"anyscale_example = ANYSCALE(\"meta-llama/Llama-2-7b-chat-hf\",\"esecret_c3u4x7...\")\n",
"print( anyscale_example.query(prompt=\"Why is the sky blue?\"))"
"print( anyscale_example.query(prompt=\"Why is the sky blue?\"))\n",
"\n",
"\n",
"minimax_example = MINIMAX(\"MiniMax-M2.7\",\"eyJhbGci...\")\n",
"print( minimax_example.query(prompt=\"Why is the sky blue?\"))"
]
}
],
Expand All @@ -63,4 +68,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
35 changes: 35 additions & 0 deletions src/llama_cookbook/inference/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,38 @@ def valid_models(self) -> list[str]:
"mistralai/Mistral-7B-Instruct-v0.1",
"HuggingFaceH4/zephyr-7b-beta",
]


class MINIMAX(LLM):
    """Accessing MiniMax via its OpenAI-compatible API (https://www.minimaxi.com)."""

    def __init__(self, model: str, api_key: str) -> None:
        """Initialize a MiniMax client.

        Args:
            model: MiniMax model identifier (see ``valid_models``).
            api_key: MiniMax API key (JWT-style token).
        """
        super().__init__(model, api_key)
        # MiniMax exposes an OpenAI-compatible endpoint, so reuse the openai client.
        self.client = openai.OpenAI(base_url="https://api.minimax.io/v1", api_key=api_key)  # noqa

    @override
    def query(self, prompt: str) -> str:
        """Send a single-turn user prompt and return the model's reply text.

        Temporarily raises the root logger level to WARNING to suppress
        openai log spew. NOTE(review): mutating the root logger level is
        process-global and not thread-safe.
        """
        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "user", "content": prompt},
                ],
                max_tokens=MAX_TOKENS,
                # MiniMax rejects temperature == 0, so clamp to a small positive value.
                temperature=TEMPERATURE if TEMPERATURE > 0 else 0.01,
            )
        finally:
            # Restore the caller's logging level even when the request raises;
            # without this a failed request would leave it stuck at WARNING.
            logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def valid_models(self) -> list[str]:
        """Return the MiniMax model identifiers this wrapper supports."""
        return [
            "MiniMax-M2.7",
            "MiniMax-M2.7-highspeed",
            "MiniMax-M2.5",
            "MiniMax-M1",
            "MiniMax-M1-80k",
        ]