
[TPU] Add envtests for multi-host (#1950)
kevin85421 authored Feb 29, 2024
1 parent 0f2f441 commit e7a00f9
Showing 1 changed file with 94 additions and 0 deletions.

ray-operator/controllers/ray/raycluster_controller_test.go
@@ -469,6 +469,100 @@ var _ = Context("Inside the default namespace", func() {
time.Second*3, time.Millisecond*500).Should(Equal(rayv1.Ready))
})
})

Describe("RayCluster with a multi-host worker group", func() {
ctx := context.Background()
namespace := "default"
rayCluster := rayClusterTemplate("raycluster-multihost", namespace)
numOfHosts := int32(4)
rayCluster.Spec.WorkerGroupSpecs[0].NumOfHosts = numOfHosts
rayCluster.Spec.EnableInTreeAutoscaling = pointer.Bool(true)
workerPods := corev1.PodList{}
workerFilterLabels := client.MatchingLabels{utils.RayClusterLabelKey: rayCluster.Name, utils.RayNodeGroupLabelKey: rayCluster.Spec.WorkerGroupSpecs[0].GroupName}

It("Verify RayCluster spec", func() {
// These tests are designed based on the following assumptions:
// (1) The Ray Autoscaler is enabled.
// (2) There is only one worker group; its `replicas` is set to 3 and its `workersToDelete` is empty.
// (3) The worker group is a multi-host TPU PodSlice consisting of 4 hosts.
Expect(*rayCluster.Spec.EnableInTreeAutoscaling).To(Equal(true))
Expect(len(rayCluster.Spec.WorkerGroupSpecs)).To(Equal(1))
Expect(rayCluster.Spec.WorkerGroupSpecs[0].NumOfHosts).To(Equal(numOfHosts))
Expect(rayCluster.Spec.WorkerGroupSpecs[0].Replicas).To(Equal(pointer.Int32(3)))
Expect(rayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete).To(BeEmpty())
})
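For reference, rayClusterTemplate is a shared helper defined elsewhere in this test file. A hypothetical sketch of the worker-group shape the assertions above assume (the group name is illustrative, and the real template also fills in a head group and Pod templates, which are omitted here):

func rayClusterTemplate(name string, namespace string) *rayv1.RayCluster {
	return &rayv1.RayCluster{
		ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace},
		Spec: rayv1.RayClusterSpec{
			WorkerGroupSpecs: []rayv1.WorkerGroupSpec{
				{
					GroupName: "small-group", // illustrative name, not confirmed by this diff
					Replicas:  pointer.Int32(3),
					// NumOfHosts is left at its default; the test above overrides it to 4.
				},
			},
		},
	}
}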

It("Create a RayCluster custom resource", func() {
err := k8sClient.Create(ctx, rayCluster)
Expect(err).NotTo(HaveOccurred(), "Failed to create RayCluster")
Eventually(
getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster),
time.Second*3, time.Millisecond*500).Should(BeNil(), "Should be able to see RayCluster: %v", rayCluster.Name)
})

It("Check the number of worker Pods", func() {
numWorkerPods := 3 * int(numOfHosts)
Eventually(
listResourceFunc(ctx, &workerPods, workerFilterLabels, &client.ListOptions{Namespace: namespace}),
time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))
})
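The listResourceFunc helper is defined later in this file, outside the hunk shown here. A minimal sketch of the behavior this assertion relies on, assuming the suite-level k8sClient (the real implementation may differ): it returns a closure that lists Pods matching the given options and counts only the ones that are alive, so Eventually can poll it until the count converges on the expected value.

func listResourceFunc(ctx context.Context, list *corev1.PodList, opt ...client.ListOption) func() (int, error) {
	return func() (int, error) {
		if err := k8sClient.List(ctx, list, opt...); err != nil {
			return -1, err
		}
		count := 0
		for _, pod := range list.Items {
			// Count only running or pending Pods that are not terminating.
			if (pod.Status.Phase == corev1.PodRunning || pod.Status.Phase == corev1.PodPending) && pod.DeletionTimestamp == nil {
				count++
			}
		}
		return count, nil
	}
}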

It("Simulate Ray Autoscaler scales down", func() {
err := retry.RetryOnConflict(retry.DefaultRetry, func() error {
Eventually(
getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster),
time.Second*3, time.Millisecond*500).Should(BeNil())
rayCluster.Spec.WorkerGroupSpecs[0].Replicas = pointer.Int32(2)
rayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete = []string{
workerPods.Items[0].Name, workerPods.Items[1].Name, workerPods.Items[2].Name, workerPods.Items[3].Name,
}
return k8sClient.Update(ctx, rayCluster)
})
Expect(err).NotTo(HaveOccurred(), "Failed to update RayCluster custom resource")

numWorkerPods := 2 * int(numOfHosts)
Eventually(
listResourceFunc(ctx, &workerPods, workerFilterLabels, &client.ListOptions{Namespace: namespace}),
time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))

// The Ray Autoscaler should clean up WorkersToDelete after the scaling process has finished.
// Call cleanUpWorkersToDelete to simulate the behavior of the Ray Autoscaler.
cleanUpWorkersToDelete(ctx, rayCluster, 0)
})
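cleanUpWorkersToDelete is another helper defined outside this hunk. A minimal sketch of the cleanup it presumably performs, reusing the retry-and-update pattern from the test above (the exact implementation is not shown in this diff):

func cleanUpWorkersToDelete(ctx context.Context, rayCluster *rayv1.RayCluster, workerGroupIdx int) {
	err := retry.RetryOnConflict(retry.DefaultRetry, func() error {
		Eventually(
			getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: rayCluster.Namespace}, rayCluster),
			time.Second*3, time.Millisecond*500).Should(BeNil())
		// Clear WorkersToDelete to mimic the Ray Autoscaler finishing a scale-down.
		rayCluster.Spec.WorkerGroupSpecs[workerGroupIdx].ScaleStrategy.WorkersToDelete = []string{}
		return k8sClient.Update(ctx, rayCluster)
	})
	Expect(err).NotTo(HaveOccurred(), "Failed to clean up WorkersToDelete")
}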

It("Simulate Ray Autoscaler scales up", func() {
err := retry.RetryOnConflict(retry.DefaultRetry, func() error {
Eventually(
getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster),
time.Second*3, time.Millisecond*500).Should(BeNil())
rayCluster.Spec.WorkerGroupSpecs[0].Replicas = pointer.Int32(4)
return k8sClient.Update(ctx, rayCluster)
})
Expect(err).NotTo(HaveOccurred(), "Failed to update RayCluster custom resource")

numWorkerPods := 4 * int(numOfHosts)
Eventually(
listResourceFunc(ctx, &workerPods, workerFilterLabels, &client.ListOptions{Namespace: namespace}),
time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))
})

It("Delete a worker Pod, and KubeRay should create a new one", func() {
numWorkerPods := 4 * int(numOfHosts)
Eventually(
listResourceFunc(ctx, &workerPods, workerFilterLabels, &client.ListOptions{Namespace: namespace}),
time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))

pod := workerPods.Items[0]
err := k8sClient.Delete(ctx, &pod, &client.DeleteOptions{GracePeriodSeconds: pointer.Int64(0)})
Expect(err).NotTo(HaveOccurred(), "Failed to delete a Pod")
Eventually(
listResourceFunc(ctx, &workerPods, workerFilterLabels, &client.ListOptions{Namespace: namespace}),
time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))
Consistently(
listResourceFunc(ctx, &workerPods, workerFilterLabels, &client.ListOptions{Namespace: namespace}),
time.Second*2, time.Millisecond*200).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))
})
})
})

func getResourceFunc(ctx context.Context, key client.ObjectKey, obj client.Object) func() error {
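	return func() error {
		// Presumed body (the rest of this function is collapsed in the page view):
		// wrap a Get so Eventually can retry until the object is fetched successfully.
		return k8sClient.Get(ctx, key, obj)
	}
}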
