Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import BaseModel
from typing_extensions import Self

from railtracks.exceptions.errors import LLMError, NodeInvocationError
from railtracks.exceptions.errors import NodeInvocationError
from railtracks.exceptions.messages.exception_messages import get_message
from railtracks.llm import (
Message,
Expand All @@ -26,6 +26,7 @@
SystemMessage,
UserMessage,
)
from railtracks.llm.errors import LLMError
from railtracks.llm.response import Response
from railtracks.nodes.nodes import Node
from railtracks.prompts.prompt import inject_context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

from railtracks.built_nodes.concrete.response import LLMResponse
from railtracks.exceptions import LLMError, NodeCreationError
from railtracks.exceptions import NodeCreationError
from railtracks.interaction._call import call
from railtracks.llm import (
AssistantMessage,
Expand All @@ -28,6 +28,7 @@
UserMessage,
)
from railtracks.llm.content import Content
from railtracks.llm.errors import LLMError
from railtracks.llm.message import Role
from railtracks.llm.providers import ModelProvider
from railtracks.llm.response import Response
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from abc import ABC

from railtracks.exceptions import LLMError
from railtracks.interaction import call
from railtracks.llm import (
AssistantMessage,
Expand All @@ -10,6 +9,7 @@
ToolResponse,
UserMessage,
)
from railtracks.llm.errors import LLMError
from railtracks.llm.message import Role

from ._llm_base import StringOutputMixIn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from pydantic import BaseModel

from railtracks.exceptions.errors import LLMError
from railtracks.llm import Message, MessageHistory, ModelBase, UserMessage
from railtracks.llm.errors import LLMError
from railtracks.validation.node_creation.validation import (
check_classmethod,
check_schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from pydantic import BaseModel

from railtracks.exceptions.errors import LLMError
from railtracks.interaction import call
from railtracks.llm import (
AssistantMessage,
Expand All @@ -12,6 +11,7 @@
ModelBase,
UserMessage,
)
from railtracks.llm.errors import LLMError

from ._llm_base import StructuredOutputMixIn
from ._tool_call_base import OutputLessToolCallLLM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC
from typing import Generator, Generic, Literal, TypeVar

from railtracks.exceptions import LLMError
from railtracks.llm.errors import LLMError
from railtracks.llm.response import Response

from ._llm_base import LLMBase, StringOutputMixIn
Expand Down
2 changes: 0 additions & 2 deletions packages/railtracks/src/railtracks/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
ContextError,
FatalError,
GlobalTimeOutError,
LLMError,
NodeCreationError,
NodeInvocationError,
)
Expand All @@ -13,7 +12,6 @@
"NodeCreationError",
"NodeInvocationError",
"GlobalTimeOutError",
"LLMError",
"ContextError",
"VisualExtraRequiredError",
]
43 changes: 0 additions & 43 deletions packages/railtracks/src/railtracks/exceptions/errors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from typing import TYPE_CHECKING

from ._base import RTError

if TYPE_CHECKING:
from railtracks.llm.history import MessageHistory


class NodeInvocationError(RTError):
"""
Expand Down Expand Up @@ -55,44 +50,6 @@ def __str__(self):
return self._color(base, self.RED)


class LLMError(RTError):
"""
Raised when an error occurs during LLM invocation or completion.
"""

def __init__(
self,
reason: str,
message_history: "MessageHistory" = None,
):
self.reason = reason
self.message_history = message_history

message = f"{self._color('LLM Error: ', self.BOLD_RED)}{self._color(reason, self.RED)}"
super().__init__(message)

def __str__(self):
base = super().__str__()
details = []
if self.message_history:
mh_str = str(self.message_history)
indented_mh = "\n".join(
" " + line for line in mh_str.splitlines()
) # 2 indents (2-spaces) per indent
details.append(
self._color("Message History:\n", self.BOLD_GREEN)
+ self._color(indented_mh, self.GREEN)
)
if details:
notes_str = (
"\n"
+ self._color("Details:\n", self.BOLD_GREEN)
+ "\n".join(f" {d}" for d in details)
)
return f"\n{self._color(base, self.RED)}{notes_str}"
return self._color(base, self.RED)


class GlobalTimeOutError(RTError):
"""
Raised on global timeout for whole execution.
Expand Down
2 changes: 2 additions & 0 deletions packages/railtracks/src/railtracks/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import retries
from .content import ToolCall, ToolResponse
from .errors import LLMError
from .history import MessageHistory
from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
from .model import ModelBase
Expand Down Expand Up @@ -27,6 +28,7 @@

