Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 6 additions & 28 deletions apisix/plugins/ai-providers/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ local transport_http = require("apisix.plugins.ai-transport.http")
local transport_auth = require("apisix.plugins.ai-transport.auth")
local log_sanitize = require("apisix.utils.log-sanitize")
local protocols = require("apisix.plugins.ai-protocols")
local deep_merge = require("apisix.plugins.ai-proxy.merge").deep_merge
local ai_proxy_base = require("apisix.plugins.ai-proxy.base")
local ngx = ngx
local ngx_now = ngx.now
local tonumber = tonumber
Expand Down Expand Up @@ -198,33 +198,11 @@ function _M.build_request(self, ctx, conf, request_body, opts)
or opts.target_host or self.host,
}

-- Inject model options (flat overwrite)
if opts.model_options then
for opt, val in pairs(opts.model_options) do
if request_body[opt] ~= nil then
core.log.info("model_options overwriting request field '", opt, "'")
end
request_body[opt] = val
end
end

-- Apply llm_options via provider capability hook (always force-overwrites)
if opts.override_llm_options then
local cap = self.capabilities and self.capabilities[ctx.ai_target_protocol]
if cap and cap.rewrite_request_body then
cap.rewrite_request_body(request_body, opts.override_llm_options, true)
end
end

-- Apply per-target-protocol request body override (deep merge)
if opts.request_body_override_map then
local patch = opts.request_body_override_map[ctx.ai_target_protocol]
if patch then
core.log.info("applying request_body override for target protocol '",
ctx.ai_target_protocol, "'")
request_body = deep_merge(request_body, patch, opts.request_body_force_override)
end
end
-- Apply instance-level overrides (options + override.{llm_options, request_body}).
-- Runs after the converter so request_body is in target-protocol shape, and the
-- request_body[target_protocol] patch applies to the post-conversion body.
request_body = ai_proxy_base.apply_instance_overrides(
request_body, opts.ai_instance, self, ctx.ai_target_protocol)
params.body = request_body

if self.remove_model then
Expand Down
110 changes: 103 additions & 7 deletions apisix/plugins/ai-proxy/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ local pcall = pcall
local pairs = pairs
local type = type
local table = table
local tostring = tostring
local exporter = require("apisix.plugins.prometheus.exporter")
local protocols = require("apisix.plugins.ai-protocols")
local transport_http = require("apisix.plugins.ai-transport.http")
local log_sanitize = require("apisix.utils.log-sanitize")
local apisix_upstream = require("resty.apisix.upstream")
local deep_merge = require("apisix.plugins.ai-proxy.merge").deep_merge

local _M = {}

Expand Down Expand Up @@ -99,6 +101,106 @@ function _M.detect_request_type(ctx)
end


-- Fold an ai_instance's configured overrides into request_body and hand back
-- the body that would be sent upstream. Three layers run in this order:
--   1. ai_instance.options: flat, key-by-key overwrite of request_body.
--   2. ai_instance.override.llm_options: handed to the provider's
--      capabilities[target_protocol].rewrite_request_body hook, when present.
--   3. ai_instance.override.request_body[target_protocol]: deep_merge patch,
--      with override.request_body_force_override as the third argument.
-- Layers 1 and 2 write into the table passed as request_body. The return
-- value is deep_merge's result when layer 3 applies, otherwise that same table.
function _M.apply_instance_overrides(request_body, ai_instance, ai_provider, target_protocol)
    local read_attr = core.table.try_read_attr

    -- Layer 1: instance options replace whatever the client supplied.
    local options = ai_instance and ai_instance.options
    if options then
        for field, value in pairs(options) do
            if request_body[field] ~= nil then
                core.log.info("model_options overwriting request field '", field, "'")
            end
            request_body[field] = value
        end
    end

    -- Layer 2: provider-specific rewrite; the trailing `true` is the
    -- force-overwrite flag passed to the hook.
    local llm_options = read_attr(ai_instance, "override", "llm_options")
    if llm_options then
        local capability = ai_provider and ai_provider.capabilities
        capability = capability and capability[target_protocol]
        local rewrite = capability and capability.rewrite_request_body
        if rewrite then
            rewrite(request_body, llm_options, true)
        end
    end

    -- Layer 3: per-target-protocol patch; nothing to do without one.
    local patch_by_protocol = read_attr(ai_instance, "override", "request_body")
    local patch = patch_by_protocol and patch_by_protocol[target_protocol]
    if not patch then
        return request_body
    end

    core.log.info("applying request_body override for target protocol '",
                  target_protocol, "'")
    local force = read_attr(ai_instance, "override", "request_body_force_override")
    local merged = deep_merge(request_body, patch, force)
    return merged
end


