Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pkg/util/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,14 @@ func (m *MLPolicySourceWrapper) MPIPolicy(numProcPerNode *int32, MPImplementatio
return m
}

func (m *MLPolicySourceWrapper) FluxPolicy(numProcPerNode *int32) *MLPolicySourceWrapper {
if m.Flux == nil {
m.Flux = &trainer.FluxMLPolicySource{}
}
m.Flux.NumProcPerNode = numProcPerNode
return m
}

func (m *MLPolicySourceWrapper) Obj() *trainer.MLPolicySource {
return &m.MLPolicySource
}
Expand Down
105 changes: 105 additions & 0 deletions test/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const (
deepSpeedRuntime = "deepspeed-distributed"
jaxRuntime = "jax-distributed"
xgboostRuntime = "xgboost-distributed"
fluxRuntime = "flux-distributed"
)

//go:embed testdata/status_update.py
Expand Down Expand Up @@ -320,6 +321,110 @@ var _ = ginkgo.Describe("TrainJob e2e", func() {
})
})

ginkgo.When("Creating TrainJob to perform Flux workload", func() {
ginkgo.It("should create TrainJob with Flux runtime reference", func() {
ginkgo.By("Create Flux ClusterTrainingRuntime")
fluxRuntimeName := fluxRuntime + "-" + ns.Name
fluxClusterRuntimeWrapper := testingutil.MakeClusterTrainingRuntimeWrapper(fluxRuntimeName)
for _, rJob := range fluxClusterRuntimeWrapper.Spec.Template.Spec.ReplicatedJobs {
if rJob.Name == constants.Node {
rJob.Template.Spec.Template.Spec.Volumes = nil
rJob.Template.Spec.Template.Spec.Containers[0].VolumeMounts = nil
fluxClusterRuntimeWrapper.Spec.Template.Spec.ReplicatedJobs = []jobsetv1alpha2.ReplicatedJob{rJob}
break
}
}
fluxClusterRuntime := fluxClusterRuntimeWrapper.
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(fluxClusterRuntimeWrapper.Spec).
WithMLPolicy(
testingutil.MakeMLPolicyWrapper().
WithNumNodes(2).
WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper().
FluxPolicy(ptr.To[int32](1)).
Obj(),
).
Obj(),
).
Container(
constants.Node,
constants.Node,
"ubuntu:22.04",
[]string{"bash"},
[]string{"-c", "true"},
corev1.ResourceList{},
).
Obj(),
).
Obj()
gomega.Expect(k8sClient.Create(ctx, fluxClusterRuntime)).Should(gomega.Succeed())
defer func() {
gomega.Expect(k8sClient.Delete(ctx, fluxClusterRuntime)).Should(gomega.Succeed())
}()

trainJob := testingutil.MakeTrainJobWrapper(ns.Name, "e2e-test-flux").
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), fluxRuntimeName).
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
NumNodes(2).
Container(
"ubuntu:22.04",
[]string{"bash"},
[]string{"-c", "true"},
corev1.ResourceList{},
).
Obj(),
).
Obj()

ginkgo.By("Create a TrainJob with flux-distributed runtime reference", func() {
gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed())
})

ginkgo.By("Wait for TrainJob jobs to become active", func() {
gomega.Eventually(func(g gomega.Gomega) {
gotTrainJob := &trainer.TrainJob{}
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed())
g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{
{
Name: constants.Node,
Ready: ptr.To(int32(0)),
Succeeded: ptr.To(int32(0)),
Failed: ptr.To(int32(0)),
Active: ptr.To(int32(1)),
Suspended: ptr.To(int32(0)),
},
}, util.SortJobsStatus))
}, util.TimeoutE2E, util.Interval).Should(gomega.Succeed())
})

ginkgo.By("Wait for TrainJob to be in Succeeded status with all jobs succeeded", func() {
gomega.Eventually(func(g gomega.Gomega) {
gotTrainJob := &trainer.TrainJob{}
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed())
g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{
{
Type: trainer.TrainJobComplete,
Status: metav1.ConditionTrue,
Reason: jobsetconsts.AllJobsCompletedReason,
Message: jobsetconsts.AllJobsCompletedMessage,
},
}, util.IgnoreConditions))
g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{
{
Name: constants.Node,
Ready: ptr.To(int32(0)),
Succeeded: ptr.To(int32(1)),
Failed: ptr.To(int32(0)),
Active: ptr.To(int32(0)),
Suspended: ptr.To(int32(0)),
},
}, util.SortJobsStatus))
}, util.TimeoutE2E, util.Interval).Should(gomega.Succeed())
})
})
})

