Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal/cmd/import-test/with-import-common.zed
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
definition user {}

caveat mycaveat(day_of_week string) {
day_of_week == "friday"
}

10 changes: 10 additions & 0 deletions internal/cmd/import-test/with-import-root.zed
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use import

import "with-import-common.zed"

definition resource {
relation writer: user
relation reader: user with mycaveat
permission write = writer
permission view = reader + write
}
5 changes: 5 additions & 0 deletions internal/cmd/import-test/with-import-validation-file.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
schemaFile: "./with-import-root.zed"
relationships: |-
resource:1#reader@user:1[mycaveat]
resource:2#writer@user:1
9 changes: 5 additions & 4 deletions internal/cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func importCmdFunc(cmd *cobra.Command, schemaClient v1.SchemaServiceClient, rela
}

if cobrautil.MustGetBool(cmd, "schema") {
if err := importSchema(cmd.Context(), schemaClient, p.Schema.Schema, prefix); err != nil {
if err := importSchema(cmd.Context(), schemaClient, p.Schema.Schema, prefix, p.RootSchemaDir); err != nil {
return fmt.Errorf("error importing schema: %w", err)
}
}
Expand All @@ -98,11 +98,12 @@ func importCmdFunc(cmd *cobra.Command, schemaClient v1.SchemaServiceClient, rela
return err
}

func importSchema(ctx context.Context, client v1.SchemaServiceClient, schema string, definitionPrefix string) error {
func importSchema(ctx context.Context, client v1.SchemaServiceClient, schema string, definitionPrefix string, rootSchemaDir string) error {
log.Info().Msg("importing schema")

// Recompile the schema with the specified prefix.
schemaText, err := rewriteSchema(ctx, schema, definitionPrefix)
// Compile with the schema's root directory so any `import` statements resolve, and
// (optionally) apply the definition prefix.
schemaText, err := rewriteSchema(ctx, schema, definitionPrefix, rootSchemaDir)
if err != nil {
return err
}
Expand Down
47 changes: 47 additions & 0 deletions internal/cmd/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,53 @@ func TestImportCmd(t *testing.T) {
}
}

func TestImportCmdSchemaWithImports(t *testing.T) {
require := require.New(t)
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "schema-definition-prefix"},
zedtesting.BoolFlag{FlagName: "schema", FlagValue: true},
zedtesting.BoolFlag{FlagName: "relationships", FlagValue: true},
zedtesting.IntFlag{FlagName: "batch-size", FlagValue: 100},
zedtesting.IntFlag{FlagName: "workers", FlagValue: 1},
)
f := filepath.Join("import-test", "with-import-validation-file.yaml")

ctx := t.Context()
srv := zedtesting.NewTestServer(ctx, t)
go func() {
assert.NoError(t, srv.Run(ctx))
}()
conn, err := srv.GRPCDialContext(ctx)
require.NoError(err)
t.Cleanup(func() {
conn.Close()
})

c, err := zedtesting.ClientFromConn(conn)(cmd)
require.NoError(err)

// The YAML points to a .zed file that uses `import "with-import-common.zed"`. WriteSchema
// rejects `import` statements, so this exercises that the client flattens the schema
// (via rewriteSchema + SourceFolder) before sending.
err = importCmdFunc(cmd, c, c, "", f)
require.NoError(err)

rel := tuple.MustParse(`resource:1#view@user:1[mycaveat]`)
resp, err := c.CheckPermission(ctx, &v1.CheckPermissionRequest{
Consistency: fullyConsistent,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: rel.Subject.ObjectType, ObjectId: rel.Subject.ObjectID}},
Permission: "view",
Resource: &v1.ObjectReference{ObjectType: rel.Resource.ObjectType, ObjectId: rel.Resource.ObjectID},
Context: &structpb.Struct{
Fields: map[string]*structpb.Value{
"day_of_week": structpb.NewStringValue("friday"),
},
},
})
require.NoError(err)
require.Equal(v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, resp.Permissionship)
}

func TestImportCmdRelationsOnly(t *testing.T) {
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "schema-definition-prefix"},
Expand Down
11 changes: 7 additions & 4 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func schemaCopyInner(ctx context.Context, srcClient, destClient v1.SchemaService
return nil, err
}

schemaText, err := rewriteSchema(ctx, readResp.SchemaText, prefix)
schemaText, err := rewriteSchema(ctx, readResp.SchemaText, prefix, "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -263,12 +263,14 @@ func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServi
}

var schemaBytes []byte
var sourceFolder string
switch len(args) {
case 1:
schemaBytes, err = os.ReadFile(args[0])
if err != nil {
return fmt.Errorf("failed to read schema file: %w", err)
}
sourceFolder = filepath.Dir(args[0])
log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file")
case 0:
schemaBytes, err = io.ReadAll(os.Stdin)
Expand All @@ -289,7 +291,7 @@ func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServi
return err
}

schemaText, err := rewriteSchema(cmd.Context(), string(schemaBytes), prefix)
schemaText, err := rewriteSchema(cmd.Context(), string(schemaBytes), prefix, sourceFolder)
if err != nil {
return err
}
Expand All @@ -316,15 +318,16 @@ func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServi
}

// rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions and caveats.
func rewriteSchema(ctx context.Context, existingSchemaText string, definitionPrefix string) (string, error) {
if definitionPrefix == "" {
func rewriteSchema(ctx context.Context, existingSchemaText string, definitionPrefix string, sourceFolder string) (string, error) {
if definitionPrefix == "" && sourceFolder == "" {
return existingSchemaText, nil
}

compiled, err := compiler.Compile(
compiler.InputSchema{Source: input.Source("schema"), SchemaString: existingSchemaText},
compiler.ObjectTypePrefix(definitionPrefix),
compiler.SkipValidation(),
compiler.SourceFolder(sourceFolder),
)
if err != nil {
return "", err
Expand Down
7 changes: 4 additions & 3 deletions internal/cmd/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ caveat test/some_caveat(someCondition int) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
found, err := rewriteSchema(t.Context(), test.existingSchema, test.definitionPrefix)
found, err := rewriteSchema(t.Context(), test.existingSchema, test.definitionPrefix, "")
require.NoError(t, err)
require.Equal(t, test.expectedSchema, found)
})
Expand Down Expand Up @@ -375,9 +375,10 @@ func TestSchemaWrite(t *testing.T) {
}, nil
},
expectSchemaWritten: `definition user {}

definition resource {
relation view: user
permission viewer = view
relation view: user
permission viewer = view
}`,
terminalChecker: &mockTermChecker{returnVal: false},
},
Expand Down
Loading