-- Work out which protocol the upstream will be spoken to in, and which
-- converter (if any) translates the client's request into it.
--
-- Returns: target_protocol, converter
--   * values already stored on ctx (ai_target_protocol / ai_converter) win;
--   * without ctx.ai_client_protocol nothing can be resolved: nil, nil;
--   * a client protocol the provider lists in its capabilities, or the
--     literal "passthrough", is used as-is with no converter;
--   * otherwise the lookup is delegated to protocols.find_converter.
-- Intended to let callers that run before before_proxy has populated ctx
-- compute the same pair themselves.
local function resolve_target_protocol(ctx, ai_provider)
    local resolved = ctx.ai_target_protocol
    if resolved then
        return resolved, ctx.ai_converter
    end

    local requested = ctx.ai_client_protocol
    if not requested then
        return nil, nil
    end

    -- Fall back to an empty capability set when the provider declares none.
    local supported = {}
    if ai_provider and ai_provider.capabilities then
        supported = ai_provider.capabilities
    end

    if supported[requested] or requested == "passthrough" then
        return requested, nil
    end

    -- find_converter returns (converter, target); callers here expect the
    -- reverse order.
    local converter, target = protocols.find_converter(requested, supported)
    return target, converter
end


-- Compute the request body as it would look after protocol conversion and
-- instance-level overrides, without performing any upstream call.
--
-- Steps: parse the JSON request body, load the provider module named by
-- ctx.picked_ai_instance.provider, resolve the target protocol / converter,
-- run the converter when one is needed, then apply
-- _M.apply_instance_overrides.
--
-- @param ctx  request context; must carry picked_ai_instance, and either
--             ai_target_protocol or ai_client_protocol for protocol-specific
--             overrides to apply.
-- @return the effective body table on success
-- @return nil, err on failure. err from the body parser is passed through
--         unchanged; the other errors are strings.
--         NOTE(review): the parser's err may be a table rather than a
--         string -- confirm before concatenating it in callers.
--
-- NOTE(review): build_request has a `self.remove_model` branch after its
-- override step that is not mirrored here -- confirm whether consumers of
-- this helper need it.
function _M.effective_request_for_cache(ctx)
    local request_body, err = core.request.get_json_request_body_table()
    if not request_body then
        return nil, err
    end

    local ai_instance = ctx and ctx.picked_ai_instance
    if not ai_instance then
        return nil, "no picked_ai_instance on ctx"
    end

    -- The module path is built by concatenation, which raises on nil/table
    -- values before pcall can catch anything. Reject non-string names up
    -- front so the caller gets the documented (nil, err) result instead.
    local provider_name = ai_instance.provider
    if type(provider_name) ~= "string" then
        return nil, "failed to load provider: " .. tostring(provider_name)
    end

    local ok, ai_provider = pcall(require,
                                  "apisix.plugins.ai-providers." .. provider_name)
    if not ok then
        -- On failure the second pcall result is the load error; keep it in
        -- the log since the returned message only names the provider.
        core.log.warn("failed to load ai provider '", provider_name, "': ",
                      tostring(ai_provider))
        return nil, "failed to load provider: " .. tostring(provider_name)
    end

    local target_protocol, converter = resolve_target_protocol(ctx, ai_provider)
    if converter and converter.convert_request then
        local converted, conv_err = converter.convert_request(request_body, ctx)
        if not converted then
            return nil, conv_err or "converter failed"
        end
        request_body = converted
    end

    return _M.apply_instance_overrides(
        request_body, ai_instance, ai_provider, target_protocol)
end


-- Execute the AI proxy pipeline:
-- 1. Validate request
-- 2. Route client protocol to driver capability (passthrough / convert / error)
Expand All @@ -124,15 +226,9 @@ function _M.before_proxy(conf, ctx, on_error)
local extra_opts = {
name = ai_instance.name,
endpoint = core.table.try_read_attr(ai_instance, "override", "endpoint"),
model_options = ai_instance.options,
conf = ai_instance.provider_conf or {},
auth = ai_instance.auth,
override_llm_options =
core.table.try_read_attr(ai_instance, "override", "llm_options"),
request_body_override_map =
core.table.try_read_attr(ai_instance, "override", "request_body"),
request_body_force_override =
core.table.try_read_attr(ai_instance, "override", "request_body_force_override"),
ai_instance = ai_instance,
}
-- Step 1: Route client protocol to driver capability
local client_protocol = ctx.ai_client_protocol
Expand Down
2 changes: 1 addition & 1 deletion apisix/plugins/ai-request-rewrite.lua
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ local function request_to_llm(conf, request_table, ctx, target_path)
local extra_opts = {
endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
auth = conf.auth,
model_options = conf.options,
ai_instance = conf,
target_path = target_path,
}
ctx.llm_request_start_time = ngx.now()
Expand Down
139 changes: 139 additions & 0 deletions t/plugin/ai-proxy-request-body-override.t
Original file line number Diff line number Diff line change
Expand Up @@ -819,3 +819,142 @@ max_tokens=321
}
--- response_body
max_completion_tokens=200 temperature=0.5



