Skip to content

Commit

Permalink
add support for extra environment variables
Browse files Browse the repository at this point in the history
This commit add support for the sriov-device-plugin to inject extra env variables

Signed-off-by: Sebastian Sch <[email protected]>
  • Loading branch information
SchSeba committed Jan 4, 2023
1 parent b46ebca commit 4202525
Show file tree
Hide file tree
Showing 45 changed files with 775 additions and 274 deletions.
22 changes: 20 additions & 2 deletions pkg/accelerator/accelDevice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,16 @@ var _ = Describe("Accelerator", func() {

// TODO: assert other fields once implemented
Expect(out.GetDriver()).To(Equal("vfio-pci"))
Expect(out.GetEnvVal()).To(Equal(pciAddr))
envs := out.GetEnvVal()
Expect(len(envs)).To(Equal(1))

_, exist := envs["vfio"]
Expect(exist).To(BeTrue())

pci, exist := envs["vfio"]["pci"]
Expect(exist).To(BeTrue())
Expect(pci).To(Equal(pciAddr))

Expect(out.GetDeviceSpecs()).To(HaveLen(2)) // /dev/vfio/vfio0 and default /dev/vfio/vfio
Expect(out.GetAPIDevice().Topology.Nodes[0].ID).To(Equal(int64(0)))
Expect(err).NotTo(HaveOccurred())
Expand Down Expand Up @@ -174,7 +183,16 @@ var _ = Describe("Accelerator", func() {
dev, err := accelerator.NewAccelDevice(in, f, config)
Expect(err).NotTo(HaveOccurred())
Expect(dev).NotTo(BeNil())
Expect(dev.GetEnvVal()).To(Equal(pciAddr))

envs := dev.GetEnvVal()
Expect(len(envs)).To(Equal(1))

_, exist := envs["vfio"]
Expect(exist).To(BeTrue())

pci, exist := envs["vfio"]["pci"]
Expect(exist).To(BeTrue())
Expect(pci).To(Equal(pciAddr))
})
})
})
Expand Down
12 changes: 8 additions & 4 deletions pkg/devices/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,14 @@ func (ad *APIDeviceImpl) GetDeviceSpecs() []*pluginapi.DeviceSpec {
}

