diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index ba05d93140..bc813d34d3 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -6,7 +6,7 @@ from backend.chat.base import BaseChat from backend.chat.custom.tool_calls import async_call_tools from backend.chat.custom.utils import get_deployment -from backend.chat.enums import StreamEvent +from backend.chat.enums import FinishReason, StreamEvent from backend.config import Settings from backend.config.tools import get_available_tools from backend.database_models.file import File @@ -87,7 +87,7 @@ async def chat( ) yield { "event_type": StreamEvent.STREAM_END, - "finish_reason": "ERROR", + "finish_reason": FinishReason.ERROR, "error": str(e), "status_code": 500, } diff --git a/src/backend/chat/enums.py b/src/backend/chat/enums.py index 49042bbec0..35cdb77e51 100644 --- a/src/backend/chat/enums.py +++ b/src/backend/chat/enums.py @@ -16,3 +16,12 @@ class StreamEvent(StrEnum): NON_STREAMED_CHAT_RESPONSE = "non-streamed-chat-response" TOOL_CALLS_GENERATION = "tool-calls-generation" TOOL_CALLS_CHUNK = "tool-calls-chunk" + + +class FinishReason(StrEnum): + """ + Reasons why the model finished the request. + """ + ERROR = "ERROR" + COMPLETE = "COMPLETE" + MAX_TOKENS = "MAX_TOKENS" diff --git a/src/backend/schemas/chat.py b/src/backend/schemas/chat.py index 5882d5b81b..4646bea82f 100644 --- a/src/backend/schemas/chat.py +++ b/src/backend/schemas/chat.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field -from backend.chat.enums import StreamEvent +from backend.chat.enums import FinishReason, StreamEvent from backend.schemas.citation import Citation from backend.schemas.document import Document from backend.schemas.search_query import SearchQuery @@ -288,7 +288,7 @@ class StreamEnd(ChatResponse): title="Tool Calls", description="List of tool calls generated for custom tools", ) - finish_reason: Optional[str] = Field( + finish_reason: Optional[FinishReason] = Field( None, title="Finish Reason", description="Reson why the model finished the request", @@ -322,7 +322,7 @@ class NonStreamedChatResponse(ChatResponse): title="Chat History", description="A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.", ) - finish_reason: str = Field( + finish_reason: FinishReason = Field( ..., title="Finish Reason", description="Reason the chat stream ended", diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 96b017aa05..d4c2096712 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -9,6 +9,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session +from backend.chat.enums import FinishReason from backend.database_models import get_session from backend.database_models.agent import Agent from backend.database_models.deployment import Deployment @@ -204,7 +205,7 @@ def mock_event_stream(inject_events: list[dict]) -> list[dict]: "search_results": [], "search_queries": [], }, - "finish_reason": "COMPLETE", + "finish_reason": FinishReason.COMPLETE, } ]) return events diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index c8016aec81..4101a44d4d 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import text +from backend.chat.enums import FinishReason from backend.database_models import get_session from backend.database_models.base import CustomFilterQuery from backend.database_models.deployment import Deployment @@ -252,7 +253,7 @@ def mock_event_stream(inject_events: list[dict]) -> list[dict]: "search_results": [], "search_queries": [], }, - "finish_reason": "COMPLETE", + "finish_reason": FinishReason.COMPLETE, } ]) return events diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py index f4531c0a54..143773c6fe 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py @@ -2,6 +2,7 @@ from cohere.types import StreamedChatResponse +from backend.chat.enums import FinishReason from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( @@ -51,7 +52,7 @@ async def invoke_chat( "is_search_required": None, "search_queries": None, "search_results": None, - "finish_reason": "MAX_TOKENS", + "finish_reason": FinishReason.MAX_TOKENS, "tool_calls": None, "chat_history": [ {"role": "USER", "message": "Hello"}, diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py index 8eed1b1fd5..3ee062c987 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py @@ -2,6 +2,7 @@ from cohere.types import StreamedChatResponse +from backend.chat.enums import FinishReason from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( @@ -48,7 +49,7 @@ async def invoke_chat( "is_search_required": None, "search_queries": None, "search_results": None, - "finish_reason": "MAX_TOKENS", + "finish_reason": FinishReason.MAX_TOKENS, "tool_calls": None, "chat_history": [ {"role": "USER", "message": "Hello"}, diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py index d97f9d1cee..13d771ce5c 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py @@ -3,6 +3,7 @@ from cohere.types import StreamedChatResponse +from backend.chat.enums import FinishReason from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD @@ -55,7 +56,7 @@ def invoke_chat( "is_search_required": None, "search_queries": None, "search_results": None, - "finish_reason": "MAX_TOKENS", + "finish_reason": FinishReason.MAX_TOKENS, "tool_calls": None, "chat_history": [ {"role": "USER", "message": "Hello"}, diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py index b9266d1292..0693aa8e49 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py @@ -2,6 +2,7 @@ from cohere.types import StreamedChatResponse +from backend.chat.enums import FinishReason from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( @@ -48,7 +49,7 @@ async def invoke_chat( "is_search_required": None, "search_queries": None, "search_results": None, - "finish_reason": "MAX_TOKENS", + "finish_reason": FinishReason.MAX_TOKENS, "tool_calls": None, "chat_history": [ {"role": "USER", "message": "Hello"}, diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index d575e56db1..c5b4c77aee 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -6,7 +6,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from backend.chat.enums import StreamEvent +from backend.chat.enums import FinishReason, StreamEvent from backend.database_models.conversation import Conversation from backend.database_models.message import Message, MessageAgent from backend.database_models.user import User @@ -1110,7 +1110,7 @@ def validate_stream_end_event( assert is_valid_uuid(data["response_id"]) assert is_valid_uuid(data["conversation_id"]) assert is_valid_uuid(data["generation_id"]) - assert data["finish_reason"] == "COMPLETE" or data["finish_reason"] == "MAX_TOKENS" + assert data["finish_reason"] == FinishReason.COMPLETE or data["finish_reason"] == FinishReason.MAX_TOKENS return data["conversation_id"] diff --git a/src/interfaces/assistants_web/src/cohere-client/constants.ts b/src/interfaces/assistants_web/src/cohere-client/constants.ts index 3ff1f4dfea..a25c91ee3f 100644 --- a/src/interfaces/assistants_web/src/cohere-client/constants.ts +++ b/src/interfaces/assistants_web/src/cohere-client/constants.ts @@ -1,4 +1,5 @@ -// @todo: import from generated types when available +// @todo: Once backend FinishReason enum is merged, run `make generate-client-web` +// and remove this enum in favor of the generated type from backend export enum FinishReason { ERROR = 'ERROR', COMPLETE = 'COMPLETE',