diff --git a/dataproxy/service/dataproxy_service.go b/dataproxy/service/dataproxy_service.go index d9063638cab..79a28d724d3 100644 --- a/dataproxy/service/dataproxy_service.go +++ b/dataproxy/service/dataproxy_service.go @@ -19,11 +19,11 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/dataproxy/config" + "github.com/flyteorg/flyte/v2/dataproxy/logs" "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/flytestdlib/storage" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" flyteIdlCore "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/flyteorg/flyte/v2/dataproxy/logs" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project" @@ -470,6 +470,10 @@ func (s *Service) GetActionData( ctx context.Context, req *connect.Request[dataproxy.GetActionDataRequest], ) (*connect.Response[dataproxy.GetActionDataResponse], error) { + if err := req.Msg.Validate(); err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + actionId := req.Msg.GetActionId() urisResp, err := s.runClient.GetActionDataURIs(ctx, connect.NewRequest(&workflow.GetActionDataURIsRequest{ @@ -495,11 +499,17 @@ func (s *Service) GetActionData( } logger.Infof(groupCtx, "GetActionData: reading inputs from %s", inputRef) if err := s.dataStore.ReadProtobuf(groupCtx, inputRef, resp.Inputs); err != nil { - logger.Errorf(groupCtx, "GetActionData: failed to read inputs from %s: %v", inputRef, err) - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs from %s: %w", inputRef, err)) + if !storage.IsNotFound(err) { + logger.Errorf(groupCtx, "GetActionData: failed to read inputs from %s: %v", inputRef, err) + return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs from %s: %w", inputRef, err)) + } + } else { + logger.Debugf(groupCtx, "Read %d input literals and %d action contexts", len(resp.Inputs.Literals), len(resp.Inputs.Context)) } return nil }) + } else { + logger.Warnf(ctx, "Action %s has empty InputURI", req.Msg.ActionId.Name) } if urisResp.Msg.GetOutputsUri() != "" { @@ -508,11 +518,16 @@ func (s *Service) GetActionData( logger.Infof(groupCtx, "GetActionData: reading outputs from %s", outputRef) var inputsOrOutputs task.Inputs if err := s.dataStore.ReadProtobuf(groupCtx, outputRef, &inputsOrOutputs); err != nil { - logger.Errorf(groupCtx, "GetActionData: failed to read outputs from %s: %v", outputRef, err) - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs from %s: %w", outputRef, err)) - } - resp.Outputs = &task.Outputs{ - Literals: inputsOrOutputs.GetLiterals(), + if !storage.IsNotFound(err) { + logger.Errorf(groupCtx, "GetActionData: failed to read outputs from %s: %v", outputRef, err) + return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs from %s: %w", outputRef, err)) + } + logger.Debugf(groupCtx, "Outputs not found at %s (action may not have finished)", urisResp.Msg.GetOutputsUri()) + } else { + resp.Outputs = &task.Outputs{ + Literals: inputsOrOutputs.GetLiterals(), + } + logger.Debugf(groupCtx, "Read %d output literals", len(resp.Outputs.Literals)) } return nil }) diff --git a/dataproxy/service/dataproxy_service_test.go b/dataproxy/service/dataproxy_service_test.go index 98fda0d4667..7b302cb10ec 100644 --- a/dataproxy/service/dataproxy_service_test.go +++ b/dataproxy/service/dataproxy_service_test.go @@ -26,9 +26,9 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project" + projectMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect/mocks" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" - projectMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect/mocks" workflowMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks" ) diff --git a/runs/service/run_service.go b/runs/service/run_service.go index efb5e48cddb..c029412525c 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -11,9 +11,7 @@ import ( "time" "connectrpc.com/connect" - "github.com/flyteorg/flyte/v2/flytestdlib/app" "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -25,6 +23,8 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/actions/actionsconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" @@ -39,12 +39,20 @@ import ( type RunService struct { repo interfaces.Repository actionsClient actionsconnect.ActionsServiceClient + dataProxyClient actionDataClient projectClient projectconnect.ProjectServiceClient storagePrefix string dataStore *storage.DataStore abortReconciler *AbortReconciler } +type actionDataClient interface { + GetActionData( + ctx context.Context, + req *connect.Request[dataproxy.GetActionDataRequest], + ) (*connect.Response[dataproxy.GetActionDataResponse], error) +} + const ( runIDLength = 20 runStringFormat = "r%s" @@ -105,10 +113,19 @@ func (s *RunService) WatchGroups(ctx context.Context, req *connect.Request[workf } // NewRunService creates a new RunService instance -func NewRunService(repo interfaces.Repository, actionsClient actionsconnect.ActionsServiceClient, projectClient projectconnect.ProjectServiceClient, storagePrefix string, dataStore *storage.DataStore, reconciler *AbortReconciler) *RunService { +func NewRunService( + repo interfaces.Repository, + actionsClient actionsconnect.ActionsServiceClient, + dataProxyClient dataproxyconnect.DataProxyServiceClient, + projectClient projectconnect.ProjectServiceClient, + storagePrefix string, + dataStore *storage.DataStore, + reconciler *AbortReconciler, +) *RunService { return &RunService{ repo: repo, actionsClient: actionsClient, + dataProxyClient: dataProxyClient, projectClient: projectClient, storagePrefix: storagePrefix, dataStore: dataStore, @@ -746,7 +763,7 @@ func lastAttemptIsTerminal(attempts []*workflow.ActionAttempt) bool { return IsTerminalPhase(last.GetPhase()) } -// GetActionData gets input and output data for an action by reading from storage. +// GetActionData keeps backward compatibility by delegating data reads to DataProxy. func (s *RunService) GetActionData( ctx context.Context, req *connect.Request[workflow.GetActionDataRequest], @@ -758,107 +775,20 @@ func (s *RunService) GetActionData( return nil, connect.NewError(connect.CodeInvalidArgument, err) } - // Get action from DB for storage URIs - action, err := s.repo.ActionRepo().GetAction(ctx, req.Msg.ActionId) - if err != nil { - logger.Errorf(ctx, "Failed to get action: %v", err) - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found: %w", err)) + if s.dataProxyClient == nil { + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("dataproxy client is not configured")) } - inputURI, _ := extractStorageURIs(action.ActionSpec) - - info := &workflow.RunInfo{} - if err := proto.Unmarshal(action.DetailedInfo, info); err != nil { + dpResp, err := s.dataProxyClient.GetActionData(ctx, connect.NewRequest(&dataproxy.GetActionDataRequest{ + ActionId: req.Msg.GetActionId(), + })) + if err != nil { return nil, err } resp := &workflow.GetActionDataResponse{ - Inputs: &task.Inputs{}, - Outputs: &task.Outputs{}, - } - - // Read inputs from storage - group, groupCtx := errgroup.WithContext(ctx) - if inputURI != "" { - group.Go(func() error { - inputRef := storage.DataReference(inputURI) - logger.Debugf(groupCtx, "Reading inputs from: %s", inputRef) - if err := s.dataStore.ReadProtobuf(groupCtx, inputRef, resp.Inputs); err != nil { - if !storage.IsNotFound(err) { - logger.Errorf(groupCtx, "Failed to read inputs from %s: %v", inputRef, err) - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs: %w", err)) - } - logger.Debugf(groupCtx, "Inputs not found at %s", inputRef) - } else { - logger.Debugf(groupCtx, "Read %d input literals and %d action contexts", len(resp.Inputs.Literals), len(resp.Inputs.Context)) - } - return nil - }) - } else { - logger.Warnf(ctx, "Action %s has empty InputURI", req.Msg.ActionId.Name) - } - - // Read outputs from storage (only present if action succeeded) - if action.Phase == int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED) { - group.Go(func() error { - // There are no attempts for trace actions, so we can skip the attempt validation - var attempts []*workflow.ActionAttempt - var err error - if workflow.ActionType(action.ActionType) == workflow.ActionType_ACTION_TYPE_TRACE { - if info.GetOutputsUri() == "" { - return nil - } - logger.Debugf(groupCtx, "Reading outputs from: %s", info.GetOutputsUri()) - - outputMap := &core.LiteralMap{} - if err := s.dataStore.ReadProtobuf(groupCtx, storage.DataReference(info.GetOutputsUri()), outputMap); err != nil { - if !storage.IsNotFound(err) { - logger.Errorf(groupCtx, "Failed to read outputs from %s: %v", info.GetOutputsUri(), err) - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs: %w", err)) - } - logger.Debugf(groupCtx, "Outputs not found at %s (action may not have finished)", info.GetOutputsUri()) - } else { - resp.Outputs = literalMapToOutputs(outputMap) - logger.Debugf(groupCtx, "Read %d output literals", len(resp.Outputs.Literals)) - } - - return nil - } - - // Default to "task" action types - attempts, err = s.getAttempts(groupCtx, req.Msg.GetActionId()) - if err != nil { - return err - } - - if len(attempts) == 0 { - return app.NewServerError(codes.NotFound, "outputs not available, no attempts for action") - } - - outputUri := attempts[len(attempts)-1].GetOutputs().GetOutputUri() - if outputUri == "" { - return app.NewServerError(codes.NotFound, "outputs not available") - } - - logger.Debugf(groupCtx, "Reading outputs from: %s", outputUri) - outputMap := &core.LiteralMap{} - if err := s.dataStore.ReadProtobuf(groupCtx, storage.DataReference(outputUri), outputMap); err != nil { - if !storage.IsNotFound(err) { - logger.Errorf(groupCtx, "Failed to read outputs from %s: %v", outputUri, err) - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs: %w", err)) - } - logger.Debugf(groupCtx, "Outputs not found at %s (action may not have finished)", outputUri) - } else { - resp.Outputs = literalMapToOutputs(outputMap) - logger.Debugf(groupCtx, "Read %d output literals", len(resp.Outputs.Literals)) - } - - return nil - }) - } - - if err := group.Wait(); err != nil { - return nil, err + Inputs: dpResp.Msg.GetInputs(), + Outputs: dpResp.Msg.GetOutputs(), } logger.Infof(ctx, "Retrieved action data for: %s (inputs=%d, outputs=%d)", diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go index 6d82827b306..3329b2fc864 100644 --- a/runs/service/run_service_test.go +++ b/runs/service/run_service_test.go @@ -25,6 +25,7 @@ import ( actionsconnectmocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/actions/actionsconnect/mocks" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project" projectMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect/mocks" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task" @@ -85,6 +86,19 @@ func newRunServiceTestClient(t *testing.T, svc *RunService) workflowconnect.RunS return workflowconnect.NewRunServiceClient(http.DefaultClient, server.URL) } +type mockDataProxyClient struct { + mock.Mock +} + +func (m *mockDataProxyClient) GetActionData( + ctx context.Context, + req *connect.Request[dataproxy.GetActionDataRequest], +) (*connect.Response[dataproxy.GetActionDataResponse], error) { + args := m.Called(ctx, req) + resp, _ := args.Get(0).(*connect.Response[dataproxy.GetActionDataResponse]) + return resp, args.Error(1) +} + func TestGetRunDetails_WithTaskSpec(t *testing.T) { actionRepo := &repoMocks.ActionRepo{} taskRepo := &repoMocks.TaskRepo{} @@ -1273,21 +1287,7 @@ func TestCreateRun_PreservesInputContextAndRawDataPath(t *testing.T) { require.NoError(t, err) } -func TestGetActionData_ReadsOutputFromAttempts(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - actionsClient := actionsconnectmocks.NewActionsServiceClient(t) - repo := &repoMocks.Repository{} - store := &storageMocks.ComposedProtobufStore{} - dataStore := &storage.DataStore{ComposedProtobufStore: store} - - repo.On("ActionRepo").Return(actionRepo) - - svc := &RunService{ - repo: repo, - actionsClient: actionsClient, - dataStore: dataStore, - } - +func TestGetActionData_DelegatesToDataProxy(t *testing.T) { actionID := &common.ActionIdentifier{ Run: &common.RunIdentifier{ Org: "test-org", @@ -1298,80 +1298,35 @@ func TestGetActionData_ReadsOutputFromAttempts(t *testing.T) { Name: "action-1", } - // Build action spec with input URI - actionSpec := &workflow.ActionSpec{ - InputUri: "s3://bucket/inputs/inputs.pb", - } - actionSpecBytes, _ := proto.Marshal(actionSpec) - - runInfo := &workflow.RunInfo{} - runInfoBytes, _ := proto.Marshal(runInfo) - - actionModel := &models.Action{ - Project: "test-project", - Domain: "test-domain", - RunName: "rtest12345", - Name: "action-1", - Phase: int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED), - ActionType: int32(workflow.ActionType_ACTION_TYPE_TASK), - ActionSpec: actionSpecBytes, - DetailedInfo: runInfoBytes, - Attempts: 1, - } - - // Build event with output URI - successEvent := &workflow.ActionEvent{ - Id: actionID, - Attempt: 0, - Phase: common.ActionPhase_ACTION_PHASE_SUCCEEDED, - Version: 1, - UpdatedTime: timestamppb.Now(), - Outputs: &task.OutputReferences{ - OutputUri: "s3://bucket/outputs/action-1/outputs.pb", + dpClient := &mockDataProxyClient{} + svc := &RunService{dataProxyClient: dpClient} + dpClient.On("GetActionData", mock.Anything, mock.MatchedBy(func(req *connect.Request[dataproxy.GetActionDataRequest]) bool { + return proto.Equal(req.Msg.GetActionId(), actionID) + })).Return(connect.NewResponse(&dataproxy.GetActionDataResponse{ + Inputs: &task.Inputs{ + Literals: []*task.NamedLiteral{ + {Name: "x", Value: newStringLiteral("input")}, + }, }, - } - eventModel, _ := models.NewActionEventModel(successEvent) - - actionRepo.On("GetAction", mock.Anything, actionID).Return(actionModel, nil) - actionRepo.On("ListEvents", mock.Anything, matchActionID(actionID), 500).Return([]*models.ActionEvent{eventModel}, nil) - - // Mock reading inputs - store.On("ReadProtobuf", mock.Anything, storage.DataReference("s3://bucket/inputs/inputs.pb"), mock.AnythingOfType("*task.Inputs")). - Return(nil).Once() - - // Mock reading outputs — verify it reads from the attempt's output URI - store.On("ReadProtobuf", mock.Anything, storage.DataReference("s3://bucket/outputs/action-1/outputs.pb"), mock.AnythingOfType("*core.LiteralMap")). - Run(func(args mock.Arguments) { - lm := args.Get(2).(*core.LiteralMap) - lm.Literals = map[string]*core.Literal{ - "result": newStringLiteral("success"), - } - }). - Return(nil).Once() + Outputs: &task.Outputs{ + Literals: []*task.NamedLiteral{ + {Name: "result", Value: newStringLiteral("success")}, + }, + }, + }), nil).Once() resp, err := svc.GetActionData(context.Background(), connect.NewRequest(&workflow.GetActionDataRequest{ ActionId: actionID, })) require.NoError(t, err) + assert.Len(t, resp.Msg.Inputs.Literals, 1) assert.Len(t, resp.Msg.Outputs.Literals, 1) + assert.Equal(t, "x", resp.Msg.Inputs.Literals[0].Name) assert.Equal(t, "result", resp.Msg.Outputs.Literals[0].Name) - - store.AssertExpectations(t) + dpClient.AssertExpectations(t) } -func TestGetActionData_NonSucceededSkipsOutputs(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - repo := &repoMocks.Repository{} - store := &storageMocks.ComposedProtobufStore{} - dataStore := &storage.DataStore{ComposedProtobufStore: store} - - repo.On("ActionRepo").Return(actionRepo) - - svc := &RunService{ - repo: repo, - dataStore: dataStore, - } - +func TestGetActionData_PropagatesDataProxyError(t *testing.T) { actionID := &common.ActionIdentifier{ Run: &common.RunIdentifier{ Org: "test-org", @@ -1382,40 +1337,20 @@ func TestGetActionData_NonSucceededSkipsOutputs(t *testing.T) { Name: "action-1", } - actionSpec := &workflow.ActionSpec{ - InputUri: "s3://bucket/inputs/inputs.pb", - } - actionSpecBytes, _ := proto.Marshal(actionSpec) + dpClient := &mockDataProxyClient{} + svc := &RunService{dataProxyClient: dpClient} + dpClient.On("GetActionData", mock.Anything, mock.Anything).Return( + nil, connect.NewError(connect.CodeNotFound, errors.New("action not found")), + ).Once() - runInfo := &workflow.RunInfo{} - runInfoBytes, _ := proto.Marshal(runInfo) - - actionModel := &models.Action{ - Project: "test-project", - Domain: "test-domain", - RunName: "rtest12345", - Name: "action-1", - Phase: int32(common.ActionPhase_ACTION_PHASE_RUNNING), - ActionSpec: actionSpecBytes, - DetailedInfo: runInfoBytes, - } - - actionRepo.On("GetAction", mock.Anything, actionID).Return(actionModel, nil) - - // Mock reading inputs - store.On("ReadProtobuf", mock.Anything, storage.DataReference("s3://bucket/inputs/inputs.pb"), mock.AnythingOfType("*task.Inputs")). - Return(nil).Once() - - resp, err := svc.GetActionData(context.Background(), connect.NewRequest(&workflow.GetActionDataRequest{ + resp, err := svc.GetActionData(context.Background(), connect.NewRequest( + &workflow.GetActionDataRequest{ ActionId: actionID, })) - require.NoError(t, err) - // Outputs should be empty since action is still running - assert.Empty(t, resp.Msg.Outputs.Literals) - - store.AssertExpectations(t) - // Verify ReadProtobuf was only called once (for inputs, not outputs) - store.AssertNumberOfCalls(t, "ReadProtobuf", 1) + assert.Nil(t, resp) + require.Error(t, err) + assert.Equal(t, connect.CodeNotFound, connect.CodeOf(err)) + dpClient.AssertExpectations(t) } func TestCreateRun_WithOffloadedInputData(t *testing.T) { diff --git a/runs/setup.go b/runs/setup.go index f89c4cdc107..22483574edd 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -10,6 +10,7 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/app" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/actions/actionsconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect" projectpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect" @@ -59,6 +60,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { http.DefaultClient, projectsURL, ) + dataProxyClient := dataproxyconnect.NewDataProxyServiceClient(http.DefaultClient, projectsURL) abortReconciler := service.NewAbortReconciler(repo, actionsClient, service.AbortReconcilerConfig{ Workers: 5, @@ -71,7 +73,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { return abortReconciler.Run(ctx) }) - runsSvc := service.NewRunService(repo, actionsClient, projectClient, cfg.StoragePrefix, sc.DataStore, abortReconciler) + runsSvc := service.NewRunService(repo, actionsClient, dataProxyClient, projectClient, cfg.StoragePrefix, sc.DataStore, abortReconciler) taskSvc := service.NewTaskService(repo, projectClient) runsPath, runsHandler := workflowconnect.NewRunServiceHandler(runsSvc) diff --git a/runs/test/api/setup_test.go b/runs/test/api/setup_test.go index b62c6242a8e..d243ed1e07e 100644 --- a/runs/test/api/setup_test.go +++ b/runs/test/api/setup_test.go @@ -115,7 +115,7 @@ func TestMain(m *testing.M) { // Create RunService with a no-op actions client (points at test server; not used by watch tests) endpointURL := fmt.Sprintf("http://localhost:%d", testPort) actionsClient := actionsconnect.NewActionsServiceClient(http.DefaultClient, endpointURL) - runSvc := service.NewRunService(repo, actionsClient, nil, "", nil, nil) + runSvc := service.NewRunService(repo, actionsClient, nil, nil, "", nil, nil) // Setup HTTP server mux := http.NewServeMux()