Skip to content

Commit

Permalink
feat: PVC support for kubernetes (continuation of #9) (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
johanneskoester authored Aug 15, 2024
1 parent 1f8539d commit 33a6809
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ keywords = ["snakemake", "plugin", "executor", "cloud", "kubernetes"]

[tool.poetry.dependencies]
python = "^3.11"
snakemake-interface-common = "^1.14.1"
snakemake-interface-common = "^1.17.3"
snakemake-interface-executor-plugins = ">=9.0.0,<10.0.0"
kubernetes = ">=27.2.0,<30"

Expand Down
70 changes: 65 additions & 5 deletions snakemake_executor_plugin_kubernetes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import base64
from dataclasses import dataclass, field
from pathlib import Path
import shlex
import subprocess
import time
from typing import List, Generator, Optional
from typing import List, Generator, Optional, Self
import uuid

import kubernetes
Expand All @@ -23,10 +24,33 @@
from snakemake_interface_executor_plugins.settings import DeploymentMethod


# Optional:
# define additional settings for your executor
# They will occur in the Snakemake CLI as --<executor-name>-<param-name>
# Omit this class if you don't need any.
@dataclass
class PersistentVolume:
name: str
path: Path

@classmethod
def parse(cls, arg: str) -> Self:
spec = arg.split(":")
if len(spec) != 2:
raise WorkflowError(
f"Invalid persistent volume spec ({arg}), has to be <name>:<path>."
)
name, path = spec
return cls(name=name, path=Path(path))

def unparse(self) -> str:
return f"{self.name}:{self.path}"


def parse_persistent_volumes(args: List[str]) -> List[PersistentVolume]:
    """Parse every ``<name>:<path>`` spec in ``args``, preserving order."""
    volumes = []
    for spec in args:
        volumes.append(PersistentVolume.parse(spec))
    return volumes


def unparse_persistent_volumes(args: List[PersistentVolume]) -> List[str]:
    """Convert each persistent volume in ``args`` back to its spec string."""
    specs = []
    for volume in args:
        specs.append(volume.unparse())
    return specs


@dataclass
class ExecutorSettings(ExecutorSettingsBase):
namespace: str = field(
Expand Down Expand Up @@ -57,6 +81,20 @@ class ExecutorSettings(ExecutorSettingsBase):
"when using Google Cloud GKE Autopilot."
},
)
privileged: Optional[bool] = field(
default=False,
metadata={"help": "Create privileged containers for jobs."},
)
persistent_volumes: List[PersistentVolume] = field(
default_factory=list,
metadata={
"help": "Mount the given persistent volumes under the given paths in each "
"job container (<name>:<path>). ",
"parse_func": parse_persistent_volumes,
"unparse_func": unparse_persistent_volumes,
"nargs": "+",
},
)


# Required:
Expand Down Expand Up @@ -104,6 +142,9 @@ def __post_init__(self):
self.log_path = self.workflow.persistence.aux_path / "kubernetes-logs"
self.log_path.mkdir(exist_ok=True, parents=True)
self.container_image = self.workflow.remote_execution_settings.container_image
self.privileged = self.workflow.executor_settings.privileged
self.persistent_volumes = self.workflow.executor_settings.persistent_volumes

self.logger.info(f"Using {self.container_image} for Kubernetes jobs.")

def run_job(self, job: JobExecutorInterface):
Expand Down Expand Up @@ -138,6 +179,10 @@ def run_job(self, job: JobExecutorInterface):
container.volume_mounts = [
kubernetes.client.V1VolumeMount(name="workdir", mount_path="/workdir"),
]
for pvc in self.persistent_volumes:
container.volume_mounts.append(
kubernetes.client.V1VolumeMount(name=pvc.name, mount_path=str(pvc.path))
)

node_selector = {}
if "machine_type" in job.resources.keys():
Expand All @@ -160,6 +205,15 @@ def run_job(self, job: JobExecutorInterface):
workdir_volume.empty_dir = kubernetes.client.V1EmptyDirVolumeSource()
body.spec.volumes = [workdir_volume]

for pvc in self.persistent_volumes:
volume = kubernetes.client.V1Volume(name=pvc.name)
volume.persistent_volume_claim = (
kubernetes.client.V1PersistentVolumeClaimVolumeSource(
claim_name=pvc.name
)
)
body.spec.volumes.append(volume)

# env vars
container.env = []
for key, e in self.secret_envvars.items():
Expand All @@ -185,6 +239,12 @@ def run_job(self, job: JobExecutorInterface):
disk_mb = int(job.resources.get("disk_mb", 1024))
container.resources.requests["ephemeral-storage"] = f"{disk_mb}M"

if self.privileged:
# allow privileged container so NFS can be used
container.security_context = kubernetes.client.V1SecurityContext(
privileged=True
)

self.logger.debug(f"k8s pod resources: {container.resources.requests}")

# capabilities
Expand Down
2 changes: 0 additions & 2 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,4 @@ def get_remote_execution_settings(
return snakemake.settings.types.RemoteExecutionSettings(
seconds_between_status_checks=10,
envvars=self.get_envvars(),
# TODO remove once we have switched to stable snakemake for dev
container_image="snakemake/snakemake:latest",
)

0 comments on commit 33a6809

Please sign in to comment.