Inject pod affinity and anti-affinity labels
ryanaoleary committed Mar 5, 2024
1 parent 9bc9133 commit 6c0e411
Showing 1 changed file with 78 additions and 31 deletions.
109 changes: 78 additions & 31 deletions applications/ray/kuberay-tpu-webhook/main.go
@@ -136,7 +136,7 @@ func genDNSHostnames(workerGroupSpec ray.WorkerGroupSpec, replicaIndex int) (str
// inject subdomain and TPU_WORKER_HOSTNAMES into pods for TPU multi-host initialization
func injectHostnames(hostNames string, envPath string, container corev1.Container, patches *[]patch) {
subdomainPatch, hostNamesPatch := patch{"op": "add"}, patch{"op": "add"}
subdomainPath := "template/spec/subdomain"
subdomainPath := "/spec/subdomain"
tpuWorkerHostNames := corev1.EnvVar{
Name: "TPU_WORKER_HOSTNAMES",
Value: hostNames,
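For context on the corrected path: JSON Patch (RFC 6902) paths are rooted at the object being patched, here the Pod, so the subdomain field sits at /spec/subdomain rather than under a template prefix. A minimal sketch of the patch document this produces, assuming patch is a map[string]interface{} alias (as the patch{"op": "add"} literals suggest) and using a made-up headless-service name:

package main

import (
	"encoding/json"
	"fmt"
)

// patch mirrors the webhook's JSON-Patch entry type (assumed to be a map alias).
type patch map[string]interface{}

func main() {
	subdomainPatch := patch{"op": "add"}
	subdomainPatch["path"] = "/spec/subdomain"                // rooted at the Pod object
	subdomainPatch["value"] = "example-raycluster-headless"   // hypothetical headless Service name
	out, _ := json.Marshal([]patch{subdomainPatch})
	fmt.Println(string(out))
	// [{"op":"add","path":"/spec/subdomain","value":"example-raycluster-headless"}]
}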
@@ -165,6 +165,37 @@ func injectMultiHostReplicaLabel(replicaIndex int, workerGroupName string, patch
*patches = append(*patches, labelPatch)
}

// inject pod affinity and anti-affinity scheduling constraints using multiHostReplica label
func injectPodAffinity(replicaIndex int, workerGroupName string, patches *[]patch) {
podAffinityPatch := patch{"op": "add"}
podAffinityPath := "/spec/affinity"
podAntiAffinityPatch := patch{"op": "add"}
podAntiAffinityPath := "/spec/affinity"

// construct pod affinity value to inject - schedule pods with the same multiHostReplica together
key := "multiHostReplica"
value := []string{workerGroupName + strconv.Itoa(replicaIndex)}
affinitySelectorRequirement := metav1.LabelSelectorRequirement{Key: key, Operator: metav1.LabelSelectorOpIn, Values: value}
affinityMatchExpressions := []metav1.LabelSelectorRequirement{affinitySelectorRequirement}
affinityLabelSelector := metav1.LabelSelector{MatchExpressions: affinityMatchExpressions}
podAffinityValue := corev1.PodAffinityTerm{LabelSelector: &affinityLabelSelector}

// construct pod anti-affinity value to inject - repel any pod that carries the multiHostReplica
// label with a different value (the Exists requirement is needed because NotIn alone also matches unlabeled pods)
antiSelectorRequirement := metav1.LabelSelectorRequirement{Key: key, Operator: metav1.LabelSelectorOpNotIn, Values: value}
labelExistsRequirement := metav1.LabelSelectorRequirement{Key: key, Operator: metav1.LabelSelectorOpExists}
antiMatchExpressions := []metav1.LabelSelectorRequirement{antiSelectorRequirement, labelExistsRequirement}
antiLabelSelector := metav1.LabelSelector{MatchExpressions: antiMatchExpressions}
podAntiAffinityValue := corev1.PodAffinityTerm{LabelSelector: &antiLabelSelector}

podAffinityPatch["path"] = podAffinityPath
podAffinityPatch["value"] = corev1.PodAffinity{RequiredDuringSchedulingIgnoredDuringExecution: []corev1.PodAffinityTerm{podAffinityValue}}
podAntiAffinityPatch["path"] = podAntiAffinityPath
podAntiAffinityPatch["value"] = corev1.PodAntiAffinity{RequiredDuringSchedulingIgnoredDuringExecution: []corev1.PodAffinityTerm{podAntiAffinityValue}}

*patches = append(*patches, podAffinityPatch, podAntiAffinityPatch)
}
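Both patches above target /spec/affinity, and under RFC 6902 a second "add" to an existing member replaces the first value; the Pod schema also expects an Affinity object (with podAffinity and podAntiAffinity fields) at that path rather than a bare PodAffinity. A minimal sketch of expressing both constraints in one patch (an illustration under assumptions, not the webhook's code), reusing the map-based patch type and using kubernetes.io/hostname purely as a placeholder for the required TopologyKey:

package main

import (
	"encoding/json"
	"fmt"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type patch map[string]interface{}

// buildAffinityPatch sets podAffinity and podAntiAffinity together under /spec/affinity.
func buildAffinityPatch(replicaLabelValue string) patch {
	key := "multiHostReplica"
	sameReplica := metav1.LabelSelector{MatchExpressions: []metav1.LabelSelectorRequirement{
		{Key: key, Operator: metav1.LabelSelectorOpIn, Values: []string{replicaLabelValue}},
	}}
	otherReplica := metav1.LabelSelector{MatchExpressions: []metav1.LabelSelectorRequirement{
		{Key: key, Operator: metav1.LabelSelectorOpNotIn, Values: []string{replicaLabelValue}},
		{Key: key, Operator: metav1.LabelSelectorOpExists}, // NotIn alone would also match unlabeled pods
	}}
	affinity := corev1.Affinity{
		PodAffinity: &corev1.PodAffinity{RequiredDuringSchedulingIgnoredDuringExecution: []corev1.PodAffinityTerm{
			{LabelSelector: &sameReplica, TopologyKey: "kubernetes.io/hostname"}, // placeholder key
		}},
		PodAntiAffinity: &corev1.PodAntiAffinity{RequiredDuringSchedulingIgnoredDuringExecution: []corev1.PodAffinityTerm{
			{LabelSelector: &otherReplica, TopologyKey: "kubernetes.io/hostname"}, // placeholder key
		}},
	}
	return patch{"op": "add", "path": "/spec/affinity", "value": affinity}
}

func main() {
	out, _ := json.MarshalIndent(buildAffinityPatch("workergroup-0"), "", "  ")
	fmt.Println(string(out))
}

The topology key is what actually scopes "schedule together" and "keep apart"; a node-pool or slice-level key is likely more appropriate than a per-node one, which is why it is left as a placeholder here.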

// check that the # of Ray TPU worker pods equals the # of hosts defined in the topology key
func checkWorkersMatchTopology(workerGroupSpec ray.WorkerGroupSpec) (bool, error) {
numHosts := workerGroupSpec.NumOfHosts // 1 TPU VM host -> 1 Ray worker pod
@@ -288,37 +319,37 @@ func getReplicaIndex() int {

// returns next lowest TPU_WORKER_ID in pod slice and updates mappings
func getNextWorkerID(podSlice slice, replicaIndex int) int {
tpu_worker_id := 0
tpuWorkerID := 0
if sliceToWorkers[podSlice] == nil {
new_worker := worker{tpu_worker_id, replicaIndex, true}
sliceToWorkers[podSlice] = []worker{new_worker}
newWorker := worker{tpuWorkerID, replicaIndex, true}
sliceToWorkers[podSlice] = []worker{newWorker}
} else {
next_lowest_id := math.MaxInt32
replace_pod := false
nextLowestID := math.MaxInt32
replacePod := false
// iterate through existing workers and check if any have been deleted
for _, worker := range sliceToWorkers[podSlice] {
if worker.isRunning == false && worker.workerIndex < next_lowest_id {
replace_pod = true
next_lowest_id = worker.workerIndex
if worker.isRunning == false && worker.workerIndex < nextLowestID {
replacePod = true
nextLowestID = worker.workerIndex
}
}
// reassign next lowest TPU_WORKER_ID if pod has been deleted
if replace_pod == true {
if replacePod == true {
for _, worker := range sliceToWorkers[podSlice] {
// set worker.isRunning to true now that pod is being re-created
if worker.workerIndex == next_lowest_id {
if worker.workerIndex == nextLowestID {
worker.isRunning = true
}
}
} else {
// all pods are running -> create new worker with next TPU_WORKER_ID
next_lowest_id = len(sliceToWorkers[podSlice])
new_worker := worker{next_lowest_id, replicaIndex, true}
sliceToWorkers[podSlice] = append(sliceToWorkers[podSlice], new_worker)
nextLowestID = len(sliceToWorkers[podSlice])
newWorker := worker{nextLowestID, replicaIndex, true}
sliceToWorkers[podSlice] = append(sliceToWorkers[podSlice], newWorker)
}
tpu_worker_id = next_lowest_id
tpuWorkerID = nextLowestID
}
return tpu_worker_id
return tpuWorkerID
}
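The branch above reuses the lowest TPU_WORKER_ID whose pod is no longer running, otherwise it appends a fresh one. Since range yields a copy of each worker struct, a write that should persist has to go through the slice index; the sketch below shows the same reuse-lowest-or-append policy in a self-contained form, with the type shapes assumed purely for illustration:

package main

import "fmt"

// Assumed shapes of the webhook's bookkeeping types; only the fields used here.
type slice struct {
	clusterName  string
	groupName    string
	replicaIndex int
	numOfHosts   int32
}

type worker struct {
	workerIndex  int
	replicaIndex int
	isRunning    bool
}

var sliceToWorkers = map[slice][]worker{}

// nextWorkerID reuses the lowest free ID in the pod slice, else appends a new worker.
func nextWorkerID(podSlice slice, replicaIndex int) int {
	workers := sliceToWorkers[podSlice]
	lowestFree := -1
	for i := range workers {
		if !workers[i].isRunning && (lowestFree == -1 || workers[i].workerIndex < workers[lowestFree].workerIndex) {
			lowestFree = i
		}
	}
	if lowestFree >= 0 {
		workers[lowestFree].isRunning = true // write through the index so the update sticks
		return workers[lowestFree].workerIndex
	}
	id := len(workers)
	sliceToWorkers[podSlice] = append(workers, worker{id, replicaIndex, true})
	return id
}

func main() {
	s := slice{"ray", "tpu-group", 0, 4}
	fmt.Println(nextWorkerID(s, 0), nextWorkerID(s, 0)) // 0 1
	sliceToWorkers[s][0].isRunning = false // simulate worker 0's pod being deleted
	fmt.Println(nextWorkerID(s, 0))        // 0 is reused
}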

// unmarshal pod from admission request
@@ -328,8 +359,14 @@ func extractPod(admissionReview *admissionv1.AdmissionReview) (*corev1.Pod, erro
}

pod := corev1.Pod{}
if err := json.Unmarshal(admissionReview.Request.Object.Raw, &pod); err != nil {
return nil, err
if admissionReview.Request.Operation == "CREATE" {
if err := json.Unmarshal(admissionReview.Request.Object.Raw, &pod); err != nil {
return nil, err
}
} else if admissionReview.Request.Operation == "DELETE" {
if err := json.Unmarshal(admissionReview.Request.OldObject.Raw, &pod); err != nil {
return nil, err
}
}

return &pod, nil
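For DELETE reviews the API server sends the existing pod in request.oldObject and leaves request.object empty, which is why the decode now branches on the operation. A simplified, self-contained sketch of the same branching, exercised with a hand-built AdmissionReview (types from k8s.io/api/admission/v1; the pod contents are made up):

package main

import (
	"encoding/json"
	"fmt"

	admissionv1 "k8s.io/api/admission/v1"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/runtime"
)

// decodePod mirrors the branch above: CREATE carries the pod in Object,
// DELETE carries the already-existing pod in OldObject.
func decodePod(review *admissionv1.AdmissionReview) (*corev1.Pod, error) {
	raw := review.Request.Object.Raw
	if review.Request.Operation == admissionv1.Delete {
		raw = review.Request.OldObject.Raw
	}
	pod := corev1.Pod{}
	if err := json.Unmarshal(raw, &pod); err != nil {
		return nil, err
	}
	return &pod, nil
}

func main() {
	podJSON, _ := json.Marshal(corev1.Pod{Spec: corev1.PodSpec{Hostname: "tpu-group-0-0"}})
	review := &admissionv1.AdmissionReview{Request: &admissionv1.AdmissionRequest{
		Operation: admissionv1.Delete,
		OldObject: runtime.RawExtension{Raw: podJSON},
	}}
	pod, err := decodePod(review)
	fmt.Println(pod.Spec.Hostname, err) // tpu-group-0-0 <nil>
}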
@@ -362,20 +399,21 @@ func mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.Admis
podSlice := slice{clusterName, groupName, replicaIndex, numOfHosts}
tpuWorkerID := getNextWorkerID(podSlice, replicaIndex)

// if multihost -> inject hostname into pod spec for DNS records
isMultiHost, _ := isTPUMultiHost(topology) // ignore error here because topology may not be set yet
if isMultiHost {
// inject hostname into pod spec for DNS records
hostname := fmt.Sprintf(groupName+"-%d-%d", replicaIndex, tpuWorkerID)
hostnamePatch := patch{"op": "add"}
hostnamePatch["path"] = "/spec/hostname"
hostnamePatch["value"] = hostname
patches = append(patches, hostnamePatch)
}

// inject multi-host replica label
injectMultiHostReplicaLabel(replicaIndex, groupName, &patches)
// inject multi-host replica label
injectMultiHostReplicaLabel(replicaIndex, groupName, &patches)

// inject pod affinity/anti-affinity for scheduling
// inject pod affinity/anti-affinity for scheduling
injectPodAffinity(replicaIndex, groupName, &patches)
}
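The hostname injected here combines with the subdomain patched earlier: once a headless Service named after the subdomain selects the pod, the pod resolves at <hostname>.<subdomain>.<namespace>.svc.<cluster-domain>, which is how the names generated for TPU_WORKER_HOSTNAMES become resolvable inside the cluster. A small sketch of the resulting name; the service name, namespace, and cluster domain below are assumptions:

package main

import "fmt"

// podFQDN shows the DNS name a pod gains when spec.hostname and spec.subdomain are set
// and a headless Service named after the subdomain selects it.
func podFQDN(groupName string, replicaIndex, workerID int, subdomain, namespace string) string {
	hostname := fmt.Sprintf(groupName+"-%d-%d", replicaIndex, workerID) // same format as the patch above
	return fmt.Sprintf("%s.%s.%s.svc.cluster.local", hostname, subdomain, namespace)
}

func main() {
	// e.g. worker 0 of replica 0 in a hypothetical "tpu-group" worker group
	fmt.Println(podFQDN("tpu-group", 0, 0, "example-raycluster-headless", "default"))
	// tpu-group-0-0.example-raycluster-headless.default.svc.cluster.local
}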

// inject all environment variables into the container requesting TPUs
for i := 0; i < len(containers); i++ {
@@ -529,14 +567,7 @@ func main() {

if admissionReview.Request.Kind.Kind == "Pod" {
klog.Info("Received review for Pod")
var response *admissionv1.AdmissionResponse
var err error
if admissionReview.Request.Operation == "CREATE" {
response, err = mutatePod(admissionReview)
}
if admissionReview.Request.Operation == "DELETE" {
response, err = deletePod(admissionReview)
}
response, err := mutatePod(admissionReview)
if err != nil {
klog.Errorf("Failed to mutate pod: %s", err)
return
@@ -558,6 +589,22 @@
return
}

if admissionReview.Request.Kind.Kind == "Pod" {
klog.Info("Received review for Pod deletion")
response, err := deletePod(admissionReview)
if err != nil {
klog.Errorf("Failed to validate pod deletion: %s", err)
return
}
admissionReview.Response = response
responseBytes, err := json.Marshal(admissionReview)
if err != nil {
klog.Errorf("Failed to encode response: %s", err)
return
}
fmt.Fprint(w, string(responseBytes))
}
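Whatever bookkeeping deletePod performs, the AdmissionResponse it returns has to echo the request UID and carry an Allowed verdict for the API server to accept it. A minimal sketch of such a response, independent of the webhook's own implementation:

package main

import (
	"fmt"

	admissionv1 "k8s.io/api/admission/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
)

// allowOrDeny builds the minimal response a delete-reviewing handler could return;
// the UID must match the AdmissionRequest UID.
func allowOrDeny(requestUID types.UID, allowed bool, reason string) *admissionv1.AdmissionResponse {
	resp := &admissionv1.AdmissionResponse{UID: requestUID, Allowed: allowed}
	if !allowed {
		resp.Result = &metav1.Status{Message: reason}
	}
	return resp
}

func main() {
	fmt.Println(allowOrDeny("uid-1234", true, "").Allowed) // true
	fmt.Println(allowOrDeny("uid-1234", false, "pod belongs to an active TPU slice").Result.Message)
}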

if admissionReview.Request.Kind.Kind == "RayCluster" {
klog.Info("Received review for RayCluster")
response, err := validateRayCluster(admissionReview)
