diff --git a/benchmarks/bench_agent_card.cpp b/benchmarks/bench_agent_card.cpp index e67bae6..76f3fc8 100644 --- a/benchmarks/bench_agent_card.cpp +++ b/benchmarks/bench_agent_card.cpp @@ -2,7 +2,7 @@ #include -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/protojson.h" #include "a2a/server/rest_server_transport.h" #include "bench_common.h" diff --git a/benchmarks/bench_common.h b/benchmarks/bench_common.h index eee233b..b96eef6 100644 --- a/benchmarks/bench_common.h +++ b/benchmarks/bench_common.h @@ -11,7 +11,7 @@ #include #include -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/protocol_bindings.h" #include "a2a/core/protojson.h" #include "a2a/core/result.h" diff --git a/examples/apps/auth_policy_server/main.cpp b/examples/apps/auth_policy_server/main.cpp index c69aae4..1357904 100644 --- a/examples/apps/auth_policy_server/main.cpp +++ b/examples/apps/auth_policy_server/main.cpp @@ -8,7 +8,7 @@ #include "a2a/client/client.h" #include "a2a/client/http_json_transport.h" -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/url_utils.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatcher.h" diff --git a/examples/apps/hello_agent/main.cpp b/examples/apps/hello_agent/main.cpp index a8be1eb..b7c8574 100644 --- a/examples/apps/hello_agent/main.cpp +++ b/examples/apps/hello_agent/main.cpp @@ -8,7 +8,7 @@ #include "a2a/client/client.h" #include "a2a/client/http_json_transport.h" -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/url_utils.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatcher.h" diff --git a/examples/apps/json_rpc_server/main.cpp b/examples/apps/json_rpc_server/main.cpp index a05cdf9..490b486 100644 --- a/examples/apps/json_rpc_server/main.cpp +++ b/examples/apps/json_rpc_server/main.cpp @@ -8,7 +8,7 @@ #include "a2a/client/client.h" #include "a2a/client/json_rpc_transport.h" -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/url_utils.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatcher.h" diff --git a/examples/apps/rest_server/main.cpp b/examples/apps/rest_server/main.cpp index 1c2fe32..5671a6b 100644 --- a/examples/apps/rest_server/main.cpp +++ b/examples/apps/rest_server/main.cpp @@ -8,7 +8,7 @@ #include "a2a/client/client.h" #include "a2a/client/http_json_transport.h" -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/url_utils.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatcher.h" diff --git a/examples/apps/streaming_server/main.cpp b/examples/apps/streaming_server/main.cpp index 5231786..d2bd3c6 100644 --- a/examples/apps/streaming_server/main.cpp +++ b/examples/apps/streaming_server/main.cpp @@ -8,7 +8,7 @@ #include "a2a/client/client.h" #include "a2a/client/http_json_transport.h" -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/url_utils.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatcher.h" diff --git a/include/a2a/client/client.h b/include/a2a/client/client.h index 8e88aae..72ad10d 100644 --- a/include/a2a/client/client.h +++ b/include/a2a/client/client.h @@ -21,6 +21,7 @@ #include #include "a2a/client/call_options.h" +#include "a2a/core/non_copyable.h" #include "a2a/core/result.h" #include "a2a/v1/a2a.pb.h" @@ -72,7 +73,7 @@ class StreamObserver { virtual void OnCompleted() = 0; }; -class StreamHandle final { +class StreamHandle final : private core::NonCopyable { public: struct State final { std::atomic cancel_requested{false}; @@ -82,8 +83,6 @@ class StreamHandle final { }; StreamHandle() = delete; - StreamHandle(const StreamHandle&) = delete; - StreamHandle& operator=(const StreamHandle&) = delete; StreamHandle(StreamHandle&&) noexcept; StreamHandle& operator=(StreamHandle&&) noexcept; ~StreamHandle(); diff --git a/include/a2a/core/agent_card/agent_card_builder.h b/include/a2a/core/agent_card/agent_card_builder.h new file mode 100644 index 0000000..9462612 --- /dev/null +++ b/include/a2a/core/agent_card/agent_card_builder.h @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#pragma once + +#include +#include + +#include "a2a/core/error.h" +#include "a2a/core/protocol_bindings.h" +#include "a2a/core/result.h" +#include "a2a/core/version.h" +#include "a2a/v1/a2a.pb.h" + +namespace a2a::core { + +class AgentCardBuilder final { + public: + AgentCardBuilder& SetName(std::string_view name); + AgentCardBuilder& SetVersion(std::string_view version); + AgentCardBuilder& SetDescription(std::string_view description); + AgentCardBuilder& AddDefaultInputMode(std::string_view mode); + AgentCardBuilder& AddDefaultOutputMode(std::string_view mode); + AgentCardBuilder& WithPushNotifications(bool enabled); + AgentCardBuilder& WithExtendedAgentCard(bool enabled); + AgentCardBuilder& AddExtension(std::string_view uri, bool required, std::string_view description = {}); + + struct InterfaceSpec final { + std::string_view binding; + std::string_view version; + std::string_view url; + }; + + AgentCardBuilder& AddInterface(const InterfaceSpec& spec); + + [[nodiscard]] Result Validate() const; + [[nodiscard]] lf::a2a::v1::AgentCard Build() const; + + [[nodiscard]] static AgentCardBuilder RestPreset(std::string_view name, std::string_view url, + std::string_view version = Version::kAgentCardVersion); + [[nodiscard]] static AgentCardBuilder JsonRpcPreset(std::string_view name, std::string_view url, + std::string_view version = Version::kAgentCardVersion); + [[nodiscard]] static AgentCardBuilder GrpcPreset(std::string_view name, std::string_view url, + std::string_view version = Version::kAgentCardVersion); + struct ConformancePresetSpec final { + std::string_view rest_url; + std::string_view json_rpc_url; + std::string_view grpc_url; + }; + + [[nodiscard]] static AgentCardBuilder ConformancePreset(const ConformancePresetSpec& spec, + std::string_view name = "Conformance SUT", + std::string_view version = Version::kAgentCardVersion, + std::string_view description = "A2A conformance agent"); + + private: + lf::a2a::v1::AgentCard card_; +}; + +} // namespace a2a::core diff --git a/include/a2a/core/agent_card/agent_card_provider.h b/include/a2a/core/agent_card/agent_card_provider.h new file mode 100644 index 0000000..2766fea --- /dev/null +++ b/include/a2a/core/agent_card/agent_card_provider.h @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#pragma once + +#include +#include +#include + +#include "a2a/core/non_copyable.h" +#include "a2a/core/result.h" +#include "a2a/v1/a2a.pb.h" + +namespace a2a::core { + +struct AgentCardRequestContext final { + std::optional tenant; + std::optional remote_address; + std::unordered_map client_headers; + std::unordered_map auth_metadata; +}; + +class AgentCardProvider : private NonCopyableOrMovable { + public: + AgentCardProvider() = default; + virtual ~AgentCardProvider() = default; + + [[nodiscard]] virtual Result GetExtendedAgentCard( + const AgentCardRequestContext& context) const = 0; +}; + +class StaticAgentCardProvider final : public AgentCardProvider { + public: + StaticAgentCardProvider() = default; + explicit StaticAgentCardProvider(std::optional extended_agent_card); + + [[nodiscard]] Result GetExtendedAgentCard( + const AgentCardRequestContext& context) const override; + + private: + std::optional extended_agent_card_; +}; + +} // namespace a2a::core diff --git a/include/a2a/core/agent_card_builder.h b/include/a2a/core/agent_card_builder.h index 2fbf40e..b892f99 100644 --- a/include/a2a/core/agent_card_builder.h +++ b/include/a2a/core/agent_card_builder.h @@ -3,57 +3,4 @@ #pragma once -#include -#include - -#include "a2a/core/error.h" -#include "a2a/core/protocol_bindings.h" -#include "a2a/core/result.h" -#include "a2a/core/version.h" -#include "a2a/v1/a2a.pb.h" - -namespace a2a::core { - -class AgentCardBuilder final { - public: - AgentCardBuilder& SetName(std::string_view name); - AgentCardBuilder& SetVersion(std::string_view version); - AgentCardBuilder& SetDescription(std::string_view description); - AgentCardBuilder& AddDefaultInputMode(std::string_view mode); - AgentCardBuilder& AddDefaultOutputMode(std::string_view mode); - AgentCardBuilder& WithPushNotifications(bool enabled); - AgentCardBuilder& AddExtension(std::string_view uri, bool required, std::string_view description = {}); - - struct InterfaceSpec final { - std::string_view binding; - std::string_view version; - std::string_view url; - }; - - AgentCardBuilder& AddInterface(const InterfaceSpec& spec); - - [[nodiscard]] Result Validate() const; - [[nodiscard]] lf::a2a::v1::AgentCard Build() const; - - [[nodiscard]] static AgentCardBuilder RestPreset(std::string_view name, std::string_view url, - std::string_view version = Version::kAgentCardVersion); - [[nodiscard]] static AgentCardBuilder JsonRpcPreset(std::string_view name, std::string_view url, - std::string_view version = Version::kAgentCardVersion); - [[nodiscard]] static AgentCardBuilder GrpcPreset(std::string_view name, std::string_view url, - std::string_view version = Version::kAgentCardVersion); - struct ConformancePresetSpec final { - std::string_view rest_url; - std::string_view json_rpc_url; - std::string_view grpc_url; - }; - - [[nodiscard]] static AgentCardBuilder ConformancePreset(const ConformancePresetSpec& spec, - std::string_view name = "Conformance SUT", - std::string_view version = Version::kAgentCardVersion, - std::string_view description = "A2A conformance agent"); - - private: - lf::a2a::v1::AgentCard card_; -}; - -} // namespace a2a::core +#include "a2a/core/agent_card/agent_card_builder.h" diff --git a/include/a2a/core/core.h b/include/a2a/core/core.h index 20988f2..b61f892 100644 --- a/include/a2a/core/core.h +++ b/include/a2a/core/core.h @@ -3,7 +3,8 @@ #pragma once -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_provider.h" #include "a2a/core/error.h" #include "a2a/core/extensions.h" #include "a2a/core/protojson.h" diff --git a/include/a2a/core/http_constants.h b/include/a2a/core/http_constants.h index edf3563..9c41f9c 100644 --- a/include/a2a/core/http_constants.h +++ b/include/a2a/core/http_constants.h @@ -21,6 +21,7 @@ inline constexpr char kContentTypeParameterSeparator = ';'; inline constexpr std::string_view kMethodGet = "GET"; inline constexpr std::string_view kMethodPost = "POST"; +inline constexpr std::string_view kMethodDelete = "DELETE"; inline constexpr std::string_view kHttpScheme = "http://"; inline constexpr std::string_view kHttpsScheme = "https://"; diff --git a/include/a2a/core/non_copyable.h b/include/a2a/core/non_copyable.h new file mode 100644 index 0000000..b986394 --- /dev/null +++ b/include/a2a/core/non_copyable.h @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#pragma once + +namespace a2a::core { + +class NonCopyable { + public: + NonCopyable(const NonCopyable&) = delete; + NonCopyable& operator=(const NonCopyable&) = delete; + + protected: + constexpr NonCopyable() noexcept = default; + constexpr NonCopyable(NonCopyable&&) noexcept = default; + constexpr NonCopyable& operator=(NonCopyable&&) noexcept = default; + ~NonCopyable() = default; +}; + +class NonCopyableOrMovable { + public: + NonCopyableOrMovable(const NonCopyableOrMovable&) = delete; + NonCopyableOrMovable& operator=(const NonCopyableOrMovable&) = delete; + NonCopyableOrMovable(NonCopyableOrMovable&&) = delete; + NonCopyableOrMovable& operator=(NonCopyableOrMovable&&) = delete; + + protected: + constexpr NonCopyableOrMovable() noexcept = default; + ~NonCopyableOrMovable() = default; +}; + +} // namespace a2a::core diff --git a/include/a2a/core/protocol_methods.h b/include/a2a/core/protocol_methods.h index 3f3c6fc..38adc83 100644 --- a/include/a2a/core/protocol_methods.h +++ b/include/a2a/core/protocol_methods.h @@ -19,4 +19,22 @@ inline constexpr std::string_view kListTaskPushNotificationConfigs = "ListTaskPu inline constexpr std::string_view kDeleteTaskPushNotificationConfig = "DeleteTaskPushNotificationConfig"; inline constexpr std::string_view kPushNotificationConfigsSegment = "/pushNotificationConfigs"; +struct GetExtendedAgentCardMethodName final { + static constexpr std::string_view kCanonical = "GetExtendedAgentCard"; + static constexpr std::string_view kJsonRpcAlias = "a2a.getExtendedAgentCard"; + + constexpr operator std::string_view() const noexcept { return kCanonical; } +}; + +inline constexpr GetExtendedAgentCardMethodName kGetExtendedAgentCard{}; + +constexpr bool operator==(std::string_view actual, GetExtendedAgentCardMethodName) noexcept { + return actual == GetExtendedAgentCardMethodName::kCanonical || + actual == GetExtendedAgentCardMethodName::kJsonRpcAlias; +} + +constexpr bool operator==(GetExtendedAgentCardMethodName method, std::string_view actual) noexcept { + return actual == method; +} + } // namespace a2a::core::protocol_methods diff --git a/include/a2a/core/protocol_paths.h b/include/a2a/core/protocol_paths.h new file mode 100644 index 0000000..9f94bf8 --- /dev/null +++ b/include/a2a/core/protocol_paths.h @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#pragma once + +#include + +namespace a2a::core::protocol_paths { + +inline constexpr std::string_view kAgentCard = "/.well-known/agent-card.json"; +inline constexpr std::string_view kLegacyAgentCard = "/.well-known/agent.json"; +inline constexpr std::string_view kExtendedAgentCard = "/extendedAgentCard"; +inline constexpr std::string_view kWellKnownPrefix = "/.well-known/"; + +} // namespace a2a::core::protocol_paths diff --git a/include/a2a/server/agent_card/agent_card_serializer.h b/include/a2a/server/agent_card/agent_card_serializer.h new file mode 100644 index 0000000..e961dfd --- /dev/null +++ b/include/a2a/server/agent_card/agent_card_serializer.h @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#pragma once + +#include + +#include "a2a/core/result.h" +#include "a2a/v1/a2a.pb.h" + +namespace a2a::server { + +[[nodiscard]] core::Result BuildNormalizedAgentCard(const lf::a2a::v1::AgentCard& agent_card, + bool include_legacy_transport_fields); + +[[nodiscard]] core::Result BuildAgentCardJsonValue(const lf::a2a::v1::AgentCard& agent_card, + bool include_legacy_transport_fields); + +} // namespace a2a::server diff --git a/include/a2a/server/dispatch_types.h b/include/a2a/server/dispatch_types.h index 638902c..b2538a2 100644 --- a/include/a2a/server/dispatch_types.h +++ b/include/a2a/server/dispatch_types.h @@ -25,6 +25,7 @@ enum class DispatcherOperation : std::uint8_t { kGetTaskPushNotificationConfig, kListTaskPushNotificationConfigs, kDeleteTaskPushNotificationConfig, + kGetExtendedAgentCard, }; struct DispatchRequest final { @@ -32,13 +33,14 @@ struct DispatchRequest final { std::variant + lf::a2a::v1::DeleteTaskPushNotificationConfigRequest, lf::a2a::v1::GetExtendedAgentCardRequest> payload = ListTasksRequest{}; }; -using DispatchPayload = std::variant, - lf::a2a::v1::Task, ListTasksResponse, lf::a2a::v1::TaskPushNotificationConfig, - lf::a2a::v1::ListTaskPushNotificationConfigsResponse, std::monostate>; +using DispatchPayload = + std::variant, lf::a2a::v1::Task, + ListTasksResponse, lf::a2a::v1::TaskPushNotificationConfig, + lf::a2a::v1::ListTaskPushNotificationConfigsResponse, lf::a2a::v1::AgentCard, std::monostate>; class DispatchResponse final { public: @@ -52,6 +54,8 @@ class DispatchResponse final { explicit DispatchResponse(const lf::a2a::v1::TaskPushNotificationConfig& payload) : payload_(payload) {} explicit DispatchResponse(lf::a2a::v1::TaskPushNotificationConfig&& payload) : payload_(std::move(payload)) {} explicit DispatchResponse(const lf::a2a::v1::ListTaskPushNotificationConfigsResponse& payload) : payload_(payload) {} + explicit DispatchResponse(const lf::a2a::v1::AgentCard& payload) : payload_(payload) {} + explicit DispatchResponse(lf::a2a::v1::AgentCard&& payload) : payload_(std::move(payload)) {} explicit DispatchResponse(lf::a2a::v1::ListTaskPushNotificationConfigsResponse&& payload) : payload_(std::move(payload)) {} DispatchResponse() : payload_(std::monostate{}) {} diff --git a/include/a2a/server/dispatcher.h b/include/a2a/server/dispatcher.h index 3b578b5..2662fef 100644 --- a/include/a2a/server/dispatcher.h +++ b/include/a2a/server/dispatcher.h @@ -7,6 +7,7 @@ #include #include +#include "a2a/core/agent_card/agent_card_provider.h" #include "a2a/core/result.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatch_types.h" @@ -18,7 +19,10 @@ namespace a2a::server { class Dispatcher final { public: explicit Dispatcher(AgentExecutor* executor); + explicit Dispatcher(AgentExecutor* executor, std::shared_ptr agent_card_provider); explicit Dispatcher(AgentExecutor* executor, std::vector> interceptors); + Dispatcher(AgentExecutor* executor, std::vector> interceptors, + std::shared_ptr agent_card_provider); [[nodiscard]] core::Result Dispatch(const DispatchRequest& request, RequestContext& context) const; void AddInterceptor(std::shared_ptr interceptor); @@ -28,6 +32,7 @@ class Dispatcher final { const core::Result& result) const; AgentExecutor* executor_ = nullptr; + std::shared_ptr agent_card_provider_; mutable std::shared_mutex interceptor_mutex_; std::vector> interceptors_; }; diff --git a/include/a2a/server/grpc_server_transport.h b/include/a2a/server/grpc_server_transport.h index 54c8f76..38f8cc9 100644 --- a/include/a2a/server/grpc_server_transport.h +++ b/include/a2a/server/grpc_server_transport.h @@ -18,7 +18,7 @@ namespace a2a::server { struct GrpcServerTransportOptions final { - std::vector required_extensions; + std::vector required_extensions = {}; }; class GrpcServerTransport final : public lf::a2a::v1::A2AService::Service { diff --git a/include/a2a/server/json_rpc_server_transport.h b/include/a2a/server/json_rpc_server_transport.h index 3ffbf11..cd47098 100644 --- a/include/a2a/server/json_rpc_server_transport.h +++ b/include/a2a/server/json_rpc_server_transport.h @@ -24,7 +24,7 @@ struct JsonRpcServerTransportOptions final { bool require_version_header = true; std::size_t default_list_tasks_page_size = 50; std::size_t max_list_tasks_page_size = 100; - std::vector required_extensions; + std::vector required_extensions = {}; }; class JsonRpcServerTransport final { diff --git a/include/a2a/server/rest_server_transport.h b/include/a2a/server/rest_server_transport.h index 2f3983b..06f88ce 100644 --- a/include/a2a/server/rest_server_transport.h +++ b/include/a2a/server/rest_server_transport.h @@ -11,6 +11,7 @@ #include #include +#include "a2a/core/protocol_paths.h" #include "a2a/core/result.h" #include "a2a/server/required_extensions_validator.h" #include "a2a/server/rest_transport.h" @@ -46,13 +47,14 @@ struct RestServerTransportOptions final { bool require_version_header = true; bool include_legacy_transport_fields = true; std::optional agent_card_cache_settings; - std::vector required_extensions; + std::vector required_extensions = {}; }; class RestServerTransport final { public: - static constexpr std::string_view kAgentCardPath = "/.well-known/agent-card.json"; - static constexpr std::string_view kLegacyAgentCardPath = "/.well-known/agent.json"; + static constexpr std::string_view kAgentCardPath = core::protocol_paths::kAgentCard; + static constexpr std::string_view kLegacyAgentCardPath = core::protocol_paths::kLegacyAgentCard; + static constexpr std::string_view kExtendedAgentCardPath = core::protocol_paths::kExtendedAgentCard; RestServerTransport(Dispatcher* dispatcher, lf::a2a::v1::AgentCard agent_card, RestServerTransportOptions options = {}); @@ -63,11 +65,14 @@ class RestServerTransport final { [[nodiscard]] core::Result BuildRestRequest(const HttpServerRequest& request) const; [[nodiscard]] core::Result ValidateVersionHeader(const HttpServerRequest& request) const; [[nodiscard]] core::Result HandleAgentCard(const HttpServerRequest& request) const; + [[nodiscard]] core::Result HandleExtendedAgentCard(const HttpServerRequest& request, + std::string_view tenant) const; [[nodiscard]] static HttpServerResponse ToHttpResponse(const RestResponse& response, const std::vector& activated_extensions); static std::string NormalizeBasePath(std::string_view path); + Dispatcher* dispatcher_ = nullptr; RestTransport transport_; lf::a2a::v1::AgentCard agent_card_; RestServerTransportOptions options_; diff --git a/include/a2a/server/stores/postgres_common.h b/include/a2a/server/stores/postgres_common.h index e743769..dead31e 100644 --- a/include/a2a/server/stores/postgres_common.h +++ b/include/a2a/server/stores/postgres_common.h @@ -12,6 +12,7 @@ #include #include "a2a/core/error.h" +#include "a2a/core/non_copyable.h" #include "a2a/server/stores/store_factory.h" typedef struct pg_conn PGconn; @@ -55,11 +56,9 @@ class PostgresConnectionPool final { public: explicit PostgresConnectionPool(std::string connection_string, std::size_t size = kDefaultPostgresConnectionPoolSize); - class Lease final { + class Lease final : private core::NonCopyable { public: Lease(PostgresConnectionPool* pool, PgConnection connection); - Lease(const Lease&) = delete; - Lease& operator=(const Lease&) = delete; Lease(Lease&& other) noexcept; Lease& operator=(Lease&& other) noexcept = delete; ~Lease(); diff --git a/include/a2a/server/stores/store_factory.h b/include/a2a/server/stores/store_factory.h index 4d0f24b..98898e4 100644 --- a/include/a2a/server/stores/store_factory.h +++ b/include/a2a/server/stores/store_factory.h @@ -6,6 +6,7 @@ #include #include +#include "a2a/core/non_copyable.h" #include "a2a/core/result.h" #include "a2a/server/push_notification_store.h" #include "a2a/server/tasks/in_memory_task_store.h" @@ -28,13 +29,9 @@ struct StoreBundle final { std::unique_ptr push_store; }; -class StoreFactory { +class StoreFactory : private core::NonCopyableOrMovable { public: StoreFactory() = default; - StoreFactory(const StoreFactory&) = delete; - StoreFactory& operator=(const StoreFactory&) = delete; - StoreFactory(StoreFactory&&) = delete; - StoreFactory& operator=(StoreFactory&&) = delete; virtual ~StoreFactory() = default; [[nodiscard]] virtual StoreBackendKind backend_kind() const noexcept = 0; diff --git a/include/a2a/server/stream_response_coroutine.h b/include/a2a/server/stream_response_coroutine.h index b22373e..964f84b 100644 --- a/include/a2a/server/stream_response_coroutine.h +++ b/include/a2a/server/stream_response_coroutine.h @@ -8,11 +8,12 @@ #include #include +#include "a2a/core/non_copyable.h" #include "a2a/v1/a2a.pb.h" namespace a2a::server { -class StreamResponseCoroutine final { +class StreamResponseCoroutine final : private core::NonCopyable { public: struct promise_type final { [[nodiscard]] StreamResponseCoroutine get_return_object() { @@ -32,9 +33,6 @@ class StreamResponseCoroutine final { }; StreamResponseCoroutine() = default; - StreamResponseCoroutine(const StreamResponseCoroutine&) = delete; - StreamResponseCoroutine& operator=(const StreamResponseCoroutine&) = delete; - StreamResponseCoroutine(StreamResponseCoroutine&& other) noexcept : handle_(std::exchange(other.handle_, {})) {} StreamResponseCoroutine& operator=(StreamResponseCoroutine&& other) noexcept { if (this != &other) { diff --git a/include/a2a/server/task_subscription_service.h b/include/a2a/server/task_subscription_service.h index f768627..3d92b97 100644 --- a/include/a2a/server/task_subscription_service.h +++ b/include/a2a/server/task_subscription_service.h @@ -14,6 +14,7 @@ #include #include +#include "a2a/core/non_copyable.h" #include "a2a/core/protocol_errors.h" #include "a2a/core/result.h" #include "a2a/core/task_states.h" @@ -23,13 +24,9 @@ namespace a2a::server { -class TaskSubscriptionService final { +class TaskSubscriptionService final : private core::NonCopyableOrMovable { public: TaskSubscriptionService() = default; - TaskSubscriptionService(const TaskSubscriptionService&) = delete; - TaskSubscriptionService& operator=(const TaskSubscriptionService&) = delete; - TaskSubscriptionService(TaskSubscriptionService&&) = delete; - TaskSubscriptionService& operator=(TaskSubscriptionService&&) = delete; ~TaskSubscriptionService(); [[nodiscard]] core::Result> Subscribe(const lf::a2a::v1::Task& task); diff --git a/include/a2a/server/transport_mux.h b/include/a2a/server/transport_mux.h index d2c8880..386060a 100644 --- a/include/a2a/server/transport_mux.h +++ b/include/a2a/server/transport_mux.h @@ -9,6 +9,7 @@ #include #include "a2a/core/http_constants.h" +#include "a2a/core/protocol_paths.h" #include "a2a/core/result.h" #include "a2a/server/json_rpc_server_transport.h" #include "a2a/server/rest_server_transport.h" @@ -57,14 +58,14 @@ class TransportMux final { struct JsonRpcRouteOptions final { std::string route_name = "jsonrpc"; std::string rpc_path = "/rpc"; - std::string method = "POST"; + std::string method = std::string(core::http::kMethodPost); int priority = 100; }; struct RestRouteOptions final { std::string route_name = "rest"; std::string rest_api_prefix = "/a2a"; - std::string well_known_prefix = "/.well-known/"; + std::string well_known_prefix = std::string(core::protocol_paths::kWellKnownPrefix); int priority = 10; }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c0c1d01..e96c066 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(a2a_core STATIC - core/agent_card_builder.cpp + core/agent_card/agent_card_builder.cpp + core/agent_card/agent_card_provider.cpp core/core.cpp core/error.cpp core/extensions.cpp @@ -79,6 +80,7 @@ target_link_libraries(a2a_client ) add_library(a2a_server STATIC + server/agent_card/agent_card_serializer.cpp server/grpc_server_transport.cpp server/http_adapter.cpp server/rest_server_transport.cpp diff --git a/src/client/discovery.cpp b/src/client/discovery.cpp index de37ade..ea28e67 100644 --- a/src/client/discovery.cpp +++ b/src/client/discovery.cpp @@ -12,7 +12,9 @@ #include #include "a2a/core/error.h" +#include "a2a/core/http_constants.h" #include "a2a/core/protocol_bindings.h" +#include "a2a/core/protocol_paths.h" #include "a2a/core/protojson.h" #include "a2a/core/version.h" #include "a2a/http/http_client.h" @@ -24,7 +26,6 @@ namespace { constexpr int kHttpStatusOkMin = 200; constexpr int kHttpStatusOkMax = 299; constexpr int kHttpStatusNotFound = 404; -constexpr std::string_view kDiscoveryGetMethod = "GET"; std::string Trim(std::string_view input) { std::string value(input); @@ -97,8 +98,9 @@ std::optional ToWireTransport(PreferredTransport transport) { HttpFetcher MakeDefaultHttpFetcher() { return [client = a2a::http::Client{}](std::string_view url) -> core::Result { a2a::http::Request request; - request.method = std::string(kDiscoveryGetMethod); + request.method = std::string(core::http::kMethodGet); request.url = std::string(url); + request.headers.push_back({std::string(core::Version::kHeaderName), core::Version::HeaderValue()}); auto response = client.SendRequest(request); if (!response.ok()) { return response.error(); @@ -203,15 +205,23 @@ core::Result DiscoveryClient::BuildDiscoveryUrl(std::string_view ba while (!normalized.empty() && normalized.back() == '/') { normalized.pop_back(); } - return normalized + "/.well-known/agent-card.json"; + return normalized + std::string(core::protocol_paths::kAgentCard); } core::Result DiscoveryClient::BuildExtendedDiscoveryUrl(std::string_view base_url) { - const auto standard = BuildDiscoveryUrl(base_url); - if (!standard.ok()) { - return standard.error(); + std::string normalized = Trim(base_url); + if (normalized.empty()) { + return core::Error::Validation("Base URL is required for extended Agent Card discovery"); + } + if (!HasHttpScheme(normalized)) { + return core::Error::Validation("Base URL must start with http:// or https://"); + } + + while (!normalized.empty() && normalized.back() == '/') { + normalized.pop_back(); } - return standard.value() + "?view=extended"; + normalized.append(core::protocol_paths::kExtendedAgentCard); + return normalized; } core::Result DiscoveryClient::ValidateAgentCard(const lf::a2a::v1::AgentCard& card) { diff --git a/src/client/http_json_transport.cpp b/src/client/http_json_transport.cpp index 45a4888..26050a1 100644 --- a/src/client/http_json_transport.cpp +++ b/src/client/http_json_transport.cpp @@ -17,6 +17,7 @@ #include "a2a/client/sse_parser.h" #include "a2a/core/error.h" #include "a2a/core/extensions.h" +#include "a2a/core/http_constants.h" #include "a2a/core/protocol_methods.h" #include "a2a/core/protojson.h" #include "a2a/core/version.h" @@ -219,7 +220,7 @@ core::Result DispatchSseEvent(const SseEvent& event, StreamObserver& obser core::Result ParseListTasksResponsePayload(const HttpClientResponse& response, std::string_view endpoint) { if (response.status_code < kHttpOkMin || response.status_code > kHttpOkMax) { - return BuildHttpError("GET", endpoint, response); + return BuildHttpError(core::http::kMethodGet, endpoint, response); } google::protobuf::Struct payload; @@ -396,12 +397,13 @@ core::Result HttpJsonTransport::SendMessage( } const std::string endpoint(EndpointMap::kSendMessage); - const auto response = SendRequest({.method = "POST", .endpoint = endpoint}, body.value(), options); + const auto response = + SendRequest({.method = std::string(core::http::kMethodPost), .endpoint = endpoint}, body.value(), options); if (!response.ok()) { return response.error(); } - return ParseBodyOrMapError("POST", endpoint, response.value()); + return ParseBodyOrMapError(core::http::kMethodPost, endpoint, response.value()); } core::Result HttpJsonTransport::GetTask(const lf::a2a::v1::GetTaskRequest& request, @@ -415,11 +417,11 @@ core::Result HttpJsonTransport::GetTask(const lf::a2a::v1::Ge endpoint += "?historyLength=" + std::to_string(request.history_length()); } - const auto response = SendRequest({.method = "GET", .endpoint = endpoint}, {}, options); + const auto response = SendRequest({.method = std::string(core::http::kMethodGet), .endpoint = endpoint}, {}, options); if (!response.ok()) { return response.error(); } - return ParseBodyOrMapError("GET", endpoint, response.value()); + return ParseBodyOrMapError(core::http::kMethodGet, endpoint, response.value()); } core::Result HttpJsonTransport::ListTasks(const ListTasksRequest& request, @@ -442,7 +444,8 @@ core::Result HttpJsonTransport::ListTasks(const ListTasksRequ } const std::string endpoint_path = endpoint.str(); - const auto response = SendRequest({.method = "GET", .endpoint = endpoint_path}, {}, options); + const auto response = + SendRequest({.method = std::string(core::http::kMethodGet), .endpoint = endpoint_path}, {}, options); if (!response.ok()) { return response.error(); } @@ -456,11 +459,12 @@ core::Result HttpJsonTransport::CancelTask(const lf::a2a::v1: } const std::string endpoint = BuildTaskPath(request.id()) + ":cancel"; - const auto response = SendRequest({.method = "POST", .endpoint = endpoint}, "{}", options); + const auto response = + SendRequest({.method = std::string(core::http::kMethodPost), .endpoint = endpoint}, "{}", options); if (!response.ok()) { return response.error(); } - return ParseBodyOrMapError("POST", endpoint, response.value()); + return ParseBodyOrMapError(core::http::kMethodPost, endpoint, response.value()); } core::Result HttpJsonTransport::CreateTaskPushNotificationConfig( @@ -475,11 +479,13 @@ core::Result HttpJsonTransport::CreateT } const std::string endpoint = BuildTaskPushConfigCollectionPath(request.task_id()); - const auto response = SendRequest({.method = "POST", .endpoint = endpoint}, body.value(), options); + const auto response = + SendRequest({.method = std::string(core::http::kMethodPost), .endpoint = endpoint}, body.value(), options); if (!response.ok()) { return response.error(); } - return ParseBodyOrMapError("POST", endpoint, response.value()); + return ParseBodyOrMapError(core::http::kMethodPost, endpoint, + response.value()); } core::Result HttpJsonTransport::GetTaskPushNotificationConfig( @@ -492,11 +498,12 @@ core::Result HttpJsonTransport::GetTask } const std::string endpoint = BuildTaskPushConfigPath({.task_id = request.task_id(), .id = request.id()}); - const auto response = SendRequest({.method = "GET", .endpoint = endpoint}, {}, options); + const auto response = SendRequest({.method = std::string(core::http::kMethodGet), .endpoint = endpoint}, {}, options); if (!response.ok()) { return response.error(); } - return ParseBodyOrMapError("GET", endpoint, response.value()); + return ParseBodyOrMapError(core::http::kMethodGet, endpoint, + response.value()); } core::Result HttpJsonTransport::ListTaskPushNotificationConfigs( @@ -523,11 +530,12 @@ core::Result HttpJsonTrans } const std::string path = endpoint.str(); - const auto response = SendRequest({.method = "GET", .endpoint = path}, {}, options); + const auto response = SendRequest({.method = std::string(core::http::kMethodGet), .endpoint = path}, {}, options); if (!response.ok()) { return response.error(); } - return ParseBodyOrMapError("GET", path, response.value()); + return ParseBodyOrMapError(core::http::kMethodGet, path, + response.value()); } core::Result HttpJsonTransport::DeleteTaskPushNotificationConfig( @@ -540,13 +548,14 @@ core::Result HttpJsonTransport::DeleteTaskPushNotificationConfig( } const std::string endpoint = BuildTaskPushConfigPath({.task_id = request.task_id(), .id = request.id()}); - const auto response = SendRequest({.method = "DELETE", .endpoint = endpoint}, {}, options); + const auto response = + SendRequest({.method = std::string(core::http::kMethodDelete), .endpoint = endpoint}, {}, options); if (!response.ok()) { return response.error(); } if (response.value().status_code < kHttpOkMin || response.value().status_code > kHttpOkMax) { - return BuildHttpError("DELETE", endpoint, response.value()); + return BuildHttpError(core::http::kMethodDelete, endpoint, response.value()); } if (response.value().status_code != kHttpNoContent && !response.value().body.empty() && @@ -568,8 +577,9 @@ core::Result> HttpJsonTransport::SendStreamingMess return body.error(); } - return StartSseStream({.method = "POST", .endpoint = EndpointMap::kSendStreamingMessage}, body.value(), observer, - options); + return StartSseStream( + {.method = std::string(core::http::kMethodPost), .endpoint = EndpointMap::kSendStreamingMessage}, body.value(), + observer, options); } core::Result> HttpJsonTransport::SubscribeTask(const lf::a2a::v1::GetTaskRequest& request, @@ -584,7 +594,7 @@ core::Result> HttpJsonTransport::SubscribeTask(con endpoint += "?historyLength=" + std::to_string(request.history_length()); } - return StartSseStream({.method = "GET", .endpoint = endpoint}, {}, observer, options); + return StartSseStream({.method = std::string(core::http::kMethodGet), .endpoint = endpoint}, {}, observer, options); } core::Result> HttpJsonTransport::StartSseStream(HttpOperation operation, std::string body, diff --git a/src/client/json_rpc_transport.cpp b/src/client/json_rpc_transport.cpp index b08990b..89f8e8b 100644 --- a/src/client/json_rpc_transport.cpp +++ b/src/client/json_rpc_transport.cpp @@ -18,6 +18,7 @@ #include "a2a/core/error.h" #include "a2a/core/extensions.h" +#include "a2a/core/http_constants.h" #include "a2a/core/json_rpc.h" #include "a2a/core/protojson.h" #include "a2a/core/version.h" @@ -200,7 +201,7 @@ core::Result JsonRpcTransport::SendJsonRpcRequest(std::strin } HttpRequest http_request; - http_request.method = "POST"; + http_request.method = std::string(core::http::kMethodPost); http_request.url = JoinUrl(resolved_interface_.url); http_request.body = std::move(request_body); http_request.timeout = options.timeout.value_or(default_timeout_); diff --git a/src/core/agent_card_builder.cpp b/src/core/agent_card/agent_card_builder.cpp similarity index 97% rename from src/core/agent_card_builder.cpp rename to src/core/agent_card/agent_card_builder.cpp index d2eeb2e..9ece681 100644 --- a/src/core/agent_card_builder.cpp +++ b/src/core/agent_card/agent_card_builder.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include #include @@ -76,6 +76,11 @@ AgentCardBuilder& AgentCardBuilder::WithPushNotifications(bool enabled) { return *this; } +AgentCardBuilder& AgentCardBuilder::WithExtendedAgentCard(bool enabled) { + card_.mutable_capabilities()->set_extended_agent_card(enabled); + return *this; +} + AgentCardBuilder& AgentCardBuilder::AddExtension(std::string_view uri, bool required, std::string_view description) { auto* extension = card_.mutable_capabilities()->add_extensions(); extension->set_uri(std::string(uri)); diff --git a/src/core/agent_card/agent_card_provider.cpp b/src/core/agent_card/agent_card_provider.cpp new file mode 100644 index 0000000..6331efa --- /dev/null +++ b/src/core/agent_card/agent_card_provider.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#include "a2a/core/agent_card/agent_card_provider.h" + +#include +#include + +#include "a2a/core/protocol_errors.h" + +namespace a2a::core { + +StaticAgentCardProvider::StaticAgentCardProvider(std::optional extended_agent_card) + : extended_agent_card_(std::move(extended_agent_card)) {} + +Result StaticAgentCardProvider::GetExtendedAgentCard( + const AgentCardRequestContext& context) const { + (void)context; + if (!extended_agent_card_.has_value()) { + return protocol_errors::ExtendedAgentCardNotConfigured(); + } + return *extended_agent_card_; +} + +} // namespace a2a::core diff --git a/src/server/agent_card/agent_card_serializer.cpp b/src/server/agent_card/agent_card_serializer.cpp new file mode 100644 index 0000000..ef280fa --- /dev/null +++ b/src/server/agent_card/agent_card_serializer.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#include "a2a/server/agent_card/agent_card_serializer.h" + +#include +#include +#include + +#include "a2a/core/legacy_transport_names.h" +#include "a2a/core/protocol_bindings.h" +#include "a2a/core/protojson.h" + +namespace a2a::server { +namespace { + +constexpr std::string_view kDefaultAgentCardVersion = "0.1.0"; +constexpr std::string_view kDefaultTextMode = "text/plain"; +constexpr std::string_view kVersionField = "version"; +constexpr std::string_view kDescriptionField = "description"; +constexpr std::string_view kCapabilitiesField = "capabilities"; +constexpr std::string_view kStreamingField = "streaming"; +constexpr std::string_view kPushNotificationsField = "pushNotifications"; +constexpr std::string_view kDefaultInputModesField = "defaultInputModes"; +constexpr std::string_view kDefaultOutputModesField = "defaultOutputModes"; +constexpr std::string_view kSkillsField = "skills"; +constexpr std::string_view kTagsField = "tags"; +constexpr std::string_view kSupportedInterfacesField = "supportedInterfaces"; +constexpr std::string_view kProtocolBindingField = "protocolBinding"; + +bool HasField(const google::protobuf::Struct& object, std::string_view key) { + return object.fields().find(std::string(key)) != object.fields().end(); +} + +google::protobuf::Value* EnsureStructField(google::protobuf::Struct* object, std::string key) { + auto& value = (*object->mutable_fields())[std::move(key)]; + if (!value.has_struct_value()) { + value.mutable_struct_value(); + } + return &value; +} + +google::protobuf::Value* EnsureListField(google::protobuf::Struct* object, std::string key) { + auto& value = (*object->mutable_fields())[std::move(key)]; + if (!value.has_list_value()) { + value.mutable_list_value(); + } + return &value; +} + +void EnsureStringField(google::protobuf::Struct* object, std::string_view key, std::string_view fallback) { + if (!HasField(*object, key)) { + (*object->mutable_fields())[std::string(key)].set_string_value(std::string(fallback)); + } +} + +void EnsureBoolField(google::protobuf::Struct* object, std::string_view key, bool fallback) { + if (!HasField(*object, key)) { + (*object->mutable_fields())[std::string(key)].set_bool_value(fallback); + } +} + +void EnsureDefaultModeField(google::protobuf::Struct* card, std::string_view key) { + if (HasField(*card, key)) { + return; + } + auto* modes = EnsureListField(card, std::string(key))->mutable_list_value(); + modes->add_values()->set_string_value(std::string(kDefaultTextMode)); +} + +void EnsureSkillTags(google::protobuf::Struct* card) { + auto* fields = card->mutable_fields(); + if (fields->find(std::string(kSkillsField)) == fields->end()) { + EnsureListField(card, std::string(kSkillsField)); + } + + auto skills_it = fields->find(std::string(kSkillsField)); + if (skills_it == fields->end() || !skills_it->second.has_list_value()) { + return; + } + + for (auto& skill : *skills_it->second.mutable_list_value()->mutable_values()) { + if (!skill.has_struct_value()) { + continue; + } + EnsureListField(skill.mutable_struct_value(), std::string(kTagsField)); + } +} + +void NormalizeAgentCardFields(google::protobuf::Struct* card) { + EnsureStringField(card, kVersionField, kDefaultAgentCardVersion); + EnsureStringField(card, kDescriptionField, ""); + + auto* capabilities = EnsureStructField(card, std::string(kCapabilitiesField))->mutable_struct_value(); + EnsureBoolField(capabilities, kStreamingField, false); + EnsureBoolField(capabilities, kPushNotificationsField, false); + + EnsureDefaultModeField(card, kDefaultInputModesField); + EnsureDefaultModeField(card, kDefaultOutputModesField); + EnsureSkillTags(card); +} + +void AddLegacyTransportFields(google::protobuf::Struct* card, const lf::a2a::v1::AgentCard& agent_card) { + if (card == nullptr) { + return; + } + + auto* fields = card->mutable_fields(); + auto interfaces_it = fields->find(std::string(kSupportedInterfacesField)); + if (interfaces_it != fields->end() && interfaces_it->second.has_list_value()) { + for (auto& interface_value : *interfaces_it->second.mutable_list_value()->mutable_values()) { + if (!interface_value.has_struct_value()) { + continue; + } + auto* interface_fields = interface_value.mutable_struct_value()->mutable_fields(); + const auto binding_it = interface_fields->find(std::string(kProtocolBindingField)); + if (binding_it == interface_fields->end() || + binding_it->second.kind_case() != google::protobuf::Value::kStringValue) { + continue; + } + (*interface_fields)[std::string(core::legacy_transport_names::kTransportField)].set_string_value( + binding_it->second.string_value()); + } + } + + if (fields->find(std::string(core::legacy_transport_names::kEndpointField)) != fields->end()) { + return; + } + for (const auto& iface : agent_card.supported_interfaces()) { + if (iface.protocol_binding() == core::protocol_bindings::kJsonRpc || + iface.protocol_binding() == core::protocol_bindings::kHttpJson) { + (*fields)[std::string(core::legacy_transport_names::kEndpointField)].set_string_value(iface.url()); + (*fields)[std::string(core::legacy_transport_names::kPreferredTransportField)].set_string_value( + iface.protocol_binding()); + break; + } + } +} + +} // namespace + +core::Result BuildNormalizedAgentCard(const lf::a2a::v1::AgentCard& agent_card, + bool include_legacy_transport_fields) { + const auto body = core::MessageToJson(agent_card); + if (!body.ok()) { + return body.error(); + } + + google::protobuf::Struct card; + const auto parsed = core::JsonToMessage(body.value(), &card, {.ignore_unknown_fields = false}); + if (!parsed.ok()) { + return parsed.error(); + } + + NormalizeAgentCardFields(&card); + if (include_legacy_transport_fields) { + AddLegacyTransportFields(&card, agent_card); + } + return card; +} + +core::Result BuildAgentCardJsonValue(const lf::a2a::v1::AgentCard& agent_card, + bool include_legacy_transport_fields) { + const auto card = BuildNormalizedAgentCard(agent_card, include_legacy_transport_fields); + if (!card.ok()) { + return card.error(); + } + + google::protobuf::Value value; + *value.mutable_struct_value() = card.value(); + return value; +} + +} // namespace a2a::server diff --git a/src/server/dispatcher.cpp b/src/server/dispatcher.cpp index 9660b15..f0a870f 100644 --- a/src/server/dispatcher.cpp +++ b/src/server/dispatcher.cpp @@ -12,6 +12,7 @@ #include "a2a/core/error.h" #include "a2a/core/protocol_error_messages.h" +#include "a2a/core/protocol_errors.h" #include "a2a/server/tasks/task_history.h" namespace a2a::server { @@ -29,6 +30,13 @@ bool IsPushNotificationOperation(DispatcherOperation operation) { operation == DispatcherOperation::kDeleteTaskPushNotificationConfig; } +core::AgentCardRequestContext ToAgentCardRequestContext(const RequestContext& context, std::string_view tenant) { + return {.tenant = tenant.empty() ? std::optional{} : std::optional(tenant), + .remote_address = context.remote_address, + .client_headers = context.client_headers, + .auth_metadata = context.auth_metadata}; +} + core::Result DispatchPushToExecutor(AgentExecutor& executor, const DispatchRequest& request, RequestContext& context) { switch (request.operation) { @@ -86,6 +94,7 @@ core::Result DispatchPushToExecutor(AgentExecutor& executor, c case DispatcherOperation::kSubscribeTask: case DispatcherOperation::kListTasks: case DispatcherOperation::kCancelTask: + case DispatcherOperation::kGetExtendedAgentCard: return core::Error::Validation("Dispatch operation is not a push notification operation"); } return core::Error::Validation("Unsupported push notification dispatcher operation"); @@ -104,84 +113,122 @@ core::Result DispatchSubscribeToExecutor(AgentExecutor& execut return DispatchResponse(std::move(response.value())); } +core::Result DispatchSendMessageToExecutor(AgentExecutor& executor, const DispatchRequest& request, + RequestContext& context) { + const auto* payload = std::get_if(&request.payload); + if (payload == nullptr) { + return DispatchPayloadTypeMismatchError(core::protocol_error_messages::kDispatchPayloadTypeMismatchForSendMessage); + } + const auto response = executor.SendMessage(*payload, context); + if (!response.ok()) { + return response.error(); + } + return DispatchResponse(response.value()); +} + +core::Result DispatchSendStreamingMessageToExecutor(AgentExecutor& executor, + const DispatchRequest& request, + RequestContext& context) { + const auto* payload = std::get_if(&request.payload); + if (payload == nullptr) { + return DispatchPayloadTypeMismatchError( + core::protocol_error_messages::kDispatchPayloadTypeMismatchForSendStreamingMessage); + } + auto response = executor.SendStreamingMessage(*payload, context); + if (!response.ok()) { + return response.error(); + } + return DispatchResponse(std::move(response.value())); +} + +core::Result DispatchGetTaskToExecutor(AgentExecutor& executor, const DispatchRequest& request, + RequestContext& context) { + const auto* payload = std::get_if(&request.payload); + if (payload == nullptr) { + return DispatchPayloadTypeMismatchError(core::protocol_error_messages::kDispatchPayloadTypeMismatchForGetTask); + } + auto response = executor.GetTask(*payload, context); + if (!response.ok()) { + return response.error(); + } + lf::a2a::v1::Task task = std::move(response.value()); + if (payload->has_history_length()) { + ApplyHistoryRetention(&task, static_cast(payload->history_length())); + } + return DispatchResponse(std::move(task)); +} + +core::Result DispatchListTasksToExecutor(AgentExecutor& executor, const DispatchRequest& request, + RequestContext& context) { + const auto* payload = std::get_if(&request.payload); + if (payload == nullptr) { + return DispatchPayloadTypeMismatchError(core::protocol_error_messages::kDispatchPayloadTypeMismatchForListTasks); + } + const auto response = executor.ListTasks(*payload, context); + if (!response.ok()) { + return response.error(); + } + return DispatchResponse(response.value()); +} + +core::Result DispatchCancelTaskToExecutor(AgentExecutor& executor, const DispatchRequest& request, + RequestContext& context) { + const auto* payload = std::get_if(&request.payload); + if (payload == nullptr) { + return DispatchPayloadTypeMismatchError(core::protocol_error_messages::kDispatchPayloadTypeMismatchForCancelTask); + } + const auto response = executor.CancelTask(*payload, context); + if (!response.ok()) { + return response.error(); + } + return DispatchResponse(response.value()); +} + +core::Result DispatchExtendedAgentCard( + const DispatchRequest& request, RequestContext& context, + const std::shared_ptr& agent_card_provider) { + const auto* payload = std::get_if(&request.payload); + if (payload == nullptr) { + return core::Error::Validation("Dispatch payload type mismatch for GetExtendedAgentCard"); + } + if (agent_card_provider == nullptr) { + return core::protocol_errors::ExtendedAgentCardNotConfigured(); + } + auto response = agent_card_provider->GetExtendedAgentCard(ToAgentCardRequestContext(context, payload->tenant())); + if (!response.ok()) { + return response.error(); + } + return DispatchResponse(std::move(response.value())); +} + core::Result DispatchToExecutor(AgentExecutor& executor, const DispatchRequest& request, - RequestContext& context) { + RequestContext& context, + const std::shared_ptr& agent_card_provider) { if (IsPushNotificationOperation(request.operation)) { return DispatchPushToExecutor(executor, request, context); } switch (request.operation) { - case DispatcherOperation::kSendMessage: { - const auto* payload = std::get_if(&request.payload); - if (payload == nullptr) { - return DispatchPayloadTypeMismatchError( - core::protocol_error_messages::kDispatchPayloadTypeMismatchForSendMessage); - } - const auto response = executor.SendMessage(*payload, context); - if (!response.ok()) { - return response.error(); - } - return DispatchResponse(response.value()); - } - case DispatcherOperation::kSendStreamingMessage: { - const auto* payload = std::get_if(&request.payload); - if (payload == nullptr) { - return DispatchPayloadTypeMismatchError( - core::protocol_error_messages::kDispatchPayloadTypeMismatchForSendStreamingMessage); - } - auto response = executor.SendStreamingMessage(*payload, context); - if (!response.ok()) { - return response.error(); - } - return DispatchResponse(std::move(response.value())); - } - case DispatcherOperation::kGetTask: { - const auto* payload = std::get_if(&request.payload); - if (payload == nullptr) { - return DispatchPayloadTypeMismatchError(core::protocol_error_messages::kDispatchPayloadTypeMismatchForGetTask); - } - auto response = executor.GetTask(*payload, context); - if (!response.ok()) { - return response.error(); - } - lf::a2a::v1::Task task = std::move(response.value()); - if (payload->has_history_length()) { - ApplyHistoryRetention(&task, static_cast(payload->history_length())); - } - return DispatchResponse(std::move(task)); - } + case DispatcherOperation::kSendMessage: + return DispatchSendMessageToExecutor(executor, request, context); + case DispatcherOperation::kSendStreamingMessage: + return DispatchSendStreamingMessageToExecutor(executor, request, context); + case DispatcherOperation::kGetTask: + return DispatchGetTaskToExecutor(executor, request, context); case DispatcherOperation::kSubscribeTask: { return DispatchSubscribeToExecutor(executor, request, context); } - case DispatcherOperation::kListTasks: { - const auto* payload = std::get_if(&request.payload); - if (payload == nullptr) { - return DispatchPayloadTypeMismatchError( - core::protocol_error_messages::kDispatchPayloadTypeMismatchForListTasks); - } - const auto response = executor.ListTasks(*payload, context); - if (!response.ok()) { - return response.error(); - } - return DispatchResponse(response.value()); - } - case DispatcherOperation::kCancelTask: { - const auto* payload = std::get_if(&request.payload); - if (payload == nullptr) { - return DispatchPayloadTypeMismatchError( - core::protocol_error_messages::kDispatchPayloadTypeMismatchForCancelTask); - } - const auto response = executor.CancelTask(*payload, context); - if (!response.ok()) { - return response.error(); - } - return DispatchResponse(response.value()); - } + case DispatcherOperation::kListTasks: + return DispatchListTasksToExecutor(executor, request, context); + case DispatcherOperation::kCancelTask: + return DispatchCancelTaskToExecutor(executor, request, context); case DispatcherOperation::kCreateTaskPushNotificationConfig: case DispatcherOperation::kGetTaskPushNotificationConfig: case DispatcherOperation::kListTaskPushNotificationConfigs: case DispatcherOperation::kDeleteTaskPushNotificationConfig: return core::Error::Validation("Push notification dispatch was not handled by push dispatcher"); + case DispatcherOperation::kGetExtendedAgentCard: + return DispatchExtendedAgentCard(request, context, agent_card_provider); } return core::Error::Validation("Unsupported dispatcher operation"); @@ -191,8 +238,17 @@ core::Result DispatchToExecutor(AgentExecutor& executor, const Dispatcher::Dispatcher(AgentExecutor* executor) : executor_(executor) {} +Dispatcher::Dispatcher(AgentExecutor* executor, std::shared_ptr agent_card_provider) + : executor_(executor), agent_card_provider_(std::move(agent_card_provider)) {} + Dispatcher::Dispatcher(AgentExecutor* executor, std::vector> interceptors) - : executor_(executor), interceptors_(std::move(interceptors)) {} + : Dispatcher(executor, std::move(interceptors), nullptr) {} + +Dispatcher::Dispatcher(AgentExecutor* executor, std::vector> interceptors, + std::shared_ptr agent_card_provider) + : executor_(executor), + agent_card_provider_(std::move(agent_card_provider)), + interceptors_(std::move(interceptors)) {} core::Result Dispatcher::Dispatch(const DispatchRequest& request, RequestContext& context) const { if (executor_ == nullptr) { @@ -214,7 +270,7 @@ core::Result Dispatcher::Dispatch(const DispatchRequest& reque } read_lock.unlock(); - auto dispatch_result = DispatchToExecutor(*executor_, request, context); + auto dispatch_result = DispatchToExecutor(*executor_, request, context, agent_card_provider_); RunAfterInterceptors(request, context, dispatch_result); return dispatch_result; } diff --git a/src/server/grpc_server_transport.cpp b/src/server/grpc_server_transport.cpp index 6c880ce..491547d 100644 --- a/src/server/grpc_server_transport.cpp +++ b/src/server/grpc_server_transport.cpp @@ -14,6 +14,7 @@ #include "a2a/core/error.h" #include "a2a/core/extensions.h" +#include "a2a/core/non_copyable.h" #include "a2a/core/protocol_codes.h" #include "a2a/core/protocol_error_messages.h" #include "a2a/core/protocol_errors.h" @@ -179,7 +180,7 @@ constexpr int32_t kMaxListTasksPageSize = 100; constexpr std::chrono::milliseconds kStreamCancellationPollInterval{50}; constexpr std::string_view kExtensionsMetadataKey = "a2a-extensions"; -class StreamCancellationWatcher final { +class StreamCancellationWatcher final : private core::NonCopyable { public: StreamCancellationWatcher(::grpc::ServerContext* context, ServerStreamSession* stream) : context_(context), stream_(stream) { @@ -188,9 +189,6 @@ class StreamCancellationWatcher final { } } - StreamCancellationWatcher(const StreamCancellationWatcher&) = delete; - StreamCancellationWatcher& operator=(const StreamCancellationWatcher&) = delete; - ~StreamCancellationWatcher() { stopped_.store(true, std::memory_order_release); if (worker_.joinable()) { @@ -656,23 +654,23 @@ ::grpc::Status GrpcServerTransport::GetExtendedAgentCard(::grpc::ServerContext* if (request == nullptr || response == nullptr) { return {::grpc::StatusCode::INVALID_ARGUMENT, "Request and response are required"}; } - (void)request; - (void)context; - - response->set_name("A2A C++ SDK Agent"); - response->set_description("Default agent card for compatibility checks"); - response->set_version(std::string(core::Version::kAgentCardVersion)); - response->add_default_input_modes("text/plain"); - response->add_default_output_modes("text/plain"); - auto* capabilities = response->mutable_capabilities(); - capabilities->set_push_notifications(false); - capabilities->set_streaming(true); - for (const auto& required_extension : required_extensions_validator_.required_extensions()) { - auto* extension = capabilities->add_extensions(); - extension->set_uri(required_extension); - extension->set_required(true); + auto request_context = BuildRequestContext(*context); + if (!request_context.ok()) { + return ToGrpcStatus(request_context.error(), context); + } + const auto dispatch = + dispatcher_->Dispatch({.operation = DispatcherOperation::kGetExtendedAgentCard, .payload = *request}, + request_context.value().request_context); + if (!dispatch.ok()) { + return ToGrpcStatus(dispatch.error(), context, request_context.value().activated_extensions); + } + const auto* payload = std::get_if(&dispatch.value().payload()); + if (payload == nullptr) { + return {::grpc::StatusCode::INTERNAL, "GetExtendedAgentCard dispatch returned an unexpected payload"}; } + *response = *payload; + AddActivatedExtensionsTrailingMetadata(context, request_context.value().activated_extensions); return ::grpc::Status::OK; } diff --git a/src/server/json_rpc_server_transport.cpp b/src/server/json_rpc_server_transport.cpp index 0e462d6..404cce7 100644 --- a/src/server/json_rpc_server_transport.cpp +++ b/src/server/json_rpc_server_transport.cpp @@ -28,6 +28,7 @@ #include "a2a/core/protojson.h" #include "a2a/core/task_states.h" #include "a2a/core/version.h" +#include "a2a/server/agent_card/agent_card_serializer.h" #include "a2a/server/http_adapter.h" namespace a2a::server { @@ -121,6 +122,10 @@ bool IsDeletePushConfigMethod(std::string_view method) { core::json_rpc::MethodNames::kDeleteTaskPushNotificationConfig); } +bool IsGetExtendedAgentCardMethod(std::string_view method) { + return method == core::protocol_methods::kGetExtendedAgentCard; +} + std::optional MethodToOperation(std::string_view method) { if (IsSendMessageMethod(method)) { return DispatcherOperation::kSendMessage; @@ -411,6 +416,14 @@ core::Result ParseListTasksPayload(const google::protobuf::Str core::Result BuildDispatchRequestFromMethod(std::string_view method_name, const google::protobuf::Struct& params, const JsonRpcServerTransportOptions& options) { + if (IsGetExtendedAgentCardMethod(method_name)) { + auto payload = ParseProtoPayload(params); + if (!payload.ok()) { + return payload.error(); + } + return DispatchRequest{.operation = DispatcherOperation::kGetExtendedAgentCard, + .payload = std::move(payload.value())}; + } if (IsCreatePushConfigMethod(method_name)) { auto payload = ParseCreatePushConfigPayload(params); if (!payload.ok()) { @@ -494,6 +507,8 @@ core::Result BuildDispatchRequestFromMethod(std::string_view me case DispatcherOperation::kListTaskPushNotificationConfigs: case DispatcherOperation::kDeleteTaskPushNotificationConfig: return core::Error::Internal("Push notification operations are handled before the generic JSON-RPC switch"); + case DispatcherOperation::kGetExtendedAgentCard: + return core::Error::Internal("GetExtendedAgentCard is handled before the generic JSON-RPC switch"); } return core::Error::Internal("Unsupported JSON-RPC dispatcher operation"); @@ -767,7 +782,7 @@ core::Result JsonRpcServerTransport::Handle(const HttpServer } const std::string normalized_target = NormalizePath(request.target); - if (request.method != "POST" || normalized_target != options_.rpc_path) { + if (request.method != core::http::kMethodPost || normalized_target != options_.rpc_path) { return BuildErrorResponse(kJsonRpcInvalidRequest, "No matching JSON-RPC route", ResponseId{}, std::nullopt, core::http::kStatusOk); } @@ -982,6 +997,14 @@ core::Result JsonRpcServerTransport::SerializeDispatchR value.mutable_struct_value(); return value; } + case DispatcherOperation::kGetExtendedAgentCard: { + const auto* payload = std::get_if(&response.payload()); + if (payload == nullptr) { + return core::protocol_errors::InvalidAgentResponse( + "JSON-RPC GetExtendedAgentCard response payload type mismatch"); + } + return BuildAgentCardJsonValue(*payload, false); + } case DispatcherOperation::kSendStreamingMessage: case DispatcherOperation::kSubscribeTask: return core::protocol_errors::InvalidAgentResponse("Streaming JSON-RPC responses must be serialized as SSE"); diff --git a/src/server/push_notifications/push_notification_delivery.cpp b/src/server/push_notifications/push_notification_delivery.cpp index 6170ec4..6cae835 100644 --- a/src/server/push_notifications/push_notification_delivery.cpp +++ b/src/server/push_notifications/push_notification_delivery.cpp @@ -17,7 +17,6 @@ namespace a2a::server { namespace { -constexpr std::string_view kPostMethod = "POST"; constexpr std::string_view kUnsupportedSchemeMessage = "push notification URL must use http or https"; constexpr std::string_view kUnsupportedHttpVersionMessage = "push notification delivery supports only HTTP/1.1, HTTP/2.0, or HTTP/3.0"; @@ -75,7 +74,7 @@ std::vector BuildHeaders(const lf::a2a::v1::TaskPushNotificationCo http::Request BuildHttpRequest(const PushDeliveryRequest& request, std::string body, std::string_view http_version, std::chrono::milliseconds timeout) { http::Request http_request; - http_request.method = std::string(kPostMethod); + http_request.method = std::string(core::http::kMethodPost); http_request.url = request.config.url(); http_request.headers = BuildHeaders(request.config); http_request.body = std::move(body); diff --git a/src/server/rest_server_transport.cpp b/src/server/rest_server_transport.cpp index f1e8d33..5cd115c 100644 --- a/src/server/rest_server_transport.cpp +++ b/src/server/rest_server_transport.cpp @@ -12,18 +12,20 @@ #include #include #include +#include #include +#include #include #include "a2a/core/error.h" #include "a2a/core/extensions.h" #include "a2a/core/http_constants.h" #include "a2a/core/http_utils.h" -#include "a2a/core/legacy_transport_names.h" -#include "a2a/core/protocol_bindings.h" +#include "a2a/core/protocol_codes.h" #include "a2a/core/protocol_errors.h" #include "a2a/core/protojson.h" #include "a2a/core/version.h" +#include "a2a/server/agent_card/agent_card_serializer.h" namespace a2a::server { namespace { @@ -31,6 +33,8 @@ namespace { constexpr int kHexAlphabetOffset = 10; constexpr std::uint64_t kFnvOffsetBasis = 14695981039346656037ULL; constexpr std::uint64_t kFnvPrime = 1099511628211ULL; +constexpr std::string_view kAgentCardViewQueryKey = "view"; +constexpr std::string_view kExtendedAgentCardViewQueryValue = "extended"; struct ErrorBodySpec final { int status_code = core::http::kStatusBadRequest; @@ -38,14 +42,24 @@ struct ErrorBodySpec final { std::string_view reason; }; -void AddLegacyTransportFields(google::protobuf::Struct* card, const lf::a2a::v1::AgentCard& agent_card); - std::string HttpStatusName(int status_code) { switch (status_code) { case core::http::kStatusBadRequest: return "INVALID_ARGUMENT"; case core::http::kStatusNotFound: return "NOT_FOUND"; + case core::http::kStatusUnauthorized: + return "UNAUTHENTICATED"; + case core::http::kStatusForbidden: + return "PERMISSION_DENIED"; + case core::http::kStatusConflict: + return "CONFLICT"; + case core::http::kStatusUnsupportedMediaType: + return "UNSUPPORTED_MEDIA_TYPE"; + case core::http::kStatusInternalServerError: + return "INTERNAL"; + case core::http::kStatusBadGateway: + return "BAD_GATEWAY"; default: return "UNKNOWN"; } @@ -107,6 +121,56 @@ HttpServerResponse BuildValidatedErrorResponse(int status_code, std::string_view return response; } +std::string_view ProtocolCodeToRestReason(std::string_view protocol_code) { + if (protocol_code == core::protocol_codes::kExtendedAgentCardNotConfigured) { + return "EXTENDED_AGENT_CARD_NOT_CONFIGURED"; + } + if (protocol_code == core::protocol_codes::kUnsupportedOperation) { + return "UNSUPPORTED_OPERATION"; + } + if (protocol_code == core::protocol_codes::kContentTypeNotSupported) { + return "CONTENT_TYPE_NOT_SUPPORTED"; + } + if (protocol_code == core::protocol_codes::kInvalidAgentResponse) { + return "INVALID_AGENT_RESPONSE"; + } + if (protocol_code == core::protocol_codes::kExtensionSupportRequired) { + return "EXTENSION_SUPPORT_REQUIRED"; + } + if (protocol_code == core::protocol_codes::kVersionNotSupported) { + return "VERSION_NOT_SUPPORTED"; + } + return "REMOTE_PROTOCOL_ERROR"; +} + +std::string_view RestReasonFromError(const core::Error& error) { + const auto& protocol_code = error.protocol_code(); + if (protocol_code.has_value()) { + return ProtocolCodeToRestReason(*protocol_code); + } + if (error.code() == core::ErrorCode::kUnsupportedVersion) { + return "VERSION_NOT_SUPPORTED"; + } + if (error.code() == core::ErrorCode::kInternal || error.code() == core::ErrorCode::kSerialization) { + return "INTERNAL"; + } + return "INVALID_ARGUMENT"; +} + +int HttpStatusFromError(const core::Error& error) { + const auto& http_status = error.http_status(); + if (http_status.has_value()) { + return *http_status; + } + if (error.code() == core::ErrorCode::kInternal || error.code() == core::ErrorCode::kSerialization) { + return core::http::kStatusInternalServerError; + } + if (error.code() == core::ErrorCode::kNetwork) { + return core::http::kStatusBadGateway; + } + return core::http::kStatusBadRequest; +} + std::uint64_t ComputeEtagHash(std::string_view data) { std::uint64_t hash = kFnvOffsetBasis; for (const char ch : data) { @@ -137,86 +201,12 @@ std::string BuildQuotedEtag(std::uint64_t hash_value) { return etag; } -google::protobuf::Value* EnsureStructField(google::protobuf::Struct* object, std::string key) { - auto& value = (*object->mutable_fields())[std::move(key)]; - if (!value.has_struct_value()) { - value.mutable_struct_value(); - } - return &value; -} - -google::protobuf::Value* EnsureListField(google::protobuf::Struct* object, std::string key) { - auto& value = (*object->mutable_fields())[std::move(key)]; - if (!value.has_list_value()) { - value.mutable_list_value(); - } - return &value; -} - -void EnsureStringField(google::protobuf::Struct* object, std::string_view key, std::string_view fallback) { - auto* fields = object->mutable_fields(); - if (fields->find(std::string(key)) == fields->end()) { - (*fields)[std::string(key)].set_string_value(std::string(fallback)); - } -} - -void EnsureBoolField(google::protobuf::Struct* object, std::string_view key, bool fallback) { - auto* fields = object->mutable_fields(); - if (fields->find(std::string(key)) == fields->end()) { - (*fields)[std::string(key)].set_bool_value(fallback); - } -} - -void EnsureDefaultModeField(google::protobuf::Struct* card, std::string_view key) { - auto* fields = card->mutable_fields(); - if (fields->find(std::string(key)) != fields->end()) { - return; - } - auto* modes = EnsureListField(card, std::string(key))->mutable_list_value(); - modes->add_values()->set_string_value("text/plain"); -} - -void EnsureSkillTags(google::protobuf::Struct* card) { - auto* fields = card->mutable_fields(); - if (fields->find("skills") == fields->end()) { - EnsureListField(card, "skills"); - } - - auto skills_it = fields->find("skills"); - if (skills_it == fields->end() || !skills_it->second.has_list_value()) { - return; - } - - for (auto& skill : *skills_it->second.mutable_list_value()->mutable_values()) { - if (!skill.has_struct_value()) { - continue; - } - EnsureListField(skill.mutable_struct_value(), "tags"); - } -} - -void NormalizeAgentCardFields(google::protobuf::Struct* card) { - EnsureStringField(card, "version", "0.1.0"); - EnsureStringField(card, "description", ""); - - auto* capabilities = EnsureStructField(card, "capabilities")->mutable_struct_value(); - EnsureBoolField(capabilities, "streaming", false); - EnsureBoolField(capabilities, "pushNotifications", false); - - EnsureDefaultModeField(card, "defaultInputModes"); - EnsureDefaultModeField(card, "defaultOutputModes"); - EnsureSkillTags(card); -} - void ApplyAgentCardCacheHeaders(const std::optional& settings, HttpServerResponse* response) { - if (!settings.has_value()) { - return; - } - if (settings->cache_control.has_value()) { + if (settings.has_value() && settings->cache_control.has_value()) { response->headers["Cache-Control"] = *settings->cache_control; } - if (!settings->last_modified.has_value()) { + if (!settings.has_value() || !settings->last_modified.has_value()) { return; } @@ -226,26 +216,6 @@ void ApplyAgentCardCacheHeaders(const std::optional BuildNormalizedAgentCard(const lf::a2a::v1::AgentCard& agent_card, - bool include_legacy_transport_fields) { - const auto body = core::MessageToJson(agent_card); - if (!body.ok()) { - return body.error(); - } - - google::protobuf::Struct card; - const auto parsed = core::JsonToMessage(body.value(), &card, {.ignore_unknown_fields = false}); - if (!parsed.ok()) { - return parsed.error(); - } - - NormalizeAgentCardFields(&card); - if (include_legacy_transport_fields) { - AddLegacyTransportFields(&card, agent_card); - } - return card; -} - core::Result DecodeUrlComponent(std::string_view raw) { std::string decoded; decoded.reserve(raw.size()); @@ -328,47 +298,57 @@ core::Result ParseQueryString(std::string_view raw, std::unordered_map HasExtendedAgentCardView(std::string_view query) { + std::unordered_map query_params; + const auto parsed = ParseQueryString(query, &query_params); + if (!parsed.ok()) { + return parsed.error(); } - auto* fields = card->mutable_fields(); - auto interfaces_it = fields->find("supportedInterfaces"); - if (interfaces_it != fields->end() && interfaces_it->second.has_list_value()) { - for (auto& interface_value : *interfaces_it->second.mutable_list_value()->mutable_values()) { - if (!interface_value.has_struct_value()) { - continue; - } - auto* interface_fields = interface_value.mutable_struct_value()->mutable_fields(); - const auto binding_it = interface_fields->find("protocolBinding"); - if (binding_it == interface_fields->end() || - binding_it->second.kind_case() != google::protobuf::Value::kStringValue) { - continue; - } - (*interface_fields)[std::string(a2a::core::legacy_transport_names::kTransportField)].set_string_value( - binding_it->second.string_value()); - } + const auto view = query_params.find(std::string(kAgentCardViewQueryKey)); + return view != query_params.end() && view->second == kExtendedAgentCardViewQueryValue; +} + +bool PathStartsWithBasePath(std::string_view path, std::string_view base_path) { + return !base_path.empty() && base_path != "/" && path.starts_with(base_path) && + (path.size() == base_path.size() || path[base_path.size()] == '/'); +} + +std::optional ExtractTenantFromRelativeExtendedAgentCardPath(std::string_view path) { + if (!path.starts_with('/') || !path.ends_with(RestServerTransport::kExtendedAgentCardPath)) { + return std::nullopt; } - if (fields->find(std::string(a2a::core::legacy_transport_names::kEndpointField)) == fields->end()) { - for (const auto& iface : agent_card.supported_interfaces()) { - if (iface.protocol_binding() == a2a::core::protocol_bindings::kJsonRpc || - iface.protocol_binding() == a2a::core::protocol_bindings::kHttpJson) { - (*fields)[std::string(a2a::core::legacy_transport_names::kEndpointField)].set_string_value(iface.url()); - (*fields)[std::string(a2a::core::legacy_transport_names::kPreferredTransportField)].set_string_value( - iface.protocol_binding()); - break; - } - } + const std::size_t tenant_end = path.size() - RestServerTransport::kExtendedAgentCardPath.size(); + if (tenant_end <= 1 || path[tenant_end] != '/') { + return std::nullopt; + } + + const std::string_view tenant = path.substr(1, tenant_end - 1); + if (tenant.empty() || tenant.find('/') != std::string_view::npos) { + return std::nullopt; + } + return std::string(tenant); +} + +std::optional ExtractTenantFromExtendedAgentCardPath(std::string_view path, std::string_view base_path) { + if (auto tenant = ExtractTenantFromRelativeExtendedAgentCardPath(path); tenant.has_value()) { + return tenant; + } + if (!PathStartsWithBasePath(path, base_path)) { + return std::nullopt; } + + const std::string_view relative_path = path.substr(base_path.size()); + return ExtractTenantFromRelativeExtendedAgentCardPath(relative_path); } } // namespace RestServerTransport::RestServerTransport(Dispatcher* dispatcher, lf::a2a::v1::AgentCard agent_card, RestServerTransportOptions options) - : transport_(dispatcher), + : dispatcher_(dispatcher), + transport_(dispatcher), agent_card_(std::move(agent_card)), options_(std::move(options)), required_extensions_validator_(options_.required_extensions) { @@ -389,10 +369,32 @@ core::Result RestServerTransport::Handle(const HttpServerReq const std::string_view path = query_start == std::string::npos ? std::string_view(request.target) : std::string_view(request.target).substr(0, query_start); - - if (path == kAgentCardPath || path == kLegacyAgentCardPath) { + const std::string_view query = + query_start == std::string::npos ? std::string_view{} : std::string_view(request.target).substr(query_start + 1); + const bool is_base_path_extended_agent_card = + options_.rest_api_base_path != "/" && + path == options_.rest_api_base_path + std::string(core::protocol_paths::kExtendedAgentCard); + + if (path == core::protocol_paths::kAgentCard) { + const auto extended_view = HasExtendedAgentCardView(query); + if (!extended_view.ok()) { + return extended_view.error(); + } + if (extended_view.value()) { + return HandleExtendedAgentCard(request, {}); + } + return HandleAgentCard(request); + } + if (path == core::protocol_paths::kLegacyAgentCard) { return HandleAgentCard(request); } + if (path == core::protocol_paths::kExtendedAgentCard || is_base_path_extended_agent_card) { + return HandleExtendedAgentCard(request, {}); + } + if (const auto tenant = ExtractTenantFromExtendedAgentCardPath(path, options_.rest_api_base_path); + tenant.has_value()) { + return HandleExtendedAgentCard(request, *tenant); + } const auto version = ValidateVersionHeader(request); if (!version.ok()) { @@ -470,7 +472,7 @@ core::Result RestServerTransport::ValidateVersionHeader(const HttpServerRe } core::Result RestServerTransport::HandleAgentCard(const HttpServerRequest& request) const { - if (request.method != "GET") { + if (request.method != core::http::kMethodGet) { return BuildJsonErrorResponse(core::http::kStatusNotFound, "No matching route or request was malformed", "UNSUPPORTED_OPERATION"); } @@ -496,6 +498,65 @@ core::Result RestServerTransport::HandleAgentCard(const Http return response; } +core::Result RestServerTransport::HandleExtendedAgentCard(const HttpServerRequest& request, + std::string_view tenant) const { + if (request.method != core::http::kMethodGet) { + return BuildJsonErrorResponse(core::http::kStatusNotFound, "No matching route or request was malformed", + "UNSUPPORTED_OPERATION"); + } + const auto version = ValidateVersionHeader(request); + if (!version.ok()) { + return BuildJsonErrorResponse(core::http::kStatusBadRequest, version.error().message(), "VERSION_NOT_SUPPORTED"); + } + const auto extensions = required_extensions_validator_.Validate(request.headers); + if (!extensions.ok()) { + return BuildJsonErrorResponse(core::http::kStatusBadRequest, extensions.error().message(), + "EXTENSION_SUPPORT_REQUIRED"); + } + if (dispatcher_ == nullptr) { + return BuildValidatedErrorResponse(core::http::kStatusInternalServerError, + "REST transport dispatcher is not configured", "INTERNAL", extensions.value()); + } + RequestContext context; + context.remote_address = request.remote_address.empty() ? std::optional{} + : std::optional(request.remote_address); + context.client_headers = request.headers; + context.auth_metadata = ExtractAuthMetadata(request.headers); + lf::a2a::v1::GetExtendedAgentCardRequest card_request; + if (!tenant.empty()) { + card_request.set_tenant(std::string(tenant)); + } + const auto dispatch = dispatcher_->Dispatch( + {.operation = DispatcherOperation::kGetExtendedAgentCard, .payload = card_request}, context); + if (!dispatch.ok()) { + return BuildValidatedErrorResponse(HttpStatusFromError(dispatch.error()), dispatch.error().message(), + RestReasonFromError(dispatch.error()), extensions.value()); + } + const auto* payload = std::get_if(&dispatch.value().payload()); + if (payload == nullptr) { + return core::Error::Internal("GetExtendedAgentCard dispatch returned an unexpected payload"); + } + + const auto card = BuildNormalizedAgentCard(*payload, options_.include_legacy_transport_fields); + if (!card.ok()) { + return card.error(); + } + + const auto normalized = core::MessageToJson(card.value()); + if (!normalized.ok()) { + return normalized.error(); + } + + HttpServerResponse response; + response.status_code = core::http::kStatusOk; + response.headers[std::string(core::http::kContentTypeHeaderName)] = + std::string(core::http::kContentTypeApplicationJson); + response.headers[std::string(core::Version::kHeaderName)] = core::Version::HeaderValue(); + AddActivatedExtensionsHeader(extensions.value(), &response); + response.body = normalized.value(); + return response; +} + HttpServerResponse RestServerTransport::ToHttpResponse(const RestResponse& response, const std::vector& activated_extensions) { HttpServerResponse http_response; diff --git a/src/server/rest_transport.cpp b/src/server/rest_transport.cpp index ca28da7..b3970fb 100644 --- a/src/server/rest_transport.cpp +++ b/src/server/rest_transport.cpp @@ -35,33 +35,36 @@ template } const std::array kRoutes = { - RestRoute{.method = "POST", + RestRoute{.method = core::http::kMethodPost, .path_pattern = RestEndpointPaths::kSendMessage, .operation = DispatcherOperation::kSendMessage}, - RestRoute{.method = "POST", + RestRoute{.method = core::http::kMethodPost, .path_pattern = RestEndpointPaths::kSendStreamingMessage, .operation = DispatcherOperation::kSendStreamingMessage}, - RestRoute{.method = "GET", .path_pattern = "/tasks/{id}", .operation = DispatcherOperation::kGetTask}, - RestRoute{.method = "GET", + RestRoute{ + .method = core::http::kMethodGet, .path_pattern = "/tasks/{id}", .operation = DispatcherOperation::kGetTask}, + RestRoute{.method = core::http::kMethodGet, .path_pattern = RestEndpointPaths::kTaskCollection, .operation = DispatcherOperation::kListTasks}, - RestRoute{.method = "POST", .path_pattern = "/tasks/{id}:cancel", .operation = DispatcherOperation::kCancelTask}, + RestRoute{.method = core::http::kMethodPost, + .path_pattern = "/tasks/{id}:cancel", + .operation = DispatcherOperation::kCancelTask}, RestRoute{.method = core::http::kMethodGet, .path_pattern = RestEndpointPaths::kTaskSubscribePath, .operation = DispatcherOperation::kSubscribeTask}, RestRoute{.method = core::http::kMethodPost, .path_pattern = RestEndpointPaths::kTaskSubscribePath, .operation = DispatcherOperation::kSubscribeTask}, - RestRoute{.method = "POST", + RestRoute{.method = core::http::kMethodPost, .path_pattern = "/tasks/{task_id}/pushNotificationConfigs", .operation = DispatcherOperation::kCreateTaskPushNotificationConfig}, - RestRoute{.method = "GET", + RestRoute{.method = core::http::kMethodGet, .path_pattern = "/tasks/{task_id}/pushNotificationConfigs/{id}", .operation = DispatcherOperation::kGetTaskPushNotificationConfig}, - RestRoute{.method = "GET", + RestRoute{.method = core::http::kMethodGet, .path_pattern = "/tasks/{task_id}/pushNotificationConfigs", .operation = DispatcherOperation::kListTaskPushNotificationConfigs}, - RestRoute{.method = "DELETE", + RestRoute{.method = core::http::kMethodDelete, .path_pattern = "/tasks/{task_id}/pushNotificationConfigs/{id}", .operation = DispatcherOperation::kDeleteTaskPushNotificationConfig}, }; @@ -418,7 +421,7 @@ core::Result BuildSubscribeResponse(std::unique_ptr BuildMessageDispatchRequest(const RestRequest& request) { - if (request.method != "POST" || + if (request.method != core::http::kMethodPost || (request.path != RestEndpointPaths::kSendMessage && request.path != RestEndpointPaths::kSendStreamingMessage)) { return std::nullopt; } @@ -436,7 +439,7 @@ std::optional BuildMessageDispatchRequest(const RestRequest& re } std::optional BuildListTasksDispatchRequest(const RestRequest& request) { - if (request.method != "GET" || request.path != RestEndpointPaths::kTaskCollection) { + if (request.method != core::http::kMethodGet || request.path != RestEndpointPaths::kTaskCollection) { return std::nullopt; } @@ -469,7 +472,7 @@ std::optional BuildListTasksDispatchRequest(const RestRequest& } std::optional BuildGetTaskDispatchRequest(const RestRequest& request) { - if (request.method != "GET") { + if (request.method != core::http::kMethodGet) { return std::nullopt; } @@ -487,7 +490,7 @@ std::optional BuildGetTaskDispatchRequest(const RestRequest& re } std::optional BuildCancelTaskDispatchRequest(const RestRequest& request) { - if (request.method != "POST") { + if (request.method != core::http::kMethodPost) { return std::nullopt; } @@ -506,7 +509,7 @@ std::optional BuildPushConfigDispatchRequest(const RestRequest& if (!path.has_value()) { return std::nullopt; } - if (request.method == "POST" && path->collection) { + if (request.method == core::http::kMethodPost && path->collection) { lf::a2a::v1::TaskPushNotificationConfig payload; const auto parse = core::JsonToMessage(request.body, &payload, {.ignore_unknown_fields = true}); if (!parse.ok()) { @@ -515,7 +518,7 @@ std::optional BuildPushConfigDispatchRequest(const RestRequest& payload.set_task_id(path->task_id); return DispatchRequest{.operation = DispatcherOperation::kCreateTaskPushNotificationConfig, .payload = payload}; } - if (request.method == "GET" && path->collection) { + if (request.method == core::http::kMethodGet && path->collection) { lf::a2a::v1::ListTaskPushNotificationConfigsRequest payload; payload.set_task_id(path->task_id); if (const auto page_size = LookupQuery(request, "pageSize"); page_size.has_value()) { @@ -526,13 +529,13 @@ std::optional BuildPushConfigDispatchRequest(const RestRequest& } return DispatchRequest{.operation = DispatcherOperation::kListTaskPushNotificationConfigs, .payload = payload}; } - if (request.method == "GET" && !path->collection) { + if (request.method == core::http::kMethodGet && !path->collection) { lf::a2a::v1::GetTaskPushNotificationConfigRequest payload; payload.set_task_id(path->task_id); payload.set_id(path->config_id); return DispatchRequest{.operation = DispatcherOperation::kGetTaskPushNotificationConfig, .payload = payload}; } - if (request.method == "DELETE" && !path->collection) { + if (request.method == core::http::kMethodDelete && !path->collection) { lf::a2a::v1::DeleteTaskPushNotificationConfigRequest payload; payload.set_task_id(path->task_id); payload.set_id(path->config_id); @@ -620,6 +623,8 @@ core::Result RestTransport::SerializeDispatchResponse(DispatcherOp google::protobuf::Struct empty; return BuildJsonResponse(empty); } + case DispatcherOperation::kGetExtendedAgentCard: + return core::Error::Validation("Extended agent card is handled by the server transport"); case DispatcherOperation::kListTasks: { const auto* payload = std::get_if(&response.payload()); if (payload == nullptr) { diff --git a/src/server/transport_mux.cpp b/src/server/transport_mux.cpp index 5bd486c..acd7b29 100644 --- a/src/server/transport_mux.cpp +++ b/src/server/transport_mux.cpp @@ -6,7 +6,26 @@ #include #include +#include "a2a/core/protocol_paths.h" + namespace a2a::server { +namespace { + +bool IsTenantExtendedAgentCardPath(std::string_view path) { + if (!path.starts_with('/') || !path.ends_with(core::protocol_paths::kExtendedAgentCard)) { + return false; + } + + const std::size_t tenant_end = path.size() - core::protocol_paths::kExtendedAgentCard.size(); + if (tenant_end <= 1 || path[tenant_end] != '/') { + return false; + } + + const std::string_view tenant = path.substr(1, tenant_end - 1); + return !tenant.empty() && tenant.find('/') == std::string_view::npos; +} + +} // namespace TransportMux::TransportMux() : TransportMux(Options{}) {} @@ -47,7 +66,8 @@ void TransportMux::RegisterRestRoute(RestServerTransport& transport, RestRouteOp [rest_api_prefix = std::move(options.rest_api_prefix), well_known_prefix = std::move(options.well_known_prefix)](std::string_view method, std::string_view path) { (void)method; - return path.starts_with(rest_api_prefix) || path.starts_with(well_known_prefix); + return path == core::protocol_paths::kExtendedAgentCard || IsTenantExtendedAgentCardPath(path) || + path.starts_with(rest_api_prefix) || path.starts_with(well_known_prefix); }, .handler = [&transport](const HttpServerRequest& routed_request) { return transport.Handle(routed_request); }, .priority = options.priority, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 857cfb8..6091595 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -50,6 +50,18 @@ target_link_libraries(core_helpers_test gtest_discover_tests(core_helpers_test) +add_executable(non_copyable_test + unit/non_copyable_test.cpp +) + +target_link_libraries(non_copyable_test + PRIVATE + a2a::core + GTest::gtest_main +) + +gtest_discover_tests(non_copyable_test) + add_executable(response_builders_test unit/response_builders_test.cpp ) diff --git a/tests/integration/grpc_transport_integration_test.cpp b/tests/integration/grpc_transport_integration_test.cpp index e684f56..2d06383 100644 --- a/tests/integration/grpc_transport_integration_test.cpp +++ b/tests/integration/grpc_transport_integration_test.cpp @@ -21,6 +21,8 @@ #include "a2a/client/client.h" #include "a2a/client/grpc_transport.h" +#include "a2a/core/agent_card/agent_card_provider.h" +#include "a2a/core/version.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatcher.h" #include "a2a/server/grpc_server_transport.h" @@ -169,7 +171,20 @@ class RecordingObserver final : public a2a::client::StreamObserver { struct GrpcServerHarness final { a2a::server::InMemoryTaskStore store; StreamingStoreExecutor executor{&store}; - a2a::server::Dispatcher dispatcher{&executor}; + lf::a2a::v1::AgentCard extended_card = [] { + lf::a2a::v1::AgentCard card; + card.set_name("A2A C++ SDK Agent"); + card.set_description("Default agent card for compatibility checks"); + card.set_version(std::string(a2a::core::Version::kAgentCardVersion)); + card.add_default_input_modes("text/plain"); + card.add_default_output_modes("text/plain"); + card.mutable_capabilities()->set_push_notifications(false); + card.mutable_capabilities()->set_streaming(true); + return card; + }(); + std::shared_ptr agent_card_provider = + std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher{&executor, agent_card_provider}; a2a::server::GrpcServerTransport transport{&dispatcher}; std::unique_ptr server; int port = 0; diff --git a/tests/interop/tck_http_sut.cpp b/tests/interop/tck_http_sut.cpp index b260fe9..047c4d2 100644 --- a/tests/interop/tck_http_sut.cpp +++ b/tests/interop/tck_http_sut.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -33,7 +34,8 @@ #include #include -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_provider.h" #include "a2a/server/dispatcher.h" #include "a2a/server/grpc_server_transport.h" #include "a2a/server/http_adapter.h" @@ -57,6 +59,10 @@ constexpr std::string_view kInMemoryBackend = "inmemory"; constexpr const char* kStoreBackendEnv = "A2A_TCK_STORE_BACKEND"; constexpr const char* kPostgresDsnEnv = "A2A_TCK_POSTGRES_DSN"; constexpr const char* kPostgresSchemaEnv = "A2A_TCK_POSTGRES_SCHEMA"; +constexpr const char* kExtendedCardModeEnv = "A2A_TCK_EXTENDED_AGENT_CARD_MODE"; +constexpr std::string_view kExtendedCardModeConfigured = "configured"; +constexpr std::string_view kExtendedCardModeDeclaredOnly = "declared_only"; +constexpr std::string_view kExtendedCardModeDisabled = "disabled"; constexpr std::string_view kDefaultPostgresSchema = "public"; constexpr std::string_view kMissingPostgresDsnMessage = "A2A_TCK_POSTGRES_DSN must be set when A2A_TCK_STORE_BACKEND=postgres"; @@ -186,14 +192,32 @@ int main(int argc, char** argv) { std::signal(SIGINT, SignalHandler); std::signal(SIGTERM, SignalHandler); + const char* extended_card_mode_value = std::getenv(kExtendedCardModeEnv); + const std::string_view extended_card_mode = + extended_card_mode_value == nullptr ? kExtendedCardModeConfigured : std::string_view(extended_card_mode_value); + if (extended_card_mode != kExtendedCardModeConfigured && extended_card_mode != kExtendedCardModeDeclaredOnly && + extended_card_mode != kExtendedCardModeDisabled) { + std::cerr << "Unsupported A2A_TCK_EXTENDED_AGENT_CARD_MODE: " << extended_card_mode << '\n'; + return 1; + } + const bool declares_extended_card = extended_card_mode != kExtendedCardModeDisabled; + const bool configures_extended_card = extended_card_mode == kExtendedCardModeConfigured; + auto agent_card = a2a::core::AgentCardBuilder::ConformancePreset( {.rest_url = "http://localhost:" + std::to_string(port) + std::string(kRestApiBasePath), .json_rpc_url = "http://localhost:" + std::to_string(port) + "/rpc", .grpc_url = "localhost:" + std::to_string(grpc_port)}, "TCK HTTP SUT", "0.1.0", "Conformance-focused local SUT for A2A") .WithPushNotifications(true) + .WithExtendedAgentCard(declares_extended_card) .Build(); + std::optional extended_agent_card; + if (configures_extended_card) { + extended_agent_card = agent_card; + extended_agent_card->set_description("Extended conformance-focused local SUT card for A2A"); + } + auto store_bundle = CreateStoreBundleFromEnvironment(); if (!store_bundle.ok()) { std::cerr << store_bundle.error().message() << '\n'; @@ -204,7 +228,8 @@ int main(int argc, char** argv) { executor_options.task_store = store_bundle.value().task_store.get(); executor_options.push_store = store_bundle.value().push_store.get(); a2a::examples::ExampleExecutor executor(std::move(executor_options)); - a2a::server::Dispatcher dispatcher(&executor); + auto agent_card_provider = std::make_shared(extended_agent_card); + a2a::server::Dispatcher dispatcher(&executor, agent_card_provider); a2a::server::GrpcServerTransportOptions grpc_options; grpc_options.required_extensions = {std::string(kTckRequiredExtensionUri)}; a2a::server::GrpcServerTransport grpc(&dispatcher, std::move(grpc_options)); diff --git a/tests/support/example_support/example_support.h b/tests/support/example_support/example_support.h index d30031c..77ee6d5 100644 --- a/tests/support/example_support/example_support.h +++ b/tests/support/example_support/example_support.h @@ -15,7 +15,7 @@ #include #include -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/error.h" #include "a2a/core/protocol_codes.h" #include "a2a/core/protocol_errors.h" diff --git a/tests/support/rest_server_test_utils.h b/tests/support/rest_server_test_utils.h index 71faf16..953b739 100644 --- a/tests/support/rest_server_test_utils.h +++ b/tests/support/rest_server_test_utils.h @@ -9,7 +9,7 @@ #include #include -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include "a2a/core/error.h" #include "a2a/server/agent_executor.h" #include "a2a/server/request_context.h" diff --git a/tests/unit/agent_card_builder_test.cpp b/tests/unit/agent_card_builder_test.cpp index cd9ccbb..2fc8a09 100644 --- a/tests/unit/agent_card_builder_test.cpp +++ b/tests/unit/agent_card_builder_test.cpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -#include "a2a/core/agent_card_builder.h" +#include "a2a/core/agent_card/agent_card_builder.h" #include @@ -126,3 +126,11 @@ TEST(AgentCardBuilderTest, WithPushNotificationsPreservesExistingCapabilities) { } } // namespace + +TEST(AgentCardBuilderTest, WithExtendedAgentCardSetsCapability) { + const auto card = a2a::core::AgentCardBuilder::RestPreset("REST Agent", "http://agent.local/a2a") + .WithExtendedAgentCard(true) + .Build(); + + EXPECT_TRUE(card.capabilities().extended_agent_card()); +} diff --git a/tests/unit/discovery_test.cpp b/tests/unit/discovery_test.cpp index b287e91..ddc9efb 100644 --- a/tests/unit/discovery_test.cpp +++ b/tests/unit/discovery_test.cpp @@ -110,7 +110,7 @@ TEST(DiscoveryClientTest, UsesInMemoryCacheWithinTtl) { EXPECT_EQ(calls, 1U); } -TEST(DiscoveryClientTest, FetchExtendedAgentCardUsesExtendedQueryPath) { +TEST(DiscoveryClientTest, FetchExtendedAgentCardUsesSpecEndpoint) { std::string called_url; DiscoveryClient client([&called_url](std::string_view url) -> a2a::core::Result { called_url = std::string(url); @@ -122,8 +122,9 @@ TEST(DiscoveryClientTest, FetchExtendedAgentCardUsesExtendedQueryPath) { const auto fetched = client.FetchExtendedAgentCard("https://agent.example.com/"); ASSERT_TRUE(fetched.ok()) << fetched.error().message(); - EXPECT_EQ(called_url, "https://agent.example.com/.well-known/agent-card.json?view=extended"); + EXPECT_EQ(called_url, "https://agent.example.com/extendedAgentCard"); } + TEST(AgentCardResolverTest, SelectsPreferredThenFallsBack) { lf::a2a::v1::AgentCard card; auto* json_rpc = card.add_supported_interfaces(); diff --git a/tests/unit/grpc_server_transport_test.cpp b/tests/unit/grpc_server_transport_test.cpp index ce7c120..363421b 100644 --- a/tests/unit/grpc_server_transport_test.cpp +++ b/tests/unit/grpc_server_transport_test.cpp @@ -3,6 +3,8 @@ #include "a2a/server/grpc_server_transport.h" +#include "a2a/core/agent_card/agent_card_provider.h" + #if __has_include() #include #define A2A_HAS_SERVER_CONTEXT_TEST_SPOUSE 1 @@ -449,9 +451,17 @@ TEST(GrpcServerTransportTest, PushNotificationRpcsReturnUnimplemented) { grpc::StatusCode::UNIMPLEMENTED); } -TEST(GrpcServerTransportTest, GetExtendedAgentCardProvidesCompatibilityDefaults) { +TEST(GrpcServerTransportTest, GetExtendedAgentCardRequiresVersionWhenConfigured) { FakeExecutor executor; - a2a::server::Dispatcher dispatcher(&executor); + lf::a2a::v1::AgentCard extended_card; + extended_card.set_name("Extended Unit Agent"); + extended_card.set_description("Configured extended card"); + extended_card.set_version(std::string(a2a::core::Version::kAgentCardVersion)); + extended_card.add_default_input_modes("text/plain"); + extended_card.add_default_output_modes("text/plain"); + extended_card.mutable_capabilities()->set_streaming(true); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); a2a::server::GrpcServerTransport transport(&dispatcher); grpc::ServerContext context; @@ -460,21 +470,36 @@ TEST(GrpcServerTransportTest, GetExtendedAgentCardProvidesCompatibilityDefaults) auto* service = static_cast(&transport); const auto status = service->GetExtendedAgentCard(&context, &request, &response); - EXPECT_TRUE(status.ok()); - EXPECT_EQ(response.name(), "A2A C++ SDK Agent"); - EXPECT_EQ(response.description(), "Default agent card for compatibility checks"); - EXPECT_EQ(response.version(), a2a::core::Version::kAgentCardVersion); - ASSERT_EQ(response.default_input_modes_size(), 1); - ASSERT_EQ(response.default_output_modes_size(), 1); - EXPECT_EQ(response.default_input_modes(0), "text/plain"); - EXPECT_EQ(response.default_output_modes(0), "text/plain"); - EXPECT_FALSE(response.capabilities().push_notifications()); - EXPECT_TRUE(response.capabilities().streaming()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED); } -TEST(GrpcServerTransportTest, GetExtendedAgentCardAdvertisesRequiredExtensions) { +#if A2A_HAS_SERVER_CONTEXT_TEST_SPOUSE +TEST(GrpcServerTransportTest, GetExtendedAgentCardReturnsNotConfiguredWhenMissingProvider) { FakeExecutor executor; a2a::server::Dispatcher dispatcher(&executor); + a2a::server::GrpcServerTransport transport(&dispatcher); + + grpc::ServerContext context; + grpc::testing::ServerContextTestSpouse spouse(&context); + AddValidVersionHeader(spouse); + lf::a2a::v1::GetExtendedAgentCardRequest request; + lf::a2a::v1::AgentCard response; + + auto* service = static_cast(&transport); + const auto status = service->GetExtendedAgentCard(&context, &request, &response); + + EXPECT_EQ(status.error_code(), grpc::StatusCode::FAILED_PRECONDITION); +} +#endif // A2A_HAS_SERVER_CONTEXT_TEST_SPOUSE + +TEST(GrpcServerTransportTest, GetExtendedAgentCardValidatesRequiredExtensions) { + FakeExecutor executor; + lf::a2a::v1::AgentCard extended_card; + extended_card.set_name("Extended Unit Agent"); + extended_card.set_description("Configured extended card"); + extended_card.set_version(std::string(a2a::core::Version::kAgentCardVersion)); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); a2a::server::GrpcServerTransport transport(&dispatcher, {.required_extensions = {std::string(kRequiredExtension)}}); grpc::ServerContext context; @@ -484,10 +509,7 @@ TEST(GrpcServerTransportTest, GetExtendedAgentCardAdvertisesRequiredExtensions) auto* service = static_cast(&transport); const auto status = service->GetExtendedAgentCard(&context, &request, &response); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(response.capabilities().extensions_size(), 1); - EXPECT_EQ(response.capabilities().extensions(0).uri(), std::string(kRequiredExtension)); - EXPECT_TRUE(response.capabilities().extensions(0).required()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED); } TEST(GrpcServerTransportTest, ReturnsInternalWhenDispatcherMissing) { diff --git a/tests/unit/http_client_test.cpp b/tests/unit/http_client_test.cpp index ed380ee..660ce1a 100644 --- a/tests/unit/http_client_test.cpp +++ b/tests/unit/http_client_test.cpp @@ -26,6 +26,7 @@ #include "a2a/client/http_json_transport.h" #include "a2a/client/json_rpc_transport.h" #include "a2a/core/http_constants.h" +#include "a2a/core/non_copyable.h" namespace { @@ -80,7 +81,7 @@ std::string BuildLoopbackUrl(int port, std::string_view scheme = a2a::core::http return url; } -class LoopbackHttpServer final { +class LoopbackHttpServer final : private a2a::core::NonCopyable { public: explicit LoopbackHttpServer(std::string response) : response_(std::move(response)) { fd_ = ::socket(AF_INET, SOCK_STREAM, 0); @@ -103,9 +104,6 @@ class LoopbackHttpServer final { worker_ = std::thread([this] { AcceptOnce(); }); } - LoopbackHttpServer(const LoopbackHttpServer&) = delete; - LoopbackHttpServer& operator=(const LoopbackHttpServer&) = delete; - ~LoopbackHttpServer() { if (worker_.joinable()) { worker_.join(); diff --git a/tests/unit/json_rpc_server_transport_test.cpp b/tests/unit/json_rpc_server_transport_test.cpp index 88a2e93..f15593c 100644 --- a/tests/unit/json_rpc_server_transport_test.cpp +++ b/tests/unit/json_rpc_server_transport_test.cpp @@ -13,6 +13,7 @@ #include #include +#include "a2a/core/agent_card/agent_card_provider.h" #include "a2a/core/protocol_errors.h" #include "a2a/core/protojson.h" #include "a2a/core/task_states.h" @@ -28,6 +29,7 @@ constexpr std::string_view kA2aVersionValue = "1.0"; constexpr std::string_view kSseHeartbeat = ": keep-alive\n\n"; constexpr std::string_view kHeartbeatSubscribeRequestBody = R"({"jsonrpc":"2.0","id":"req-sub-heartbeat","method":"a2a.subscribeToTask","params":{"id":"task-sub"}})"; +constexpr std::string_view kTenantId = "tenant-1"; class RecordingHttpTransport final : public a2a::server::HttpByteTransport { public: @@ -230,6 +232,23 @@ a2a::server::HttpServerRequest BuildJsonRpcRequest(std::string body) { .remote_address = {}}; } +class RecordingAgentCardProvider final : public a2a::core::AgentCardProvider { + public: + explicit RecordingAgentCardProvider(lf::a2a::v1::AgentCard extended_agent_card) + : extended_agent_card_(std::move(extended_agent_card)) {} + + a2a::core::Result GetExtendedAgentCard( + const a2a::core::AgentCardRequestContext& context) const override { + observed_tenant = context.tenant; + return extended_agent_card_; + } + + mutable std::optional observed_tenant; + + private: + lf::a2a::v1::AgentCard extended_agent_card_; +}; + TEST(JsonRpcServerTransportTest, HandlesSendMessageEnvelope) { JsonRpcEchoExecutor executor; a2a::server::Dispatcher dispatcher(&executor); @@ -748,4 +767,60 @@ TEST(JsonRpcServerTransportTest, HandlesPushNotificationConfigMethods) { EXPECT_NE(delete_response.value().body.find("\"result\""), std::string::npos); } +TEST(JsonRpcServerTransportTest, GetExtendedAgentCardReturnsConfiguredCard) { + constexpr std::string_view kRequestBody = + R"({"jsonrpc":"2.0","id":"req-card","method":"GetExtendedAgentCard","params":{}})"; + constexpr std::string_view kExpectedNameJson = R"("name":"Extended JSON-RPC Agent")"; + JsonRpcEchoExecutor executor; + lf::a2a::v1::AgentCard extended_card; + extended_card.set_name("Extended JSON-RPC Agent"); + extended_card.set_description("Configured extended card"); + extended_card.set_version("1.0.0"); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::JsonRpcServerTransport server(&dispatcher, {.rpc_path = "/rpc"}); + + const auto response = server.Handle(BuildJsonRpcRequest(std::string(kRequestBody))); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, kHttpOk); + EXPECT_NE(response.value().body.find(kExpectedNameJson), std::string::npos); +} + +TEST(JsonRpcServerTransportTest, GetExtendedAgentCardPropagatesTenantParam) { + constexpr std::string_view kRequestBody = + R"({"jsonrpc":"2.0","id":"req-card-tenant","method":"GetExtendedAgentCard","params":{"tenant":"tenant-1"}})"; + constexpr std::string_view kExpectedNameJson = R"("name":"Tenant JSON-RPC Agent")"; + JsonRpcEchoExecutor executor; + lf::a2a::v1::AgentCard extended_card; + extended_card.set_name("Tenant JSON-RPC Agent"); + extended_card.set_description("Configured tenant extended card"); + extended_card.set_version("1.0.0"); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::JsonRpcServerTransport server(&dispatcher, {.rpc_path = "/rpc"}); + + const auto response = server.Handle(BuildJsonRpcRequest(std::string(kRequestBody))); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, kHttpOk); + EXPECT_NE(response.value().body.find(kExpectedNameJson), std::string::npos); + EXPECT_EQ(provider->observed_tenant, std::optional(std::string(kTenantId))); +} + +TEST(JsonRpcServerTransportTest, GetExtendedAgentCardReturnsNotConfiguredErrorWhenMissing) { + constexpr std::string_view kRequestBody = + R"({"jsonrpc":"2.0","id":"req-card","method":"GetExtendedAgentCard","params":{}})"; + constexpr std::string_view kExpectedCodeJson = R"("code":-32007)"; + JsonRpcEchoExecutor executor; + a2a::server::Dispatcher dispatcher(&executor); + a2a::server::JsonRpcServerTransport server(&dispatcher, {.rpc_path = "/rpc"}); + + const auto response = server.Handle(BuildJsonRpcRequest(std::string(kRequestBody))); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, kHttpOk); + EXPECT_NE(response.value().body.find(kExpectedCodeJson), std::string::npos); +} + } // namespace diff --git a/tests/unit/non_copyable_test.cpp b/tests/unit/non_copyable_test.cpp new file mode 100644 index 0000000..7ed2a04 --- /dev/null +++ b/tests/unit/non_copyable_test.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2026 Vladimir Pavlov (https://github.com/MisterVVP) + +#include "a2a/core/non_copyable.h" + +#include + +#include + +namespace { + +class MoveOnlyDerived final : private a2a::core::NonCopyable { + public: + MoveOnlyDerived() = default; + MoveOnlyDerived(MoveOnlyDerived&&) noexcept = default; + MoveOnlyDerived& operator=(MoveOnlyDerived&&) noexcept = default; +}; + +class ImmovableDerived final : private a2a::core::NonCopyableOrMovable { + public: + ImmovableDerived() = default; +}; + +static_assert(!std::is_copy_constructible_v); +static_assert(!std::is_copy_assignable_v); +static_assert(std::is_move_constructible_v); +static_assert(std::is_move_assignable_v); + +static_assert(!std::is_copy_constructible_v); +static_assert(!std::is_copy_assignable_v); +static_assert(!std::is_move_constructible_v); +static_assert(!std::is_move_assignable_v); + +} // namespace + +TEST(NonCopyableTest, TraitsAreEnforcedAtCompileTime) { SUCCEED(); } diff --git a/tests/unit/push_notification_delivery_test.cpp b/tests/unit/push_notification_delivery_test.cpp index bf83243..50ef549 100644 --- a/tests/unit/push_notification_delivery_test.cpp +++ b/tests/unit/push_notification_delivery_test.cpp @@ -21,6 +21,7 @@ #include #include "a2a/core/http_constants.h" +#include "a2a/core/non_copyable.h" namespace { @@ -74,7 +75,7 @@ std::string BuildLoopbackUrl(int port, std::string_view scheme = a2a::core::http return url; } -class LoopbackHttpServer final { +class LoopbackHttpServer final : private a2a::core::NonCopyable { public: explicit LoopbackHttpServer(std::string response) : response_(std::move(response)) { fd_ = ::socket(AF_INET, SOCK_STREAM, 0); @@ -97,9 +98,6 @@ class LoopbackHttpServer final { worker_ = std::thread([this] { AcceptOnce(); }); } - LoopbackHttpServer(const LoopbackHttpServer&) = delete; - LoopbackHttpServer& operator=(const LoopbackHttpServer&) = delete; - ~LoopbackHttpServer() { if (worker_.joinable()) { worker_.join(); diff --git a/tests/unit/rest_server_transport_test.cpp b/tests/unit/rest_server_transport_test.cpp index 77d83d9..38f2a90 100644 --- a/tests/unit/rest_server_transport_test.cpp +++ b/tests/unit/rest_server_transport_test.cpp @@ -6,11 +6,14 @@ #include #include +#include #include #include #include +#include "a2a/core/agent_card/agent_card_provider.h" #include "a2a/core/protocol_bindings.h" +#include "a2a/core/protocol_errors.h" #include "a2a/core/protojson.h" #include "a2a/core/version.h" @@ -18,6 +21,7 @@ namespace { constexpr std::time_t kAgentCardLastModifiedUnix = 1704067200; constexpr std::string_view kRequiredExtension = "urn:a2a:tck:required-extension"; +constexpr std::string_view kTenantId = "tenant-1"; class EchoExecutor final : public a2a::server::AgentExecutor { public: @@ -70,6 +74,32 @@ class EchoExecutor final : public a2a::server::AgentExecutor { std::string observed_api_key; }; +class RecordingAgentCardProvider final : public a2a::core::AgentCardProvider { + public: + explicit RecordingAgentCardProvider(lf::a2a::v1::AgentCard extended_agent_card) + : extended_agent_card_(std::move(extended_agent_card)) {} + + [[nodiscard]] a2a::core::Result GetExtendedAgentCard( + const a2a::core::AgentCardRequestContext& context) const override { + observed_tenant = context.tenant; + return extended_agent_card_; + } + + mutable std::optional observed_tenant; + + private: + lf::a2a::v1::AgentCard extended_agent_card_; +}; + +class FailingAgentCardProvider final : public a2a::core::AgentCardProvider { + public: + [[nodiscard]] a2a::core::Result GetExtendedAgentCard( + const a2a::core::AgentCardRequestContext& context) const override { + (void)context; + return a2a::core::protocol_errors::InvalidAgentResponse("extended card provider failed"); + } +}; + lf::a2a::v1::AgentCard BuildCard() { lf::a2a::v1::AgentCard card; card.set_name("Unit Agent"); @@ -271,4 +301,119 @@ TEST(RestServerTransportTest, DoesNotEchoActivatedExtensionsWhenRequiredExtensio EXPECT_FALSE(response.value().headers.contains("A2A-Extensions")); } +TEST(RestServerTransportTest, ServesConfiguredExtendedAgentCard) { + constexpr std::string_view kExtendedName = "Extended REST Agent"; + EchoExecutor executor; + auto extended_card = BuildCard(); + extended_card.set_name(std::string(kExtendedName)); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::RestServerTransport server(&dispatcher, BuildCard(), RestOptions("/a2a")); + + const auto response = server.Handle({.method = "GET", + .target = "/extendedAgentCard", + .headers = {{"A2A-Version", "1.0"}}, + .body = {}, + .remote_address = {}}); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, 200); + EXPECT_NE(response.value().body.find(kExtendedName), std::string::npos); +} + +TEST(RestServerTransportTest, ServesConfiguredExtendedAgentCardFromDiscoveryView) { + constexpr std::string_view kExtendedName = "Extended Discovery Agent"; + EchoExecutor executor; + auto extended_card = BuildCard(); + extended_card.set_name(std::string(kExtendedName)); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::RestServerTransport server(&dispatcher, BuildCard(), RestOptions("/a2a")); + + const auto response = server.Handle({.method = "GET", + .target = "/.well-known/agent-card.json?view=extended", + .headers = {{"A2A-Version", "1.0"}}, + .body = {}, + .remote_address = {}}); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, 200); + EXPECT_NE(response.value().body.find(kExtendedName), std::string::npos); +} + +TEST(RestServerTransportTest, ServesConfiguredExtendedAgentCardUnderRestBasePath) { + constexpr std::string_view kExtendedName = "Extended REST Base Agent"; + EchoExecutor executor; + auto extended_card = BuildCard(); + extended_card.set_name(std::string(kExtendedName)); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::RestServerTransport server(&dispatcher, BuildCard(), RestOptions("/a2a")); + + const auto response = server.Handle({.method = "GET", + .target = "/a2a/extendedAgentCard", + .headers = {{"A2A-Version", "1.0"}}, + .body = {}, + .remote_address = {}}); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, 200); + EXPECT_NE(response.value().body.find(kExtendedName), std::string::npos); +} + +TEST(RestServerTransportTest, PropagatesTenantFromExtendedAgentCardPath) { + EchoExecutor executor; + auto extended_card = BuildCard(); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::RestServerTransport server(&dispatcher, BuildCard(), RestOptions("/a2a")); + + const auto response = server.Handle({.method = "GET", + .target = "/a2a/tenant-1/extendedAgentCard", + .headers = {{"A2A-Version", "1.0"}}, + .body = {}, + .remote_address = {}}); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, 200); + EXPECT_EQ(provider->observed_tenant, std::optional(std::string(kTenantId))); +} + +TEST(RestServerTransportTest, ExtendedAgentCardReturnsNotConfiguredWhenMissing) { + constexpr std::string_view kErrorReason = "EXTENDED_AGENT_CARD_NOT_CONFIGURED"; + EchoExecutor executor; + a2a::server::Dispatcher dispatcher(&executor); + a2a::server::RestServerTransport server(&dispatcher, BuildCard(), RestOptions("/a2a")); + + const auto response = server.Handle({.method = "GET", + .target = "/extendedAgentCard", + .headers = {{"A2A-Version", "1.0"}}, + .body = {}, + .remote_address = {}}); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, 400); + EXPECT_NE(response.value().body.find(kErrorReason), std::string::npos); +} + +TEST(RestServerTransportTest, ExtendedAgentCardPreservesProviderErrorReason) { + constexpr std::string_view kExpectedReason = "INVALID_AGENT_RESPONSE"; + constexpr std::string_view kUnexpectedReason = "EXTENDED_AGENT_CARD_NOT_CONFIGURED"; + EchoExecutor executor; + auto provider = std::make_shared(); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::RestServerTransport server(&dispatcher, BuildCard(), RestOptions("/a2a")); + + const auto response = server.Handle({.method = "GET", + .target = "/extendedAgentCard", + .headers = {{"A2A-Version", "1.0"}}, + .body = {}, + .remote_address = {}}); + + ASSERT_TRUE(response.ok()); + EXPECT_EQ(response.value().status_code, 502); + EXPECT_NE(response.value().body.find(kExpectedReason), std::string::npos); + EXPECT_EQ(response.value().body.find(kUnexpectedReason), std::string::npos); +} + } // namespace diff --git a/tests/unit/server_dispatcher_test.cpp b/tests/unit/server_dispatcher_test.cpp index cf53ba9..6cf0604 100644 --- a/tests/unit/server_dispatcher_test.cpp +++ b/tests/unit/server_dispatcher_test.cpp @@ -10,7 +10,9 @@ #include #include +#include "a2a/core/agent_card/agent_card_provider.h" #include "a2a/core/error.h" +#include "a2a/core/protocol_codes.h" #include "a2a/server/agent_executor.h" #include "a2a/server/dispatch_types.h" #include "a2a/server/dispatcher.h" @@ -174,6 +176,40 @@ TEST(ServerDispatcherTest, DispatchesAllSupportedOperations) { ASSERT_TRUE(std::holds_alternative(cancel_result.value().payload())); } +TEST(ServerDispatcherTest, DispatchesGetExtendedAgentCardThroughProvider) { + constexpr std::string_view kExtendedName = "Extended Dispatcher Agent"; + FakeExecutor executor; + lf::a2a::v1::AgentCard extended_card; + extended_card.set_name(std::string(kExtendedName)); + auto provider = std::make_shared(extended_card); + a2a::server::Dispatcher dispatcher(&executor, provider); + a2a::server::RequestContext context; + context.auth_metadata.emplace("bearer_token", "token"); + + const a2a::server::DispatchRequest dispatch{.operation = a2a::server::DispatcherOperation::kGetExtendedAgentCard, + .payload = lf::a2a::v1::GetExtendedAgentCardRequest{}}; + const auto result = dispatcher.Dispatch(dispatch, context); + + ASSERT_TRUE(result.ok()); + const auto* payload = std::get_if(&result.value().payload()); + ASSERT_NE(payload, nullptr); + EXPECT_EQ(payload->name(), kExtendedName); +} + +TEST(ServerDispatcherTest, GetExtendedAgentCardReturnsNotConfiguredWithoutProvider) { + FakeExecutor executor; + a2a::server::Dispatcher dispatcher(&executor); + a2a::server::RequestContext context; + + const a2a::server::DispatchRequest dispatch{.operation = a2a::server::DispatcherOperation::kGetExtendedAgentCard, + .payload = lf::a2a::v1::GetExtendedAgentCardRequest{}}; + const auto result = dispatcher.Dispatch(dispatch, context); + + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.error().protocol_code(), + std::optional(std::string(a2a::core::protocol_codes::kExtendedAgentCardNotConfigured))); +} + TEST(ServerDispatcherTest, ReturnsValidationErrorForPayloadMismatch) { FakeExecutor executor; a2a::server::Dispatcher dispatcher(&executor);