Commit 3c4bd5a: more cleanup
luke-lombardi committed Oct 25, 2024 (1 parent: f2d7d57)

Showing 17 changed files with 96 additions and 60 deletions.
2 changes: 1 addition & 1 deletion pkg/abstractions/endpoint/instance.go
@@ -90,7 +90,7 @@ func (i *endpointInstance) startContainers(containersToRun int) error {
 		EntryPoint:        i.EntryPoint,
 		Mounts:            mounts,
 		Stub:              *i.Stub,
-		CheckpointEnabled: true, // XXX: Hardcoded for testing
+		CheckpointEnabled: i.StubConfig.CheckpointEnabled,
 	}

 	// Set initial keepwarm to prevent rapid spin-up/spin-down of containers
27 changes: 14 additions & 13 deletions pkg/abstractions/taskqueue/instance.go
@@ -72,19 +72,20 @@ func (i *taskQueueInstance) startContainers(containersToRun int) error {

 	for c := 0; c < containersToRun; c++ {
 		runRequest := &types.ContainerRequest{
-			ContainerId: i.genContainerId(),
-			Env:         env,
-			Cpu:         i.StubConfig.Runtime.Cpu,
-			Memory:      i.StubConfig.Runtime.Memory,
-			GpuRequest:  gpuRequest,
-			GpuCount:    uint32(gpuCount),
-			ImageId:     i.StubConfig.Runtime.ImageId,
-			StubId:      i.Stub.ExternalId,
-			WorkspaceId: i.Workspace.ExternalId,
-			Workspace:   *i.Workspace,
-			EntryPoint:  i.EntryPoint,
-			Mounts:      mounts,
-			Stub:        *i.Stub,
+			ContainerId:       i.genContainerId(),
+			Env:               env,
+			Cpu:               i.StubConfig.Runtime.Cpu,
+			Memory:            i.StubConfig.Runtime.Memory,
+			GpuRequest:        gpuRequest,
+			GpuCount:          uint32(gpuCount),
+			ImageId:           i.StubConfig.Runtime.ImageId,
+			StubId:            i.Stub.ExternalId,
+			WorkspaceId:       i.Workspace.ExternalId,
+			Workspace:         *i.Workspace,
+			EntryPoint:        i.EntryPoint,
+			Mounts:            mounts,
+			Stub:              *i.Stub,
+			CheckpointEnabled: i.StubConfig.CheckpointEnabled,
 		}

 		err := i.Scheduler.Run(runRequest)
22 changes: 11 additions & 11 deletions pkg/common/config.default.yaml
@@ -133,6 +133,16 @@ worker:
   terminationGracePeriod: 30
   addWorkerTimeout: 10m
   blobCacheEnabled: false
+  checkpointing:
+    enabled: false
+    cedana:
+      client:
+        leaveRunning: true
+      sharedStorage:
+        dumpStorageDir: /data
+      connection:
+        cedanaUrl: url
+        cedanaAuthToken: token
   providers:
     ec2:
       accessKey:
@@ -188,14 +198,4 @@ monitoring:
 #    tag: internal_api
   openmeter:
     serverUrl: ""
-    apiKey: ""
-checkpointing:
-  enabled: true
-  cedana:
-    client:
-      leaveRunning: true
-    sharedStorage:
-      dumpStorageDir: /data
-    connection:
-      cedanaUrl: url
-      cedanaAuthToken: token
+    apiKey: ""
13 changes: 4 additions & 9 deletions pkg/gateway/gateway.proto
@@ -254,6 +254,7 @@ message GetOrCreateStubRequest {
   Autoscaler autoscaler = 22;
   TaskPolicy task_policy = 23;
   uint32 concurrent_requests = 24;
+  bool checkpoint_enabled = 25;
 }

 message GetOrCreateStubResponse {
@@ -478,27 +479,21 @@ message ListWorkersResponse {
   repeated Worker workers = 3;
 }

-message CordonWorkerRequest {
-  string worker_id = 1;
-}
+message CordonWorkerRequest { string worker_id = 1; }

 message CordonWorkerResponse {
   bool ok = 1;
   string err_msg = 2;
 }

-message UncordonWorkerRequest {
-  string worker_id = 1;
-}
+message UncordonWorkerRequest { string worker_id = 1; }

 message UncordonWorkerResponse {
   bool ok = 1;
   string err_msg = 2;
 }

-message DrainWorkerRequest {
-  string worker_id = 1;
-}
+message DrainWorkerRequest { string worker_id = 1; }

 message DrainWorkerResponse {
   bool ok = 1;
1 change: 1 addition & 0 deletions pkg/gateway/services/stub.go
@@ -98,6 +98,7 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
 		Secrets:           []types.Secret{},
 		Authorized:        in.Authorized,
 		Autoscaler:        autoscaler,
+		CheckpointEnabled: in.CheckpointEnabled,
 	}

 	// Get secrets

[CI annotation: check failure on line 101 in pkg/gateway/services/stub.go, from GitHub Actions / lint_and_test_go_pkg: in.CheckpointEnabled undefined (type *"github.com/beam-cloud/beta9/proto".GetOrCreateStubRequest has no field or method CheckpointEnabled). The generated Go protobuf bindings were presumably not regenerated for the new checkpoint_enabled field in this commit.]
2 changes: 1 addition & 1 deletion pkg/scheduler/pool_local.go
@@ -144,7 +144,7 @@ func (wpc *LocalKubernetesWorkerPoolController) createWorkerJob(workerId string,
 	workerImage := fmt.Sprintf("%s/%s:%s",
 		wpc.config.Worker.ImageRegistry,
 		wpc.config.Worker.ImageName,
-		"cedana-w9", //wpc.config.Worker.ImageTag,
+		"cedana-w11", //wpc.config.Worker.ImageTag,
 	)

 	resources := corev1.ResourceRequirements{}
6 changes: 1 addition & 5 deletions pkg/types/backend.go
@@ -182,11 +182,7 @@ type StubConfigV1 struct {
 	Volumes           []*pb.Volume `json:"volumes"`
 	Secrets           []Secret     `json:"secrets,omitempty"`
 	Autoscaler        *Autoscaler  `json:"autoscaler"`
-	Experimental      Experimental `json:"experimental"`
-}
-
-type Experimental struct {
-	CheckpointEnabled bool `json:"checkpoint_enabled"`
+	CheckpointEnabled bool         `json:"checkpoint_enabled"`
 }

 type AutoscalerType string
2 changes: 1 addition & 1 deletion pkg/types/config.go
@@ -21,7 +21,6 @@ type AppConfig struct {
 	Proxy         ProxyConfig               `key:"proxy" json:"proxy"`
 	Monitoring    MonitoringConfig          `key:"monitoring" json:"monitoring"`
 	BlobCache     blobcache.BlobCacheConfig `key:"blobcache" json:"blobcache"`
-	Checkpointing CheckpointingConfig       `key:"checkpointing" json:"checkpointing"`
 }

 type DatabaseConfig struct {
@@ -195,6 +194,7 @@ type WorkerConfig struct {
 	AddWorkerTimeout       time.Duration       `key:"addWorkerTimeout" json:"add_worker_timeout"`
 	TerminationGracePeriod int64               `key:"terminationGracePeriod"`
 	BlobCacheEnabled       bool                `key:"blobCacheEnabled" json:"blob_cache_enabled"`
+	Checkpointing          CheckpointingConfig `key:"checkpointing" json:"checkpointing"`
 }

 type PoolMode string
4 changes: 2 additions & 2 deletions pkg/worker/lifecycle.go
@@ -542,8 +542,8 @@ func (s *Worker) isBuildRequest(request *types.ContainerRequest) bool {
 	return request.SourceImage != nil
 }

-// Waits for the endpoint to be ready to checkpoint at the desired point in execution, ie.
-// after all endpoint workers have reached a checkpointable state
+// Waits for the container to be ready to checkpoint at the desired point in execution, i.e.
+// after all processes within a container have reached a checkpointable state
 func (s *Worker) createCheckpoint(ctx context.Context, request *types.ContainerRequest) error {
 	timeout := defaultCheckpointDeadline
 	managing := false
6 changes: 3 additions & 3 deletions pkg/worker/worker.go
@@ -42,6 +42,7 @@ type Worker struct {
 	runcHandle              runc.Runc
 	runcServer              *RunCServer
 	cedanaClient            *CedanaClient
+	checkpointingAvailable  bool
 	fileCacheManager        *FileCacheManager
 	containerNetworkManager *ContainerNetworkManager
 	containerCudaManager    GPUManager
@@ -64,7 +65,6 @@ type Worker struct {
 	ctx                     context.Context
 	cancel                  func()
 	config                  types.AppConfig
-	checkpointingAvailable  bool
 }

 type ContainerInstance struct {
@@ -153,8 +153,8 @@ func NewWorker() (*Worker, error) {
 	}

 	var cedanaClient *CedanaClient = nil
-	if config.Checkpointing.Enabled {
-		cedanaClient, err = NewCedanaClient(context.Background(), config.Checkpointing.Cedana, gpuType != "")
+	if config.Worker.Checkpointing.Enabled {
+		cedanaClient, err = NewCedanaClient(context.Background(), config.Worker.Checkpointing.Cedana, gpuType != "")
 		if err != nil {
 			log.Printf("[WARNING] C/R unavailable, failed to create cedana client: %v\n", err)
 		}
4 changes: 3 additions & 1 deletion sdk/src/beta9/abstractions/base/runner.py
@@ -89,6 +89,7 @@ def __init__(
         name: Optional[str] = None,
         autoscaler: Autoscaler = QueueDepthAutoscaler(),
         task_policy: TaskPolicy = TaskPolicy(),
+        checkpoint_enabled: bool = False,
     ) -> None:
         super().__init__()

@@ -123,7 +124,7 @@ def __init__(
             timeout=task_policy.timeout or timeout,
             ttl=task_policy.ttl,
         )
-
+        self.checkpoint_enabled = checkpoint_enabled
         if on_start is not None:
             self._map_callable_to_attr(attr="on_start", func=on_start)

@@ -403,6 +404,7 @@ def prepare_runtime(
                 ttl=self.task_policy.ttl,
             ),
             concurrent_requests=self.concurrent_requests,
+            checkpoint_enabled=self.checkpoint_enabled,
         )
         if _is_stub_created_for_workspace():
             stub_response: GetOrCreateStubResponse = self.gateway_stub.get_or_create_stub(
7 changes: 7 additions & 0 deletions sdk/src/beta9/abstractions/endpoint.py
@@ -340,6 +340,11 @@ class RealtimeASGI(ASGI):
             various autoscaling strategies (Defaults to QueueDepthAutoscaler())
         callback_url (Optional[str]):
             An optional URL to send a callback to when a task is completed, timed out, or cancelled.
+        checkpoint_enabled (bool):
+            (experimental) Whether to enable checkpointing for the endpoint. Default is False.
+            If enabled, the container will be checkpointed after the on_start function has completed.
+            On the next invocation, it will restore from the checkpoint and resume execution
+            instead of booting up from cold.
     Example:
         ```python
         from beta9 import realtime
@@ -375,6 +380,7 @@ def __init__(
         authorized: bool = True,
         autoscaler: Autoscaler = QueueDepthAutoscaler(),
         callback_url: Optional[str] = None,
+        checkpoint_enabled: bool = False,
     ):
         super().__init__(
             cpu=cpu,
@@ -393,6 +399,7 @@ def __init__(
             autoscaler=autoscaler,
             callback_url=callback_url,
             concurrent_requests=concurrent_requests,
+            checkpoint_enabled=checkpoint_enabled,
         )

     def __call__(self, func):
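For orientation (not part of this commit), a minimal sketch of the new flag from the caller's side. It assumes realtime accepts on_start the way the base runner above does; load_model is a hypothetical stand-in for real setup:

```python
from beta9 import realtime


def load_model():
    # Expensive setup; with checkpoint_enabled=True, the container is
    # checkpointed once on_start completes and all workers are ready.
    ...


@realtime(cpu=1, memory=128, on_start=load_model, checkpoint_enabled=True)
def handler(event):
    # Subsequent cold starts restore from the checkpoint, skipping load_model.
    return event
```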
7 changes: 7 additions & 0 deletions sdk/src/beta9/abstractions/taskqueue.py
@@ -86,6 +86,11 @@ class TaskQueue(RunnerAbstraction):
         task_policy (TaskPolicy):
             The task policy for the function. This helps manage the lifecycle of an individual task.
             Setting values here will override timeout and retries.
+        checkpoint_enabled (bool):
+            (experimental) Whether to enable checkpointing for the task queue. Default is False.
+            If enabled, the container will be checkpointed after the on_start function has completed.
+            On the next invocation, it will restore from the checkpoint and resume execution
+            instead of booting up from cold.
     Example:
         ```python
         from beta9 import task_queue, Image
@@ -119,6 +124,7 @@ def __init__(
         authorized: bool = True,
         autoscaler: Autoscaler = QueueDepthAutoscaler(),
         task_policy: TaskPolicy = TaskPolicy(),
+        checkpoint_enabled: bool = False,
     ) -> None:
         super().__init__(
             cpu=cpu,
@@ -138,6 +144,7 @@ def __init__(
             authorized=authorized,
             autoscaler=autoscaler,
             task_policy=task_policy,
+            checkpoint_enabled=checkpoint_enabled,
         )
         self._taskqueue_stub: Optional[TaskQueueServiceStub] = None
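Likewise for task queues, a hedged sketch; the .put() call at the end is assumed from the wider SDK and is not shown in this diff:

```python
from beta9 import Image, task_queue


def load_model():
    ...  # heavy one-time setup; the checkpoint is taken after this completes


@task_queue(
    cpu=1.0,
    memory=128,
    image=Image(python_version="python3.10"),
    on_start=load_model,
    checkpoint_enabled=True,
)
def process(item):
    return {"processed": item}


# process.put(item=...) would enqueue work; restored containers skip load_model.
```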
1 change: 1 addition & 0 deletions sdk/src/beta9/clients/gateway/__init__.py

(Generated gRPC client code; diff not rendered.)

23 changes: 23 additions & 0 deletions sdk/src/beta9/runner/common.py
@@ -10,6 +10,8 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import wraps
+from multiprocessing import Value
+from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Union

 import requests
@@ -133,6 +135,9 @@ def new(
         )


+workers_ready = Value("i", 0)
+
+
 class FunctionHandler:
     """
     Helper class for loading user entry point functions
@@ -335,3 +340,21 @@ class ThreadPoolExecutorOverride(ThreadPoolExecutor):
     def __exit__(self, *_, **__):
         # cancel_futures added in 3.9
         self.shutdown(cancel_futures=True)
+
+
+CHECKPOINT_SIGNAL_FILE = "/cedana/READY_FOR_CHECKPOINT"
+
+
+def wait_for_checkpoint():
+    with workers_ready.get_lock():
+        workers_ready.value += 1
+
+    if workers_ready.value == config.workers:
+        Path(CHECKPOINT_SIGNAL_FILE).touch(exist_ok=True)
+        return
+
+    while True:
+        with workers_ready.get_lock():
+            if workers_ready.value == config.workers:
+                break
+        time.sleep(1)
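wait_for_checkpoint is a simple shared-counter rendezvous: every worker checks in, and the last one to arrive drops a signal file that the checkpointing side watches. As a standalone illustration, a minimal sketch of the same pattern, with a hypothetical /tmp path and a fixed worker count in place of /cedana/READY_FOR_CHECKPOINT and config.workers; unlike the diffed code, it folds the final check under the lock:

```python
import time
from multiprocessing import Process, Value
from pathlib import Path

SIGNAL_FILE = Path("/tmp/READY_FOR_CHECKPOINT")  # stand-in for /cedana/READY_FOR_CHECKPOINT


def worker(ready, n_workers):
    # Each worker checks in under the lock; the last one to arrive touches
    # the signal file that an external checkpointer (Cedana, here) watches.
    with ready.get_lock():
        ready.value += 1
        if ready.value == n_workers:
            SIGNAL_FILE.touch(exist_ok=True)
            return

    # Everyone else polls until the head count is complete.
    while True:
        with ready.get_lock():
            if ready.value == n_workers:
                return
        time.sleep(0.1)


if __name__ == "__main__":
    ready = Value("i", 0)  # passed explicitly so the sketch also works under spawn
    procs = [Process(target=worker, args=(ready, 4)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print("ready for checkpoint:", SIGNAL_FILE.exists())
```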
22 changes: 9 additions & 13 deletions sdk/src/beta9/runner/endpoint.py
@@ -5,8 +5,6 @@
 import traceback
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from multiprocessing import Value
-from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union

 from fastapi import Depends, FastAPI, Request
@@ -30,13 +28,16 @@
     TaskLifecycleMiddleware,
     WebsocketTaskLifecycleMiddleware,
 )
-from ..runner.common import FunctionContext, FunctionHandler, execute_lifecycle_method
+from ..runner.common import (
+    FunctionContext,
+    FunctionHandler,
+    execute_lifecycle_method,
+    wait_for_checkpoint,
+)
 from ..runner.common import config as cfg
 from ..type import LifeCycleMethod, TaskStatus
 from .common import is_asgi3

-workers_ready = Value("i", 0)
-

 class EndpointFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
@@ -93,14 +94,9 @@ def post_fork_initialize(_, worker: UvicornWorker):
         # Override the default starlette app
         worker.app.callable = asgi_app

-        with workers_ready.get_lock():
-            workers_ready.value += 1
-        print(f"Worker PID {worker.pid} is ready")
-
-        print(f"workers_ready.value: {workers_ready.value}, cfg.workers: {cfg.workers}")
-        if workers_ready.value == cfg.workers:
-            print("creating READY_FOR_CHECKPOINT file")
-            Path("/cedana/READY_FOR_CHECKPOINT").touch(exist_ok=True)
+        # If checkpointing is enabled, wait for all workers to be ready before creating a checkpoint
+        if cfg.checkpoint_enabled:
+            wait_for_checkpoint()

     except EOFError:
         return
7 changes: 7 additions & 0 deletions sdk/src/beta9/runner/taskqueue.py
@@ -33,7 +33,9 @@
     config,
     execute_lifecycle_method,
     send_callback,
+    wait_for_checkpoint,
 )
+from ..runner.common import config as cfg
 from ..type import LifeCycleMethod, TaskExitCode, TaskStatus

 TASK_PROCESS_WATCHDOG_INTERVAL = 0.01
@@ -272,6 +274,11 @@ def process_tasks(self, channel: Channel) -> None:
         on_start_value = execute_lifecycle_method(name=LifeCycleMethod.OnStart)

         print(f"Worker[{self.worker_index}] ready")
+
+        # If checkpointing is enabled, wait for all workers to be ready before creating a checkpoint
+        if cfg.checkpoint_enabled:
+            wait_for_checkpoint()
+
         with ThreadPoolExecutorOverride() as thread_pool:
             while True:
                 task = self._get_next_task(taskqueue_stub, config.stub_id, config.container_id)
