diff --git a/components/callbacks/request.go b/components/callbacks/request.go index 488835b9706..6f58992a905 100644 --- a/components/callbacks/request.go +++ b/components/callbacks/request.go @@ -6,6 +6,7 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" "go.temporal.io/api/serviceerror" + persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/log" @@ -42,16 +43,40 @@ func routeSystemCallbackRequest( logger.Error("failed to decode completion from token", tag.Error(err)) return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } - ns, err := namespaceRegistry.GetNamespaceByID(namespace.ID(completion.NamespaceId)) + + // Normalize to support two possible token shapes: + // - legacy HSM tokens carry namespace/workflow IDs directly + // - CHASM tokens carry an encoded component ref instead + namespaceID := completion.GetNamespaceId() + businessID := completion.GetWorkflowId() + if namespaceID == "" && len(completion.GetComponentRef()) > 0 { + ref := &persistencespb.ChasmComponentRef{} + if err := ref.Unmarshal(completion.GetComponentRef()); err != nil { + logger.Error("failed to decode CHASM component ref from callback token", tag.Error(err)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + } + if ref.GetNamespaceId() == "" { + logger.Error("decoded CHASM component ref is missing namespace ID") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + } + if ref.GetBusinessId() == "" { + logger.Error("decoded CHASM component ref is missing business ID") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + } + namespaceID = ref.GetNamespaceId() + businessID = ref.GetBusinessId() + } + + ns, err := namespaceRegistry.GetNamespaceByID(namespace.ID(namespaceID)) if err != nil { - logger.Error("failed to get namespace for nexus completion request", tag.WorkflowNamespaceID(completion.NamespaceId), tag.Error(err)) + logger.Error("failed to get namespace for nexus completion request", tag.WorkflowNamespaceID(namespaceID), tag.Error(err)) var nfe *serviceerror.NamespaceNotFound if errors.As(err, &nfe) { - return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", namespaceID) } return nil, commonnexus.ConvertGRPCError(err, false) } - clusterName := ns.ActiveClusterName(namespace.RoutingKey{ID: completion.GetWorkflowId()}) + clusterName := ns.ActiveClusterName(namespace.RoutingKey{ID: businessID}) if clusterMetadata.GetCurrentClusterName() == clusterName { frontendClient = localClient } else { diff --git a/components/callbacks/request_test.go b/components/callbacks/request_test.go index dac01cef5ab..bcfea254f9b 100644 --- a/components/callbacks/request_test.go +++ b/components/callbacks/request_test.go @@ -238,58 +238,142 @@ func TestRouteSystemCallbackRequest_NamespaceNotFound(t *testing.T) { } func TestRouteSystemCallbackRequest_Success(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, commonnexus.PathCompletionCallbackNoIdentifier, r.URL.Path) - w.WriteHeader(http.StatusOK) - })) - defer ts.Close() - - ctrl := gomock.NewController(t) - clusterMeta := cluster.NewMockMetadata(ctrl) - nsRegistry := namespace.NewMockRegistry(ctrl) - - tokenGen := commonnexus.NewCallbackTokenGenerator() - tokenStr, err := tokenGen.Tokenize(&tokenspb.NexusOperationCompletion{ - NamespaceId: "ns-id-1", - WorkflowId: "wf-1", - RunId: "run-1", - Ref: &persistencespb.StateMachineRef{}, - }) - require.NoError(t, err) - - testNS := namespace.NewLocalNamespaceForTest( - &persistencespb.NamespaceInfo{Id: "ns-id-1", Name: "test-ns"}, - nil, - "cluster-A", - ) - nsRegistry.EXPECT().GetNamespaceByID(namespace.ID("ns-id-1")).Return(testNS, nil) - - // httpClientCache.Get will fail for "cluster-A", so it falls back to localClient. - clusterMeta.EXPECT().GetCurrentClusterName().Return("cluster-A").AnyTimes() - clusterMeta.EXPECT().GetAllClusterInfo().Return(map[string]cluster.ClusterInformation{}).AnyTimes() - clusterMeta.EXPECT().RegisterMetadataChangeCallback(gomock.Any(), gomock.Any()) - - localClient := newTestFrontendHTTPClient(ts) - - // Create a cache that will fail for the requested cluster since we don't set up metadata fully. - httpClientCache := cluster.NewFrontendHTTPClientCache(clusterMeta, nil) - - r, err := http.NewRequest(http.MethodPost, commonnexus.SystemCallbackURL, nil) - require.NoError(t, err) - r.Header.Set(commonnexus.CallbackTokenHeader, tokenStr) + for _, tc := range []struct { + name string + completionToken func(*commonnexus.CallbackTokenGenerator) (string, error) + }{ + { + name: "HSM", + completionToken: func(tokenGen *commonnexus.CallbackTokenGenerator) (string, error) { + return tokenGen.Tokenize(&tokenspb.NexusOperationCompletion{ + // HSM sets the deprecated execution fields and ref. + NamespaceId: "ns-id-1", + WorkflowId: "wf-1", + RunId: "run-1", + Ref: &persistencespb.StateMachineRef{}, + }) + }, + }, + { + name: "CHASM", + completionToken: func(tokenGen *commonnexus.CallbackTokenGenerator) (string, error) { + ref, err := (&persistencespb.ChasmComponentRef{ + NamespaceId: "ns-id-1", + BusinessId: "wf-1", + RunId: "run-1", + }).Marshal() + if err != nil { + return "", err + } + return tokenGen.Tokenize(&tokenspb.NexusOperationCompletion{ + // CHASM sets ComponentRef instead of HSM execution fields. + ComponentRef: ref, + }) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var gotPath string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + ctrl := gomock.NewController(t) + clusterMeta := cluster.NewMockMetadata(ctrl) + nsRegistry := namespace.NewMockRegistry(ctrl) + + tokenGen := commonnexus.NewCallbackTokenGenerator() + tokenStr, err := tc.completionToken(tokenGen) + require.NoError(t, err) + + testNS := namespace.NewLocalNamespaceForTest( + &persistencespb.NamespaceInfo{Id: "ns-id-1", Name: "test-ns"}, + nil, + "cluster-A", + ) + nsRegistry.EXPECT().GetNamespaceByID(namespace.ID("ns-id-1")).Return(testNS, nil) + + // httpClientCache.Get will fail for "cluster-A", so it falls back to localClient. + clusterMeta.EXPECT().GetCurrentClusterName().Return("cluster-A").AnyTimes() + clusterMeta.EXPECT().GetAllClusterInfo().Return(map[string]cluster.ClusterInformation{}).AnyTimes() + clusterMeta.EXPECT().RegisterMetadataChangeCallback(gomock.Any(), gomock.Any()) + + localClient := newTestFrontendHTTPClient(ts) + // Create a cache that will fail for the requested cluster since we don't set up metadata fully. + httpClientCache := cluster.NewFrontendHTTPClientCache(clusterMeta, nil) + + r, err := http.NewRequest(http.MethodPost, commonnexus.SystemCallbackURL, nil) + require.NoError(t, err) + r.Header.Set(commonnexus.CallbackTokenHeader, tokenStr) + + resp, err := routeSystemCallbackRequest( + r, + clusterMeta, + nsRegistry, + httpClientCache, + tokenGen, + localClient, + log.NewNoopLogger(), + ) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, commonnexus.PathCompletionCallbackNoIdentifier, gotPath) + }) + } +} - resp, err := routeSystemCallbackRequest( - r, - clusterMeta, - nsRegistry, - httpClientCache, - tokenGen, - localClient, - log.NewNoopLogger(), - ) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - require.Equal(t, http.StatusOK, resp.StatusCode) +func TestRouteSystemCallbackRequest_InvalidChasmComponentRef(t *testing.T) { + for _, tc := range []struct { + name string + ref *persistencespb.ChasmComponentRef + }{ + { + name: "missing namespace id", + ref: &persistencespb.ChasmComponentRef{ + BusinessId: "wf-1", + RunId: "run-1", + }, + }, + { + name: "missing business id", + ref: &persistencespb.ChasmComponentRef{ + NamespaceId: "ns-id-1", + RunId: "run-1", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tokenGen := commonnexus.NewCallbackTokenGenerator() + + ref, err := tc.ref.Marshal() + require.NoError(t, err) + + tokenStr, err := tokenGen.Tokenize(&tokenspb.NexusOperationCompletion{ComponentRef: ref}) + require.NoError(t, err) + + r, err := http.NewRequest(http.MethodPost, commonnexus.SystemCallbackURL, nil) + require.NoError(t, err) + r.Header.Set(commonnexus.CallbackTokenHeader, tokenStr) + + _, err = routeSystemCallbackRequest( + r, + nil, + nil, + nil, + tokenGen, + nil, + log.NewNoopLogger(), + ) + require.Error(t, err) + var handlerErr *nexus.HandlerError + require.ErrorAs(t, err, &handlerErr) + require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + require.Contains(t, handlerErr.Error(), "invalid callback token") + }) + } } func TestRouteRequest_SystemCallback(t *testing.T) {