diff --git a/docker/koordlet.dockerfile b/docker/koordlet.dockerfile index 682198f4a..228b4acd9 100644 --- a/docker/koordlet.dockerfile +++ b/docker/koordlet.dockerfile @@ -36,6 +36,7 @@ RUN go build -a -o koordlet cmd/koordlet/main.go FROM --platform=$TARGETPLATFORM nvidia/cuda:11.8.0-base-ubuntu22.04 WORKDIR / RUN apt-get update && apt-get install -y lvm2 && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y iptables COPY --from=builder /go/src/github.com/koordinator-sh/koordinator/koordlet . COPY --from=builder /usr/local/lib /usr/lib ENTRYPOINT ["/koordlet"] diff --git a/go.mod b/go.mod index 1a8c7055c..348aba927 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/NVIDIA/go-nvml v0.11.6-0.0.20220823120812-7e2082095e82 github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5 github.com/containerd/nri v0.6.1 + github.com/coreos/go-iptables v0.5.0 github.com/docker/docker v20.10.21+incompatible github.com/evanphx/json-patch v5.6.0+incompatible github.com/fsnotify/fsnotify v1.6.0 @@ -196,7 +197,7 @@ require ( github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.9 // indirect - github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5 // indirect + github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5 github.com/vishvananda/netns v0.0.4 // indirect github.com/vmware/govmomi v0.30.6 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect diff --git a/go.sum b/go.sum index 3b0b3c27d..86f2d61b3 100644 --- a/go.sum +++ b/go.sum @@ -473,6 +473,7 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-iptables v0.4.5/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= +github.com/coreos/go-iptables v0.5.0 h1:mw6SAibtHKZcNzAsOxjoHIG0gy5YFHhypWSSNc6EjbQ= github.com/coreos/go-iptables v0.5.0/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= diff --git a/pkg/koordlet/resourceexecutor/reader.go b/pkg/koordlet/resourceexecutor/reader.go index dfe68cc64..58a4d23f6 100644 --- a/pkg/koordlet/resourceexecutor/reader.go +++ b/pkg/koordlet/resourceexecutor/reader.go @@ -40,6 +40,7 @@ type CgroupReader interface { ReadCPUProcs(parentDir string) ([]uint32, error) ReadPSI(parentDir string) (*sysutil.PSIByResource, error) ReadMemoryColdPageUsage(parentDir string) (uint64, error) + ReadNetClsId(parentDir string) (uint64, error) } var _ CgroupReader = &CgroupV1Reader{} @@ -230,6 +231,14 @@ func (r *CgroupV1Reader) ReadPSI(parentDir string) (*sysutil.PSIByResource, erro return psi, nil } +func (r *CgroupV1Reader) ReadNetClsId(parentDir string) (uint64, error) { + resource, ok := sysutil.DefaultRegistry.Get(sysutil.CgroupVersionV1, sysutil.NetClsClassIdName) + if !ok { + return 0, ErrResourceNotRegistered + } + return readCgroupAndParseUint64(parentDir, resource) +} + var _ CgroupReader = &CgroupV2Reader{} type CgroupV2Reader struct{} @@ -436,6 +445,14 @@ func (r *CgroupV2Reader) ReadPSI(parentDir string) (*sysutil.PSIByResource, erro return psi, nil } +func (r *CgroupV2Reader) ReadNetClsId(parentDir string) (uint64, error) { + resource, ok := sysutil.DefaultRegistry.Get(sysutil.CgroupVersionV2, sysutil.NetClsClassIdName) + if !ok { + return 0, ErrResourceNotRegistered + } + return readCgroupAndParseUint64(parentDir, resource) +} + func NewCgroupReader() CgroupReader { if sysutil.GetCurrentCgroupVersion() == sysutil.CgroupVersionV2 { return &CgroupV2Reader{} diff --git a/pkg/koordlet/resourceexecutor/updater.go b/pkg/koordlet/resourceexecutor/updater.go index 526226a8c..6a21f07f8 100644 --- a/pkg/koordlet/resourceexecutor/updater.go +++ b/pkg/koordlet/resourceexecutor/updater.go @@ -55,6 +55,7 @@ func init() { sysutil.MemoryPriorityName, sysutil.MemoryUsePriorityOomName, sysutil.MemoryOomGroupName, + sysutil.NetClsClassIdName, ) // special cases DefaultCgroupUpdaterFactory.Register(NewCgroupUpdaterWithUpdateFunc(CgroupUpdateCPUSharesFunc), sysutil.CPUSharesName) diff --git a/pkg/koordlet/runtimehooks/config.go b/pkg/koordlet/runtimehooks/config.go index 62314d71d..fdc070d68 100644 --- a/pkg/koordlet/runtimehooks/config.go +++ b/pkg/koordlet/runtimehooks/config.go @@ -32,6 +32,7 @@ import ( "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks/cpuset" "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks/gpu" "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks/groupidentity" + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks/tc" "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks/terwayqos" "github.com/koordinator-sh/koordinator/pkg/koordlet/util/system" ) @@ -81,6 +82,11 @@ const ( // owner: @l1b0k // alpha: v1.5 TerwayQoS featuregate.Feature = "TerwayQoS" + + // TCNetworkQoS indicates a network qos implementation based on tc. + // owner: @lucming + // alpha: v1.5 + TCNetworkQoS featuregate.Feature = "TCNetworkQoS" ) var ( @@ -92,6 +98,7 @@ var ( CPUNormalization: {Default: false, PreRelease: featuregate.Alpha}, CoreSched: {Default: false, PreRelease: featuregate.Alpha}, TerwayQoS: {Default: false, PreRelease: featuregate.Alpha}, + TCNetworkQoS: {Default: false, PreRelease: featuregate.Alpha}, } runtimeHookPlugins = map[featuregate.Feature]HookPlugin{ @@ -102,6 +109,7 @@ var ( CPUNormalization: cpunormalization.Object(), CoreSched: coresched.Object(), TerwayQoS: terwayqos.Object(), + TCNetworkQoS: tc.Object(), } ) diff --git a/pkg/koordlet/runtimehooks/hooks/tc/helper.go b/pkg/koordlet/runtimehooks/hooks/tc/helper.go new file mode 100644 index 000000000..968a23c76 --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/helper.go @@ -0,0 +1,127 @@ +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "fmt" + "strconv" + "strings" + + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/intstr" + + slov1alpha1 "github.com/koordinator-sh/koordinator/apis/slo/v1alpha1" +) + +func loadConfigFromNodeSlo(nodesloSpec *slov1alpha1.NodeSLOSpec) *NetQosGlobalConfig { + res := NetQosGlobalConfig{} + var total uint64 = 0 + if nodesloSpec != nil && nodesloSpec.SystemStrategy != nil { + total = uint64(nodesloSpec.SystemStrategy.TotalNetworkBandwidth.Value()) + res.HwRxBpsMax = total + res.HwTxBpsMax = total + } + + if nodesloSpec.ResourceQOSStrategy == nil { + return &res + } + + strategy := nodesloSpec.ResourceQOSStrategy + if strategy.LSClass != nil && + strategy.LSClass.NetworkQOS != nil && + *strategy.LSClass.NetworkQOS.Enable { + cur := strategy.LSClass.NetworkQOS + res.L1RxBpsMin = getBandwidthVal(total, cur.IngressRequest) + res.L1RxBpsMax = getBandwidthVal(total, cur.IngressLimit) + res.L1TxBpsMin = getBandwidthVal(total, cur.EgressRequest) + res.L1TxBpsMax = getBandwidthVal(total, cur.EgressLimit) + } + + if strategy.BEClass != nil && + strategy.BEClass.NetworkQOS != nil && + *strategy.BEClass.NetworkQOS.Enable { + cur := strategy.BEClass.NetworkQOS + res.L2RxBpsMin = getBandwidthVal(total, cur.IngressRequest) + res.L2RxBpsMax = getBandwidthVal(total, cur.IngressLimit) + res.L2TxBpsMin = getBandwidthVal(total, cur.EgressRequest) + res.L2TxBpsMax = getBandwidthVal(total, cur.EgressLimit) + } + + return &res +} + +func getBandwidthVal(total uint64, intOrPercent *intstr.IntOrString) uint64 { + if intOrPercent == nil { + return 0 + } + + switch intOrPercent.Type { + case intstr.String: + return getBandwidthByQuantityFormat(intOrPercent.StrVal) + case intstr.Int: + return getBandwidthByPercentageFormat(total, intOrPercent.IntValue()) + default: + return 0 + } +} + +func getBandwidthByQuantityFormat(quanityStr string) uint64 { + val, err := resource.ParseQuantity(quanityStr) + if err != nil { + return 0 + } + + return uint64(val.Value()) +} + +func getBandwidthByPercentageFormat(total uint64, percentage int) uint64 { + if percentage < 0 || percentage > 100 { + return 0 + } + + return total * uint64(percentage) / 100 +} + +func convertToClassId(major, minor int) string { + return fmt.Sprintf("%d:%d", major, minor) +} + +// convertToHexClassId get class id in hex. +func convertToHexClassId(major, minor int) uint32 { + hexVal := fmt.Sprintf("%d%04d", major, minor) + decimalVal, _ := strconv.ParseUint(hexVal, 16, 32) + return uint32(decimalVal) +} + +// convertIpToHex convert ip to it's hex format +// 10.211.248.149 => 0ad3f895 +func convertIpToHex(ip string) string { + result := "" + elems := strings.Split(ip, ".") + for _, elem := range elems { + cur, _ := strconv.Atoi(elem) + hex := fmt.Sprintf("%x", cur) + // each ip segment takes up two hexadecimal digits, and when it does not take up all the bits, + // it needs to be filled with 0. + for i := 0; i < 2-len(hex); i++ { + hex = "0" + hex + } + result += hex + } + + return result +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/helper_test.go b/pkg/koordlet/runtimehooks/hooks/tc/helper_test.go new file mode 100644 index 000000000..31f7e6b1c --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/helper_test.go @@ -0,0 +1,308 @@ +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/utils/pointer" + + slov1alpha1 "github.com/koordinator-sh/koordinator/apis/slo/v1alpha1" +) + +func Test_convertToHexClassId(t *testing.T) { + type args struct { + major int + minor int + } + tests := []struct { + name string + args args + want uint32 + }{ + // TODO: Add test cases. + { + name: "", + args: args{ + major: 11, + minor: 2, + }, + want: 1114114, + }, + { + name: "", + args: args{ + major: 1, + minor: 2222, + }, + want: 74274, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, convertToHexClassId(tt.args.major, tt.args.minor), "convertToHexClassId(%v, %v)", tt.args.major, tt.args.minor) + }) + } +} + +func genVal(in intstr.IntOrString) *intstr.IntOrString { + return &in +} + +func Test_loadConfigFromNodeSlo(t *testing.T) { + type args struct { + nodesloSpec *slov1alpha1.NodeSLOSpec + } + + tests := []struct { + name string + args args + want *NetQosGlobalConfig + }{ + // TODO: Add test cases. + { + name: "nodeslo.spec is nil", + args: args{ + nodesloSpec: &slov1alpha1.NodeSLOSpec{}, + }, + want: &NetQosGlobalConfig{}, + }, + { + name: "network qos is nil", + args: args{ + nodesloSpec: &slov1alpha1.NodeSLOSpec{ + ResourceQOSStrategy: &slov1alpha1.ResourceQOSStrategy{ + LSClass: &slov1alpha1.ResourceQOS{ + NetworkQOS: &slov1alpha1.NetworkQOSCfg{ + Enable: pointer.Bool(true), + }, + }, + }, + }, + }, + want: &NetQosGlobalConfig{}, + }, + { + name: "network config not enable to be set", + args: args{ + nodesloSpec: &slov1alpha1.NodeSLOSpec{ + ResourceQOSStrategy: &slov1alpha1.ResourceQOSStrategy{ + LSClass: &slov1alpha1.ResourceQOS{ + NetworkQOS: &slov1alpha1.NetworkQOSCfg{ + Enable: pointer.Bool(false), + NetworkQOS: slov1alpha1.NetworkQOS{ + IngressRequest: genVal(intstr.FromInt(10)), + }, + }, + }, + }, + SystemStrategy: &slov1alpha1.SystemStrategy{ + TotalNetworkBandwidth: resource.MustParse("100M"), + }, + }, + }, + want: &NetQosGlobalConfig{ + HwTxBpsMax: 100000000, + HwRxBpsMax: 100000000, + }, + }, + { + name: "total network bandwidth not been set", + args: args{ + nodesloSpec: &slov1alpha1.NodeSLOSpec{ + ResourceQOSStrategy: &slov1alpha1.ResourceQOSStrategy{ + LSClass: &slov1alpha1.ResourceQOS{ + NetworkQOS: &slov1alpha1.NetworkQOSCfg{ + Enable: pointer.Bool(true), + NetworkQOS: slov1alpha1.NetworkQOS{ + IngressRequest: genVal(intstr.FromInt(10)), + }, + }, + }, + }, + }, + }, + want: &NetQosGlobalConfig{}, + }, + { + name: "get network config from a int value", + args: args{ + nodesloSpec: &slov1alpha1.NodeSLOSpec{ + ResourceQOSStrategy: &slov1alpha1.ResourceQOSStrategy{ + LSClass: &slov1alpha1.ResourceQOS{ + NetworkQOS: &slov1alpha1.NetworkQOSCfg{ + Enable: pointer.Bool(true), + NetworkQOS: slov1alpha1.NetworkQOS{ + IngressRequest: genVal(intstr.FromInt(10)), + }, + }, + }, + }, + SystemStrategy: &slov1alpha1.SystemStrategy{ + TotalNetworkBandwidth: resource.MustParse("100M"), + }, + }, + }, + want: &NetQosGlobalConfig{ + HwTxBpsMax: 100000000, + HwRxBpsMax: 100000000, + L1RxBpsMin: 10000000, + }, + }, + { + name: "get network config from a string value", + args: args{ + nodesloSpec: &slov1alpha1.NodeSLOSpec{ + ResourceQOSStrategy: &slov1alpha1.ResourceQOSStrategy{ + LSClass: &slov1alpha1.ResourceQOS{ + NetworkQOS: &slov1alpha1.NetworkQOSCfg{ + Enable: pointer.Bool(true), + NetworkQOS: slov1alpha1.NetworkQOS{ + IngressRequest: genVal(intstr.FromString("10M")), + }, + }, + }, + }, + SystemStrategy: &slov1alpha1.SystemStrategy{ + TotalNetworkBandwidth: resource.MustParse("100M"), + }, + }, + }, + want: &NetQosGlobalConfig{ + HwTxBpsMax: 100000000, + HwRxBpsMax: 100000000, + L1RxBpsMin: 10000000, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, loadConfigFromNodeSlo(tt.args.nodesloSpec), "loadConfigFromNodeSlo(%v)", tt.args.nodesloSpec) + }) + } +} + +func Test_convertIpToHex(t *testing.T) { + type args struct { + ip string + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + { + name: "legal ip", + args: args{ + ip: "10.211.248.149", + }, + want: "0ad3f895", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, convertIpToHex(tt.args.ip), "convertIpToHex(%v)", tt.args.ip) + }) + } +} + +func Test_getBandwidthVal(t *testing.T) { + type args struct { + total uint64 + intOrPercent *intstr.IntOrString + } + tests := []struct { + name string + args args + want uint64 + }{ + // TODO: Add test cases. + { + name: "input is nil", + args: args{ + total: 1000, + intOrPercent: nil, + }, + want: 0, + }, + { + name: "percentage value less than 0", + args: args{ + total: 1000, + intOrPercent: genVal(intstr.FromInt(-1)), + }, + want: 0, + }, + { + name: "percentage value is zero", + args: args{ + total: 1000, + intOrPercent: genVal(intstr.FromInt(0)), + }, + want: 0, + }, + { + name: "percentage value over 100", + args: args{ + total: 1000, + intOrPercent: genVal(intstr.FromInt(200)), + }, + want: 0, + }, + { + name: "valid percentage", + args: args{ + total: 1000, + intOrPercent: genVal(intstr.FromInt(10)), + }, + want: 100, + }, + { + name: "invalid quantity format", + args: args{ + total: 1000, + intOrPercent: genVal(intstr.FromString("aaa")), + }, + want: 0, + }, + { + name: "quantity string is nil", + args: args{ + total: 1000, + intOrPercent: genVal(intstr.FromString("")), + }, + want: 0, + }, + { + name: "valid string format", + args: args{ + total: 10000, + intOrPercent: genVal(intstr.FromString("2k")), + }, + want: 2000, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, getBandwidthVal(tt.args.total, tt.args.intOrPercent), "getBandwidthVal(%v, %v)", tt.args.total, tt.args.intOrPercent) + }) + } +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/ipset.go b/pkg/koordlet/runtimehooks/hooks/tc/ipset.go new file mode 100644 index 000000000..6705fb708 --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/ipset.go @@ -0,0 +1,78 @@ +//go:build linux +// +build linux + +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "fmt" + + "github.com/vishvananda/netlink" + apierror "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/klog/v2" +) + +func (p *tcPlugin) ipsetExisted() (bool, error) { + var errs []error + for _, cur := range ipsets { + if _, err := netlink.IpsetList(cur); err != nil { + errs = append(errs, err) + } + } + + if apierror.NewAggregate(errs) != nil { + return false, apierror.NewAggregate(errs) + } + + return true, nil +} + +func (p *tcPlugin) EnsureIpset() error { + klog.V(5).Infof("start to create ipset.") + var errs []error + for _, cur := range ipsets { + result, err := netlink.IpsetList(cur) + if err == nil && result != nil { + continue + } + + err = netlink.IpsetCreate(cur, "hash:ip", netlink.IpsetCreateOptions{}) + if err != nil { + err = fmt.Errorf("failed to create ipset. err=%v", err) + errs = append(errs, err) + } + } + + return apierror.NewAggregate(errs) +} + +func (p *tcPlugin) DestoryIpset() error { + klog.V(5).Infof("start to delete ipset rules created by tc plugin.") + var errs []error + for _, cur := range ipsets { + result, err := netlink.IpsetList(cur) + if err == nil && result != nil { + if err := netlink.IpsetDestroy(cur); err != nil { + err = fmt.Errorf("failed to destroy ipset. err=%v", err) + errs = append(errs, err) + } + } + } + + return apierror.NewAggregate(errs) +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/iptables.go b/pkg/koordlet/runtimehooks/hooks/tc/iptables.go new file mode 100644 index 000000000..a47546c28 --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/iptables.go @@ -0,0 +1,105 @@ +//go:build linux +// +build linux + +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "fmt" + + apierror "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/klog/v2" +) + +func (p *tcPlugin) EnsureIptables() error { + klog.V(5).Infof("start to create iptables.") + if p.iptablesHandler == nil { + return fmt.Errorf("can't create tc iptables rulues,because qos manager is nil") + } + + ipsetToClassid := map[NetQoSClass]string{ + NETQoSSystem: convertToClassId(ROOT_CLASS_MINOR_ID, SYSTEM_CLASS_MINOR_ID), + NETQoSLS: convertToClassId(ROOT_CLASS_MINOR_ID, LS_CLASS_MINOR_ID), + NETQoSBE: convertToClassId(ROOT_CLASS_MINOR_ID, BE_CLASS_MINOR_ID), + } + + for ipsetName, classid := range ipsetToClassid { + err := p.iptablesHandler.AppendUnique("mangle", "POSTROUTING", + "-m", "set", "--match-set", string(ipsetName), "src", + "-j", "CLASSIFY", "--set-class", classid) + if err != nil { + klog.Errorf("ipt append err=%v", err) + return err + } + } + + return nil +} + +func (p *tcPlugin) DelIptables() error { + klog.V(5).Infof("start to delete iptables rules created by tc plugin.") + if p.iptablesHandler == nil { + return fmt.Errorf("can't create tc iptables rulues,because qos manager is nil") + } + ipsetToClassid := map[NetQoSClass]string{ + NETQoSSystem: convertToClassId(ROOT_CLASS_MINOR_ID, SYSTEM_CLASS_MINOR_ID), + NETQoSLS: convertToClassId(ROOT_CLASS_MINOR_ID, LS_CLASS_MINOR_ID), + NETQoSBE: convertToClassId(ROOT_CLASS_MINOR_ID, BE_CLASS_MINOR_ID), + } + var errs []error + + for ipsetName, classid := range ipsetToClassid { + err := p.iptablesHandler.DeleteIfExists("mangle", "POSTROUTING", + "-m", "set", "--match-set", string(ipsetName), "src", + "-j", "CLASSIFY", "--set-class", classid) + errs = append(errs, err) + } + + return apierror.NewAggregate(errs) +} + +func (p *tcPlugin) iptablesExisted() (bool, error) { + ipsetToClassid := map[NetQoSClass]string{ + NETQoSSystem: convertToClassId(ROOT_CLASS_MINOR_ID, SYSTEM_CLASS_MINOR_ID), + NETQoSLS: convertToClassId(ROOT_CLASS_MINOR_ID, LS_CLASS_MINOR_ID), + NETQoSBE: convertToClassId(ROOT_CLASS_MINOR_ID, BE_CLASS_MINOR_ID), + } + + for ipsetName, classid := range ipsetToClassid { + exp := fmt.Sprintf("-A POSTROUTING -m set --match-set %s src -j CLASSIFY --set-class %s", string(ipsetName), classid) + existed := false + + // looks like this one: + // -A POSTROUTING -m set --match-set mid_class src -j CLASSIFY --set-class 0001:0003 + rules, err := p.iptablesHandler.List("mangle", "POSTROUTING") + if err == nil { + for _, rule := range rules { + if rule == exp { + existed = true + break + } + } + } + + if !existed { + return false, fmt.Errorf("iptables for matching ipset(%s) not found", ipsetName) + } + } + + return true, nil +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/netqos_tc.go b/pkg/koordlet/runtimehooks/hooks/tc/netqos_tc.go new file mode 100644 index 000000000..546f3013f --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/netqos_tc.go @@ -0,0 +1,124 @@ +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "encoding/json" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/klog/v2" + + "github.com/koordinator-sh/koordinator/apis/extension" +) + +type NetQosGlobalConfig struct { + HwTxBpsMax uint64 `json:"hw_tx_bps_max"` + HwRxBpsMax uint64 `json:"hw_rx_bps_max"` + L1TxBpsMin uint64 `json:"l1_tx_bps_min"` + L1TxBpsMax uint64 `json:"l1_tx_bps_max"` + L2TxBpsMin uint64 `json:"l2_tx_bps_min"` + L2TxBpsMax uint64 `json:"l2_tx_bps_max"` + L1RxBpsMin uint64 `json:"l1_rx_bps_min"` + L1RxBpsMax uint64 `json:"l1_rx_bps_max"` + L2RxBpsMin uint64 `json:"l2_rx_bps_min"` + L2RxBpsMax uint64 `json:"l2_rx_bps_max"` +} + +type NetQoSClass string + +const ( + NETQoSSystem NetQoSClass = "system_class" + NETQoSLS NetQoSClass = "ls_class" + NETQoSBE NetQoSClass = "be_class" + NETQoSNone NetQoSClass = "" +) + +func GetPodNetQoSClassByName(qos string) NetQoSClass { + q := extension.QoSClass(qos) + + switch q { + case extension.QoSSystem: + return NETQoSSystem + case extension.QoSLSE, extension.QoSLSR, extension.QoSLS: + return NETQoSLS + case extension.QoSBE: + return NETQoSBE + } + + return NETQoSNone +} + +func GetPodNetQoSClass(pod *corev1.Pod) NetQoSClass { + if pod == nil || pod.Labels == nil { + return NETQoSNone + } + return GetNetQoSClassByAttrs(pod.Labels, pod.Annotations) +} + +func GetNetQoSClassByAttrs(labels, annotations map[string]string) NetQoSClass { + if labels == nil { + return NETQoSNone + } + // annotations are for old format adaption reason + if q, exist := labels[extension.LabelPodQoS]; exist { + return GetPodNetQoSClassByName(q) + } + return NETQoSNone +} + +type NetworkQoS struct { + // IngressLimit and EgressLimit is the bandwidth in bps + // are used to set bandwidth for Pod. The unit is bps. + // For example, 10M means 10 megabits per second. + IngressLimit string `json:"ingressLimit"` + EgressLimit string `json:"egressLimit"` +} + +func getIngressAndEgress(anno map[string]string) (uint64, uint64, error) { + klog.V(5).Infof("start to get pod qos from anno: %v", anno) + var ingress, egress uint64 + + if anno[extension.AnnotationNetworkQOS] != "" { + nqos := &NetworkQoS{ + IngressLimit: "0", + EgressLimit: "0", + } + err := json.Unmarshal([]byte(anno[extension.AnnotationNetworkQOS]), nqos) + if err != nil { + return 0, 0, err + } + + ing, err := resource.ParseQuantity(nqos.IngressLimit) + if err != nil { + return 0, 0, err + } + ingress = BitsToBytes(uint64(ing.Value())) + + eg, err := resource.ParseQuantity(nqos.EgressLimit) + if err != nil { + return 0, 0, err + } + egress = BitsToBytes(uint64(eg.Value())) + } + + return ingress, egress, nil +} + +func BitsToBytes[T uint64 | float64 | int](bits T) T { + return bits / 8 +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/rule.go b/pkg/koordlet/runtimehooks/hooks/tc/rule.go new file mode 100644 index 000000000..b1a88600b --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/rule.go @@ -0,0 +1,175 @@ +//go:build linux +// +build linux + +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "errors" + "fmt" + "reflect" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2" + + slov1alpha1 "github.com/koordinator-sh/koordinator/apis/slo/v1alpha1" + "github.com/koordinator-sh/koordinator/pkg/koordlet/resourceexecutor" + "github.com/koordinator-sh/koordinator/pkg/koordlet/statesinformer" +) + +type tcRule struct { + enable bool + netCfg *NetQosGlobalConfig + speed uint64 + uidToHandle map[types.UID]uint32 + handleToUid map[uint32]types.UID +} + +func newRule() *tcRule { + return &tcRule{ + enable: false, + netCfg: nil, + speed: 0, + uidToHandle: map[types.UID]uint32{}, + handleToUid: map[uint32]types.UID{}, + } +} + +func (p *tcPlugin) getRule() *tcRule { + p.ruleRWMutex.RLock() + defer p.ruleRWMutex.RUnlock() + if p.rule == nil { + return nil + } + rule := *p.rule + return &rule +} + +func (p *tcPlugin) updateRule(newRule *tcRule) bool { + p.ruleRWMutex.RLock() + defer p.ruleRWMutex.RUnlock() + if !reflect.DeepEqual(newRule, p.rule) { + p.rule = newRule + return true + } + return false +} + +func (p *tcPlugin) parseRuleForNodeSLO(mergedNodeSLOIf interface{}) (bool, error) { + mergedNodeSLO := mergedNodeSLOIf.(*slov1alpha1.NodeSLOSpec) + if mergedNodeSLO == nil { + return false, nil + } + qosStrategy := mergedNodeSLO.ResourceQOSStrategy + + // default policy enables + isNETQOSPolicyTC := qosStrategy == nil || qosStrategy.Policies == nil || qosStrategy.Policies.NETQOSPolicy == nil || + *qosStrategy.Policies.NETQOSPolicy == slov1alpha1.NETQOSPolicyTC + + newRule := p.getRule() + if isNETQOSPolicyTC { + if mergedNodeSLO.SystemStrategy == nil { + return false, nil + } + newRule.enable = true + newRule.speed = uint64(mergedNodeSLO.SystemStrategy.TotalNetworkBandwidth.Value()) + newRule.netCfg = loadConfigFromNodeSlo(mergedNodeSLO) + } else { + newRule.enable = false + } + + updated := p.updateRule(newRule) + return updated, nil +} + +func (p *tcPlugin) parseForAllPods(e interface{}) (bool, error) { + _, ok := e.(*struct{}) + if !ok { + return false, fmt.Errorf("invalid rule type %T", e) + } + + return true, nil +} + +func (p *tcPlugin) ruleUpdateCbForNodeSlo(target *statesinformer.CallbackTarget) error { + if err := p.prepare(); err != nil { + return err + } + + r := p.getRule() + if r == nil { + klog.V(5).Infof("hook plugin rule is nil, nothing to do for plugin %v", name) + return nil + } + + if r.enable { + klog.V(5).Infof("tc plugin is enabled, ready to init related rules.") + return p.InitRelatedRules() + } + + klog.V(5).Infof("tc plugin is not enabled, ready to cleanup related rules.") + return p.CleanUp() +} + +func (p *tcPlugin) ruleUpdateCbForPod(target *statesinformer.CallbackTarget) error { + if target == nil { + return errors.New("callback target is nil") + } + podMetas := target.Pods + if len(podMetas) <= 0 { + klog.V(5).Infof("plugin %s skipped for rule update, no pod passed from callback", name) + return nil + } + + if err := p.prepare(); err != nil { + return err + } + + r := p.getRule() + if r == nil { + klog.V(5).Infof("hook plugin rule is nil, nothing to do for plugin %v", name) + return nil + } + + // cache classId from net_cls cgroup when first sync. + p.allPodsSyncOnce.Do(func() { + // mark the class that has been used + r.handleToUid[rootClass] = "" + r.handleToUid[systemClass] = "" + r.handleToUid[lsClass] = "" + r.handleToUid[beClass] = "" + + cgroupReader := resourceexecutor.NewCgroupReader() + for _, pod := range podMetas { + if value, ok := r.uidToHandle[pod.Pod.UID]; ok || value == 0 { + continue + } + + // netClsId is a decimal number. + netClsId, err := cgroupReader.ReadNetClsId(pod.CgroupDir) + if err != nil || netClsId == 0 { + continue + } + r.uidToHandle[pod.Pod.UID] = uint32(netClsId) + r.handleToUid[uint32(netClsId)] = pod.Pod.UID + p.updateRule(r) + } + }) + + return p.refreshForAllPods(podMetas, r) +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/tc.go b/pkg/koordlet/runtimehooks/hooks/tc/tc.go new file mode 100644 index 000000000..43aa00f10 --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/tc.go @@ -0,0 +1,46 @@ +//go:build !linux +// +build !linux + +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "k8s.io/klog/v2" + + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks" +) + +const ( + name = "TCInjection" + description = "set tc rules for nodes" +) + +type tcPlugin struct{} + +func Object() *tcPlugin { + return nil +} + +func (n *tcPlugin) Reconcile() { + klog.V(5).Info("net qos plugin start to reconcile in !linux os") + return +} + +func (n *tcPlugin) Register(op hooks.Options) { + klog.V(5).Infof("register hook %v", name) +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/tc_linux.go b/pkg/koordlet/runtimehooks/hooks/tc/tc_linux.go new file mode 100644 index 000000000..3a2771fca --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/tc_linux.go @@ -0,0 +1,879 @@ +//go:build linux +// +build linux + +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "errors" + "fmt" + "math/rand" + "net" + "os" + "strconv" + "strings" + "sync" + + "github.com/coreos/go-iptables/iptables" + "github.com/vishvananda/netlink" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + apierror "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/klog/v2" + "k8s.io/kubernetes/pkg/kubelet/util/format" + "k8s.io/utils/exec" + "k8s.io/utils/pointer" + + "github.com/koordinator-sh/koordinator/pkg/koordlet/resourceexecutor" + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks" + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/protocol" + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/reconciler" + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/rule" + "github.com/koordinator-sh/koordinator/pkg/koordlet/statesinformer" + "github.com/koordinator-sh/koordinator/pkg/koordlet/util/system" + sysutil "github.com/koordinator-sh/koordinator/pkg/koordlet/util/system" +) + +const ( + name = "tcPlugin" + description = "setup tc rules for node" + + ruleNameForNodeSLO = name + " (nodeSLO)" + ruleNameForAllPods = name + " (allPods)" +) + +const ( + MAJOR_ID = 1 + QDISC_MINOR_ID = 0 + ROOT_CLASS_MINOR_ID = 1 + SYSTEM_CLASS_MINOR_ID = 2 + LS_CLASS_MINOR_ID = 3 + BE_CLASS_MINOR_ID = 4 + + // 0-7, In the round-robin process, classes with the lowest priority field are tried for packets first. + SYSTEM_CLASS_PRIO = 1 + LS_CLASS_PRIO = 2 + BE_CLASS_PRIO = 3 + + POD_FILTER_PRIO = 4 + + // Maximum rate this class and all its children are guaranteed. Mandatory. + // attention: the values below only represent the percentage of bandwidth can be used by different tc classes on the host, + // the real values need to be calculated based on the physical network bandwidth. + // eg: eth0: speed:200Mbit => high_clss.rate = 200Mbit * 40 / 100 = 80Mbit + BE_CLASS_RATE_PERCENTAGE = 30 + LS_CLASS_RATE_PERCENTAGE = 30 + SYSTEM_CLASS_RATE_PERCENTAGE = 40 + + // Maximum rate at which a class can send, if its parent has bandwidth to spare. Defaults to the configured rate, + // which implies no borrowing + CEIL_PERCENTAGE = 100 + + DEFAULT_INTERFACE_NAME = "eth0" +) + +var ( + rootClass = netlink.MakeHandle(MAJOR_ID, ROOT_CLASS_MINOR_ID) + systemClass = netlink.MakeHandle(MAJOR_ID, SYSTEM_CLASS_MINOR_ID) + lsClass = netlink.MakeHandle(MAJOR_ID, LS_CLASS_MINOR_ID) + beClass = netlink.MakeHandle(MAJOR_ID, BE_CLASS_MINOR_ID) + + ipsets = []string{string(NETQoSSystem), string(NETQoSLS), string(NETQoSBE)} +) + +type tcPlugin struct { + ruleRWMutex sync.RWMutex + rule *tcRule + + // this is the physical NIC on host, default eth0 + interfLink netlink.Link + + // for executing the iptables command. + iptablesHandler *iptables.IPTables + // for executing the tc and ipset command. + netLinkHandler netlink.Handle + + allPodsSyncOnce sync.Once + + executor resourceexecutor.ResourceUpdateExecutor +} + +var singleton *tcPlugin + +func Object() *tcPlugin { + if singleton == nil { + singleton = newPlugin() + } + return singleton +} + +func newPlugin() *tcPlugin { + return &tcPlugin{ + ruleRWMutex: sync.RWMutex{}, + netLinkHandler: netlink.Handle{}, + allPodsSyncOnce: sync.Once{}, + rule: newRule(), + } +} + +func (p *tcPlugin) Register(op hooks.Options) { + klog.V(5).Infof("register hook %v", name) + + rule.Register(ruleNameForNodeSLO, description, + rule.WithParseFunc(statesinformer.RegisterTypeNodeSLOSpec, p.parseRuleForNodeSLO), + rule.WithUpdateCallback(p.ruleUpdateCbForNodeSlo), + ) + + rule.Register(ruleNameForAllPods, description, + rule.WithParseFunc(statesinformer.RegisterTypeAllPods, p.parseForAllPods), + rule.WithUpdateCallback(p.ruleUpdateCbForPod)) + // TODO register NRI after there is pod ip in NRI request + + reconciler.RegisterCgroupReconciler(reconciler.PodLevel, sysutil.NetClsClassId, description+" (pod net class id)", + p.SetPodNetCls, reconciler.NoneFilter()) + + p.executor = op.Executor +} + +func (p *tcPlugin) SetPodNetCls(proto protocol.HooksProtocol) error { + podCtx := proto.(*protocol.PodContext) + if podCtx == nil { + return fmt.Errorf("pod protocol is nil for plugin %v", name) + } + + netQos := GetNetQoSClassByAttrs(podCtx.Request.Labels, podCtx.Request.Annotations) + if netQos == NETQoSNone { + return nil + } + + ing, egress, err := getIngressAndEgress(podCtx.Request.Annotations) + if err != nil { + klog.Errorf("failed to get net config from annotation in pod(%s/%s/%v)", podCtx.Request.PodMeta.Namespace, podCtx.Request.PodMeta.Name, podCtx.Request.PodMeta.UID) + } + + r := p.getRule() + if r == nil { + klog.V(5).Infof("hook plugin rule is nil, nothing to do for plugin %v", name) + return nil + } + + var handle uint32 + needLimitInPodLevel := ing != 0 || egress != 0 + if needLimitInPodLevel { + if handleId, ok := r.uidToHandle[types.UID(podCtx.Request.PodMeta.UID)]; ok { + handle = handleId + } + } else { + netqosToHandle := func(qos NetQoSClass) uint32 { + m := map[NetQoSClass]uint32{ + NETQoSSystem: systemClass, + NETQoSLS: lsClass, + NETQoSBE: beClass, + } + if handle, ok := m[qos]; ok { + return handle + } + return lsClass + } + + handle = netqosToHandle(netQos) + } + + podCtx.Response.Resources.NetClsClassId = pointer.Uint32(handle) + + return nil +} + +func (p *tcPlugin) createTcRulesForHostPod(rule *tcRule, pod *v1.Pod, egress uint64) error { + handle, _ := rule.uidToHandle[pod.UID] + netqos := GetNetQoSClassByAttrs(pod.Labels, pod.Annotations) + cls := newClass(p.interfLink.Attrs().Index, rootClass, handle, egress, egress, GetPrio(netqos)) + err := p.ensureClass(p.interfLink, cls) + if err != nil { + return err + } + + minorHex := getMinorId(handle) + minorDecimal, _ := strconv.ParseUint(minorHex, 16, 64) + genFilterCmd := func(op string) exec.Cmd { + // tc filter add dev eth0 parent 1: protocol ip prio 2 handle 5: cgroup + // means to 1:5 class + return exec.New().Command("tc", "filter", op, "dev", p.interfLink.Attrs().Name, + "protocol", "ip", "prio", strconv.Itoa(getPrio(p.interfLink.Attrs().Name)), + "parent", fmt.Sprintf("%d:", ROOT_CLASS_MINOR_ID), + "handle", fmt.Sprintf("%d:", minorDecimal), + "cgroup", + ) + } + + return p.ensureFilter(KeyByHandle(minorHex), genFilterCmd("add"), genFilterCmd("change")) +} + +func (p *tcPlugin) delTcRules(handle uint32) error { + prio := getFilterPrio(p.interfLink.Attrs().Name, handle) + if prio == "" { + return nil + } + + delFilterCmd := exec.New().Command("tc", "filter", "delete", "dev", p.interfLink.Attrs().Name, + "parent", fmt.Sprintf("%d:", ROOT_CLASS_MINOR_ID), "prio", prio, + ) + data, err := delFilterCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to delete tc filter, output: %s, err: %v", string(data), err) + } + + cls := newClass(p.interfLink.Attrs().Index, rootClass, handle, 0, 0, 0) + return p.deleteClass(p.interfLink, cls) +} + +func getFilterPrio(nic string, handle uint32) string { + minorHex := getMinorId(handle) + prio := getPrioForFilter(nic, KeyByHandle(minorHex)) + if prio != "" { + return prio + } + + return getPrioForFilter(nic, KeyByFlowId(minorHex)) +} + +// find filter by a key value, just as "handle 0x**" +func getPrioForFilter(nic string, key string) string { + // output just like this: + // filter parent 1: protocol ip pref 1 cgroup chain 0 + // filter parent 1: protocol ip pref 1 cgroup chain 0 handle 0x3 + output, err := exec.New().Command("tc", "filter", "show", "dev", nic).CombinedOutput() + if err != nil { + return "" + } + strs := strings.Split(strings.TrimSpace(string(output)), "\n") + if len(strs) == 0 { + return "" + } + + for _, line := range strs { + if !strings.Contains(line, key) { + continue + } + + params := strings.Fields(line) + for idx, param := range params { + if param == "pref" { + return params[idx+1] + } + } + } + + return "" +} + +func (p *tcPlugin) prepare() error { + linkInfo, err := system.GetLinkInfoByDefaultRoute() + if err != nil { + return fmt.Errorf("failed to get link info by default route. err=%v\n", err) + } + if linkInfo == nil || linkInfo.Attrs() == nil { + return fmt.Errorf("link info is nil") + } + + p.interfLink = linkInfo + + ipt, err := iptables.New() + if err != nil { + klog.Errorf("failed to get iptables handler in those dir(%s). err=%v\n", os.Getenv("PATH"), err) + return err + } + p.iptablesHandler = ipt + + return nil +} + +func getMinorId(num uint32) string { + minor := num - MAJOR_ID<<16 + return strconv.FormatUint(uint64(minor), 16) +} + +func (p *tcPlugin) refreshForAllPods(pods []*statesinformer.PodMeta, rule *tcRule) error { + ipInK8S := make(map[string]sets.String) + ipInIpset := make(map[string]sets.String) + activePods := make(map[types.UID]interface{}) + + // handle active pod + for _, pod := range pods { + activePods[pod.Pod.UID] = nil + netqos := GetPodNetQoSClass(pod.Pod) + if netqos == NETQoSNone { + continue + } + + ing, egress, err := getIngressAndEgress(pod.Pod.Annotations) + if err != nil { + klog.Errorf("failed to get net config from annotation in pod(%s)", format.Pod(pod.Pod)) + } + needLimitAtPodLevel := ing != 0 || egress != 0 + + if needLimitAtPodLevel { + // create tc rules + if err := initHandleId(pod.Pod.UID, rule.handleToUid, rule.uidToHandle); err != nil { + klog.Errorf(err.Error()) + continue + } + p.updateRule(rule) + + if pod.Pod.Spec.HostNetwork { + // pod in host network namespace, network bandwidth can be limited by net_cls cgroup. + err = p.createTcRulesForHostPod(rule, pod.Pod, egress) + } else { + err = p.createRulesForPod(rule, pod.Pod, netqos, egress) + } + if err != nil { + klog.Errorf("failed to create network rules for pod(uid:%s; ip:%s). err=%v", string(pod.Pod.UID), pod.Pod.Status.PodIP, err) + } + continue + } + + // finally, handled by network rules at the node level + if ipInK8S[string(netqos)] == nil { + ipInK8S[string(netqos)] = sets.NewString() + } + ipInK8S[string(netqos)].Insert(pod.Pod.Status.PodIP) + } + + // delete netqos rules for the deleted pods. + for uid, handle := range rule.uidToHandle { + _, ok := activePods[uid] + if ok { + continue + } + + if err := p.delTcRules(handle); err != nil { + klog.Errorf("failed to delete network rules for pod(uid:%s). err=%v", uid, err) + continue + } + + delete(rule.uidToHandle, uid) + delete(rule.handleToUid, handle) + p.updateRule(rule) + } + + // handle netqos rules at the node level. + for _, setName := range ipsets { + result, err := netlink.IpsetList(setName) + if err != nil || result == nil { + klog.Errorf("failed to get ipset.err=%v", err) + continue + } + + for _, entry := range result.Entries { + if ipInIpset[setName] == nil { + ipInIpset[setName] = sets.NewString() + } + ipInIpset[setName].Insert(entry.IP.String()) + } + } + + for _, setName := range ipsets { + for ip := range ipInK8S[setName].Difference(ipInIpset[setName]) { + klog.V(5).Infof("ready to add %s to ipset %s.", ip, setName) + if err := netlink.IpsetAdd(setName, &netlink.IPSetEntry{IP: net.ParseIP(ip).To4()}); err != nil { + klog.Warningf("failed to write ip %s to ipset %s, err=%v", ip, setName, err) + return err + } + } + + for ip := range ipInIpset[setName].Difference(ipInK8S[setName]) { + klog.V(5).Infof("ready to del ip %s from ipset %s.", ip, setName) + if err := netlink.IpsetDel(setName, &netlink.IPSetEntry{IP: net.ParseIP(ip).To4()}); err != nil { + klog.Warningf("failed to write ip %s to ipset %s, err=%v", ip, setName, err) + return err + } + } + } + + return nil +} + +func (p *tcPlugin) createRulesForPod(rule *tcRule, pod *v1.Pod, netqos NetQoSClass, egress uint64) error { + klog.V(5).Infof("start to create related rules for pod(uid:%s; ip:%s), anno:%v", pod.UID, pod.Status.PodIP, pod.Annotations) + + handle, _ := rule.uidToHandle[pod.UID] + cls := newClass(p.interfLink.Attrs().Index, rootClass, handle, egress, egress, GetPrio(netqos)) + err := p.ensureClass(p.interfLink, cls) + if err != nil { + klog.Errorf("failed to create class for pod %s, err=%v", string(pod.UID), err) + return err + } + minorHex := getMinorId(handle) + // tc filter add dev eth0 parent 1:0 protocol ip prio 2 u32 match ip dst 0.0.0.0/0 flowid 1:5 + genFilterCmd := func(op string) exec.Cmd { + // tc filter add dev br0 parent 1:0 protocol ip prio 2 u32 match ip src 1.2.0.0 classid 1:5 + return exec.New().Command("tc", "filter", op, "dev", p.interfLink.Attrs().Name, + "parent", fmt.Sprintf("%d:", ROOT_CLASS_MINOR_ID), + "protocol", "ip", "prio", strconv.Itoa(getPrio(p.interfLink.Attrs().Name)), + "u32", "match", "ip", "src", pod.Status.PodIP, + "classid", fmt.Sprintf("%d:%s", ROOT_CLASS_MINOR_ID, minorHex), + ) + } + + if err := p.ensureFilter(KeyByFlowId(minorHex), genFilterCmd("add"), genFilterCmd("change")); err != nil { + klog.Errorf("failed to create class for pod %s, err=%v", string(pod.UID), err) + return err + } + + return nil +} + +// getPrio get next available priority for tc filter. +func getPrio(name string) int { + // output just like this: + // filter parent 1: protocol ip pref 1 cgroup chain 0 + // filter parent 1: protocol ip pref 1 cgroup chain 0 handle 0x3 + output, err := exec.New().Command("tc", "filter", "show", "dev", name).CombinedOutput() + if err != nil { + return 0 + } + strs := strings.Split(strings.TrimSpace(string(output)), "\n") + if len(strs) == 0 { + return 0 + } + + lastLine := "" + for i := len(strs) - 1; i >= 0; i-- { + if strings.HasPrefix(strs[i], "filter") { + lastLine = strs[i] + break + } + } + + params := strings.Fields(lastLine) + for idx, param := range params { + if param == "pref" { + prio, _ := strconv.Atoi(params[idx+1]) + return prio + 1 + } + } + + return 0 +} + +// initHandleId get class minor id from pod.uid(last 4 digits). +func initHandleId(uid types.UID, handleToUid map[uint32]types.UID, uidToHandle map[types.UID]uint32) error { + if _, ok := uidToHandle[uid]; ok { + return nil + } + + if len(handleToUid) >= (1<<16)-1 { + return errors.New("tc class is too much") + } + + for { + minorId := rand.Int31n(MAJOR_ID << 16) + handleId := netlink.MakeHandle(MAJOR_ID, uint16(minorId)) + if _, ok := handleToUid[handleId]; !ok { + handleToUid[handleId] = uid + uidToHandle[uid] = handleId + return nil + } + } +} + +func (p *tcPlugin) InitRelatedRules() error { + return apierror.NewAggregate([]error{ + p.EnsureQdisc(), + p.EnsureClasses(), + p.EnsureCgroupFilters(), + p.EnsureIpset(), + p.EnsureIptables(), + }) +} + +func (p *tcPlugin) CleanUp() error { + return apierror.NewAggregate([]error{ + p.DelQdisc(), + p.DelIptables(), + p.DestoryIpset(), + }) +} + +func (p *tcPlugin) EnsureQdisc() error { + klog.V(5).Infoln("start to create qdisc for default net interface") + attrs := netlink.QdiscAttrs{ + LinkIndex: p.interfLink.Attrs().Index, + Handle: netlink.MakeHandle(MAJOR_ID, QDISC_MINOR_ID), + Parent: netlink.HANDLE_ROOT, + } + htb := netlink.NewHtb(attrs) + htb.Defcls = SYSTEM_CLASS_MINOR_ID + + qdiscs, err := p.netLinkHandler.QdiscList(p.interfLink) + if err != nil { + return fmt.Errorf("failed to get qdisc. err=%v", err) + } + + for _, qdisc := range qdiscs { + if qdisc.Type() != "htb" { + continue + } + if qdisc.Attrs().Handle == htb.Handle { + return nil + } + if err := netlink.QdiscDel(htb); err != nil { + return fmt.Errorf("failed to delete old qidsc on %s, err=%v", p.interfLink.Attrs().Name, err) + } + } + + return p.netLinkHandler.QdiscAdd(htb) +} + +func (p *tcPlugin) DelQdisc() error { + klog.V(5).Infof("start to delete qdisc created by tc plugin.") + attrs := netlink.QdiscAttrs{ + LinkIndex: p.interfLink.Attrs().Index, + Handle: netlink.MakeHandle(MAJOR_ID, QDISC_MINOR_ID), + Parent: netlink.HANDLE_ROOT, + } + htb := netlink.NewHtb(attrs) + + qdiscs, err := p.netLinkHandler.QdiscList(p.interfLink) + if err != nil { + return err + } + + for _, qdisc := range qdiscs { + if qdisc.Type() == "htb" && qdisc.Attrs().Handle == htb.Handle { + if err := netlink.QdiscDel(htb); err != nil { + return fmt.Errorf("failed to delete old qidsc on %s, err=%v", p.interfLink.Attrs().Name, err) + } + } + } + + return nil +} + +func (p *tcPlugin) EnsureClasses() error { + klog.V(5).Infof("start to create tc class rules.") + r := p.getRule() + if r == nil { + klog.V(5).Infof("hook plugin rule is nil, nothing to do for plugin %v", name) + return nil + } + + return apierror.NewAggregate([]error{ + p.ensureClass(p.interfLink, newClass(p.interfLink.Attrs().Index, netlink.HANDLE_ROOT, rootClass, r.netCfg.HwTxBpsMax, r.netCfg.HwTxBpsMax, 0)), + p.ensureClass(p.interfLink, newClass(p.interfLink.Attrs().Index, rootClass, systemClass, r.netCfg.HwTxBpsMax-r.netCfg.L1TxBpsMin-r.netCfg.L2TxBpsMin, r.netCfg.HwTxBpsMax, SYSTEM_CLASS_PRIO)), + p.ensureClass(p.interfLink, newClass(p.interfLink.Attrs().Index, rootClass, lsClass, r.netCfg.L1TxBpsMin, r.netCfg.L1TxBpsMax, LS_CLASS_PRIO)), + p.ensureClass(p.interfLink, newClass(p.interfLink.Attrs().Index, rootClass, beClass, r.netCfg.L2TxBpsMin, r.netCfg.L2TxBpsMax, BE_CLASS_PRIO)), + }) +} + +func (p *tcPlugin) EnsureCgroupFilters() error { + klog.V(5).Infof("start to create tc cgroup filter rules.") + genFilterCmd := func(op string, prio int, minor uint16) exec.Cmd { + return exec.New().Command("tc", "filter", op, "dev", p.interfLink.Attrs().Name, + "protocol", "ip", "prio", fmt.Sprintf("%d", prio), + "parent", fmt.Sprintf("%d:", ROOT_CLASS_MINOR_ID), + "handle", fmt.Sprintf("%d:", minor), + "cgroup", + ) + } + + return apierror.NewAggregate([]error{ + p.ensureFilter(KeyByHandle(strconv.Itoa(SYSTEM_CLASS_MINOR_ID)), genFilterCmd("add", SYSTEM_CLASS_PRIO, SYSTEM_CLASS_MINOR_ID), + genFilterCmd("change", SYSTEM_CLASS_PRIO, SYSTEM_CLASS_MINOR_ID)), + p.ensureFilter(KeyByHandle(strconv.Itoa(LS_CLASS_MINOR_ID)), genFilterCmd("add", LS_CLASS_PRIO, LS_CLASS_MINOR_ID), + genFilterCmd("change", LS_CLASS_PRIO, LS_CLASS_MINOR_ID)), + p.ensureFilter(KeyByHandle(strconv.Itoa(BE_CLASS_MINOR_ID)), genFilterCmd("add", BE_CLASS_PRIO, BE_CLASS_MINOR_ID), + genFilterCmd("change", BE_CLASS_PRIO, BE_CLASS_MINOR_ID)), + }) +} + +func KeyByHandle(clsMinorId string) string { + return "handle 0x" + clsMinorId +} + +func KeyByFlowId(clsMinorId string) string { + return fmt.Sprintf("flowid %d:%s", MAJOR_ID, clsMinorId) +} + +func GetPrio(qos NetQoSClass) uint32 { + m := map[NetQoSClass]uint32{ + NETQoSSystem: SYSTEM_CLASS_PRIO, + NETQoSLS: LS_CLASS_PRIO, + NETQoSBE: BE_CLASS_PRIO, + NETQoSNone: LS_CLASS_PRIO, + } + + return m[qos] +} + +func newClass(index int, parent, handle uint32, rate, ceil uint64, prio uint32) *netlink.HtbClass { + attr := netlink.ClassAttrs{ + LinkIndex: index, + Parent: parent, + Handle: handle, + } + classAttr := netlink.HtbClassAttrs{ + Rate: rate, + Ceil: ceil, + Prio: prio, + } + htbClass := NewHtbClass(attr, classAttr) + if htbClass.Cbuffer < 200 { + htbClass.Cbuffer = 200 + } + if htbClass.Buffer < 200 { + htbClass.Buffer = 200 + } + + return htbClass +} + +// NewHtbClass NOTE: function is in here because it uses other linux functions +func NewHtbClass(attrs netlink.ClassAttrs, cattrs netlink.HtbClassAttrs) *netlink.HtbClass { + mtu := 1600 + rate := cattrs.Rate / 8 + ceil := cattrs.Ceil / 8 + buffer := cattrs.Buffer + cbuffer := cattrs.Cbuffer + + if ceil == 0 { + ceil = rate + } + + if buffer == 0 { + buffer = uint32(float64(rate)/netlink.Hz() + float64(mtu)) + klog.V(2).Infof("buffer[%v]=rate[%v]/hz[%v]+mtu[%v]\n", buffer, rate, netlink.Hz(), mtu) + } + dstBuffer := netlink.Xmittime(rate, buffer) + klog.V(2).Infof("buffer[%v]=(1000000*(srcBuffer[%v]/rate[%v]))/tick[%v]\n", dstBuffer, buffer, rate, netlink.TickInUsec()) + + if cbuffer == 0 { + cbuffer = uint32(float64(ceil)/netlink.Hz() + float64(mtu)) + klog.V(2).Infof("cbuffer[%v]=ceil[%v]/hz[%v]+mtu[%v]\n", cbuffer, ceil, netlink.Hz(), mtu) + } + dstCbuffer := netlink.Xmittime(ceil, cbuffer) + klog.V(2).Infof("cbuffer[%v]=(1000000*(srcCbuffer[%v]/ceil[%v]))/tick[%v]", dstCbuffer, cbuffer, ceil, netlink.TickInUsec()) + + return &netlink.HtbClass{ + ClassAttrs: attrs, + Rate: rate, + Ceil: ceil, + Buffer: buffer, + Cbuffer: cbuffer, + Level: 0, + Prio: cattrs.Prio, + Quantum: cattrs.Quantum, + } +} + +func (p *tcPlugin) ensureClass(nic netlink.Link, expect *netlink.HtbClass) error { + classes, err := netlink.ClassList(nic, 0) + if err != nil { + return fmt.Errorf("failed to get tc class. err=%v", err) + } + var existing *netlink.HtbClass + for _, class := range classes { + switch class.(type) { + case *netlink.HtbClass: + htbClass := class.(*netlink.HtbClass) + if htbClass != nil && + htbClass.Handle == expect.Handle && + htbClass.Parent == expect.Parent { + existing = htbClass + break + } + } + } + + if existing != nil { + if expect.Rate == existing.Rate && + expect.Ceil == existing.Ceil && + expect.Prio == existing.Prio && + expect.Buffer == existing.Buffer && + expect.Cbuffer == existing.Cbuffer { + return nil + } + if err := netlink.ClassChange(expect); err != nil { + return fmt.Errorf("failed to change class from %v to %v on interface %s. err=: %v", existing, expect, p.interfLink.Attrs().Name, err) + } + klog.Infof("succeed to changed htb class from %v to %v on interface %s.", existing, expect, p.interfLink.Attrs().Name) + return nil + } + if err := netlink.ClassAdd(expect); err != nil { + return fmt.Errorf("failed to create htb class %v: %v on interface %s. err=%v", expect, err, p.interfLink.Attrs().Name, err) + } + + klog.V(2).Infof("succed to creat htb class: %v on interface %s\n", expect, p.interfLink.Attrs().Name) + return nil +} + +func (p *tcPlugin) deleteClass(nic netlink.Link, expect *netlink.HtbClass) error { + classes, err := netlink.ClassList(nic, 0) + if err != nil { + return fmt.Errorf("failed to get tc class. err=%v", err) + } + + for _, class := range classes { + switch class.(type) { + case *netlink.HtbClass: + htbClass := class.(*netlink.HtbClass) + if htbClass != nil && + htbClass.Handle == expect.Handle && + htbClass.Parent == expect.Parent { + return netlink.ClassDel(expect) + } + } + } + + return nil +} + +func (p *tcPlugin) deleteFilter(key string, delFunc exec.Cmd) error { + matchCmd := exec.New().Command("tc", "filter", "show", "dev", p.interfLink.Attrs().Name) + data, err := matchCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to get tc filter by key:%s, err:%v", key, err) + } + + if !strings.Contains(string(data), key) { + return nil + } + + data, err = delFunc.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to delete tc filter, output: %s, err: %v", string(data), err) + } + klog.V(5).Infof("succeed to delete filter for %s ", p.interfLink.Attrs().Name) + + return nil +} + +func (p *tcPlugin) ensureFilter(find string, createCmd, updateCmd exec.Cmd) error { + matchCmd := exec.New().Command("tc", "filter", "show", "dev", p.interfLink.Attrs().Name) + data, err := matchCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to get tc filter by key:%s, err:%v", find, err) + } + + if strings.Contains(string(data), find) { + return nil + } + + // handled by command because netlink does not support creating filters for cgroup types. + // creating a tc filter for be netqos pod, just as follows: + // tc filter add dev eth0 parent 1:0 protocol ip prio 2 match ip src 1.2.0.0 classid 1:5 + data, err = createCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to create tc filter, output: %s, err: %v", string(data), err) + } + klog.V(5).Infof("%s created filter", p.interfLink.Attrs().Name) + + return nil +} + +func (p *tcPlugin) checkAllRulesExisted() (bool, error) { + if _, err := p.QdiscExisted(); err != nil { + return false, err + } + + if _, err := p.classesExisted(); err != nil { + return false, err + } + + if _, err := p.ipsetExisted(); err != nil { + return false, err + } + + if _, err := p.iptablesExisted(); err != nil { + return false, err + } + + return true, nil +} + +func (p *tcPlugin) QdiscExisted() (bool, error) { + attrs := netlink.QdiscAttrs{ + LinkIndex: p.interfLink.Attrs().Index, + Handle: netlink.MakeHandle(MAJOR_ID, QDISC_MINOR_ID), + Parent: netlink.HANDLE_ROOT, + } + htb := netlink.NewHtb(attrs) + + qdiscs, err := p.netLinkHandler.QdiscList(p.interfLink) + if err != nil || qdiscs == nil { + return false, err + } + + if len(qdiscs) == 1 && qdiscs[0].Type() == "htb" && + qdiscs[0].Attrs().Handle == htb.Handle { + return true, nil + } + + return false, fmt.Errorf("qdisc not found") +} + +func (p *tcPlugin) classesExisted() (bool, error) { + link, err := netlink.LinkByIndex(p.interfLink.Attrs().Index) + if err != nil { + return false, err + } + + r := p.getRule() + if r == nil { + klog.V(5).Infof("hook plugin rule is nil, nothing to do for plugin %v", name) + return false, nil + } + + maxCeil := r.speed * CEIL_PERCENTAGE / 100 + // other leaf class + highClassRate := r.speed * SYSTEM_CLASS_RATE_PERCENTAGE / 100 + midClassRate := r.speed * LS_CLASS_RATE_PERCENTAGE / 100 + lowClassRate := r.speed * BE_CLASS_RATE_PERCENTAGE / 100 + + errs := apierror.NewAggregate([]error{ + p.classExisted(link, newClass(p.interfLink.Attrs().Index, netlink.HANDLE_ROOT, rootClass, maxCeil, maxCeil, 0)), + p.classExisted(link, newClass(p.interfLink.Attrs().Index, rootClass, systemClass, highClassRate, maxCeil, SYSTEM_CLASS_PRIO)), + p.classExisted(link, newClass(p.interfLink.Attrs().Index, rootClass, lsClass, midClassRate, maxCeil, LS_CLASS_PRIO)), + p.classExisted(link, newClass(p.interfLink.Attrs().Index, rootClass, beClass, lowClassRate, maxCeil, BE_CLASS_PRIO)), + }) + + if errs != nil { + return false, errs + } + + return true, nil +} + +func (p *tcPlugin) classExisted(nic netlink.Link, expect *netlink.HtbClass) error { + classes, err := netlink.ClassList(nic, 0) + if err != nil { + return err + } + + for _, class := range classes { + htbClass := class.(*netlink.HtbClass) + if class.Type() == "htb" && + htbClass.Handle == expect.Handle && + htbClass.Parent == expect.Parent { + return nil + } + } + + return fmt.Errorf("class(classid:%d) not find", expect.Handle) +} diff --git a/pkg/koordlet/runtimehooks/hooks/tc/tc_linux_test.go b/pkg/koordlet/runtimehooks/hooks/tc/tc_linux_test.go new file mode 100644 index 000000000..61f8212c8 --- /dev/null +++ b/pkg/koordlet/runtimehooks/hooks/tc/tc_linux_test.go @@ -0,0 +1,322 @@ +//go:build linux +// +build linux + +/* +Copyright 2022 The Koordinator Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tc + +import ( + "fmt" + "os" + "testing" + + "github.com/coreos/go-iptables/iptables" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/klog/v2" + + "github.com/koordinator-sh/koordinator/pkg/koordlet/runtimehooks/hooks" + "github.com/koordinator-sh/koordinator/pkg/koordlet/statesinformer" + "github.com/koordinator-sh/koordinator/pkg/koordlet/util/system" +) + +func TestObject(t *testing.T) { + t.Run("test", func(t *testing.T) { + b := Object() + assert.NotNil(t, b) + b1 := Object() + assert.Equal(t, b, b1) + }) +} + +func Test_bvtPlugin_Register(t *testing.T) { + t.Run("register tc plugin", func(t *testing.T) { + r := &tcPlugin{} + r.Register(hooks.Options{}) + }) +} + +func newTestTCPlugin() *tcPlugin { + klog.Info("start to init net qos manager") + p := tcPlugin{ + netLinkHandler: netlink.Handle{}, + rule: &tcRule{ + enable: true, + netCfg: &NetQosGlobalConfig{ + HwTxBpsMax: 1000000000, + HwRxBpsMax: 1000000000, + L1TxBpsMin: 500000000, + L1TxBpsMax: 1000000000, + L2TxBpsMin: 500000000, + L2TxBpsMax: 1000000000, + L1RxBpsMin: 500000000, + L1RxBpsMax: 1000000000, + L2RxBpsMin: 500000000, + L2RxBpsMax: 1000000000, + }, + }, + } + + linkInfo, err := system.GetLinkInfoByDefaultRoute() + if err != nil { + klog.Errorf("failed to get link info by default route. err=%v\n", err) + return nil + } + if linkInfo == nil || linkInfo.Attrs() == nil { + klog.Errorf("link info is nil") + return nil + } + + p.interfLink = linkInfo + + ipt, err := iptables.New() + if err != nil { + klog.Errorf("failed to get iptables handler in those dir(%s). err=%v\n", os.Getenv("PATH"), err) + return nil + } + p.iptablesHandler = ipt + + return &p +} + +func TestTCPlugin_Init(t *testing.T) { + plugin := newTestTCPlugin() + if err := plugin.InitRelatedRules(); err != nil { + return + } + + tests := []struct { + name string + preHandle func() error + wantErr bool + endHandle func() error + }{ + { + name: "tc qdisc rules already existed", + preHandle: plugin.EnsureQdisc, + wantErr: false, + endHandle: plugin.CleanUp, + }, + { + name: "tc class rules already existed", + preHandle: func() error { + return errors.NewAggregate([]error{ + plugin.EnsureQdisc(), + plugin.EnsureClasses(), + }) + }, + wantErr: false, + endHandle: plugin.CleanUp, + }, + { + name: "ipset rules already existed", + preHandle: plugin.EnsureIpset, + wantErr: false, + endHandle: plugin.CleanUp, + }, + { + name: "iptables rules already existed", + preHandle: func() error { + return errors.NewAggregate([]error{ + plugin.EnsureIpset(), + plugin.EnsureIptables(), + }) + }, + wantErr: false, + endHandle: plugin.CleanUp, + }, + { + name: "all rulues have already been inited", + preHandle: plugin.InitRelatedRules, + wantErr: false, + endHandle: plugin.CleanUp, + }, + { + name: "cleanup all rules will be used in advance", + preHandle: plugin.CleanUp, + wantErr: false, + endHandle: plugin.CleanUp, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.preHandle(); err != nil { + t.Errorf("failed to run preHandle.err=%v", err) + return + } + if err := plugin.InitRelatedRules(); (err != nil) != tt.wantErr { + t.Errorf("Init() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err := tt.endHandle(); err != nil { + t.Errorf("failed to run endHandle.err=%v", err) + return + } + }) + } +} + +func genPod(podName, netqos, ip string) *statesinformer.PodMeta { + return &statesinformer.PodMeta{ + Pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Labels: map[string]string{ + "koordinator.sh/netQoSClass": netqos, + }, + }, + Status: corev1.PodStatus{ + PodIP: ip, + }, + }, + } +} + +func TestTCPlugin_Callback(t *testing.T) { + pod1 := genPod("pod1", "high_class", "192.168.0.1") + pod2 := genPod("pod2", "high_class", "192.168.0.2") + pod3 := genPod("pod3", "mid_class", "192.168.0.3") + pod4 := genPod("pod4", "low_class", "192.168.0.4") + pod5 := genPod("pod5", "", "192.168.0.5") + pod6 := genPod("pod6", "low_class", "192.168.0.6") + + plugin := newTestTCPlugin() + if err := plugin.InitRelatedRules(); err != nil { + klog.Errorf("failed to init some necessary info tc plugin.") + return + } + defer plugin.CleanUp() + + type args struct { + targets *statesinformer.CallbackTarget + } + + tests := []struct { + name string + args args + wantFields *int64 + ipsetExpected map[string][]string + }{ + { + name: "", + args: args{ + targets: &statesinformer.CallbackTarget{ + Pods: []*statesinformer.PodMeta{ + pod1, pod2, pod3, pod4, pod5, + }, + }, + }, + wantFields: nil, + ipsetExpected: map[string][]string{ + "high_class": {"192.168.0.1", "192.168.0.2"}, + "mid_class": {"192.168.0.3"}, + "low_class": {"192.168.0.4"}, + }, + }, + { + name: "", + args: args{ + targets: &statesinformer.CallbackTarget{ + Pods: []*statesinformer.PodMeta{ + pod2, pod3, pod4, pod5, pod6, + }, + }, + }, + wantFields: nil, + ipsetExpected: map[string][]string{ + "high_class": {"192.168.0.2"}, + "mid_class": {"192.168.0.3"}, + "low_class": {"192.168.0.4", "192.168.0.6"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + if err := plugin.ruleUpdateCbForPod(tt.args.targets); err != nil { + klog.Errorf("failed to process ruleUpdateCb, err=%v", err) + return + } + if _, err := plugin.checkAllRulesExisted(); err != nil { + t.Errorf("some necessary rules not existed. err=%v", err) + return + } + if !checkIpsetIsRight(tt.ipsetExpected) { + t.Errorf("ipset rules not the same as expected") + return + } + }) + } +} + +func checkIpsetIsRight(rules map[string][]string) bool { + for setName, ips := range rules { + for _, ip := range ips { + if !ipsetEntryExisted(setName, ip) { + fmt.Printf("%s:%s ipset rules not the same as expected\n", setName, ip) + return false + } + } + } + + return true +} + +func ipsetEntryExisted(setName, ip string) bool { + result, err := netlink.IpsetList(setName) + if err != nil || result == nil { + return false + } + + for _, entry := range result.Entries { + if entry.IP.String() == ip { + return true + } + } + + return false +} + +func Test_getMinorId(t *testing.T) { + type args struct { + num uint32 + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + { + name: "demo", + args: args{num: 115914}, + want: "c4ca", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, getMinorId(tt.args.num), "getMinorId(%v)", tt.args.num) + }) + } +} diff --git a/pkg/koordlet/runtimehooks/protocol/pod_context.go b/pkg/koordlet/runtimehooks/protocol/pod_context.go index 58f6d5897..ea1ec3106 100644 --- a/pkg/koordlet/runtimehooks/protocol/pod_context.go +++ b/pkg/koordlet/runtimehooks/protocol/pod_context.go @@ -285,6 +285,20 @@ func (p *PodContext) injectForExt() { p.Request.PodMeta.Namespace, p.Request.PodMeta.Name, *p.Response.Resources.MemoryLimit, p.Request.CgroupParent) } } + + if p.Response.Resources.NetClsClassId != nil { + eventHelper := audit.V(3).Pod(p.Request.PodMeta.Namespace, p.Request.PodMeta.Name).Reason("runtime-hooks").Message( + "set pod net class id to %v", *p.Response.Resources.NetClsClassId) + updater, err := injectNetClsClassId(p.Request.CgroupParent, *p.Response.Resources.NetClsClassId, eventHelper, p.executor) + if err != nil { + klog.Infof("set pod %v/%v net class id %v on cgroup parent %v failed, error %v", p.Request.PodMeta.Namespace, + p.Request.PodMeta.Name, *p.Response.Resources.NetClsClassId, p.Request.CgroupParent, err) + } else { + p.updaters = append(p.updaters, updater) + klog.V(5).Infof("set pod %v/%v net class id %v on cgroup parent %v", + p.Request.PodMeta.Namespace, p.Request.PodMeta.Name, *p.Response.Resources.NetClsClassId, p.Request.CgroupParent) + } + } } func (p *PodContext) removeForExt() { diff --git a/pkg/koordlet/runtimehooks/protocol/protocol.go b/pkg/koordlet/runtimehooks/protocol/protocol.go index 6950ba6c7..b3f6a93c8 100644 --- a/pkg/koordlet/runtimehooks/protocol/protocol.go +++ b/pkg/koordlet/runtimehooks/protocol/protocol.go @@ -73,10 +73,11 @@ var HooksProtocolBuilder = hooksProtocolBuilder{ type Resources struct { // origin resources - CPUShares *int64 - CFSQuota *int64 - CPUSet *string - MemoryLimit *int64 + CPUShares *int64 + CFSQuota *int64 + CPUSet *string + MemoryLimit *int64 + NetClsClassId *uint32 // extended resources CPUBvt *int64 @@ -177,3 +178,12 @@ func injectCPUIdle(cgroupParent string, idleValue int64, a *audit.EventHelper, e } return updater, nil } + +func injectNetClsClassId(cgroupParent string, classId uint32, a *audit.EventHelper, e resourceexecutor.ResourceUpdateExecutor) (resourceexecutor.ResourceUpdater, error) { + clsIdStr := strconv.FormatUint(uint64(classId), 10) + updater, err := resourceexecutor.DefaultCgroupUpdaterFactory.New(sysutil.NetClsClassIdName, cgroupParent, clsIdStr, a) + if err != nil { + return nil, err + } + return updater, nil +} diff --git a/pkg/koordlet/runtimehooks/reconciler/reconciler.go b/pkg/koordlet/runtimehooks/reconciler/reconciler.go index 2c382ecd5..8f64e6c4d 100644 --- a/pkg/koordlet/runtimehooks/reconciler/reconciler.go +++ b/pkg/koordlet/runtimehooks/reconciler/reconciler.go @@ -17,6 +17,7 @@ limitations under the License. package reconciler import ( + "strconv" "sync" "time" @@ -107,6 +108,7 @@ type podQOSFilter struct{} const ( PodQOSFilterName = "podQOS" + HostNetWork = "hostNetwork" ) func (p *podQOSFilter) Name() string { @@ -127,6 +129,16 @@ func (p *podQOSFilter) Filter(podMeta *statesinformer.PodMeta) string { return string(qosClass) } +type podHostNetworkFilter struct{} + +func (p *podHostNetworkFilter) Name() string { + return HostNetWork +} + +func (p *podHostNetworkFilter) Filter(podMeta *statesinformer.PodMeta) string { + return strconv.FormatBool(podMeta.Pod.Spec.HostNetwork) +} + var singletonPodQOSFilter *podQOSFilter // PodQOSFilter returns a Filter which filters pod qos class @@ -137,6 +149,16 @@ func PodQOSFilter() Filter { return singletonPodQOSFilter } +var singletonPodHostNetworkFilter *podHostNetworkFilter + +// PodHostNetworkFilter returns a Filter which filters pod hostnetwork is true +func PodHostNetworkFilter() *podHostNetworkFilter { + if singletonPodQOSFilter == nil { + singletonPodHostNetworkFilter = &podHostNetworkFilter{} + } + return singletonPodHostNetworkFilter +} + type reconcileFunc func(protocol.HooksProtocol) error type reconcileFunc4AllPods func([]protocol.HooksProtocol) error diff --git a/pkg/koordlet/util/system/cgroup_resource.go b/pkg/koordlet/util/system/cgroup_resource.go index d89ebbd95..2915cbeaf 100644 --- a/pkg/koordlet/util/system/cgroup_resource.go +++ b/pkg/koordlet/util/system/cgroup_resource.go @@ -111,6 +111,7 @@ const ( // subsystems CgroupCPUAcctDir string = "cpuacct/" CgroupMemDir string = "memory/" CgroupBlkioDir string = "blkio/" + CgroupNetClsDir string = "net_cls/" CgroupV2Dir = "" ) @@ -170,6 +171,8 @@ const ( BlkioIOWeightName = "blkio.cost.weight" BlkioIOQoSName = "blkio.cost.qos" BlkioIOModelName = "blkio.cost.model" + + NetClsClassIdName = "net_cls.classid" ) var ( @@ -195,6 +198,8 @@ var ( BlkioIOQoSValidator = &BlkIORangeValidator{min: 0, max: math.MaxInt64, resource: BlkioIOQoSName} BlkioIOModelValidator = &BlkIORangeValidator{min: 1, max: math.MaxInt64, resource: BlkioIOModelName} + NetClsClassIdValidator = &NetClsRangeValidator{resource: NetClsClassIdName} + CPUSetCPUSValidator = &CPUSetStrValidator{} ) @@ -243,6 +248,8 @@ var ( BlkioIOQoS = DefaultFactory.New(BlkioIOQoSName, CgroupBlkioDir).WithValidator(BlkioIOQoSValidator).WithSupported(SupportedIfFileExistsInRootCgroup(BlkioIOQoSName, CgroupBlkioDir)) BlkioIOModel = DefaultFactory.New(BlkioIOModelName, CgroupBlkioDir).WithValidator(BlkioIOModelValidator).WithSupported(SupportedIfFileExistsInRootCgroup(BlkioIOModelName, CgroupBlkioDir)) + NetClsClassId = DefaultFactory.New(NetClsClassIdName, CgroupNetClsDir).WithValidator(NetClsClassIdValidator).WithCheckSupported(SupportedIfFileExistsInKubepods).WithCheckOnce(true) + knownCgroupResources = []Resource{ CPUStat, CPUShares, @@ -280,6 +287,7 @@ var ( BlkioIOWeight, BlkioIOQoS, BlkioIOModel, + NetClsClassId, } CPUCFSQuotaV2 = DefaultFactory.NewV2(CPUCFSQuotaName, CPUMaxName) @@ -345,6 +353,8 @@ var ( MemoryUsePriorityOomV2, MemoryOomGroupV2, // TODO: register BlkioIOWeight, BlkioIOQoS and BlkioIOModel + + NetClsClassId, } ) diff --git a/pkg/koordlet/util/system/common_linux.go b/pkg/koordlet/util/system/common_linux.go index 8272d4eba..5854d8e49 100644 --- a/pkg/koordlet/util/system/common_linux.go +++ b/pkg/koordlet/util/system/common_linux.go @@ -23,6 +23,7 @@ import ( "bytes" "fmt" "io" + "net" "os" "os/exec" "path" @@ -36,6 +37,7 @@ import ( "unicode" "github.com/cakturk/go-netstat/netstat" + "github.com/vishvananda/netlink" utilerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/klog/v2" ) @@ -222,3 +224,20 @@ func WorkingDirOf(pid int) (string, error) { return strings.TrimSpace(tokens[1]), nil } } + +func GetLinkInfoByDefaultRoute() (netlink.Link, error) { + routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{}, netlink.RT_FILTER_DST) + if err != nil { + return nil, err + } + if len(routes) == 0 { + return nil, fmt.Errorf("not find route info by dst ip=%s", net.IPv4zero.String()) + } + + linkInfo, err := netlink.LinkByIndex(routes[0].LinkIndex) + if err != nil { + return nil, err + } + + return linkInfo, nil +} diff --git a/pkg/koordlet/util/system/common_linux_test.go b/pkg/koordlet/util/system/common_linux_test.go index 03390d166..eb92b2d10 100644 --- a/pkg/koordlet/util/system/common_linux_test.go +++ b/pkg/koordlet/util/system/common_linux_test.go @@ -29,6 +29,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" ) func Test_KubeletPortToPid(t *testing.T) { @@ -115,3 +116,24 @@ func Test_WorkingDirOf(t *testing.T) { assert.NotEmpty(t, err) }) } + +func TestGetLinkInfoByDefaultRoute(t *testing.T) { + tests := []struct { + name string + want netlink.Link + wantErr error + }{ + // TODO: Add test cases. + { + name: "normal", + want: nil, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := GetLinkInfoByDefaultRoute() + assert.Equalf(t, tt.wantErr, err, "GetLinkInfoByDefaultRoute()") + }) + } +} diff --git a/pkg/koordlet/util/system/util_test_tool.go b/pkg/koordlet/util/system/util_test_tool.go index 4384e7726..a79c05b74 100644 --- a/pkg/koordlet/util/system/util_test_tool.go +++ b/pkg/koordlet/util/system/util_test_tool.go @@ -50,6 +50,7 @@ var ( BlkioReadIops, BlkioWriteBps, BlkioWriteIops, + NetClsClassId, } ) diff --git a/pkg/koordlet/util/system/validator.go b/pkg/koordlet/util/system/validator.go index 9b89a6ef2..591f53290 100644 --- a/pkg/koordlet/util/system/validator.go +++ b/pkg/koordlet/util/system/validator.go @@ -121,3 +121,46 @@ func (r *BlkIORangeValidator) Validate(value string) (bool, string) { return true, "" } + +type NetClsRangeValidator struct { + resource string +} + +const ( + maxClassIdDecimal = 41231686041 + maxClassIdHex = 99999999 +) + +func (r *NetClsRangeValidator) Validate(value string) (bool, string) { + if value == "" { + return false, "value is nil" + } + + if r.resource == NetClsClassIdName { + if strings.HasPrefix(value, "0x") { + value = value[2:] + // You can write hexadecimal values to net_cls.classid; the format for these values is 0xAAAABBBB; + // AAAA is the major handle number and BBBB is the minor handle number. Reading net_cls.classid yields a decimal result. + // so, the max length of this value is 8. + hexVal, err := strconv.Atoi(value) + if err != nil { + return false, err.Error() + } + + if hexVal < 0 || hexVal > maxClassIdHex { + return false, "class id is invalid, decimal value must in 0~0x99999999" + } + } else { + decimalVal, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return false, err.Error() + } + + if decimalVal > maxClassIdDecimal || decimalVal < 0 { + return false, fmt.Sprintf("class id is invaild, decimal vaule must in 0~%d", maxClassIdDecimal) + } + } + } + + return true, "" +} diff --git a/pkg/koordlet/util/system/validator_test.go b/pkg/koordlet/util/system/validator_test.go index 433a84d86..39f281cd5 100644 --- a/pkg/koordlet/util/system/validator_test.go +++ b/pkg/koordlet/util/system/validator_test.go @@ -70,3 +70,99 @@ func Test_RangeValidate(t *testing.T) { }) } } + +func TestNetClsRangeValidator_Validate(t *testing.T) { + type fields struct { + resource string + } + type args struct { + value string + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + // TODO: Add test cases. + { + name: "nil", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "", + }, + want: false, + }, + { + name: "not number", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "abc", + }, + want: false, + }, + { + name: "decimal negative number", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "-1", + }, + want: false, + }, + { + name: "decimal positive number but too big", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "111111111111111111111111", + }, + want: false, + }, + { + name: "invalid hex number", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "0xmm", + }, + want: false, + }, + { + name: "negative flag in hex number", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "0x-1", + }, + want: false, + }, + { + name: "invalid number but too big", + fields: fields{ + resource: NetClsClassIdName, + }, + args: args{ + value: "0x1111111111111111", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &NetClsRangeValidator{ + resource: tt.fields.resource, + } + got, _ := r.Validate(tt.args.value) + assert.Equalf(t, tt.want, got, "Validate(%v)", tt.args.value) + }) + } +}