Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion ray-operator/controllers/ray/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
"github.com/ray-project/kuberay/ray-operator/pkg/features"
)

const (
Expand Down Expand Up @@ -244,7 +245,7 @@ func getEnableProbesInjection() bool {
}

// DefaultWorkerPodTemplate sets the config values
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string) corev1.PodTemplateSpec {
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec {
podTemplate := workerSpec.Template
podTemplate.GenerateName = podName
// Pods created by RayCluster should be restricted to the namespace of the RayCluster.
Expand Down Expand Up @@ -315,6 +316,11 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
podTemplate.Labels = make(map[string]string)
}
podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, workerSpec.Template.ObjectMeta.Labels)
// Add additional labels for RayMultihostIndexing
multihostIndexingEnabled := features.Enabled(features.RayMulithostIndexing) && workerSpec.NumOfHosts > 1
if multihostIndexingEnabled {
podTemplate.Labels = addMultihostIndexingPodLabels(podTemplate.Labels, replicaGrpName, numHostIndex)
}
workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP)

initTemplateAnnotations(instance, &podTemplate)
Expand Down Expand Up @@ -628,6 +634,15 @@ func labelPod(rayNodeType rayv1.RayNodeType, rayClusterName string, groupName st
return labels
}

// addMultihostIndexingPodLabels returns labels that contain RayMultihostIndexing feature labels
func addMultihostIndexingPodLabels(currentLabels map[string]string, replicaGrpName string, numHostIndex int) map[string]string {
labels := currentLabels
labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)

return labels
}

func setInitContainerEnvVars(container *corev1.Container, fqdnRayIP string) {
if len(container.Env) == 0 {
container.Env = []corev1.EnvVar{}
Expand Down
22 changes: 11 additions & 11 deletions ray-operator/controllers/ray/common/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ func TestBuildPod(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)

// Check resources
Expand Down Expand Up @@ -752,7 +752,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
Expand Down Expand Up @@ -783,7 +783,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
assert.Equal(t, []string{"I am worker"}, workerContainer.Command)
Expand Down Expand Up @@ -838,7 +838,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)

val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
Expand Down Expand Up @@ -894,7 +894,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)

// Verify worker container command
Expand Down Expand Up @@ -1157,7 +1157,7 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) {
expectedWorker := *worker.DeepCopy()

// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
assert.Equal(t, expectedWorker, worker)
}
Expand Down Expand Up @@ -1204,7 +1204,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
// DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it.
// Verify the default metrics port exists.
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort)))
Expand All @@ -1214,7 +1214,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
ContainerPort: customMetricsPort,
}
cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort}
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
// Verify the custom metrics port exists.
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort))
}
Expand Down Expand Up @@ -1253,7 +1253,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) {

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379")
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0)
assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy)
})
}
Expand All @@ -1269,7 +1269,7 @@ func TestDefaultInitContainer(t *testing.T) {
expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1

// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
numInitContainers := len(podTemplateSpec.Spec.InitContainers)
assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.")

Expand Down Expand Up @@ -1328,7 +1328,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) {
// set ray container imagePullPolicy
worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy

podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test with replicaGrpName set and numHostIndex > 1


healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1]
assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.")
Expand Down
Loading
Loading