Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP for modal integration #623

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions examples/parallelism/modal/h_modal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from types import FunctionType
from typing import Any, Callable, Dict

import modal

from hamilton.execution import executors
from hamilton.execution.executors import (
ExecutionManager,
SynchronousLocalTaskExecutor,
TaskExecutor,
TaskFuture,
base_execute_task,
)
from hamilton.execution.grouping import TaskImplementation
from hamilton.execution.state import TaskState
from hamilton.function_modifiers import tag

MODAL = "modal"


#
# stub = modal.Stub()

#
# @stub.function()
# def globally_scoped(a: int):
# return 1
#
# # return base_execute_task(task)
#
# class TaskExecutor:
# def __init__(self, task: TaskImplementation):
# self.task = task
#
# def __enter__(self):
# pass
#
# @method()
# def run(self):
# return base_execute_task(self.task)
#


class ModalExecutor(executors.TaskExecutor):
def __init__(self, stub_params: Dict[str, Any], global_function_params: Dict[str, Any] = None):
self.stub = modal.Stub(**stub_params)
self.run = self.stub.run()
self.global_function_params = global_function_params or {}

def init(self):
"""Initializes the modal executor, by entering the context manager.
TODO -- be more specific about this -- we could enter and exit multiple times?
"""

def finalize(self):
"""Finalizes the modal executor, by ending the context manager.

@return:
"""
# self.run.__exit__(None, None, None)

def submit_task(self, task: TaskImplementation) -> TaskFuture:
function_params = dict(name=task.task_id, serialized=True, **self.global_function_params)
function_to_submit = self.stub.function(**function_params)(base_execute_task)

with self.stub.run():
result = function_to_submit.remote(task=task)
return TaskFuture(get_state=lambda: TaskState.SUCCESSFUL, get_result=lambda: result)

def can_submit_task(self) -> bool:
return True


class RemoteExecutionManager(ExecutionManager):
TAG_KEY = "remote"

def __init__(self, **remote_executors: TaskExecutor):
self.local_executor = SynchronousLocalTaskExecutor()
self.remote_executors = remote_executors
super(RemoteExecutionManager, self).__init__(
[self.local_executor] + list(remote_executors.values())
)

def get_executor_for_task(self, task: TaskImplementation) -> TaskExecutor:
"""Simple implementation that returns the local executor for single task executions,

:param task: Task to get executor for
:return: A local task if this is a "single-node" task, a remote task otherwise
"""
is_single_node_task = len(task.nodes) == 1
if not is_single_node_task:
raise ValueError("Only single node tasks supported")
(node,) = task.nodes
tag_value = node.tags.get(self.TAG_KEY)
if tag_value is None:
return self.local_executor
if tag_value in self.remote_executors:
return self.remote_executors[tag_value]
raise ValueError(f"Unknown remote source {tag_value}")


# class ModalExecutionManager(DelegatingExecutionManager):
# def __init__(self, stub_params: Dict[str, Any], global_function_params: Dict[str, Any] = None):
# remote_executor = ModalExecutor(stub_params)
# super().__init__(remote_executor, remote_tag_value=MODAL)


def remote(source: str) -> Callable[[FunctionType], FunctionType]:
def decorator(fn: FunctionType) -> FunctionType:
return tag(remote=source)(fn)

return decorator
26 changes: 26 additions & 0 deletions examples/parallelism/modal/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import modal
import simple_pipeline
from h_modal import ModalExecutor, RemoteExecutionManager

from hamilton import driver


def test_simple_pipeline():
image = modal.Image.debian_slim().pip_install("sf-hamilton")
dr = (
driver.Builder()
.with_modules(simple_pipeline)
.enable_dynamic_execution(allow_experimental_mode=True)
.with_execution_manager(
RemoteExecutionManager(
modal=ModalExecutor(stub_params={}, global_function_params={"image": image})
)
)
.build()
)
result = dr.execute(final_vars=["locally_gathered_data"])
print(result)


if __name__ == "__main__":
test_simple_pipeline()
19 changes: 19 additions & 0 deletions examples/parallelism/modal/simple_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pandas as pd
from h_modal import remote


def data() -> pd.DataFrame:
return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@remote(source="modal")
def remote_processed_data(data: pd.DataFrame) -> pd.DataFrame:
import time

time.sleep(1)
return data**2


def locally_gathered_data(remote_processed_data: pd.DataFrame) -> pd.DataFrame:
print(remote_processed_data)
return remote_processed_data
2 changes: 2 additions & 0 deletions hamilton/execution/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ def run_graph_to_completion(
execution_state.update_task_state(task_name, state, result)
if TaskState.is_terminal(state):
del task_futures[task_name]
if TaskState.is_failure(state):
raise RuntimeError(f"Task {task_name} failed to execute.")
logger.info(f"Graph is done, graph state is {execution_state.get_graph_state()}")
finally:
execution_manager.finalize()
4 changes: 4 additions & 0 deletions hamilton/execution/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class TaskState(enum.Enum):
def is_terminal(task_state: "TaskState") -> bool:
return task_state in [TaskState.SUCCESSFUL, TaskState.FAILED]

@staticmethod
def is_failure(task_state: "TaskState") -> bool:
return task_state in [TaskState.FAILED]


# TODO -- determine a better set of states for the graph
GraphState = TaskState
Expand Down