Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from textgrad import Variable, TextualGradientDescent, BlackboxLLM, sum
from textgrad.engine.base import EngineLM
from textgrad.engine.openai import ChatOpenAI
from textgrad.engine.openai import AzureChatOpenAI, ChatOpenAI
from textgrad.autograd import LLMCall, FormattedLLMCall

logging.disable(logging.CRITICAL)
Expand Down Expand Up @@ -247,4 +247,15 @@ def test_multimodal_from_url():
image_variable_2 = Variable(image_data,
role_description="image to answer a question about", requires_grad=False)

assert image_variable_2.value == image_variable.value
assert image_variable_2.value == image_variable.value

def test_azure_openai_engine():
    """AzureChatOpenAI must raise ValueError when the Azure credentials are
    missing, and construct successfully once AZURE_OPENAI_API_KEY and
    AZURE_OPENAI_API_BASE are both set.
    """
    # Snapshot every environment variable this test touches so it can be
    # restored afterwards -- the original version popped OPENAI_API_KEY for
    # good and left fake Azure credentials behind, polluting later tests.
    touched = ("OPENAI_API_KEY", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_API_BASE")
    saved = {var: os.environ.pop(var, None) for var in touched}
    try:
        # With no Azure credentials present the constructor must fail fast.
        with pytest.raises(ValueError):
            AzureChatOpenAI()

        # With both required variables set, construction succeeds.
        os.environ['AZURE_OPENAI_API_KEY'] = "fake_key"
        os.environ['AZURE_OPENAI_API_BASE'] = "fake_base"
        AzureChatOpenAI()
    finally:
        # Put the environment back exactly as we found it.
        for var, value in saved.items():
            if value is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = value
40 changes: 23 additions & 17 deletions textgrad/engine/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ def __init__(
system_prompt: str=DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool=False,
base_url: str=None,
azure_openai: bool=False,
**kwargs):
"""
:param model_string:
:param system_prompt:
:param base_url: Used to support Ollama
:param azure_openai: Set to True if you use Azure OpenAI.
"""
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
Expand All @@ -47,20 +49,21 @@ def __init__(
self.system_prompt = system_prompt
self.base_url = base_url

if not base_url:
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)
elif base_url and base_url == OLLAMA_BASE_URL:
self.client = OpenAI(
base_url=base_url,
api_key="ollama"
)
else:
raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")
if not azure_openai:
Copy link
Copy Markdown
Collaborator

@vinid vinid Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now getting a bit messy. A possible way to clean it up is a `create_client` factory that dispatches on a client type passed by the caller (defaulting to `openai`).

So something like

__init__(..., client_type: str ...):

self.client = _factory_client_openai(client_type)

and have the checks in the factored classes

any thoughts?

if not base_url:
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)
elif base_url and base_url == OLLAMA_BASE_URL:
self.client = OpenAI(
base_url=base_url,
api_key="ollama"
)
else:
raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")

self.model_string = model_string
self.is_multimodal = is_multimodal
Expand Down Expand Up @@ -184,11 +187,14 @@ def __init__(
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_azure_{model_string}.db") # Changed cache path to differentiate from OpenAI cache

super().__init__(cache_path=cache_path, system_prompt=system_prompt, **kwargs)
super().__init__(cache_path=cache_path,
system_prompt=system_prompt,
azure_openai=True,
**kwargs)

self.system_prompt = system_prompt
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview")
if os.getenv("AZURE_OPENAI_API_KEY") is None:
if (os.getenv("AZURE_OPENAI_API_KEY") is None) or (os.getenv("AZURE_OPENAI_API_BASE") is None):
raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.")

self.client = AzureOpenAI(
Expand All @@ -197,4 +203,4 @@ def __init__(
azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
azure_deployment=model_string,
)
self.model_string = model_string
self.model_string = model_string