Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions src/JsonRpc/MessageFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,24 @@ final class MessageFactory
Schema\Request\SetLogLevelRequest::class,
];

/**
* Upper bound on the number of messages accepted in a single batch, guarding
* against amplification where one small request expands into many operations.
*/
public const DEFAULT_MAX_BATCH_SIZE = 100;

/**
* @param list<class-string<Request>|class-string<Notification>> $registeredMessages
* @param int $maxBatchSize Maximum number of messages accepted in a single JSON-RPC batch
*/
public function __construct(
private readonly array $registeredMessages,
private readonly int $maxBatchSize = self::DEFAULT_MAX_BATCH_SIZE,
) {
if ($this->maxBatchSize < 1) {
throw new InvalidArgumentException('maxBatchSize must be at least 1.');
}

foreach ($this->registeredMessages as $messageClass) {
if (!is_subclass_of($messageClass, Request::class) && !is_subclass_of($messageClass, Notification::class)) {
throw new InvalidArgumentException(\sprintf('Message classes must extend %s or %s.', Request::class, Notification::class));
Expand All @@ -83,9 +95,9 @@ public function __construct(
/**
* Creates a new Factory instance with all the protocol's default messages.
*/
public static function make(): self
public static function make(int $maxBatchSize = self::DEFAULT_MAX_BATCH_SIZE): self
{
return new self(self::REGISTERED_MESSAGES);
return new self(self::REGISTERED_MESSAGES, $maxBatchSize);
}

/**
Expand All @@ -102,13 +114,37 @@ public function create(string $input): array
{
$data = json_decode($input, true, flags: \JSON_THROW_ON_ERROR);

if ('{' === $input[0]) {
$data = [$data];
// A JSON-RPC payload is a single message (JSON object) or a batch (JSON
// array). Anything else (scalar, null) is invalid input rather than a
// parse error, and must not reach the per-message loop below.
if (!\is_array($data)) {
return [new InvalidInputMessageException('A JSON-RPC message must be a JSON object or a batch array.')];
}

// json_decode(assoc: true) maps both objects and arrays to PHP arrays. A
// list is a batch; a non-list (string keys) is a single message. An empty
// array is ambiguous ({} vs []) and invalid as either, so reject it.
if ([] === $data) {
return [new InvalidInputMessageException('A JSON-RPC message must not be empty.')];
}

if (array_is_list($data)) {
if (\count($data) > $this->maxBatchSize) {
return [new InvalidInputMessageException(\sprintf('JSON-RPC batch size %d exceeds the maximum allowed batch size of %d.', \count($data), $this->maxBatchSize))];
}

$batch = $data;
} else {
$batch = [$data];
}

$messages = [];
foreach ($data as $message) {
foreach ($batch as $message) {
try {
if (!\is_array($message)) {
throw new InvalidInputMessageException('A JSON-RPC message must be a JSON object.');
}

$messages[] = $this->createMessage($message);
} catch (InvalidInputMessageException $e) {
$messages[] = $e;
Expand Down
51 changes: 50 additions & 1 deletion src/Server/Transport/StreamableHttpTransport.php
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
use Psr\Http\Message\ResponseInterface;
use Psr\Http\Message\ServerRequestInterface;
use Psr\Http\Message\StreamFactoryInterface;
use Psr\Http\Message\StreamInterface;
use Psr\Http\Server\MiddlewareInterface;
use Psr\Log\LoggerInterface;
use Symfony\Component\Uid\Uuid;
Expand All @@ -36,6 +37,12 @@ class StreamableHttpTransport extends BaseTransport
public const SESSION_HEADER = 'Mcp-Session-Id';
public const PROTOCOL_VERSION_HEADER = 'Mcp-Protocol-Version';

/**
* Upper bound on the request body read for a POST, guarding against memory
* exhaustion from an oversized (or unbounded chunked) payload.
*/
public const DEFAULT_MAX_BODY_BYTES = 4 * 1024 * 1024;

private ResponseFactoryInterface $responseFactory;
private StreamFactoryInterface $streamFactory;

Expand All @@ -54,9 +61,14 @@ public function __construct(
?StreamFactoryInterface $streamFactory = null,
?LoggerInterface $logger = null,
?iterable $middleware = null,
private readonly int $maxBodyBytes = self::DEFAULT_MAX_BODY_BYTES,
) {
parent::__construct($logger);

if ($this->maxBodyBytes < 1) {
throw new InvalidArgumentException('maxBodyBytes must be at least 1.');
}

$this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory();
$this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory();

Expand Down Expand Up @@ -107,7 +119,13 @@ protected function handleOptionsRequest(): ResponseInterface

protected function handlePostRequest(): ResponseInterface
{
$body = $this->request->getBody()->getContents();
$body = $this->readBody($this->request->getBody());
if (null === $body) {
$this->logger->warning('Rejected POST body exceeding the maximum allowed size.', ['limit' => $this->maxBodyBytes]);

return $this->createErrorResponse(Error::forInvalidRequest(\sprintf('Request body exceeds the maximum allowed size of %d bytes.', $this->maxBodyBytes)), 413);
}

$this->handleMessage($body, $this->sessionId);

if (null !== $this->immediateResponse) {
Expand Down Expand Up @@ -273,6 +291,37 @@ protected function createErrorResponse(Error $jsonRpcError, int $statusCode): Re
return $response;
}

/**
* Reads the request body, bounded by {@see self::$maxBodyBytes}.
*
* Returns the body contents, or `null` when the payload exceeds the cap. When
* the stream advertises a size we reject up-front; otherwise (e.g. chunked
* transfer with unknown size) we read incrementally and stop at the cap so an
* unbounded stream cannot exhaust memory.
*/
private function readBody(StreamInterface $body): ?string
{
$size = $body->getSize();
if (null !== $size && $size > $this->maxBodyBytes) {
return null;
}

$contents = '';
while (!$body->eof()) {
$chunk = $body->read(8192);
if ('' === $chunk) {
break;
}

$contents .= $chunk;
if (\strlen($contents) > $this->maxBodyBytes) {
return null;
}
}

return $contents;
}

/**
* @param iterable<MiddlewareInterface> $middleware
*
Expand Down
82 changes: 82 additions & 0 deletions tests/Unit/JsonRpc/MessageFactoryTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Mcp\Tests\Unit\JsonRpc;

use Mcp\Exception\InvalidArgumentException;
use Mcp\Exception\InvalidInputMessageException;
use Mcp\JsonRpc\MessageFactory;
use Mcp\Schema\JsonRpc\Error;
Expand Down Expand Up @@ -400,4 +401,85 @@ public function testErrorWithInvalidMessageType(): void
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
$this->assertStringContainsString('message', $results[0]->getMessage());
}

public function testScalarJsonIsRejected(): void
{
$results = $this->factory->create('5');

$this->assertCount(1, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
}

public function testStringJsonIsRejected(): void
{
$results = $this->factory->create('"hello"');

$this->assertCount(1, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
}

public function testEmptyBatchIsRejected(): void
{
$results = $this->factory->create('[]');

$this->assertCount(1, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
}

public function testBatchElementMustBeObject(): void
{
$results = $this->factory->create('[1, 2]');

$this->assertCount(2, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[1]);
}

public function testLeadingWhitespaceObjectIsParsedAsSingleMessage(): void
{
$json = " \n {\"jsonrpc\": \"2.0\", \"method\": \"ping\", \"id\": 1}";

$results = $this->factory->create($json);

$this->assertCount(1, $results);
$this->assertInstanceOf(PingRequest::class, $results[0]);
}

public function testBatchSizeExceedingMaxIsRejected(): void
{
$factory = new MessageFactory([PingRequest::class], maxBatchSize: 2);
$json = '[
{"jsonrpc": "2.0", "method": "ping", "id": 1},
{"jsonrpc": "2.0", "method": "ping", "id": 2},
{"jsonrpc": "2.0", "method": "ping", "id": 3}
]';

$results = $factory->create($json);

$this->assertCount(1, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
$this->assertStringContainsString('batch', $results[0]->getMessage());
}

public function testBatchSizeWithinMaxIsAccepted(): void
{
$factory = new MessageFactory([PingRequest::class], maxBatchSize: 2);
$json = '[
{"jsonrpc": "2.0", "method": "ping", "id": 1},
{"jsonrpc": "2.0", "method": "ping", "id": 2}
]';

$results = $factory->create($json);

$this->assertCount(2, $results);
$this->assertInstanceOf(PingRequest::class, $results[0]);
$this->assertInstanceOf(PingRequest::class, $results[1]);
}

public function testNonPositiveMaxBatchSizeThrows(): void
{
$this->expectException(InvalidArgumentException::class);

new MessageFactory([PingRequest::class], maxBatchSize: 0);
}
}
36 changes: 36 additions & 0 deletions tests/Unit/Server/Transport/StreamableHttpTransportTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,42 @@ public function testInvalidMiddlewareEntryThrows(): void
);
}

public function testPostBodyExceedingMaxBytesReturns413(): void
{
$request = $this->factory
->createServerRequest('POST', 'http://localhost/')
->withBody($this->factory->createStream(str_repeat('a', 64)));

// Empty middleware bypasses the default security stack to isolate body-size handling.
$transport = new StreamableHttpTransport($request, $this->factory, $this->factory, null, [], maxBodyBytes: 16);

$response = $transport->listen();

$this->assertSame(413, $response->getStatusCode());
}

public function testPostBodyWithinMaxBytesIsNotRejected(): void
{
$request = $this->factory
->createServerRequest('POST', 'http://localhost/')
->withBody($this->factory->createStream('{}'));

$transport = new StreamableHttpTransport($request, $this->factory, $this->factory, null, [], maxBodyBytes: 1024);

$response = $transport->listen();

$this->assertNotSame(413, $response->getStatusCode());
}

public function testNonPositiveMaxBodyBytesThrows(): void
{
$request = $this->factory->createServerRequest('POST', 'http://localhost/');

$this->expectException(InvalidArgumentException::class);

new StreamableHttpTransport($request, $this->factory, $this->factory, null, [], maxBodyBytes: 0);
}

private function stubAuth401(): MiddlewareInterface
{
return new class($this->factory) implements MiddlewareInterface {
Expand Down