diff --git a/apisix/plugins/ai-providers/base.lua b/apisix/plugins/ai-providers/base.lua index 944be263296a..2607f710f684 100644 --- a/apisix/plugins/ai-providers/base.lua +++ b/apisix/plugins/ai-providers/base.lua @@ -34,7 +34,7 @@ local transport_http = require("apisix.plugins.ai-transport.http") local transport_auth = require("apisix.plugins.ai-transport.auth") local log_sanitize = require("apisix.utils.log-sanitize") local protocols = require("apisix.plugins.ai-protocols") -local deep_merge = require("apisix.plugins.ai-proxy.merge").deep_merge +local ai_proxy_base = require("apisix.plugins.ai-proxy.base") local ngx = ngx local ngx_now = ngx.now local tonumber = tonumber @@ -198,33 +198,11 @@ function _M.build_request(self, ctx, conf, request_body, opts) or opts.target_host or self.host, } - -- Inject model options (flat overwrite) - if opts.model_options then - for opt, val in pairs(opts.model_options) do - if request_body[opt] ~= nil then - core.log.info("model_options overwriting request field '", opt, "'") - end - request_body[opt] = val - end - end - - -- Apply llm_options via provider capability hook (always force-overwrites) - if opts.override_llm_options then - local cap = self.capabilities and self.capabilities[ctx.ai_target_protocol] - if cap and cap.rewrite_request_body then - cap.rewrite_request_body(request_body, opts.override_llm_options, true) - end - end - - -- Apply per-target-protocol request body override (deep merge) - if opts.request_body_override_map then - local patch = opts.request_body_override_map[ctx.ai_target_protocol] - if patch then - core.log.info("applying request_body override for target protocol '", - ctx.ai_target_protocol, "'") - request_body = deep_merge(request_body, patch, opts.request_body_force_override) - end - end + -- Apply instance-level overrides (options + override.{llm_options, request_body}). 
+ -- Runs after the converter so request_body is in target-protocol shape, and the + -- request_body[target_protocol] patch applies to the post-conversion body. + request_body = ai_proxy_base.apply_instance_overrides( + request_body, opts.ai_instance, self, ctx.ai_target_protocol) params.body = request_body if self.remove_model then diff --git a/apisix/plugins/ai-proxy/base.lua b/apisix/plugins/ai-proxy/base.lua index 745f3b7a5c42..d13f89ff252a 100644 --- a/apisix/plugins/ai-proxy/base.lua +++ b/apisix/plugins/ai-proxy/base.lua @@ -23,11 +23,13 @@ local pcall = pcall local pairs = pairs local type = type local table = table +local tostring = tostring local exporter = require("apisix.plugins.prometheus.exporter") local protocols = require("apisix.plugins.ai-protocols") local transport_http = require("apisix.plugins.ai-transport.http") local log_sanitize = require("apisix.utils.log-sanitize") local apisix_upstream = require("resty.apisix.upstream") +local deep_merge = require("apisix.plugins.ai-proxy.merge").deep_merge local _M = {} @@ -99,6 +101,106 @@ function _M.detect_request_type(ctx) end +-- Apply ai_instance overrides to request_body and return the effective body +-- that would be sent upstream. Precedence: options (flat overwrite) -> +-- override.llm_options (provider capability rewrite) -> +-- override.request_body[target_protocol] (deep merge). Mutates request_body +-- in place. 
+function _M.apply_instance_overrides(request_body, ai_instance, ai_provider, target_protocol)
+    local read_attr = core.table.try_read_attr
+
+    -- Stage 1: instance `options` are copied onto the body key by key; a
+    -- value already present under the same key is replaced (and logged).
+    local flat_options = ai_instance and ai_instance.options
+    if flat_options then
+        for key, value in pairs(flat_options) do
+            if request_body[key] ~= nil then
+                core.log.info("model_options overwriting request field '", key, "'")
+            end
+            request_body[key] = value
+        end
+    end
+
+    -- Stage 2: override.llm_options is handed to the provider's
+    -- rewrite_request_body hook for the target protocol, when one exists.
+    -- The provider is only consulted if llm_options is actually configured.
+    local llm_options = read_attr(ai_instance, "override", "llm_options")
+    local capabilities = llm_options and ai_provider and ai_provider.capabilities
+    local capability = capabilities and capabilities[target_protocol]
+    if capability and capability.rewrite_request_body then
+        capability.rewrite_request_body(request_body, llm_options, true)
+    end
+
+    -- Stage 3: override.request_body is keyed by target protocol; without a
+    -- patch for this protocol the body from the earlier stages is final.
+    local patches = read_attr(ai_instance, "override", "request_body")
+    local patch = patches and patches[target_protocol]
+    if not patch then
+        return request_body
+    end
+
+    core.log.info("applying request_body override for target protocol '",
+                  target_protocol, "'")
+    local force = read_attr(ai_instance, "override", "request_body_force_override")
+    local merged = deep_merge(request_body, patch, force)
+    return merged
+end
+
+
+-- Resolve (target_protocol, converter) from ctx.ai_client_protocol + provider
+-- capabilities. Mirrors before_proxy's routing so peer plugins running in
+-- access phase (before before_proxy sets ctx.ai_target_protocol /
+-- ctx.ai_converter) can compute them themselves.
+local function resolve_target_protocol(ctx, ai_provider)
+    -- Target already resolved on ctx: reuse it together with its converter.
+    if ctx.ai_target_protocol then
+        return ctx.ai_target_protocol, ctx.ai_converter
+    end
+    local client_protocol = ctx.ai_client_protocol
+    if not client_protocol then
+        return nil, nil
+    end
+    local caps = ai_provider and ai_provider.capabilities or {}
+    -- The provider lists the client protocol among its capabilities, so the
+    -- body is used as-is (no converter).
+    if caps[client_protocol] then
+        return client_protocol, nil
+    end
+    if client_protocol == "passthrough" then
+        return "passthrough", nil
+    end
+    -- find_converter returns (converter, target); callers of this helper
+    -- expect (target, converter), hence the swap.
+    local converter, target = protocols.find_converter(client_protocol, caps)
+    return target, converter
+end
+
+
+-- Return the request body as it would be sent upstream for the current ctx.
+-- Reads the parsed body, applies the converter (if the client protocol differs
+-- from the provider's target protocol), then applies apply_instance_overrides.
+-- The result matches what build_request would send upstream. Pure: no HTTP,
+-- no signing, no upstream call. Requires ctx.picked_ai_instance and
+-- ctx.ai_client_protocol (both set by ai-proxy access phase).
+-- Returns the effective body table on success, or nil plus an error message
+-- when the body cannot be parsed, no instance is picked, the provider module
+-- cannot be loaded, or the converter fails.
+function _M.effective_request_for_cache(ctx)
+    local request_body, err = core.request.get_json_request_body_table()
+    if not request_body then
+        return nil, err
+    end
+    local ai_instance = ctx and ctx.picked_ai_instance
+    if not ai_instance then
+        return nil, "no picked_ai_instance on ctx"
+    end
+    -- The module name is concatenated before pcall gets control, so a missing
+    -- or non-string provider has to be rejected here; otherwise the `..`
+    -- raises instead of producing the nil, err return used everywhere else.
+    local provider_name = ai_instance.provider
+    if type(provider_name) ~= "string" then
+        return nil, "failed to load provider: " .. tostring(provider_name)
+    end
+    local ok, ai_provider = pcall(require,
+                                  "apisix.plugins.ai-providers." .. provider_name)
+    if not ok then
+        return nil, "failed to load provider: " .. tostring(provider_name)
+    end
+    local target_protocol, converter = resolve_target_protocol(ctx, ai_provider)
+    -- Convert first so the overrides below see the target-protocol shape.
+    if converter and converter.convert_request then
+        local converted, conv_err = converter.convert_request(request_body, ctx)
+        if not converted then
+            return nil, conv_err or "converter failed"
+        end
+        request_body = converted
+    end
+    return _M.apply_instance_overrides(
+        request_body, ai_instance, ai_provider, target_protocol)
+end
+
+
 -- Execute the AI proxy pipeline:
 -- 1. Validate request
 -- 2.
Route client protocol to driver capability (passthrough / convert / error) @@ -124,15 +226,9 @@ function _M.before_proxy(conf, ctx, on_error) local extra_opts = { name = ai_instance.name, endpoint = core.table.try_read_attr(ai_instance, "override", "endpoint"), - model_options = ai_instance.options, conf = ai_instance.provider_conf or {}, auth = ai_instance.auth, - override_llm_options = - core.table.try_read_attr(ai_instance, "override", "llm_options"), - request_body_override_map = - core.table.try_read_attr(ai_instance, "override", "request_body"), - request_body_force_override = - core.table.try_read_attr(ai_instance, "override", "request_body_force_override"), + ai_instance = ai_instance, } -- Step 1: Route client protocol to driver capability local client_protocol = ctx.ai_client_protocol diff --git a/apisix/plugins/ai-request-rewrite.lua b/apisix/plugins/ai-request-rewrite.lua index 900f700836e9..31712caa5771 100644 --- a/apisix/plugins/ai-request-rewrite.lua +++ b/apisix/plugins/ai-request-rewrite.lua @@ -122,7 +122,7 @@ local function request_to_llm(conf, request_table, ctx, target_path) local extra_opts = { endpoint = core.table.try_read_attr(conf, "override", "endpoint"), auth = conf.auth, - model_options = conf.options, + ai_instance = conf, target_path = target_path, } ctx.llm_request_start_time = ngx.now() diff --git a/t/plugin/ai-proxy-request-body-override.t b/t/plugin/ai-proxy-request-body-override.t index 088123bebac2..4ca3aae2b535 100644 --- a/t/plugin/ai-proxy-request-body-override.t +++ b/t/plugin/ai-proxy-request-body-override.t @@ -819,3 +819,142 @@ max_tokens=321 } --- response_body max_completion_tokens=200 temperature=0.5 + + + +=== TEST 17: effective_request_for_cache returns post-override body +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + -- ai-proxy applies overrides; serverless-post-function (priority -2000) + -- runs after ai-proxy access (priority 1040) in the access phase, invokes 
+ -- the helper, and logs its output. The test asserts BOTH the + -- upstream-received body AND the helper output reflect the same + -- post-override view. + local code = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/chat", + "plugins": { + "ai-proxy": { + "provider": "openai", + "auth": { "header": { "Authorization": "Bearer t" } }, + "options": { "model": "options-model" }, + "override": { + "endpoint": "http://localhost:6732", + "request_body": { + "openai-chat": { "temperature": 0.42 } + } + }, + "ssl_verify": false + }, + "serverless-post-function": { + "functions": ["return function(_, ctx) + local b = require('apisix.plugins.ai-proxy.base') + local cjson = require('cjson.safe') + local body, err = b.effective_request_for_cache(ctx) + ngx.log(ngx.WARN, 'EFFECTIVE_BODY=', + body and cjson.encode(body) + or ('ERR:' .. tostring(err))) + end"] + } + } + }]] + ) + if code >= 300 then ngx.status = code; return end + + local http = require("resty.http").new() + local res = assert(http:request_uri("http://127.0.0.1:" .. ngx.var.server_port .. "/chat", { + method = "POST", + body = '{"messages":[{"role":"user","content":"hi"}],"model":"client-model"}', + headers = { ["Content-Type"] = "application/json" }, + })) + local cjson = require("cjson.safe") + local body = cjson.decode(res.body) + local echoed = cjson.decode(body.choices[1].message.content) + ngx.say("upstream model=", echoed.model, + " upstream temperature=", echoed.temperature) + } + } +--- response_body +upstream model=options-model upstream temperature=0.42 +--- error_log eval +[ + qr/EFFECTIVE_BODY=.*"model":"options-model"/, + qr/EFFECTIVE_BODY=.*"temperature":0\.42/, +] + + + +=== TEST 18: effective_request_for_cache applies the converter (anthropic-messages -> openai-chat) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + -- Client sends anthropic-messages format to an openai provider, which + -- speaks openai-chat natively. 
The converter translates the body and + -- override.request_body.openai-chat then applies. The helper should + -- mirror this: convert first, then apply overrides. Distinctive + -- post-converter marker: max_tokens (anthropic) becomes + -- max_completion_tokens (openai-chat) and the original max_tokens + -- is stripped by the converter ("never forward max_tokens"). + local code = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/v1/messages", + "plugins": { + "ai-proxy": { + "provider": "openai", + "auth": { "header": { "Authorization": "Bearer t" } }, + "override": { + "endpoint": "http://localhost:6732", + "request_body": { + "openai-chat": { "temperature": 0.42 } + } + }, + "ssl_verify": false + }, + "serverless-post-function": { + "functions": ["return function(_, ctx) + local b = require('apisix.plugins.ai-proxy.base') + local cjson = require('cjson.safe') + local body, err = b.effective_request_for_cache(ctx) + ngx.log(ngx.WARN, 'EFFECTIVE_BODY=', + body and cjson.encode(body) + or ('ERR:' .. tostring(err))) + end"] + } + } + }]] + ) + if code >= 300 then ngx.status = code; return end + + local http = require("resty.http").new() + local res = assert(http:request_uri("http://127.0.0.1:" .. ngx.var.server_port .. "/v1/messages", { + method = "POST", + body = '{"model":"claude-3","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}', + headers = { ["Content-Type"] = "application/json" }, + })) + ngx.status = res.status + -- The /v1/messages stub echoes the raw upstream body as the message + -- text; ai-proxy converts the openai-chat response back to + -- anthropic-messages, so body.content[1].text is the post-converter + -- post-override body the upstream actually received. 
+ local cjson = require("cjson.safe") + local body = cjson.decode(res.body) + local echoed = cjson.decode(body.content[1].text) + ngx.say("upstream max_completion_tokens=", echoed.max_completion_tokens, + " upstream temperature=", echoed.temperature, + " upstream max_tokens=", tostring(echoed.max_tokens)) + } + } +--- response_body +upstream max_completion_tokens=10 upstream temperature=0.42 upstream max_tokens=nil +--- error_log eval +[ + qr/EFFECTIVE_BODY=.*"max_completion_tokens":10/, + qr/EFFECTIVE_BODY=.*"temperature":0\.42/, +] +--- no_error_log eval +qr/EFFECTIVE_BODY=.*"max_tokens":10/