diff --git a/pkg/util/testing/wrapper.go b/pkg/util/testing/wrapper.go index 665254c7ee..01d8902f25 100644 --- a/pkg/util/testing/wrapper.go +++ b/pkg/util/testing/wrapper.go @@ -1288,6 +1288,14 @@ func (m *MLPolicySourceWrapper) MPIPolicy(numProcPerNode *int32, MPImplementatio return m } +func (m *MLPolicySourceWrapper) FluxPolicy(numProcPerNode *int32) *MLPolicySourceWrapper { + if m.Flux == nil { + m.Flux = &trainer.FluxMLPolicySource{} + } + m.Flux.NumProcPerNode = numProcPerNode + return m +} + func (m *MLPolicySourceWrapper) Obj() *trainer.MLPolicySource { return &m.MLPolicySource } diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index e87b2688b3..1c540efb01 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -41,6 +41,7 @@ const ( deepSpeedRuntime = "deepspeed-distributed" jaxRuntime = "jax-distributed" xgboostRuntime = "xgboost-distributed" + fluxRuntime = "flux-distributed" ) //go:embed testdata/status_update.py @@ -320,6 +321,110 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { }) }) + ginkgo.When("Creating TrainJob to perform Flux workload", func() { + ginkgo.It("should create TrainJob with Flux runtime reference", func() { + ginkgo.By("Create Flux ClusterTrainingRuntime") + fluxRuntimeName := fluxRuntime + "-" + ns.Name + fluxClusterRuntimeWrapper := testingutil.MakeClusterTrainingRuntimeWrapper(fluxRuntimeName) + for _, rJob := range fluxClusterRuntimeWrapper.Spec.Template.Spec.ReplicatedJobs { + if rJob.Name == constants.Node { + rJob.Template.Spec.Template.Spec.Volumes = nil + rJob.Template.Spec.Template.Spec.Containers[0].VolumeMounts = nil + fluxClusterRuntimeWrapper.Spec.Template.Spec.ReplicatedJobs = []jobsetv1alpha2.ReplicatedJob{rJob} + break + } + } + fluxClusterRuntime := fluxClusterRuntimeWrapper. + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(fluxClusterRuntimeWrapper.Spec). + WithMLPolicy( + testingutil.MakeMLPolicyWrapper(). + WithNumNodes(2). + WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper(). + FluxPolicy(ptr.To[int32](1)). + Obj(), + ). + Obj(), + ). + Container( + constants.Node, + constants.Node, + "ubuntu:22.04", + []string{"bash"}, + []string{"-c", "true"}, + corev1.ResourceList{}, + ). + Obj(), + ). + Obj() + gomega.Expect(k8sClient.Create(ctx, fluxClusterRuntime)).Should(gomega.Succeed()) + defer func() { + gomega.Expect(k8sClient.Delete(ctx, fluxClusterRuntime)).Should(gomega.Succeed()) + }() + + trainJob := testingutil.MakeTrainJobWrapper(ns.Name, "e2e-test-flux"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), fluxRuntimeName). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(). + NumNodes(2). + Container( + "ubuntu:22.04", + []string{"bash"}, + []string{"-c", "true"}, + corev1.ResourceList{}, + ). + Obj(), + ). + Obj() + + ginkgo.By("Create a TrainJob with flux-distributed runtime reference", func() { + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + }) + + ginkgo.By("Wait for TrainJob jobs to become active", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.Node, + Ready: ptr.To(int32(0)), + Succeeded: ptr.To(int32(0)), + Failed: ptr.To(int32(0)), + Active: ptr.To(int32(1)), + Suspended: ptr.To(int32(0)), + }, + }, util.SortJobsStatus)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Wait for TrainJob to be in Succeeded status with all jobs succeeded", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{ + { + Type: trainer.TrainJobComplete, + Status: metav1.ConditionTrue, + Reason: jobsetconsts.AllJobsCompletedReason, + Message: jobsetconsts.AllJobsCompletedMessage, + }, + }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.Node, + Ready: ptr.To(int32(0)), + Succeeded: ptr.To(int32(1)), + Failed: ptr.To(int32(0)), + Active: ptr.To(int32(0)), + Suspended: ptr.To(int32(0)), + }, + }, util.SortJobsStatus)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + }) + }) + ginkgo.When("Creating a TrainJob with RuntimePatches", func() { ginkgo.It("should preserve user-provided manager fields", func() { userTime := metav1.NewTime(time.Now().Add(-time.Hour).Truncate(time.Second)) diff --git a/test/integration/controller/trainjob_controller_test.go b/test/integration/controller/trainjob_controller_test.go index f3a28e3702..0270a25d31 100644 --- a/test/integration/controller/trainjob_controller_test.go +++ b/test/integration/controller/trainjob_controller_test.go @@ -2275,5 +2275,256 @@ alpha-node-0-1.alpha slots=8 }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) }) + ginkgo.Context("Integration Tests for the Flux Runtime", func() { + var ( + configMapKey client.ObjectKey + secKey client.ObjectKey + ) + + makeFluxObjects := func(suspend bool) { + trainJobWrapper := testingutil.MakeTrainJobWrapper(ns.Name, "alpha"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "alpha"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(). + NumNodes(2). + Container("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests). + Obj(), + ) + if suspend { + trainJobWrapper.Suspend(true) + } + trainJob = trainJobWrapper.Obj() + trainJobKey = client.ObjectKeyFromObject(trainJob) + + configMapKey = client.ObjectKey{ + Name: fmt.Sprintf("%s-flux-entrypoint", trainJobKey.Name), + Namespace: trainJobKey.Namespace, + } + secKey = client.ObjectKey{ + Name: fmt.Sprintf("%s-flux-curve", trainJobKey.Name), + Namespace: trainJobKey.Namespace, + } + + trainingRuntime = testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha"). + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha").Spec). + WithMLPolicy( + testingutil.MakeMLPolicyWrapper(). + WithNumNodes(2). + WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper(). + FluxPolicy(ptr.To[int32](1)). + Obj(), + ). + Obj(), + ). + Container(constants.Node, constants.Node, "test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests). + Obj(), + ). + Obj() + } + + ginkgo.It("Should succeed to create TrainJob with Flux TrainingRuntime", func() { + ginkgo.By("Creating Flux TrainingRuntime and TrainJob") + makeFluxObjects(false) + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if the appropriate Flux JobSet is created") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + g.Expect(jobSet.Spec.ReplicatedJobs).Should(gomega.HaveLen(3)) + + var nodeJob *jobsetv1alpha2.ReplicatedJob + for i := range jobSet.Spec.ReplicatedJobs { + if jobSet.Spec.ReplicatedJobs[i].Name == constants.Node { + nodeJob = &jobSet.Spec.ReplicatedJobs[i] + } + } + g.Expect(nodeJob).ShouldNot(gomega.BeNil()) + g.Expect(nodeJob.Replicas).Should(gomega.Equal(int32(1))) + g.Expect(nodeJob.Template.Spec.Parallelism).Should(gomega.Equal(ptr.To[int32](2))) + g.Expect(nodeJob.Template.Spec.Completions).Should(gomega.Equal(ptr.To[int32](2))) + + podSpec := nodeJob.Template.Spec.Template.Spec + var fluxInstaller *corev1.Container + for i := range podSpec.InitContainers { + if podSpec.InitContainers[i].Name == constants.FluxInstallerContainerName { + fluxInstaller = &podSpec.InitContainers[i] + } + } + g.Expect(fluxInstaller).ShouldNot(gomega.BeNil()) + g.Expect(fluxInstaller.Image).Should(gomega.Equal(constants.FluxInstallerImage)) + g.Expect(fluxInstaller.Command).Should(gomega.Equal([]string{"/bin/bash", "/etc/flux-config/init.sh"})) + g.Expect(fluxInstaller.VolumeMounts).Should(gomega.ConsistOf( + corev1.VolumeMount{Name: constants.FluxInstallVolumeName, MountPath: constants.FluxVolumePath}, + corev1.VolumeMount{Name: configMapKey.Name, MountPath: constants.FluxConfigVolumeName, ReadOnly: true}, + )) + g.Expect(podSpec.Volumes).Should(gomega.ContainElements( + corev1.Volume{ + Name: constants.FluxSpackViewVolumeName, + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + corev1.Volume{ + Name: constants.FluxInstallVolumeName, + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + corev1.Volume{ + Name: constants.FluxMemoryVolumeName, + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{ + Medium: corev1.StorageMediumMemory, + }, + }, + }, + )) + var configVolume, curveVolume *corev1.Volume + for i := range podSpec.Volumes { + switch podSpec.Volumes[i].Name { + case configMapKey.Name: + configVolume = &podSpec.Volumes[i] + case constants.FluxCurveVolumeName: + curveVolume = &podSpec.Volumes[i] + } + } + g.Expect(configVolume).ShouldNot(gomega.BeNil()) + g.Expect(configVolume.ConfigMap).ShouldNot(gomega.BeNil()) + g.Expect(configVolume.ConfigMap.Name).Should(gomega.Equal(configMapKey.Name)) + g.Expect(curveVolume).ShouldNot(gomega.BeNil()) + g.Expect(curveVolume.Secret).ShouldNot(gomega.BeNil()) + g.Expect(curveVolume.Secret.SecretName).Should(gomega.Equal(secKey.Name)) + + var nodeContainer *corev1.Container + for i := range podSpec.Containers { + if podSpec.Containers[i].Name == constants.Node { + nodeContainer = &podSpec.Containers[i] + } + } + g.Expect(nodeContainer).ShouldNot(gomega.BeNil()) + g.Expect(nodeContainer.Image).Should(gomega.Equal("test:trainjob")) + g.Expect(nodeContainer.Command).Should(gomega.Equal([]string{"/bin/bash", "/etc/flux-config/entrypoint.sh", "trainjob trainjob"})) + g.Expect(nodeContainer.VolumeMounts).Should(gomega.ContainElements( + corev1.VolumeMount{Name: constants.FluxInstallVolumeName, MountPath: constants.FluxVolumePath}, + corev1.VolumeMount{Name: constants.FluxSpackViewVolumeName, MountPath: constants.FluxSpackViewVolumePath}, + corev1.VolumeMount{Name: configMapKey.Name, MountPath: constants.FluxConfigVolumeName, ReadOnly: true}, + corev1.VolumeMount{Name: constants.FluxCurveVolumeName, MountPath: constants.FluxCurveVolumePath, ReadOnly: true}, + corev1.VolumeMount{Name: constants.FluxMemoryVolumeName, MountPath: constants.FluxMemoryVolumePath, ReadOnly: true}, + )) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if the Flux ConfigMap and Secret are created") + gomega.Eventually(func(g gomega.Gomega) { + cm := &corev1.ConfigMap{} + g.Expect(k8sClient.Get(ctx, configMapKey, cm)).Should(gomega.Succeed()) + g.Expect(cm.Data).Should(gomega.HaveKey("entrypoint.sh")) + g.Expect(cm.Data).Should(gomega.HaveKey("init.sh")) + + sec := &corev1.Secret{} + g.Expect(k8sClient.Get(ctx, secKey, sec)).Should(gomega.Succeed()) + g.Expect(sec.Data).Should(gomega.HaveKey("curve.cert")) + g.Expect(string(sec.Data["curve.cert"])).Should(gomega.ContainSubstring("public-key")) + g.Expect(string(sec.Data["curve.cert"])).Should(gomega.ContainSubstring("secret-key")) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.It("Should succeed to reconcile TrainJob conditions with Complete condition", func() { + ginkgo.By("Creating Flux TrainingRuntime and suspended TrainJob") + makeFluxObjects(true) + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if the JobSet was created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Unsuspending TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + gotTrainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating JobSet with completed status") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + meta.SetStatusCondition(&jobSet.Status.Conditions, metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetCompleted), + Reason: jobsetconsts.AllJobsCompletedReason, + Message: jobsetconsts.AllJobsCompletedMessage, + Status: metav1.ConditionTrue, + }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + {Name: constants.Node, Ready: 0, Succeeded: 1, Failed: 0, Active: 0, Suspended: 0}, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking Complete=True condition") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(meta.IsStatusConditionTrue(gotTrainJob.Status.Conditions, trainer.TrainJobComplete)).Should(gomega.BeTrue()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.It("Should succeed to reconcile TrainJob conditions with Failed condition", func() { + ginkgo.By("Creating Flux TrainingRuntime and suspended TrainJob") + makeFluxObjects(true) + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if the JobSet was created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Unsuspending TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + gotTrainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating JobSet with failed condition") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + meta.SetStatusCondition(&jobSet.Status.Conditions, metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetFailed), + Reason: jobsetconsts.FailedJobsReason, + Message: jobsetconsts.FailedJobsMessage, + Status: metav1.ConditionTrue, + }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + {Name: constants.Node, Ready: 0, Succeeded: 0, Failed: 1, Active: 0, Suspended: 0}, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking Failed=True condition") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(meta.IsStatusConditionTrue(gotTrainJob.Status.Conditions, trainer.TrainJobFailed)).Should(gomega.BeTrue()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + }) }) })