__all__ = [
"ModelBase",
"LLMError",
"ToolCall",
"ToolResponse",
"UserMessage",
Expand Down
46 changes: 46 additions & 0 deletions packages/railtracks/src/railtracks/llm/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from ._exceptions import RTLLMError

if TYPE_CHECKING:
from .history import MessageHistory


class LLMError(RTLLMError):
"""
Raised when an error occurs during LLM invocation or completion.
"""

def __init__(
self,
reason: str,
message_history: "MessageHistory" = None,
):
self.reason = reason
self.message_history = message_history

message = f"{self._color('LLM Error: ', self.BOLD_RED)}{self._color(reason, self.RED)}"
super().__init__(message)

def __str__(self):
base = super().__str__()
details = []
if self.message_history:
mh_str = str(self.message_history)
indented_mh = "\n".join(
" " + line for line in mh_str.splitlines()
) # 2 indents (2-spaces) per indent
details.append(
self._color("Message History:\n", self.BOLD_GREEN)
+ self._color(indented_mh, self.GREEN)
)
if details:
notes_str = (
"\n"
+ self._color("Details:\n", self.BOLD_GREEN)
+ "\n".join(f" {d}" for d in details)
)
return f"\n{self._color(base, self.RED)}{notes_str}"
return self._color(base, self.RED)
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@
from litellm.types.utils import ModelResponse
from pydantic import BaseModel, Field

from ...exceptions.errors import LLMError, NodeInvocationError
from ..content import ToolCall
from ..errors import LLMError
from ..history import MessageHistory
from ..message import AssistantMessage, Message, ToolMessage, UserMessage
from ..model import ModelBase
from ..response import MessageInfo, Response
from ..retries import RetryApproach
from ..tools import Tool
from ..tools.parameters import Parameter
from ._model_exception_base import ModelError

_TBaseModel = TypeVar("_TBaseModel", bound=BaseModel)

Expand Down Expand Up @@ -96,12 +97,8 @@ def _parameters_to_json_schema(
):
return _handle_set_of_parameters(list(parameters))

raise NodeInvocationError(
message=f"Unable to parse Tool.parameters. It was {parameters}",
fatal=True,
notes=[
"Tool.parameters must be a set of Parameter objects",
],
raise ModelError(
reason=f"Unable to parse Tool.parameters. It was {parameters}. Tool.parameters must be a set of Parameter objects.",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import litellm
import pytest

import railtracks as rt
from railtracks.exceptions import LLMError
from railtracks.llm.errors import LLMError
from railtracks.llm.retries import ExponentialRetry, FixedRetry


Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import json
from json import JSONDecodeError
from typing import Generator, Literal
from unittest.mock import patch

import litellm
import pytest
from pydantic import BaseModel
from railtracks.llm import AssistantMessage, UserMessage
from railtracks.llm.errors import LLMError
from railtracks.llm.models._litellm_wrapper import (
LiteLLMWrapper,
_parameters_to_json_schema,
_to_litellm_tool,
)
from railtracks.exceptions import NodeInvocationError, LLMError
from railtracks.llm import AssistantMessage, UserMessage
from railtracks.llm.models._model_exception_base import ModelError
from railtracks.llm.providers import ModelProvider
from pydantic import BaseModel
from railtracks.llm.response import Response
from json import JSONDecodeError
import litellm
from railtracks.llm.content import Stream
import json


class _ConcreteLiteLLMWrapperForTest(LiteLLMWrapper[Literal[False]]):
Expand Down Expand Up @@ -54,7 +55,7 @@ def test_parameters_to_json_schema_invalid_input(self):
"""
Test _parameters_to_json_schema with invalid input.
"""
with pytest.raises(NodeInvocationError):
with pytest.raises(ModelError):
_parameters_to_json_schema(123) # type: ignore

# =================================== END _parameters_to_json_schema Tests ====================================
Expand Down Expand Up @@ -252,9 +253,9 @@ class Schema(BaseModel):
wrapper = mock_litellm_wrapper(content="Invalid JSON")
method = getattr(wrapper, method_name)
if is_async:
result = await method(message_history, schema=Schema)
await method(message_history, schema=Schema)
else:
result = method(message_history, schema=Schema)
method(message_history, schema=Schema)

@pytest.mark.asyncio
@pytest.mark.parametrize("method_name,is_async", [
Expand All @@ -270,9 +271,9 @@ class Schema(BaseModel):
wrapper = mock_litellm_wrapper(content='{"field": "VAL", "invalid": "json"}')
method = getattr(wrapper, method_name)
if is_async:
result = await method(message_history, schema=Schema)
await method(message_history, schema=Schema)
else:
result = method(message_history, schema=Schema)
method(message_history, schema=Schema)

@pytest.mark.parametrize("method_name,is_async,stream", [
("_chat_with_tools", False, False),
Expand Down Expand Up @@ -324,7 +325,7 @@ async def test_chat_with_tools(
assert calls[0]["name"] == "tool_x"
assert calls[0]["arguments"] == {"foo": 1}
assert calls[0]["identifier"] == "id123"
except Exception as e:
except Exception:
pytest.fail("Structured response did not match schema")
elif not isinstance(chunk, str):
pytest.fail("Stream yielded non-string, non-Response chunk")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest
import railtracks as rt
from pydantic import BaseModel, Field

from railtracks.llm.response import Response
from railtracks.built_nodes.easy_usage_wrappers.helpers import structured_tool_call_llm, structured_llm
from railtracks.exceptions import NodeCreationError, LLMError
from railtracks.llm import MessageHistory, SystemMessage, UserMessage, AssistantMessage

from railtracks.built_nodes.concrete import StructuredToolCallLLM

from railtracks.built_nodes.easy_usage_wrappers.helpers import (
structured_llm,
structured_tool_call_llm,
)
from railtracks.exceptions import NodeCreationError
from railtracks.llm import AssistantMessage, MessageHistory, SystemMessage, UserMessage
from railtracks.llm.errors import LLMError

# =========================== Basic functionality ==========================

Expand Down
10 changes: 5 additions & 5 deletions packages/railtracks/tests/unit_tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
from railtracks.exceptions._base import RTError
from railtracks.exceptions.errors import (
NodeInvocationError,
NodeCreationError,
LLMError,
GlobalTimeOutError,
ContextError,
FatalError,
GlobalTimeOutError,
NodeCreationError,
NodeInvocationError,
)
from railtracks.exceptions._base import RTError
from railtracks.llm.errors import LLMError

# NOTE: This file contains very basic tests to ensure that the exception classes are working as expected.

Expand Down
Loading