diff --git a/pkg/networking/pod_eni_info_resolver.go b/pkg/networking/pod_eni_info_resolver.go index a010b58b6..b70999f73 100644 --- a/pkg/networking/pod_eni_info_resolver.go +++ b/pkg/networking/pod_eni_info_resolver.go @@ -26,7 +26,8 @@ const ( // EC2:DescribeNetworkInterface supports up to 200 filters per call. describeNetworkInterfacesFiltersLimit = 200 - labelEKSComputeType = "eks.amazonaws.com/compute-type" + labelEKSComputeType = "eks.amazonaws.com/compute-type" + labelSageMakerComputeType = "sagemaker.amazonaws.com/compute-type" ) // PodENIInfoResolver is responsible for resolve the AWS VPC ENI that supports pod network. @@ -141,20 +142,20 @@ func (r *defaultPodENIInfoResolver) saveENIInfosToCache(pods []k8s.PodInfo, eniI } func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) { - podsOnEc2, podsOnFargate, err := r.classifyPodsByComputeType(ctx, pods) + podsByComputeType, err := r.classifyPodsByComputeType(ctx, pods) if err != nil { return nil, err } eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo) - if len(podsOnEc2) > 0 { - eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsOnEc2, false) + if len(podsByComputeType.ec2Pods) > 0 { + eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.ec2Pods, false) if err != nil { return nil, err } eniInfoByPodKey = eniInfoByPodKeyEc2 } - if len(podsOnFargate) > 0 { - eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsOnFargate, true) + if len(podsByComputeType.fargatePods) > 0 { + eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.fargatePods, true) if err != nil { return nil, err } @@ -164,17 +165,28 @@ func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Con } } } + if len(podsByComputeType.sageMakerHyperPodPods) > 0 { + eniInfoByPodKeySageMakerHyperPod, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.sageMakerHyperPodPods, true) + if err != nil { + return nil, err + } + if len(eniInfoByPodKeySageMakerHyperPod) > 0 { + for podKey, eniInfo := range eniInfoByPodKeySageMakerHyperPod { + eniInfoByPodKey[podKey] = eniInfo + } + } + } return eniInfoByPodKey, nil } -func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isFargateNode bool) (map[types.NamespacedName]ENIInfo, error) { +func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isNonEc2Pod bool) (map[types.NamespacedName]ENIInfo, error) { eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo) resolveFuncs := []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){ r.resolveViaPodENIAnnotation, r.resolveViaNodeENIs, // TODO, add support for kubenet CNI plugin(kops) by resolve via routeTable. } - if isFargateNode { + if isNonEc2Pod { resolveFuncs = []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){ r.resolveViaVPCENIs, } @@ -281,6 +293,7 @@ func (r *defaultPodENIInfoResolver) resolveViaNodeENIs(ctx context.Context, pods // resolveViaVPCENIs tries to resolve pod ENI by matching podIP against ENIs in vpc. // with EKS fargate pods, podIP is supported by an ENI in vpc. +// with SageMaker HyperPod pods, podIP is supported by the visible cross-account ENI in customer vpc. func (r *defaultPodENIInfoResolver) resolveViaVPCENIs(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) { podKeysByIP := make(map[string][]types.NamespacedName, len(pods)) for _, pod := range pods { @@ -388,33 +401,45 @@ func (r *defaultPodENIInfoResolver) isPodSupportedByNodeENI(pod k8s.PodInfo, nod return false } -// classifyPodsByComputeType classifies in to ec2 and fargate groups -func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, error) { - podsOnFargate := make([]k8s.PodInfo, 0, len(pods)) - podsOnEc2 := make([]k8s.PodInfo, 0, len(pods)) +// PodsByComputeType groups pods based on their compute type (EC2, Fargate, SageMaker HyperPod) +type PodsByComputeType struct { + ec2Pods []k8s.PodInfo + fargatePods []k8s.PodInfo + sageMakerHyperPodPods []k8s.PodInfo +} + +// classifyPodsByComputeType classifies in to ec2, fargate and sagemaker-hyperpod groups +func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) (PodsByComputeType, error) { + var podsByComputeType PodsByComputeType nodeNameByComputeType := make(map[string]string) for _, pod := range pods { if _, exists := nodeNameByComputeType[pod.NodeName]; exists { if nodeNameByComputeType[pod.NodeName] == "fargate" { - podsOnFargate = append(podsOnFargate, pod) + podsByComputeType.fargatePods = append(podsByComputeType.fargatePods, pod) + } else if nodeNameByComputeType[pod.NodeName] == "sagemaker-hyperpod" { + podsByComputeType.sageMakerHyperPodPods = append(podsByComputeType.sageMakerHyperPodPods, pod) } else { - podsOnEc2 = append(podsOnEc2, pod) + podsByComputeType.ec2Pods = append(podsByComputeType.ec2Pods, pod) } } + nodeKey := types.NamespacedName{Name: pod.NodeName} node := &corev1.Node{} if err := r.k8sClient.Get(ctx, nodeKey, node); err != nil { - return nil, nil, err + return PodsByComputeType{}, err } if node.Labels[labelEKSComputeType] == "fargate" { - podsOnFargate = append(podsOnFargate, pod) + podsByComputeType.fargatePods = append(podsByComputeType.fargatePods, pod) nodeNameByComputeType[pod.NodeName] = "fargate" + } else if node.Labels[labelSageMakerComputeType] == "hyperpod" { + podsByComputeType.sageMakerHyperPodPods = append(podsByComputeType.sageMakerHyperPodPods, pod) + nodeNameByComputeType[pod.NodeName] = "sagemaker-hyperpod" } else { - podsOnEc2 = append(podsOnEc2, pod) + podsByComputeType.ec2Pods = append(podsByComputeType.ec2Pods, pod) nodeNameByComputeType[pod.NodeName] = "ec2" } } - return podsOnEc2, podsOnFargate, nil + return podsByComputeType, nil } // computePodENIInfoCacheKey computes the cacheKey for pod's ENIInfo cache. diff --git a/pkg/networking/pod_eni_info_resolver_test.go b/pkg/networking/pod_eni_info_resolver_test.go index f9a700d68..682a37184 100644 --- a/pkg/networking/pod_eni_info_resolver_test.go +++ b/pkg/networking/pod_eni_info_resolver_test.go @@ -999,6 +999,187 @@ func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_Fargate(t *testing. } } +func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_SageMakerHyperPod(t *testing.T) { + hyperPodNodeA := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "hyperpod-i-04442beca624ba65b", + Labels: map[string]string{ + "sagemaker.amazonaws.com/compute-type": "hyperpod", + }, + }, + Spec: corev1.NodeSpec{ + ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04442beca624ba65b", + }, + } + hyperPodNodeB := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "hyperpod-i-04159267183583d03", + Labels: map[string]string{ + "sagemaker.amazonaws.com/compute-type": "hyperpod", + }, + }, + Spec: corev1.NodeSpec{ + ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04159267183583d03", + }, + } + type describeNetworkInterfacesAsListCall struct { + req *ec2sdk.DescribeNetworkInterfacesInput + resp []ec2types.NetworkInterface + err error + } + type fetchNodeInstancesCall struct { + nodes []*corev1.Node + nodeInstanceByNodeKey map[types.NamespacedName]*ec2types.Instance + err error + } + type env struct { + nodes []*corev1.Node + } + type fields struct { + describeNetworkInterfacesAsListCalls []describeNetworkInterfacesAsListCall + fetchNodeInstancesCalls []fetchNodeInstancesCall + } + type args struct { + pods []k8s.PodInfo + } + tests := []struct { + name string + env env + fields fields + args args + want map[types.NamespacedName]ENIInfo + wantErr error + }{ + { + name: "all pod's ENI resolved via VPC's ENIs", + env: env{ + nodes: []*corev1.Node{hyperPodNodeA, hyperPodNodeB}, + }, + fields: fields{ + describeNetworkInterfacesAsListCalls: []describeNetworkInterfacesAsListCall{ + { + req: &ec2sdk.DescribeNetworkInterfacesInput{ + Filters: []ec2types.Filter{ + { + Name: awssdk.String("vpc-id"), + Values: []string{"vpc-0d6d9ee10bd062dcc"}, + }, + { + Name: awssdk.String("addresses.private-ip-address"), + Values: []string{"192.168.128.151", "192.168.128.152"}, + }, + }, + }, + resp: []ec2types.NetworkInterface{ + { + NetworkInterfaceId: awssdk.String("eni-c"), + PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{ + { + PrivateIpAddress: awssdk.String("192.168.128.150"), + }, + { + PrivateIpAddress: awssdk.String("192.168.128.151"), + }, + }, + Groups: []ec2types.GroupIdentifier{ + { + GroupId: awssdk.String("sg-c-1"), + }, + }, + }, + { + NetworkInterfaceId: awssdk.String("eni-d"), + PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{ + { + PrivateIpAddress: awssdk.String("192.168.128.152"), + }, + { + PrivateIpAddress: awssdk.String("192.168.128.153"), + }, + }, + Groups: []ec2types.GroupIdentifier{ + { + GroupId: awssdk.String("sg-d-1"), + }, + }, + }, + }, + }, + }, + }, + args: args{ + pods: []k8s.PodInfo{ + { + Key: types.NamespacedName{Namespace: "default", Name: "pod-1"}, + UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc01"), + NodeName: "hyperpod-i-04442beca624ba65b", + PodIP: "192.168.128.151", + }, + { + Key: types.NamespacedName{Namespace: "default", Name: "pod-2"}, + UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc02"), + NodeName: "hyperpod-i-04159267183583d03", + PodIP: "192.168.128.152", + }, + }, + }, + want: map[types.NamespacedName]ENIInfo{ + types.NamespacedName{Namespace: "default", Name: "pod-1"}: { + NetworkInterfaceID: "eni-c", + SecurityGroups: []string{"sg-c-1"}, + }, + types.NamespacedName{Namespace: "default", Name: "pod-2"}: { + NetworkInterfaceID: "eni-d", + SecurityGroups: []string{"sg-d-1"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + for _, call := range tt.fields.describeNetworkInterfacesAsListCalls { + ec2Client.EXPECT().DescribeNetworkInterfacesAsList(gomock.Any(), call.req).Return(call.resp, call.err) + } + k8sSchema := runtime.NewScheme() + clientgoscheme.AddToScheme(k8sSchema) + k8sClient := fake.NewClientBuilder().WithScheme(k8sSchema).Build() + for _, node := range tt.env.nodes { + assert.NoError(t, k8sClient.Create(context.Background(), node.DeepCopy())) + } + nodeInfoProvider := NewMockNodeInfoProvider(ctrl) + for _, call := range tt.fields.fetchNodeInstancesCalls { + updatedNodes := make([]*corev1.Node, 0, len(call.nodes)) + for _, node := range call.nodes { + updatedNode := &corev1.Node{} + assert.NoError(t, k8sClient.Get(context.Background(), k8s.NamespacedName(node), updatedNode)) + updatedNodes = append(updatedNodes, updatedNode) + } + nodeInfoProvider.EXPECT().FetchNodeInstances(gomock.Any(), gomock.InAnyOrder(updatedNodes)).Return(call.nodeInstanceByNodeKey, call.err) + } + r := &defaultPodENIInfoResolver{ + ec2Client: ec2Client, + k8sClient: k8sClient, + nodeInfoProvider: nodeInfoProvider, + vpcID: "vpc-0d6d9ee10bd062dcc", + logger: logr.New(&log.NullLogSink{}), + describeNetworkInterfacesIPChunkSize: 2, + } + + got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, true) + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + func Test_defaultPodENIInfoResolver_resolveViaPodENIAnnotation(t *testing.T) { type describeNetworkInterfacesAsListCall struct { req *ec2sdk.DescribeNetworkInterfacesInput