OSPP: Sedna joint inference and federated learning controller optimization #451

Closed
wants to merge 8 commits
176 changes: 157 additions & 19 deletions pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go
@@ -20,11 +20,13 @@ import (
"context"
"fmt"
"strconv"
"sync"
"time"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
utilrand "k8s.io/apimachinery/pkg/util/rand"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
@@ -88,6 +90,15 @@ type Controller struct {
cfg *config.ControllerConfig

sendToEdgeFunc runtime.DownstreamSendFunc

// records pods recently recreated by this controller so a manually deleted pod is not recreated twice
recreatedPods sync.Map

// label selector generated from the FederatedLearningJob, used to list its worker pods
flSelector labels.Selector

// host of the EdgeMesh service that exposes the aggregation worker
aggServiceHost string

// set while updateJob is intentionally deleting pods, so deletePod does not recreate them
preventRecreation bool
}

// Run starts the main goroutine responsible for watching and syncing jobs.
@@ -190,6 +201,49 @@ func (c *Controller) deletePod(obj interface{}) {
}
}
c.enqueueByPod(pod, true)

// when the job spec has just been updated, updateJob deletes the old pods on purpose,
// so do not recreate them here
if c.preventRecreation {
return
}
// if pod is manually deleted, recreate it
// first check if the pod is owned by a FederatedLearningJob
controllerRef := metav1.GetControllerOf(pod)
if controllerRef == nil || controllerRef.Kind != Kind.Kind {
return
}

// then check if the pod is already in the map
if _, exists := c.recreatedPods.Load(pod.Name); exists {
return
}

// if not, recreate it
klog.Infof("Pod %s/%s deleted, recreating...", pod.Namespace, pod.Name)
// Create a deep copy of the old pod
newPod := pod.DeepCopy()
// Reset the resource version and UID as they are unique to each object
newPod.ResourceVersion = ""
newPod.UID = ""
// Clear the status
newPod.Status = v1.PodStatus{}
// Remove the deletion timestamp
newPod.DeletionTimestamp = nil
// Remove the deletion grace period seconds
newPod.DeletionGracePeriodSeconds = nil
_, err := c.kubeClient.CoreV1().Pods(pod.Namespace).Create(context.TODO(), newPod, metav1.CreateOptions{})
if err != nil {
klog.Errorf("failed to recreate pod %s/%s: %v", pod.Namespace, pod.Name, err)
return
}
klog.Infof("Successfully recreated pod %s/%s", newPod.Namespace, newPod.Name)
// mark the pod as recreated
c.recreatedPods.Store(newPod.Name, true)
// drop the record from the map after a short window (5s)
go func() {
time.Sleep(5 * time.Second)
c.recreatedPods.Delete(pod.Name)
}()
}
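Note on the timed cleanup above: the goroutine-with-Sleep pattern works, but `time.AfterFunc` expresses the same idea a bit more directly. A minimal, self-contained sketch, assuming the same 5-second window and a hypothetical helper name `markRecreated`:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// markRecreated records that a pod was just recreated and drops the record
// again after the given window, mirroring the controller logic above.
func markRecreated(recreated *sync.Map, podName string, window time.Duration) {
	recreated.Store(podName, true)
	time.AfterFunc(window, func() {
		recreated.Delete(podName)
	})
}

func main() {
	var recreated sync.Map
	markRecreated(&recreated, "example-pod", 5*time.Second)

	if _, ok := recreated.Load("example-pod"); ok {
		fmt.Println("record present right after recreation")
	}
	time.Sleep(6 * time.Second)
	if _, ok := recreated.Load("example-pod"); !ok {
		fmt.Println("record expired after the window")
	}
}
```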

// obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
@@ -271,14 +325,16 @@ func (c *Controller) sync(key string) (bool, error) {
return true, nil
}

selector, _ := runtime.GenerateSelector(&job)
pods, err := c.podStore.Pods(job.Namespace).List(selector)
c.flSelector, _ = runtime.GenerateSelector(&job)
pods, err := c.podStore.Pods(job.Namespace).List(c.flSelector)
if err != nil {
return false, err
}

activePods := k8scontroller.FilterActivePods(pods)
active := int32(len(activePods))
var activeAgg int32
var activeTrain int32
succeeded, failed := countPods(pods)
conditions := len(job.Status.Conditions)

@@ -289,6 +345,8 @@ func (c *Controller) sync(key string) (bool, error) {
}

var manageJobErr error
var manageAggErr error
var manageTrainErr error
jobFailed := false
var failureReason string
var failureMessage string
@@ -307,7 +365,13 @@
} else {
// the first time through, create the worker pods
if len(pods) == 0 {
active, manageJobErr = c.createPod(&job)
activeAgg, manageAggErr = c.createAggPod(&job)
createServiceErr := c.createService(&job)
if createServiceErr != nil {
return false, createServiceErr
}
activeTrain, manageTrainErr = c.createTrainPod(&job)
active = activeAgg + activeTrain
}
complete := false
if succeeded > 0 && active == 0 {
@@ -324,6 +388,10 @@
}
}

// Combine manageAggErr and manageTrainErr into a single error
if manageAggErr != nil || manageTrainErr != nil {
manageJobErr = fmt.Errorf("aggregator error: %v, training error: %v", manageAggErr, manageTrainErr)
}
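Aside on the error combination just above: when only one of the two create paths fails, the combined message still contains a literal "nil" for the other path. If the module targets Go 1.20 or newer (an assumption, not verified against Sedna's go.mod), `errors.Join` keeps nil-ness and reports only the errors that actually occurred. A minimal sketch:

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	// stand-ins for manageAggErr and manageTrainErr in the controller
	var manageAggErr, manageTrainErr error
	manageTrainErr = fmt.Errorf("training worker 0 failed")

	// errors.Join returns nil when every argument is nil and otherwise
	// wraps only the non-nil errors.
	manageJobErr := errors.Join(manageAggErr, manageTrainErr)
	fmt.Println(manageJobErr) // training worker 0 failed
}
```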
forget := false
// Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
// This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
@@ -499,8 +567,7 @@ func (c *Controller) addTransmitterToWorkerParam(param *runtime.WorkerParam, job

return nil
}

func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
func (c *Controller) createAggPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
active = 0
ctx := context.Background()

@@ -513,7 +580,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
modelName := job.Spec.AggregationWorker.Model.Name
model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
if err != nil {
return active, err
return active, fmt.Errorf("failed to get aggregation model: %w", err)
}

participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
@@ -524,6 +591,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
// Configure aggregation worker's mounts and envs
var aggPort int32 = 7363
var aggWorkerParam runtime.WorkerParam

aggWorkerParam.Env = map[string]string{
"NAMESPACE": job.Namespace,
"WORKER_NAME": "aggworker-" + utilrand.String(5),
@@ -534,7 +602,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
}

if err := c.addTransmitterToWorkerParam(&aggWorkerParam, job); err != nil {
return active, err
return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
}

aggWorkerParam.WorkerType = jobStageAgg
@@ -547,19 +615,36 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
c.addWorkerMount(&aggWorkerParam, pretrainedModel.Spec.URL, "PRETRAINED_MODEL_URL",
pretrainedModelSecret, true)
}

aggWorker.Template.Name = fmt.Sprintf("%s-aggworker", job.Name)
// create the aggregation worker pod from the configured parameters
_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, &aggWorkerParam)
if err != nil {
return active, fmt.Errorf("failed to create aggregation worker: %w", err)
}
klog.Infof("create aggpod success")
active++
return
}

func (c *Controller) createTrainPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
active = 0
ctx := context.Background()

aggServiceHost, err := runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
pretrainedModelName := job.Spec.PretrainedModel.Name
pretrainedModel, pretrainedModelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, pretrainedModelName)
if err != nil {
return active, err
return active, fmt.Errorf("failed to get pretrained model: %w", err)
}

modelName := job.Spec.AggregationWorker.Model.Name
model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
if err != nil {
return active, fmt.Errorf("failed to get aggregation model: %w", err)
}

var aggPort int32 = 7363
participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))

// deliver pod for training worker
for i, trainingWorker := range job.Spec.TrainingWorkers {
// Configure training worker's mounts and envs
@@ -583,7 +668,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,

workerParam.Env = map[string]string{
"AGG_PORT": strconv.Itoa(int(aggPort)),
"AGG_IP": aggServiceHost,
"AGG_IP": c.aggServiceHost,

"WORKER_NAME": "trainworker-" + utilrand.String(5),
"JOB_NAME": job.Name,
@@ -593,14 +678,15 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
"DATASET_NAME": datasetName,
"LC_SERVER": c.cfg.LC.Server,
}

workerParam.WorkerType = runtime.TrainPodType
workerParam.HostNetwork = true
workerParam.RestartPolicy = v1.RestartPolicyOnFailure

if err := c.addTransmitterToWorkerParam(&workerParam, job); err != nil {
return active, err
return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
}

trainingWorker.Template.Name = fmt.Sprintf("%s-trainworker-%d", job.Name, i)
// create training worker based on configured parameters
_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, &workerParam)
if err != nil {
@@ -640,13 +726,8 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
// send it to edge's LC.
fc.syncToEdge(watch.Added, obj)
},
UpdateFunc: func(old, cur interface{}) {
fc.enqueueController(cur, true)
UpdateFunc: fc.updateJob,

// when a federated learning job is updated,
// send it to edge's LC as Added event.
fc.syncToEdge(watch.Added, cur)
},
DeleteFunc: func(obj interface{}) {
fc.enqueueController(obj, true)

@@ -669,3 +750,60 @@

return fc, nil
}

func (c *Controller) updateJob(old, cur interface{}) {
oldJob, ok := old.(*sednav1.FederatedLearningJob)
if !ok {
return
}
curJob, ok := cur.(*sednav1.FederatedLearningJob)
if !ok {
return
}

if oldJob.ResourceVersion == curJob.ResourceVersion {
return
}

if oldJob.Generation != curJob.Generation {
// work on a copy so the object in the informer cache is not mutated below
curJob = curJob.DeepCopy()
pods, err := c.podStore.Pods(curJob.Namespace).List(c.flSelector)
if err != nil {
klog.Errorf("failed to list pods for job %s/%s: %v", curJob.Namespace, curJob.Name, err)
return
}
c.preventRecreation = true
for _, pod := range pods {
// delete every pod that belongs to the old spec
if err := c.kubeClient.CoreV1().Pods(pod.Namespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{}); err != nil && !errors.IsNotFound(err) {
klog.Errorf("failed to delete pod %s/%s: %v", pod.Namespace, pod.Name, err)
} else {
klog.Infof("job spec changed, deleted pod %s/%s", pod.Namespace, pod.Name)
}
}
klog.Infof("job spec changed, deleted the old worker pods and will create new ones")
curJob.SetGroupVersionKind(Kind)
_, err = c.createAggPod(curJob)
if err != nil {
klog.Errorf("Failed to create aggregation worker: %v", err)
}
_, err = c.createTrainPod(curJob)
if err != nil {
klog.Errorf("Failed to create training workers: %v", err)
}
// update the job
if _, err := c.client.FederatedLearningJobs(curJob.Namespace).Update(context.TODO(), curJob, metav1.UpdateOptions{}); err != nil {
klog.Errorf("failed to update job %s/%s: %v", curJob.Namespace, curJob.Name, err)
}
}

c.preventRecreation = false
c.enqueueController(curJob, true)

// when a federated learning job is updated,
// send it to edge's LC as Added event.
c.syncToEdge(watch.Added, curJob)
}

// createService creates the EdgeMesh service that exposes the aggregation worker
func (c *Controller) createService(job *sednav1.FederatedLearningJob) (err error) {
var aggPort int32 = 7363
c.aggServiceHost, err = runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
if err != nil {
return err
}
return nil
}
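A last hedged observation on the new Controller fields (`recreatedPods`, `flSelector`, `aggServiceHost`, `preventRecreation`): they are written from the job informer's handlers (`sync`, `updateJob`) and read from the pod informer's handler (`deletePod`). Distinct shared informers normally deliver events on separate goroutines, so a plain bool flag can race; `recreatedPods` already uses `sync.Map` for this reason. One way to make the flag race-free, assuming Go 1.19+ and using a hypothetical stand-in type, is `atomic.Bool`:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// hypothetical stand-in for the controller; only the flag is shown
type controllerGuard struct {
	preventRecreation atomic.Bool
}

func main() {
	var c controllerGuard

	// writer side, e.g. updateJob while it rolls out a new job spec
	c.preventRecreation.Store(true)

	// reader side, e.g. deletePod deciding whether to recreate a pod
	if c.preventRecreation.Load() {
		fmt.Println("skip recreating pods while the spec change is rolled out")
	}

	c.preventRecreation.Store(false)
}
```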