ginkgo.When("Creating a TrainJob with RuntimePatches", func() {
ginkgo.It("should preserve user-provided manager fields", func() {
userTime := metav1.NewTime(time.Now().Add(-time.Hour).Truncate(time.Second))
Expand Down
251 changes: 251 additions & 0 deletions test/integration/controller/trainjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2275,5 +2275,256 @@ alpha-node-0-1.alpha slots=8
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})
})
ginkgo.Context("Integration Tests for the Flux Runtime", func() {
var (
configMapKey client.ObjectKey
secKey client.ObjectKey
)

makeFluxObjects := func(suspend bool) {
trainJobWrapper := testingutil.MakeTrainJobWrapper(ns.Name, "alpha").
RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "alpha").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
NumNodes(2).
Container("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Obj(),
)
if suspend {
trainJobWrapper.Suspend(true)
}
trainJob = trainJobWrapper.Obj()
trainJobKey = client.ObjectKeyFromObject(trainJob)

configMapKey = client.ObjectKey{
Name: fmt.Sprintf("%s-flux-entrypoint", trainJobKey.Name),
Namespace: trainJobKey.Namespace,
}
secKey = client.ObjectKey{
Name: fmt.Sprintf("%s-flux-curve", trainJobKey.Name),
Namespace: trainJobKey.Namespace,
}

trainingRuntime = testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha").
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha").Spec).
WithMLPolicy(
testingutil.MakeMLPolicyWrapper().
WithNumNodes(2).
WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper().
FluxPolicy(ptr.To[int32](1)).
Obj(),
).
Obj(),
).
Container(constants.Node, constants.Node, "test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Obj(),
).
Obj()
}

ginkgo.It("Should succeed to create TrainJob with Flux TrainingRuntime", func() {
ginkgo.By("Creating Flux TrainingRuntime and TrainJob")
makeFluxObjects(false)
gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed())
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed())

ginkgo.By("Checking if the appropriate Flux JobSet is created")
gomega.Eventually(func(g gomega.Gomega) {
jobSet := &jobsetv1alpha2.JobSet{}
g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed())
g.Expect(jobSet.Spec.ReplicatedJobs).Should(gomega.HaveLen(3))

var nodeJob *jobsetv1alpha2.ReplicatedJob
for i := range jobSet.Spec.ReplicatedJobs {
if jobSet.Spec.ReplicatedJobs[i].Name == constants.Node {
nodeJob = &jobSet.Spec.ReplicatedJobs[i]
}
}
g.Expect(nodeJob).ShouldNot(gomega.BeNil())
g.Expect(nodeJob.Replicas).Should(gomega.Equal(int32(1)))
g.Expect(nodeJob.Template.Spec.Parallelism).Should(gomega.Equal(ptr.To[int32](2)))
g.Expect(nodeJob.Template.Spec.Completions).Should(gomega.Equal(ptr.To[int32](2)))

podSpec := nodeJob.Template.Spec.Template.Spec
var fluxInstaller *corev1.Container
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == constants.FluxInstallerContainerName {
fluxInstaller = &podSpec.InitContainers[i]
}
}
g.Expect(fluxInstaller).ShouldNot(gomega.BeNil())
g.Expect(fluxInstaller.Image).Should(gomega.Equal(constants.FluxInstallerImage))
g.Expect(fluxInstaller.Command).Should(gomega.Equal([]string{"/bin/bash", "/etc/flux-config/init.sh"}))
g.Expect(fluxInstaller.VolumeMounts).Should(gomega.ConsistOf(
corev1.VolumeMount{Name: constants.FluxInstallVolumeName, MountPath: constants.FluxVolumePath},
corev1.VolumeMount{Name: configMapKey.Name, MountPath: constants.FluxConfigVolumeName, ReadOnly: true},
))
g.Expect(podSpec.Volumes).Should(gomega.ContainElements(
corev1.Volume{
Name: constants.FluxSpackViewVolumeName,
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
corev1.Volume{
Name: constants.FluxInstallVolumeName,
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
corev1.Volume{
Name: constants.FluxMemoryVolumeName,
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
},
},
},
))
var configVolume, curveVolume *corev1.Volume
for i := range podSpec.Volumes {
switch podSpec.Volumes[i].Name {
case configMapKey.Name:
configVolume = &podSpec.Volumes[i]
case constants.FluxCurveVolumeName:
curveVolume = &podSpec.Volumes[i]
}
}
g.Expect(configVolume).ShouldNot(gomega.BeNil())
g.Expect(configVolume.ConfigMap).ShouldNot(gomega.BeNil())
g.Expect(configVolume.ConfigMap.Name).Should(gomega.Equal(configMapKey.Name))
g.Expect(curveVolume).ShouldNot(gomega.BeNil())
g.Expect(curveVolume.Secret).ShouldNot(gomega.BeNil())
g.Expect(curveVolume.Secret.SecretName).Should(gomega.Equal(secKey.Name))

