Skip to content

Commit 25aaeb0

Browse files
authored
Merge pull request #560 from chaitin/feat/model-response-guard
fix: 收敛模型响应敏感信息
2 parents af6a51c + e0006b1 commit 25aaeb0

6 files changed

Lines changed: 198 additions & 23 deletions

File tree

backend/biz/setting/usecase/model.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ func (u *modelUsecase) List(ctx context.Context, uid uuid.UUID, cursor domain.Cu
7575
models := cvt.Iter(ms, func(_ int, m *db.Model) *domain.Model {
7676
j := cvt.From(m, &domain.Model{})
7777
j.IsDefault = j.GetIsDefault(user)
78+
j.HideSharedCredentials()
7879
return j
7980
})
8081

@@ -84,6 +85,9 @@ func (u *modelUsecase) List(ctx context.Context, uid uuid.UUID, cursor domain.Cu
8485
u.logger.ErrorContext(ctx, "failed to list additional models from hook", "error", err, "user_id", uid)
8586
return nil, fmt.Errorf("failed to list additional models: %w", err)
8687
}
88+
for _, model := range additionalModels {
89+
model.HideSharedCredentials()
90+
}
8791
models = append(models, additionalModels...)
8892
}
8993

backend/biz/task/usecase/task.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,7 @@ func (a *TaskUsecase) SwitchModel(ctx context.Context, user *domain.User, taskID
233233
}
234234
}
235235