=== TEST 17: effective_request_for_cache returns post-override body
--- config
location /t {
content_by_lua_block {
local t = require("lib.test_admin").test
-- ai-proxy applies overrides; serverless-post-function (priority -2000)
-- runs after ai-proxy access (priority 1040) in the access phase, invokes
-- the helper, and logs its output. The test asserts BOTH the
-- upstream-received body AND the helper output reflect the same
-- post-override view.
local code = t('/apisix/admin/routes/1',
ngx.HTTP_PUT,
[[{
"uri": "/chat",
"plugins": {
"ai-proxy": {
"provider": "openai",
"auth": { "header": { "Authorization": "Bearer t" } },
"options": { "model": "options-model" },
"override": {
"endpoint": "http://localhost:6732",
"request_body": {
"openai-chat": { "temperature": 0.42 }
}
},
"ssl_verify": false
},
"serverless-post-function": {
"functions": ["return function(_, ctx)
local b = require('apisix.plugins.ai-proxy.base')
local cjson = require('cjson.safe')
local body, err = b.effective_request_for_cache(ctx)
ngx.log(ngx.WARN, 'EFFECTIVE_BODY=',
body and cjson.encode(body)
or ('ERR:' .. tostring(err)))
end"]
}
}
}]]
)
if code >= 300 then ngx.status = code; return end

local http = require("resty.http").new()
local res = assert(http:request_uri("http://127.0.0.1:" .. ngx.var.server_port .. "/chat", {
method = "POST",
body = '{"messages":[{"role":"user","content":"hi"}],"model":"client-model"}',
headers = { ["Content-Type"] = "application/json" },
}))
local cjson = require("cjson.safe")
local body = cjson.decode(res.body)
local echoed = cjson.decode(body.choices[1].message.content)
ngx.say("upstream model=", echoed.model,
" upstream temperature=", echoed.temperature)
}
}
--- response_body
upstream model=options-model upstream temperature=0.42
--- error_log eval
[
qr/EFFECTIVE_BODY=.*"model":"options-model"/,
qr/EFFECTIVE_BODY=.*"temperature":0\.42/,
]



=== TEST 18: effective_request_for_cache applies the converter (anthropic-messages -> openai-chat)
--- config
location /t {
content_by_lua_block {
local t = require("lib.test_admin").test
-- Client sends anthropic-messages format to an openai provider, which
-- speaks openai-chat natively. The converter translates the body and
-- override.request_body.openai-chat then applies. The helper should
-- mirror this: convert first, then apply overrides. Distinctive
-- post-converter marker: max_tokens (anthropic) becomes
-- max_completion_tokens (openai-chat) and the original max_tokens
-- is stripped by the converter ("never forward max_tokens").
local code = t('/apisix/admin/routes/1',
ngx.HTTP_PUT,
[[{
"uri": "/v1/messages",
"plugins": {
"ai-proxy": {
"provider": "openai",
"auth": { "header": { "Authorization": "Bearer t" } },
"override": {
"endpoint": "http://localhost:6732",
"request_body": {
"openai-chat": { "temperature": 0.42 }
}
},
"ssl_verify": false
},
"serverless-post-function": {
"functions": ["return function(_, ctx)
local b = require('apisix.plugins.ai-proxy.base')
local cjson = require('cjson.safe')
local body, err = b.effective_request_for_cache(ctx)
ngx.log(ngx.WARN, 'EFFECTIVE_BODY=',
body and cjson.encode(body)
or ('ERR:' .. tostring(err)))
end"]
}
}
}]]
)
if code >= 300 then ngx.status = code; return end

local http = require("resty.http").new()
local res = assert(http:request_uri("http://127.0.0.1:" .. ngx.var.server_port .. "/v1/messages", {
method = "POST",
body = '{"model":"claude-3","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}',
headers = { ["Content-Type"] = "application/json" },
}))
ngx.status = res.status
-- The /v1/messages stub echoes the raw upstream body as the message
-- text; ai-proxy converts the openai-chat response back to
-- anthropic-messages, so body.content[1].text is the post-converter
-- post-override body the upstream actually received.
local cjson = require("cjson.safe")
local body = cjson.decode(res.body)
local echoed = cjson.decode(body.content[1].text)
ngx.say("upstream max_completion_tokens=", echoed.max_completion_tokens,
" upstream temperature=", echoed.temperature,
" upstream max_tokens=", tostring(echoed.max_tokens))
}
}
--- response_body
upstream max_completion_tokens=10 upstream temperature=0.42 upstream max_tokens=nil
--- error_log eval
[
qr/EFFECTIVE_BODY=.*"max_completion_tokens":10/,
qr/EFFECTIVE_BODY=.*"temperature":0\.42/,
]
--- no_error_log eval
qr/EFFECTIVE_BODY=.*"max_tokens":10/
Loading