Merge branch 'flyteorg:master' into flyteremote-interruptible-override
redartera authored Jan 28, 2025
2 parents 73cae9e + 4208a64 commit a51c2fe
Showing 25 changed files with 1,113 additions and 69 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/pythonbuild.yml
@@ -39,7 +39,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os:
- ubuntu-24.04-arm
- ubuntu-latest
- windows-latest
- macos-latest
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
@@ -78,7 +82,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os:
- ubuntu-24.04-arm
- ubuntu-latest
- windows-latest
- macos-latest
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
11 changes: 9 additions & 2 deletions .pre-commit-config.yaml
@@ -26,6 +26,13 @@ repos:
rev: v2.3.0
hooks:
      - id: codespell
        args:
          - --ignore-words-list=assertIn # Ignore 'assertIn'
        additional_dependencies: [tomli]
        additional_dependencies:
          - tomli
  - repo: https://github.com/jsh9/pydoclint
    rev: 0.6.0
    hooks:
      - id: pydoclint
        args:
          - --style=google
          - --exclude='.git|tests/flytekit/*|tests/'
          - --baseline=pydoclint-errors-baseline.txt
29 changes: 29 additions & 0 deletions flytekit/core/data_persistence.py
@@ -52,6 +52,10 @@

Uploadable = typing.Union[str, os.PathLike, pathlib.Path, bytes, io.BufferedReader, io.BytesIO, io.StringIO]

# This is the default chunk size flytekit will use when writing to S3 and GCS (25 MB). It can be overridden via the
# _F_P_WRITE_CHUNK_SIZE environment variable and is applied when put() is called on a filesystem.
_WRITE_SIZE_CHUNK_BYTES = int(os.environ.get("_F_P_WRITE_CHUNK_SIZE", "26214400")) # 25 * 2**20


def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {
@@ -108,6 +112,27 @@ def get_fsspec_storage_options(
return {}


def get_additional_fsspec_call_kwargs(protocol: typing.Union[str, tuple], method_name: str) -> Dict[str, Any]:
"""
    These are different from the setup-args functions defined above: those kwargs are applied when asking fsspec
    to create the filesystem, whereas the kwargs returned here are passed when the filesystem's methods are invoked.
    :param protocol: s3, gcs, etc.
    :param method_name: Pass in the __name__ of the fsspec filesystem method. Underscores are stripped before matching.
"""
kwargs = {}
method_name = method_name.replace("_", "")
if isinstance(protocol, tuple):
protocol = protocol[0]

# For s3fs and gcsfs, we feel the default chunksize of 50MB is too big.
# Re-evaluate these kwargs when we move off of s3fs to obstore.
if method_name == "put" and protocol in ["s3", "gs"]:
kwargs["chunksize"] = _WRITE_SIZE_CHUNK_BYTES

return kwargs


@decorator
def retry_request(func, *args, **kwargs):
# TODO: Remove this method once s3fs has a new release. https://github.com/fsspec/s3fs/pull/865
@@ -353,6 +378,10 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw
if "metadata" not in kwargs:
kwargs["metadata"] = {}
kwargs["metadata"].update(self._execution_metadata)

additional_kwargs = get_additional_fsspec_call_kwargs(file_system.protocol, file_system.put.__name__)
kwargs.update(additional_kwargs)

if isinstance(file_system, AsyncFileSystem):
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
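A minimal sketch of how the new chunk-size knob can be exercised from user code. The environment variable name (_F_P_WRITE_CHUNK_SIZE) and the 25 MB default come from the diff above; the FileAccessProvider arguments and bucket paths below are hypothetical and only for illustration.

```python
import os

# Assumed override: 8 MB multipart chunks instead of the 25 MB default.
# The value is read once at import time, so set it before flytekit.core.data_persistence is imported.
os.environ["_F_P_WRITE_CHUNK_SIZE"] = str(8 * 2**20)

from flytekit.core.data_persistence import FileAccessProvider  # noqa: E402

# Hypothetical provider pointing at an S3 raw-output prefix.
fap = FileAccessProvider(local_sandbox_dir="/tmp/sandbox", raw_output_prefix="s3://my-bucket/raw")

# put_data() eventually calls the filesystem's put(); for s3/gs the extra
# chunksize kwarg from get_additional_fsspec_call_kwargs() is merged in.
fap.put_data("/tmp/local_file.bin", "s3://my-bucket/raw/local_file.bin")
```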
12 changes: 6 additions & 6 deletions flytekit/core/resources.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
from typing import List, Optional, Union

from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
from mashumaro.mixins.json import DataClassJSONMixin
@@ -103,11 +103,11 @@ def convert_resources_to_resource_model(


def pod_spec_from_resources(
k8s_pod_name: str,
primary_container_name: Optional[str] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
k8s_gpu_resource_key: str = "nvidia.com/gpu",
) -> dict[str, Any]:
) -> V1PodSpec:
def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
if resources is None:
return None
@@ -133,10 +133,10 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
name=primary_container_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
@@ -145,4 +145,4 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
]
)

return k8s_pod.to_dict()
return pod_spec
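A short, hedged usage sketch of the updated helper, which now takes primary_container_name and returns a V1PodSpec object instead of a dict. The resource values below are made up for illustration.

```python
from flytekit import Resources
from flytekit.core.resources import pod_spec_from_resources

# Hypothetical resource values; requests and limits fall back to each other when only one is given.
pod_spec = pod_spec_from_resources(
    primary_container_name="primary",
    requests=Resources(cpu="1", mem="2Gi"),
    limits=Resources(cpu="2", mem="4Gi"),
)

print(type(pod_spec).__name__)                    # V1PodSpec, no longer a plain dict
print(pod_spec.containers[0].resources.requests)  # e.g. {'cpu': '1', 'memory': '2Gi'}
```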
35 changes: 16 additions & 19 deletions flytekit/core/type_engine.py
@@ -60,6 +60,8 @@
DEFINITIONS = "definitions"
TITLE = "title"

_TYPE_ENGINE_COROS_BATCH_SIZE = int(os.environ.get("_F_TE_MAX_COROS", "10"))


# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
# This is relevant for cases like Dict[int, str].
@@ -1359,7 +1361,7 @@ def to_literal(
) -> Literal:
"""
The current dance is because we are allowing users to call from an async function, this synchronous
to_literal function, and allowing this to_literal function, to then invoke yet another async functionl,
to_literal function, and allowing this to_literal function, to then invoke yet another async function,
namely an async transformer.
"""
from flytekit.core.promise import Promise
@@ -1686,10 +1688,9 @@ async def async_to_literal(
raise TypeTransformerFailedError("Expected a list")

t = self.get_sub_type(python_type)
lit_list = [
asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type)) for x in python_val
]
lit_list = await _run_coros_in_chunks(lit_list)
lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val]

lit_list = await _run_coros_in_chunks(lit_list, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)

return Literal(collection=LiteralCollection(literals=lit_list))

@@ -1711,7 +1712,7 @@ async def async_to_python_value( # type: ignore

st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await _run_coros_in_chunks(result)
result = await _run_coros_in_chunks(result, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
return result # type: ignore # should be a list, thinks its a tuple

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
@@ -1907,7 +1908,7 @@ async def async_to_literal(
res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name)
found_res = True
except Exception as e:
logger.warning(
logger.debug(
f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}",
)
continue
@@ -2158,13 +2159,10 @@ async def async_to_literal(
else:
_, v_type = self.extract_types_or_metadata(python_type)

lit_map[k] = asyncio.create_task(
TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
)

await _run_coros_in_chunks([c for c in lit_map.values()])
for k, v in lit_map.items():
lit_map[k] = v.result()
lit_map[k] = TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
vals = await _run_coros_in_chunks([c for c in lit_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
for idx, k in zip(range(len(vals)), lit_map.keys()):
lit_map[k] = vals[idx]

return Literal(map=LiteralMap(literals=lit_map))

@@ -2185,12 +2183,11 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key")
py_map = {}
for k, v in lv.map.literals.items():
fut = asyncio.create_task(TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1])))
py_map[k] = fut
py_map[k] = TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1]))

await _run_coros_in_chunks([c for c in py_map.values()])
for k, v in py_map.items():
py_map[k] = v.result()
vals = await _run_coros_in_chunks([c for c in py_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
for idx, k in zip(range(len(vals)), py_map.keys()):
py_map[k] = vals[idx]

return py_map

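For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of what running coroutines in bounded batches looks like. The real helper (_run_coros_in_chunks) lives inside flytekit and may differ in details; only the _F_TE_MAX_COROS default of 10 is taken from the diff above.

```python
import asyncio
import os
from typing import Any, Coroutine, List

# Concurrency cap, mirroring the knob introduced above.
_TYPE_ENGINE_COROS_BATCH_SIZE = int(os.environ.get("_F_TE_MAX_COROS", "10"))


async def run_coros_in_chunks(coros: List[Coroutine[Any, Any, Any]], batch_size: int) -> List[Any]:
    """Gather at most `batch_size` coroutines at a time, preserving input order in the results."""
    results: List[Any] = []
    for start in range(0, len(coros), batch_size):
        results.extend(await asyncio.gather(*coros[start : start + batch_size]))
    return results


async def _demo() -> None:
    async def square(x: int) -> int:
        await asyncio.sleep(0)  # stand-in for real async conversion work
        return x * x

    vals = await run_coros_in_chunks([square(i) for i in range(25)], _TYPE_ENGINE_COROS_BATCH_SIZE)
    assert vals == [i * i for i in range(25)]


if __name__ == "__main__":
    asyncio.run(_demo())
```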
3 changes: 3 additions & 0 deletions flytekit/image_spec/default_builder.py
@@ -216,6 +216,9 @@ def prepare_python_install(image_spec: ImageSpec, tmp_dir: Path) -> str:
extra_urls = [f"--extra-index-url {url}" for url in image_spec.pip_extra_index_url]
pip_install_args.extend(extra_urls)

if image_spec.pip_extra_args:
pip_install_args.append(image_spec.pip_extra_args)

requirements = []
if image_spec.requirements:
requirement_basename = os.path.basename(image_spec.requirements)
2 changes: 2 additions & 0 deletions flytekit/image_spec/image_spec.py
@@ -48,6 +48,7 @@ class ImageSpec:
    platform: Specify the target platforms for the build output (for example, windows/amd64 or linux/amd64,darwin/arm64)
pip_index: Specify the custom pip index url
pip_extra_index_url: Specify one or more pip index urls as a list
pip_extra_args: Specify one or more extra pip install arguments as a space-delimited string
registry_config: Specify the path to a JSON registry config file
entrypoint: List of strings to overwrite the entrypoint of the base image with, set to [] to remove the entrypoint.
commands: Command to run during the building process
@@ -82,6 +83,7 @@ class ImageSpec:
platform: str = "linux/amd64"
pip_index: Optional[str] = None
pip_extra_index_url: Optional[List[str]] = None
pip_extra_args: Optional[str] = None
registry_config: Optional[str] = None
entrypoint: Optional[List[str]] = None
commands: Optional[List[str]] = None
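A hedged example of the new field; the image name, packages, index URL, and flags below are placeholders. pip_extra_args is a single space-delimited string that the default builder appends to the pip install invocation.

```python
from flytekit.image_spec import ImageSpec

image = ImageSpec(
    name="my-image",                                      # hypothetical image name
    packages=["pandas", "pyarrow"],
    pip_extra_index_url=["https://my.private.index/simple"],
    pip_extra_args="--no-cache-dir --prefer-binary",      # passed through verbatim to pip install
)
```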
6 changes: 6 additions & 0 deletions flytekit/models/security.py
@@ -17,6 +17,9 @@ class Secret(_common.FlyteIdlEntity):
    key is optional and can be an individual secret identifier within the secret. For k8s this is required
version is the version of the secret. This is an optional field
mount_requirement provides a hint to the system as to how the secret should be injected
    env_var is optional. Custom environment variable name in which to make the secret available.
If mount_requirement is ENV_VAR, then the value is the secret itself.
If mount_requirement is FILE, then the value is the path to the secret file.
"""

class MountType(Enum):
@@ -39,6 +42,7 @@ class MountType(Enum):
key: Optional[str] = None
group_version: Optional[str] = None
mount_requirement: MountType = MountType.ANY
env_var: Optional[str] = None

def __post_init__(self):
from flytekit.configuration.plugin import get_plugin
@@ -56,6 +60,7 @@ def to_flyte_idl(self) -> _sec.Secret:
group_version=self.group_version,
key=self.key,
mount_requirement=self.mount_requirement.value,
env_var=self.env_var,
)

@classmethod
@@ -65,6 +70,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret":
group_version=pb2_object.group_version if pb2_object.group_version else None,
key=pb2_object.key if pb2_object.key else None,
mount_requirement=Secret.MountType(pb2_object.mount_requirement),
env_var=pb2_object.env_var if pb2_object.env_var else None,
)


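A sketch of how the new env_var field might be used from a task; the secret group and key names are hypothetical. With an ENV_VAR mount the variable holds the secret value itself; with a FILE mount it holds the path to the secret file.

```python
import os

from flytekit import Secret, task


@task(
    secret_requests=[
        Secret(
            group="db-creds",                             # hypothetical secret group
            key="password",                               # hypothetical key within the group
            mount_requirement=Secret.MountType.ENV_VAR,
            env_var="DB_PASSWORD",                        # custom env var name added in this commit
        )
    ]
)
def read_secret() -> int:
    # The secret value is exposed under the custom name rather than a derived default.
    return len(os.environ["DB_PASSWORD"])
```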
20 changes: 20 additions & 0 deletions flytekit/models/task.py
@@ -8,6 +8,7 @@
from flyteidl.core import tasks_pb2 as _core_task
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from kubernetes.client import ApiClient

from flytekit.models import common as _common
from flytekit.models import interface as _interface
@@ -16,6 +17,9 @@
from flytekit.models.core import identifier as _identifier
from flytekit.models.documentation import Documentation

if typing.TYPE_CHECKING:
from flytekit import PodTemplate


class Resources(_common.FlyteIdlEntity):
class ResourceName(object):
@@ -1042,6 +1046,22 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod):
else None,
)

def to_pod_template(self) -> "PodTemplate":
from flytekit import PodTemplate

return PodTemplate(
labels=self.metadata.labels,
annotations=self.metadata.annotations,
pod_spec=self.pod_spec,
)

@classmethod
def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod":
return cls(
metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations),
pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec),
)


class Sql(_common.FlyteIdlEntity):
class Dialect(object):
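A minimal round-trip sketch for the two new helpers; the labels, annotations, and container below are placeholders. from_pod_template() serializes the V1PodSpec into a plain dict via the Kubernetes ApiClient, and to_pod_template() wraps that dict back into a PodTemplate.

```python
from kubernetes.client import V1Container, V1PodSpec

from flytekit import PodTemplate
from flytekit.models.task import K8sPod

# Hypothetical pod template with a single primary container.
template = PodTemplate(
    labels={"team": "ml"},
    annotations={"owner": "data-platform"},
    pod_spec=V1PodSpec(containers=[V1Container(name="primary")]),
)

k8s_pod = K8sPod.from_pod_template(template)       # pod_spec becomes a sanitized dict
assert k8s_pod.pod_spec["containers"][0]["name"] == "primary"

round_tripped = k8s_pod.to_pod_template()          # back to a PodTemplate (pod_spec stays a dict here)
assert round_tripped.labels == {"team": "ml"}
```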
6 changes: 6 additions & 0 deletions flytekit/remote/remote_fs.py
@@ -65,13 +65,19 @@ def _upload_chunk(self, final=False):
"""Only uploads the file at once from the buffer.
Not suitable for large files as the buffer will blow the memory for very large files.
Suitable for default values or local dataframes being uploaded all at once.
This function is called by fsspec.flush(). This will create a new file upload location.
"""
if final is False:
return False
self.buffer.seek(0)
data = self.buffer.read()

try:
            # The inputs here are flipped a bit: ideally the filename argument would carry the real filename and the
            # filename root would be something deterministic, like a hash. But since this is supposed to mimic open(),
            # we can't hash. With the args currently below, the backend will create a random suffix for the filename.
            # Since no hash is set on it, we will not be able to write to it again (which is totally fine).
res = self._remote.client.get_upload_signed_url(
self._remote.default_project,
self._remote.default_domain,
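The docstring above describes a flush-once pattern: everything written to the file object is buffered in memory and uploaded in a single shot when the final flush arrives. Below is a generic, library-agnostic sketch of that pattern; the upload callable is a stand-in for the signed-URL PUT that flytekit performs.

```python
import io
from typing import Callable


class WholeBufferUploadFile:
    """Buffer all writes in memory and upload once on the final flush.

    Deliberately not suitable for very large files, since the whole payload is held in memory.
    """

    def __init__(self, upload_fn: Callable[[bytes], None]):
        self.buffer = io.BytesIO()
        self._upload_fn = upload_fn          # e.g. an HTTP PUT to a signed URL

    def write(self, data: bytes) -> int:
        return self.buffer.write(data)

    def _upload_chunk(self, final: bool = False) -> bool:
        if not final:                        # intermediate flushes are a no-op
            return False
        self.buffer.seek(0)
        self._upload_fn(self.buffer.read())  # single-shot upload of the whole buffer
        return True


# Usage sketch
uploaded = []
f = WholeBufferUploadFile(uploaded.append)
f.write(b"hello ")
f.write(b"world")
f._upload_chunk(final=True)
assert uploaded == [b"hello world"]
```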
12 changes: 9 additions & 3 deletions flytekit/types/structured/structured_dataset.py
@@ -22,7 +22,7 @@
from flytekit import lazy_module
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, modify_literal_uris
from flytekit.deck.renderer import Renderable
from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator
from flytekit.loggers import developer_logger, logger
@@ -63,6 +63,7 @@ class (that is just a model, a Python class representation of the protobuf).
file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String()))

def _serialize(self) -> Dict[str, Optional[str]]:
# dataclass case
lv = StructuredDatasetTransformerEngine().to_literal(
FlyteContextManager.current_context(), self, type(self), None
)
@@ -85,7 +86,7 @@ def _deserialize(cls, value) -> "StructuredDataset":
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
structured_dataset=literals.StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=file_format)
),
@@ -98,6 +99,7 @@

@model_serializer
def serialize_structured_dataset(self) -> Dict[str, Optional[str]]:
# pydantic case
lv = StructuredDatasetTransformerEngine().to_literal(
FlyteContextManager.current_context(), self, type(self), None
)
@@ -117,7 +119,7 @@ def deserialize_structured_dataset(self, info) -> StructuredDataset:
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
structured_dataset=literals.StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=self.file_format)
),
@@ -807,6 +809,10 @@ def encode(
# with a format of "" is used.
sd_model.metadata._structured_dataset_type.format = handler.supported_format
lit = Literal(scalar=Scalar(structured_dataset=sd_model))

# Because the handler.encode may have uploaded something, and because the sd may end up living inside a
# dataclass, we need to modify any uploaded flyte:// urls here.
modify_literal_uris(lit)
sd._literal_sd = sd_model
sd._already_uploaded = True
return lit