var nodeContainer *corev1.Container
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == constants.Node {
nodeContainer = &podSpec.Containers[i]
}
}
g.Expect(nodeContainer).ShouldNot(gomega.BeNil())
g.Expect(nodeContainer.Image).Should(gomega.Equal("test:trainjob"))
g.Expect(nodeContainer.Command).Should(gomega.Equal([]string{"/bin/bash", "/etc/flux-config/entrypoint.sh", "trainjob trainjob"}))
g.Expect(nodeContainer.VolumeMounts).Should(gomega.ContainElements(
corev1.VolumeMount{Name: constants.FluxInstallVolumeName, MountPath: constants.FluxVolumePath},
corev1.VolumeMount{Name: constants.FluxSpackViewVolumeName, MountPath: constants.FluxSpackViewVolumePath},
corev1.VolumeMount{Name: configMapKey.Name, MountPath: constants.FluxConfigVolumeName, ReadOnly: true},
corev1.VolumeMount{Name: constants.FluxCurveVolumeName, MountPath: constants.FluxCurveVolumePath, ReadOnly: true},
corev1.VolumeMount{Name: constants.FluxMemoryVolumeName, MountPath: constants.FluxMemoryVolumePath, ReadOnly: true},
))
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Checking if the Flux ConfigMap and Secret are created")
gomega.Eventually(func(g gomega.Gomega) {
cm := &corev1.ConfigMap{}
g.Expect(k8sClient.Get(ctx, configMapKey, cm)).Should(gomega.Succeed())
g.Expect(cm.Data).Should(gomega.HaveKey("entrypoint.sh"))
g.Expect(cm.Data).Should(gomega.HaveKey("init.sh"))

sec := &corev1.Secret{}
g.Expect(k8sClient.Get(ctx, secKey, sec)).Should(gomega.Succeed())
g.Expect(sec.Data).Should(gomega.HaveKey("curve.cert"))
g.Expect(string(sec.Data["curve.cert"])).Should(gomega.ContainSubstring("public-key"))
g.Expect(string(sec.Data["curve.cert"])).Should(gomega.ContainSubstring("secret-key"))
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.It("Should succeed to reconcile TrainJob conditions with Complete condition", func() {
ginkgo.By("Creating Flux TrainingRuntime and suspended TrainJob")
makeFluxObjects(true)
gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed())
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed())

ginkgo.By("Checking if the JobSet was created")
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Unsuspending TrainJob")
gomega.Eventually(func(g gomega.Gomega) {
gotTrainJob := &trainer.TrainJob{}
g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed())
gotTrainJob.Spec.Suspend = ptr.To(false)
g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Updating JobSet with completed status")
gomega.Eventually(func(g gomega.Gomega) {
jobSet := &jobsetv1alpha2.JobSet{}
g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed())
meta.SetStatusCondition(&jobSet.Status.Conditions, metav1.Condition{
Type: string(jobsetv1alpha2.JobSetCompleted),
Reason: jobsetconsts.AllJobsCompletedReason,
Message: jobsetconsts.AllJobsCompletedMessage,
Status: metav1.ConditionTrue,
})
jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{
{Name: constants.Node, Ready: 0, Succeeded: 1, Failed: 0, Active: 0, Suspended: 0},
}
g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Checking Complete=True condition")
gomega.Eventually(func(g gomega.Gomega) {
gotTrainJob := &trainer.TrainJob{}
g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed())
g.Expect(meta.IsStatusConditionTrue(gotTrainJob.Status.Conditions, trainer.TrainJobComplete)).Should(gomega.BeTrue())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})

ginkgo.It("Should succeed to reconcile TrainJob conditions with Failed condition", func() {
ginkgo.By("Creating Flux TrainingRuntime and suspended TrainJob")
makeFluxObjects(true)
gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed())
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed())

ginkgo.By("Checking if the JobSet was created")
gomega.Eventually(func(g gomega.Gomega) {
g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Unsuspending TrainJob")
gomega.Eventually(func(g gomega.Gomega) {
gotTrainJob := &trainer.TrainJob{}
g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed())
gotTrainJob.Spec.Suspend = ptr.To(false)
g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Updating JobSet with failed condition")
gomega.Eventually(func(g gomega.Gomega) {
jobSet := &jobsetv1alpha2.JobSet{}
g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed())
meta.SetStatusCondition(&jobSet.Status.Conditions, metav1.Condition{
Type: string(jobsetv1alpha2.JobSetFailed),
Reason: jobsetconsts.FailedJobsReason,
Message: jobsetconsts.FailedJobsMessage,
Status: metav1.ConditionTrue,
})
jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{
{Name: constants.Node, Ready: 0, Succeeded: 0, Failed: 1, Active: 0, Suspended: 0},
}
g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed())
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Checking Failed=True condition")
gomega.Eventually(func(g gomega.Gomega) {
gotTrainJob := &trainer.TrainJob{}
g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed())
g.Expect(meta.IsStatusConditionTrue(gotTrainJob.Status.Conditions, trainer.TrainJobFailed)).Should(gomega.BeTrue())
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})
})
})
})
Loading