236-
respModel := cvt.From(model, &domain.Model{})
237-
respModel.APIKey = ""
236+
respModel := cvt.From(model, &domain.ModelBrief{})
238237
return &domain.SwitchTaskModelResp{
239238
ID: item.ID,
240239
RequestID: resp.RequestId,

backend/biz/task/usecase/task_switch_model_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,6 @@ func TestSwitchModelRestartsWithExecutionConfigAndUpdatesModel(t *testing.T) {
116116
if resp.Model == nil || resp.Model.ID != toModelID {
117117
t.Fatalf("resp.Model = %+v, want target model", resp.Model)
118118
}
119-
if resp.Model.APIKey != "" {
120-
t.Fatalf("resp.Model.APIKey = %q, want empty response api key", resp.Model.APIKey)
121-
}
122119

123120
if repo.created == nil {
124121
t.Fatal("CreateModelSwitch was not called")

backend/domain/model.go

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,6 @@ func (m *Model) From(src *db.Model) *Model {
7777
m.UpdatedAt = src.UpdatedAt.Unix()
7878
m.LastCheckAt = src.LastCheckAt.Unix()
7979

80-
if src.Remark == "economy" {
81-
m.APIKey = ""
82-
m.BaseURL = ""
83-
}
84-
8580
if src.Edges.User == nil {
8681
return m
8782
}
@@ -99,7 +94,6 @@ func (m *Model) From(src *db.Model) *Model {
9994
Type: consts.OwnerTypeTeam,
10095
Name: team.Name,
10196
}
102-
m.APIKey = ""
10397
return m
10498
}
10599
if src.Edges.User.Role == consts.UserRoleAdmin {
@@ -108,13 +102,72 @@ func (m *Model) From(src *db.Model) *Model {
108102
Type: consts.OwnerTypePublic,
109103
Name: consts.MonkeyCodeAITeamName,
110104
}
111-
m.APIKey = ""
112-
m.BaseURL = ""
113105
return m
114106
}
115107
return m
116108
}
117109

110+
func (m *Model) HideCredentials() *Model {
111+
if m == nil {
112+
return m
113+
}
114+
m.APIKey = ""
115+
m.BaseURL = ""
116+
return m
117+
}
118+
119+
func (m *Model) HideSharedCredentials() *Model {
120+
if m == nil || m.Owner == nil || m.Owner.Type == consts.OwnerTypePrivate {
121+
return m
122+
}
123+
return m.HideCredentials()
124+
}
125+
126+
type ModelBrief struct {
127+
ID uuid.UUID `json:"id"`
128+
Provider string `json:"provider"`
129+
Model string `json:"model"`
130+
Temperature float64 `json:"temperature"`
131+
CreatedAt int64 `json:"created_at"`
132+
UpdatedAt int64 `json:"updated_at"`
133+
Weight int `json:"weight"`
134+
Owner *Owner `json:"owner,omitempty"`
135+
InterfaceType consts.InterfaceType `json:"interface_type"`
136+
IsFree bool `json:"is_free"`
137+
AccessLevel string `json:"access_level"`
138+
LastCheckAt int64 `json:"last_check_at"`
139+
LastCheckSuccess bool `json:"last_check_success"`
140+
LastCheckError string `json:"last_check_error"`
141+
ThinkingEnabled bool `json:"thinking_enabled"`
142+
ContextLimit int `json:"context_limit"`
143+
OutputLimit int `json:"output_limit"`
144+
}
145+
146+
func (m *ModelBrief) From(src *db.Model) *ModelBrief {
147+
if src == nil {
148+
return m
149+
}
150+
full := (&Model{}).From(src)
151+
m.ID = full.ID
152+
m.Provider = full.Provider
153+
m.Model = full.Model
154+
m.Temperature = full.Temperature
155+
m.CreatedAt = full.CreatedAt
156+
m.UpdatedAt = full.UpdatedAt
157+
m.Weight = full.Weight
158+
m.Owner = full.Owner
159+
m.InterfaceType = full.InterfaceType
160+
m.IsFree = full.IsFree
161+
m.AccessLevel = full.AccessLevel
162+
m.LastCheckAt = full.LastCheckAt
163+
m.LastCheckSuccess = full.LastCheckSuccess
164+
m.LastCheckError = full.LastCheckError
165+
m.ThinkingEnabled = full.ThinkingEnabled
166+
m.ContextLimit = full.ContextLimit
167+
m.OutputLimit = full.OutputLimit
168+
return m
169+
}
170+
118171
func (m *Model) GetIsDefault(user *db.User) bool {
119172
if defaultModelID, ok := user.DefaultConfigs[consts.DefaultConfigTypeModel]; ok {
120173
if defaultModelID.String() == m.ID.String() {
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package domain_test
2+
3+
import (
4+
"encoding/json"
5+
"strings"
6+
"testing"
7+
"time"
8+
9+
"github.com/google/uuid"
10+
11+
"github.com/chaitin/MonkeyCode/backend/consts"
12+
"github.com/chaitin/MonkeyCode/backend/db"
13+
"github.com/chaitin/MonkeyCode/backend/domain"
14+
)
15+
16+
func TestModelFromPreservesCredentialsForPureConversion(t *testing.T) {
17+
modelID := uuid.New()
18+
src := &db.Model{
19+
ID: modelID,
20+
Provider: "OpenAI",
21+
APIKey: "sk-admin-secret",
22+
BaseURL: "https://api.example.com/v1",
23+
Model: "gpt-4o",
24+
CreatedAt: time.Unix(100, 0),
25+
UpdatedAt: time.Unix(200, 0),
26+
Edges: db.ModelEdges{
27+
User: &db.User{
28+
ID: uuid.New(),
29+
Role: consts.UserRoleAdmin,
30+
Name: "admin",
31+
},
32+
},
33+
}
34+
35+
got := (&domain.Model{}).From(src)
36+
37+
if got.APIKey != src.APIKey {
38+
t.Fatalf("APIKey = %q, want %q", got.APIKey, src.APIKey)
39+
}
40+
if got.BaseURL != src.BaseURL {
41+
t.Fatalf("BaseURL = %q, want %q", got.BaseURL, src.BaseURL)
42+
}
43+
if got.Owner == nil || got.Owner.Type != consts.OwnerTypePublic {
44+
t.Fatalf("Owner = %#v, want public owner", got.Owner)
45+
}
46+
}
47+
48+
func TestProjectTaskFromDoesNotExposeModelCredentials(t *testing.T) {
49+
pt := (&domain.ProjectTask{}).From(&db.ProjectTask{
50+
Branch: "main",
51+
Edges: db.ProjectTaskEdges{
52+
Model: privateModelWithCredentials(),
53+
Task: &db.Task{
54+
ID: uuid.New(),
55+
UserID: uuid.New(),
56+
CreatedAt: time.Unix(100, 0),
57+
UpdatedAt: time.Unix(200, 0),
58+
},
59+
},
60+
})
61+
62+
payload, err := json.Marshal(pt)
63+
if err != nil {
64+
t.Fatalf("marshal project task: %v", err)
65+
}
66+
67+
assertNoModelCredentials(t, string(payload))
68+
}
69+
70+
func TestTaskFromDoesNotExposeModelCredentials(t *testing.T) {
71+
task := (&domain.Task{}).From(&db.Task{
72+
ID: uuid.New(),
73+
UserID: uuid.New(),
74+
CreatedAt: time.Unix(100, 0),
75+
UpdatedAt: time.Unix(200, 0),
76+
Edges: db.TaskEdges{
77+
ProjectTasks: []*db.ProjectTask{
78+
{
79+
Edges: db.ProjectTaskEdges{
80+
Model: privateModelWithCredentials(),
81+
},
82+
},
83+
},
84+
},
85+
})
86+
87+
payload, err := json.Marshal(task)
88+
if err != nil {
89+
t.Fatalf("marshal task: %v", err)
90+
}
91+
92+
assertNoModelCredentials(t, string(payload))
93+
}
94+
95+
func privateModelWithCredentials() *db.Model {
96+
return &db.Model{
97+
ID: uuid.New(),
98+
Provider: "OpenAI",
99+
APIKey: "sk-private-secret",
100+
BaseURL: "https://private.example.com/v1",
101+
Model: "gpt-4o",
102+
CreatedAt: time.Unix(100, 0),
103+
UpdatedAt: time.Unix(200, 0),
104+
Edges: db.ModelEdges{
105+
User: &db.User{
106+
ID: uuid.New(),
107+
Role: consts.UserRoleIndividual,
108+
Name: "user",
109+
},
110+
},
111+
}
112+
}
113+
114+
func assertNoModelCredentials(t *testing.T, payload string) {
115+
t.Helper()
116+
117+
for _, forbidden := range []string{"api_key", "sk-private-secret", "https://private.example.com/v1"} {
118+
if strings.Contains(payload, forbidden) {
119+
t.Fatalf("payload exposes %q: %s", forbidden, payload)
120+
}
121+
}
122+
}

backend/domain/task.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,12 @@ type SwitchTaskModelReq struct {
124124

125125
// SwitchTaskModelResp 切换任务运行模型响应
126126
type SwitchTaskModelResp struct {
127-
ID uuid.UUID `json:"id"`
128-
RequestID string `json:"request_id,omitempty"`
129-
Success bool `json:"success"`
130-
Message string `json:"message"`
131-
SessionID string `json:"session_id"`
132-
Model *Model `json:"model,omitempty"`
127+
ID uuid.UUID `json:"id"`
128+
RequestID string `json:"request_id,omitempty"`
129+
Success bool `json:"success"`
130+
Message string `json:"message"`
131+
SessionID string `json:"session_id"`
132+
Model *ModelBrief `json:"model,omitempty"`
133133
}
134134

135135
// TaskModelSwitch 任务模型切换记录
@@ -160,7 +160,7 @@ type TaskListReq struct {
160160
// ProjectTask 项目任务
161161
type ProjectTask struct {
162162
ID uuid.UUID `json:"id" validate:"required"`
163-
Model *Model `json:"model,omitempty"`
163+
Model *ModelBrief `json:"model,omitempty"`
164164
Image *Image `json:"image,omitempty"`
165165
Branch string `json:"branch,omitempty"`
166166
CliName consts.CliName `json:"cli_name,omitempty"`
@@ -180,7 +180,7 @@ func (pt *ProjectTask) From(src *db.ProjectTask) *ProjectTask {
180180
if src.Edges.Task != nil {
181181
pt.ID = src.Edges.Task.ID
182182
}
183-
pt.Model = cvt.From(src.Edges.Model, &Model{})
183+
pt.Model = cvt.From(src.Edges.Model, &ModelBrief{})
184184
pt.Task = cvt.From(src.Edges.Task, &Task{})
185185
pt.CliName = src.CliName
186186
pt.RepoURL = src.RepoURL
@@ -217,7 +217,7 @@ type Task struct {
217217
CreatedAt int64 `json:"created_at"`
218218
LastActiveAt int64 `json:"last_active_at"`
219219
CompletedAt int64 `json:"completed_at"`
220-
Model *Model `json:"model,omitempty"`
220+
Model *ModelBrief `json:"model,omitempty"`
221221
Image *Image `json:"image,omitempty"`
222222
Branch string `json:"branch,omitempty"`
223223
CliName consts.CliName `json:"cli_name,omitempty"`
@@ -264,7 +264,7 @@ func (t *Task) From(src *db.Task) *Task {
264264
}
265265
if pts := src.Edges.ProjectTasks; len(pts) > 0 {
266266
pt := pts[0]
267-
t.Model = cvt.From(pt.Edges.Model, &Model{})
267+
t.Model = cvt.From(pt.Edges.Model, &ModelBrief{})
268268
t.Image = cvt.From(pt.Edges.Image, &Image{})
269269
t.Branch = pt.Branch
270270
t.RepoURL = pt.RepoURL

0 commit comments

Comments
 (0)