// GetEnvVal returns device environment variable
func (ad *APIDeviceImpl) GetEnvVal() string {
// Currently Device Plugin does not support returning multiple Env Vars
// so we use the value provided by the first InfoProvider.
return ad.infoProviders[0].GetEnvVal()
func (ad *APIDeviceImpl) GetEnvVal() map[string]types.EnvironmentVariables {
envValList := make(map[string]types.EnvironmentVariables, 0)
for _, provider := range ad.infoProviders {
for k, v := range provider.GetEnvVal() {
envValList[k] = v
}
}
return envValList
}

// GetMounts returns list of device host mounts
Expand Down
13 changes: 11 additions & 2 deletions pkg/devices/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,29 @@ var _ = Describe("ApiDevice", func() {
mockSpec1 := []*v1beta1.DeviceSpec{
{HostPath: "/mock/spec/1"},
}
mockInfo1.On("GetEnvVal").Return("0000:00:00.1")
mockEnv := map[string]types.EnvironmentVariables{"netdevice": {"pci": "0000:00:00.1"}}
mockInfo1.On("GetEnvVal").Return(mockEnv)
mockInfo1.On("GetDeviceSpecs").Return(mockSpec1)
mockInfo1.On("GetMounts").Return(nil)
mockInfo2 := &mocks.DeviceInfoProvider{}
mockSpec2 := []*v1beta1.DeviceSpec{
{HostPath: "/mock/spec/2"},
}
mockInfo2.On("GetEnvVal").Return(mockEnv)
mockInfo2.On("GetDeviceSpecs").Return(mockSpec2)
mockInfo2.On("GetMounts").Return(nil)

infoProviders := []types.DeviceInfoProvider{mockInfo1, mockInfo2}
dev := devices.NewAPIDeviceImpl("0000:00:00.1", infoProviders, -1)

Expect(dev.GetEnvVal()).To(Equal("0000:00:00.1"))
envs := dev.GetEnvVal()
Expect(len(envs)).To(Equal(1))
_, exist := envs["netdevice"]
Expect(exist).To(BeTrue())
pci, exist := envs["netdevice"]["pci"]
Expect(exist).To(BeTrue())
Expect(pci).To(Equal("0000:00:00.1"))

Expect(dev.GetDeviceSpecs()).To(HaveLen(2))
Expect(dev.GetMounts()).To(HaveLen(0))
Expect(dev.GetAPIDevice()).NotTo(BeNil())
Expand Down
4 changes: 4 additions & 0 deletions pkg/devices/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package devices
import (
"github.com/jaypipes/ghw"

"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/infoprovider"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils"
)
Expand All @@ -46,6 +47,9 @@ func NewHostDeviceImpl(dev *ghw.PCIDevice, deviceID string, rFactory types.Resou
// Use the default Information Provided if not
if len(infoProviders) == 0 {
infoProviders = append(infoProviders, rFactory.GetDefaultInfoProvider(deviceID, driverName))
if rc.ExtraEnvVariables != nil {
infoProviders = append(infoProviders, infoprovider.NewEnvInfoProvider(dev.Address, rc.ExtraEnvVariables))
}
}

nodeNum := -1
Expand Down
22 changes: 18 additions & 4 deletions pkg/devices/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,16 @@ var _ = Describe("HostDevice", func() {
mockSpec1 := []*v1beta1.DeviceSpec{
{HostPath: "/mock/spec/1"},
}
mockInfo1.On("GetEnvVal").Return(pciAddr1)
mockEnv1 := map[string]types.EnvironmentVariables{"netdevice": {"pci": pciAddr1}}
mockInfo1.On("GetEnvVal").Return(mockEnv1)
mockInfo1.On("GetDeviceSpecs").Return(mockSpec1)
mockInfo1.On("GetMounts").Return(nil)
mockInfo2 := &mocks.DeviceInfoProvider{}
mockSpec2 := []*v1beta1.DeviceSpec{
{HostPath: "/mock/spec/2"},
}
mockInfo2.On("GetEnvVal").Return(pciAddr2)
mockEnv2 := map[string]types.EnvironmentVariables{"netdevice": {"pci": pciAddr2}}
mockInfo2.On("GetEnvVal").Return(mockEnv2)
mockInfo2.On("GetDeviceSpecs").Return(mockSpec2)
mockInfo2.On("GetMounts").Return(nil)
f.On("GetDefaultInfoProvider", pciAddr1, "mlx5_core").Return(mockInfo1).
Expand All @@ -211,7 +213,13 @@ var _ = Describe("HostDevice", func() {
dev, err := devices.NewHostDeviceImpl(in1, pciAddr1, f, rc, infoProviders)

Expect(dev.GetDriver()).To(Equal("mlx5_core"))
Expect(dev.GetEnvVal()).To(Equal(pciAddr1))
envs := dev.GetEnvVal()
Expect(len(envs)).To(Equal(1))
_, exist := envs["netdevice"]
Expect(exist).To(BeTrue())
pci, exist := envs["netdevice"]["pci"]
Expect(exist).To(BeTrue())
Expect(pci).To(Equal(pciAddr1))
Expect(dev.GetDeviceSpecs()).To(Equal(mockSpec1))
Expect(dev.GetMounts()).To(HaveLen(0))
Expect(err).NotTo(HaveOccurred())
Expand All @@ -224,7 +232,13 @@ var _ = Describe("HostDevice", func() {
dev, err := devices.NewHostDeviceImpl(in2, pciAddr2, f, rc, infoProviders)

Expect(dev.GetDriver()).To(Equal("mlx5_core"))
Expect(dev.GetEnvVal()).To(Equal(pciAddr2))
envs := dev.GetEnvVal()
Expect(len(envs)).To(Equal(1))
_, exist := envs["netdevice"]
Expect(exist).To(BeTrue())
pci, exist := envs["netdevice"]["pci"]
Expect(exist).To(BeTrue())
Expect(pci).To(Equal(pciAddr2))
Expect(dev.GetDeviceSpecs()).To(Equal(mockSpec2))
Expect(dev.GetMounts()).To(HaveLen(0))
Expect(err).NotTo(HaveOccurred())
Expand Down
66 changes: 66 additions & 0 deletions pkg/infoprovider/envInfoProvider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2018 Intel Corp. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package infoprovider

import (
"github.com/golang/glog"

pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"

"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
)

/*
envInfoProvider implements DeviceInfoProvider
*/
type envInfoProvider struct {
pciAddr string
envs map[string]types.EnvironmentVariables
}

// NewEnvInfoProvider create instance of Environment DeviceInfoProvider
func NewEnvInfoProvider(pciAddr string, envs map[string]types.EnvironmentVariables) types.DeviceInfoProvider {
return &envInfoProvider{
pciAddr: pciAddr,
envs: envs,
}
}

func (rp *envInfoProvider) GetDeviceSpecs() []*pluginapi.DeviceSpec {
devSpecs := make([]*pluginapi.DeviceSpec, 0)
return devSpecs
}

func (rp *envInfoProvider) GetEnvVal() map[string]types.EnvironmentVariables {
envs := make(map[string]string, 0)

// first we search for global configuration with the * then we check for specific one to override
for _, value := range []string{"*", rp.pciAddr} {
extraEnvDict, ok := rp.envs[value]
if ok {
for k, v := range extraEnvDict {
envs[k] = v
}
}
}
envMap := map[string]types.EnvironmentVariables{"envs": envs}
glog.Infof("Env GetEnvVal(): %v", envMap)
return envMap
}

func (rp *envInfoProvider) GetMounts() []*pluginapi.Mount {
mounts := make([]*pluginapi.Mount, 0)
return mounts
}
94 changes: 94 additions & 0 deletions pkg/infoprovider/envInfoProvider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package infoprovider_test

import (
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/infoprovider"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("EnvInfoProvider", func() {
Describe("creating new rdmaInfoProvider", func() {
It("should return valid rdmaInfoProvider object", func() {
dip := infoprovider.NewEnvInfoProvider("fake01", map[string]types.EnvironmentVariables{})
Expect(dip).NotTo(Equal(nil))
})
})
Describe("GetEnvVal", func() {
It("should return an empty object if there are no environment variables", func() {
dip := infoprovider.NewEnvInfoProvider("fake", nil)
envs := dip.GetEnvVal()
Expect(len(envs)).To(Equal(1))
envMap, exist := envs["envs"]
Expect(exist).To(BeTrue())
Expect(len(envMap)).To(Equal(0))
})
It("should return an object with environment variables", func() {
dip := infoprovider.NewEnvInfoProvider("fake", map[string]types.EnvironmentVariables{"*": map[string]string{"test": "test"}})
envs := dip.GetEnvVal()
Expect(len(envs)).To(Equal(1))
envMap, exist := envs["envs"]
Expect(exist).To(BeTrue())
Expect(len(envMap)).To(Equal(1))
value, exist := envs["envs"]["test"]
Expect(exist).To(BeTrue())
Expect(value).To(Equal("test"))
})
It("should return an object with specific selector for environment variable", func() {
dip := infoprovider.NewEnvInfoProvider("0000:00:00.1", map[string]types.EnvironmentVariables{"*": map[string]string{"test": "test"}, "0000:00:00.1": map[string]string{"test": "test1"}})
envs := dip.GetEnvVal()
Expect(len(envs)).To(Equal(1))
envMap, exist := envs["envs"]
Expect(exist).To(BeTrue())
Expect(len(envMap)).To(Equal(1))
value, exist := envs["envs"]["test"]
Expect(exist).To(BeTrue())
Expect(value).To(Equal("test1"))
})
It("should return an object with specific selector for multiple environment variable", func() {
dip := infoprovider.NewEnvInfoProvider("0000:00:00.1", map[string]types.EnvironmentVariables{"*": map[string]string{"test": "test", "bla": "bla"}, "0000:00:00.1": map[string]string{"test": "test1"}})
envs := dip.GetEnvVal()
Expect(len(envs)).To(Equal(1))
envMap, exist := envs["envs"]
Expect(exist).To(BeTrue())
Expect(len(envMap)).To(Equal(2))
value, exist := envs["envs"]["test"]
Expect(exist).To(BeTrue())
Expect(value).To(Equal("test1"))
value, exist = envs["envs"]["bla"]
Expect(exist).To(BeTrue())
Expect(value).To(Equal("bla"))
})
It("should return an object with multiple specific selector for environment variable", func() {
dip := infoprovider.NewEnvInfoProvider("0000:00:00.1", map[string]types.EnvironmentVariables{"*": map[string]string{"test": "test"}, "0000:00:00.1": map[string]string{"test": "test1", "bla": "bla"}})
envs := dip.GetEnvVal()
Expect(len(envs)).To(Equal(1))
envMap, exist := envs["envs"]
Expect(exist).To(BeTrue())
Expect(len(envMap)).To(Equal(2))
value, exist := envs["envs"]["test"]
Expect(exist).To(BeTrue())
Expect(value).To(Equal("test1"))
value, exist = envs["envs"]["bla"]
Expect(exist).To(BeTrue())
Expect(value).To(Equal("bla"))
})
})
})
10 changes: 8 additions & 2 deletions pkg/infoprovider/genericInfoProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package infoprovider

import (
"github.com/golang/glog"

pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"

"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
Expand All @@ -37,8 +39,12 @@ func (rp *genericInfoProvider) GetDeviceSpecs() []*pluginapi.DeviceSpec {
return devSpecs
}

func (rp *genericInfoProvider) GetEnvVal() string {
return rp.pciAddr
func (rp *genericInfoProvider) GetEnvVal() map[string]types.EnvironmentVariables {
envs := make(map[string]string, 0)
envs["pci"] = rp.pciAddr
genericMap := map[string]types.EnvironmentVariables{"netdevice": envs}
glog.Infof("Generic GetEnvVal(): %v", genericMap)
return genericMap
}

func (rp *genericInfoProvider) GetMounts() []*pluginapi.Mount {
Expand Down
27 changes: 24 additions & 3 deletions pkg/infoprovider/rdmaInfoProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package infoprovider

import (
"strings"

"github.com/golang/glog"

pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
Expand All @@ -30,6 +32,7 @@ import (
*/
type rdmaInfoProvider struct {
rdmaSpec types.RdmaSpec
mounts string
}

// NewRdmaInfoProvider returns a new Rdma Information Provider
Expand All @@ -47,11 +50,29 @@ func (ip *rdmaInfoProvider) GetDeviceSpecs() []*pluginapi.DeviceSpec {
glog.Errorf("GetDeviceSpecs(): rdma is required in the configuration but the device is not rdma device")
return nil
}
return ip.rdmaSpec.GetRdmaDeviceSpec()

devsSpec := ip.rdmaSpec.GetRdmaDeviceSpec()
mounts := ""
for _, devSpec := range devsSpec {
mounts = mounts + devSpec.ContainerPath + ","
}
if mounts != "" {
mounts = strings.TrimSuffix(mounts, ",")
ip.mounts = mounts
}

return devsSpec
}

func (ip *rdmaInfoProvider) GetEnvVal() string {
return ""
func (ip *rdmaInfoProvider) GetEnvVal() map[string]types.EnvironmentVariables {
envs := make(map[string]string, 0)
if ip.mounts != "" {
envs["mounts"] = ip.mounts
}

rdmaMap := map[string]types.EnvironmentVariables{"rdma": envs}
glog.Infof("RDMA GetEnvVal(): %v", rdmaMap)
return rdmaMap
}

func (ip *rdmaInfoProvider) GetMounts() []*pluginapi.Mount {
Expand Down
Loading

0 comments on commit 4202525

Please sign in to comment.