diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py
index ed43c4fb5a727..60b9e43ed10eb 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 """DaskRunner, executing remote jobs on Dask.distributed.
 
 The DaskRunner is a runner implementation that executes a graph of
@@ -39,7 +40,6 @@
 
 
 class DaskOptions(PipelineOptions):
-
   @staticmethod
   def _parse_timeout(candidate):
     try:
@@ -51,37 +51,37 @@ def _parse_timeout(candidate):
   @classmethod
   def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
-      '--dask_client_address',
-      dest='address',
-      type=str,
-      default=None,
-      help='Address of a dask Scheduler server. Will default to a '
-      '`dask.LocalCluster()`.')
+        '--dask_client_address',
+        dest='address',
+        type=str,
+        default=None,
+        help='Address of a dask Scheduler server. Will default to a '
+        '`dask.LocalCluster()`.')
     parser.add_argument(
-      '--dask_connection_timeout',
-      dest='timeout',
-      type=DaskOptions._parse_timeout,
-      help='Timeout duration for initial connection to the scheduler.')
+        '--dask_connection_timeout',
+        dest='timeout',
+        type=DaskOptions._parse_timeout,
+        help='Timeout duration for initial connection to the scheduler.')
     parser.add_argument(
-      '--dask_scheduler_file',
-      type=str,
-      default=None,
-      help='Path to a file with scheduler information if available.')
+        '--dask_scheduler_file',
+        type=str,
+        default=None,
+        help='Path to a file with scheduler information if available.')
     # TODO(alxr): Add options for security.
     parser.add_argument(
-      '--dask_client_name',
-      dest='name',
-      type=str,
-      default=None,
-      help='Gives the client a name that will be included in logs generated on '
-      'the scheduler for matters relating to this client.')
+        '--dask_client_name',
+        dest='name',
+        type=str,
+        default=None,
+        help='Gives the client a name that will be included in logs generated on '
+        'the scheduler for matters relating to this client.')
     parser.add_argument(
-      '--dask_connection_limit',
-      dest='connection_limit',
-      type=int,
-      default=512,
-      help='The number of open comms to maintain at once in the connection '
-      'pool.')
+        '--dask_connection_limit',
+        dest='connection_limit',
+        type=int,
+        default=512,
+        help='The number of open comms to maintain at once in the connection '
+        'pool.')
 
 
 @dataclasses.dataclass
@@ -120,7 +120,6 @@ def metrics(self):
 
 class DaskRunner(BundleBasedDirectRunner):
   """Executes a pipeline on a Dask distributed client."""
-
   @staticmethod
   def to_dask_bag_visitor() -> PipelineVisitor:
     from dask import bag as db
@@ -168,7 +167,7 @@ def run_pipeline(self, pipeline, options):
       import dask.distributed as ddist
     except ImportError:
       raise ImportError(
-        'DaskRunner is not available. Please install apache_beam[dask].')
+          'DaskRunner is not available. Please install apache_beam[dask].')
 
     dask_options = options.view_as(DaskOptions).get_all_options(
         drop_default=True)
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index e4933eeb11cd6..a75b1f2fb94a2 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -30,7 +30,6 @@
 
 class DaskRunnerRunPipelineTest(unittest.TestCase):
   """Test class used to introspect the dask runner via a debugger."""
-
   def setUp(self) -> None:
     self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())
 
@@ -40,7 +39,6 @@ def test_create(self):
     assert_that(pcoll, equal_to([1]))
 
   def test_create_and_map(self):
-
     def double(x):
       return x * 2
 
@@ -49,18 +47,14 @@ def double(x):
     assert_that(pcoll, equal_to([2]))
 
   def test_create_map_and_groupby(self):
-
     def double(x):
       return x * 2, x
 
     with self.pipeline as p:
       pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
-      assert_that(pcoll, equal_to([
-        (2, [1])
-      ]))
+      assert_that(pcoll, equal_to([(2, [1])]))
 
   def test_map_with_side_inputs(self):
-
     def mult_by(x, y):
       return x * y
 
diff --git a/sdks/python/apache_beam/runners/dask/overrides.py b/sdks/python/apache_beam/runners/dask/overrides.py
index 0735eba99b039..7528e0132d5a1 100644
--- a/sdks/python/apache_beam/runners/dask/overrides.py
+++ b/sdks/python/apache_beam/runners/dask/overrides.py
@@ -45,7 +45,6 @@ def get_windowing(self, inputs: t.Any) -> beam.Windowing:
 @typehints.with_input_types(K)
 @typehints.with_output_types(K)
 class _Reshuffle(beam.PTransform):
-
   def expand(self, input_or_inputs):
     return beam.pvalue.PCollection.from_(input_or_inputs)
 
@@ -61,7 +60,6 @@ def expand(self, input_or_inputs):
 @typehints.with_input_types(t.Tuple[K, V])
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupByKeyOnly(beam.PTransform):
-
   def expand(self, input_or_inputs):
     return beam.pvalue.PCollection.from_(input_or_inputs)
 
@@ -77,7 +75,6 @@ def infer_output_type(self, input_type):
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupAlsoByWindow(beam.ParDo):
   """Not used yet..."""
-
   def __init__(self, windowing):
     super().__init__(_GroupAlsoByWindowDoFn(windowing))
     self.windowing = windowing
@@ -89,22 +86,18 @@ def expand(self, input_or_inputs):
 @typehints.with_input_types(t.Tuple[K, V])
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupByKey(beam.PTransform):
-
   def expand(self, input_or_inputs):
     return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()
 
 
 class _Flatten(beam.PTransform):
-
   def expand(self, input_or_inputs):
     is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
     return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
 
 
 def dask_overrides() -> t.List[PTransformOverride]:
-
   class CreateOverride(PTransformOverride):
-
     def matches(self, applied_ptransform: AppliedPTransform) -> bool:
       return applied_ptransform.transform.__class__ == beam.Create
 
@@ -113,7 +106,6 @@ def get_replacement_transform_for_applied_ptransform(
       return _Create(t.cast(beam.Create, applied_ptransform.transform).values)
 
   class ReshuffleOverride(PTransformOverride):
-
     def matches(self, applied_ptransform: AppliedPTransform) -> bool:
       return applied_ptransform.transform.__class__ == beam.Reshuffle
 
@@ -122,7 +114,6 @@ def get_replacement_transform_for_applied_ptransform(
       return _Reshuffle()
 
   class ReadOverride(PTransformOverride):
-
     def matches(self, applied_ptransform: AppliedPTransform) -> bool:
       return applied_ptransform.transform.__class__ == beam.io.Read
 
@@ -131,7 +122,6 @@ def get_replacement_transform_for_applied_ptransform(
       return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source)
 
   class GroupByKeyOverride(PTransformOverride):
-
     def matches(self, applied_ptransform: AppliedPTransform) -> bool:
       return applied_ptransform.transform.__class__ == beam.GroupByKey
 
@@ -140,7 +130,6 @@ def get_replacement_transform_for_applied_ptransform(
       return _GroupByKey()
 
   class FlattenOverride(PTransformOverride):
-
     def matches(self, applied_ptransform: AppliedPTransform) -> bool:
       return applied_ptransform.transform.__class__ == beam.Flatten
 
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index e6fb5c336554e..7ff31ef505cea 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -51,13 +51,11 @@ def apply(self, input_bag: OpInput) -> db.Bag:
 
 
 class NoOp(DaskBagOp):
-
   def apply(self, input_bag: OpInput) -> db.Bag:
     return input_bag
 
 
 class Create(DaskBagOp):
-
   def apply(self, input_bag: OpInput) -> db.Bag:
     assert input_bag is None, 'Create expects no input!'
     original_transform = t.cast(_Create, self.applied.transform)
@@ -66,29 +64,20 @@ def apply(self, input_bag: OpInput) -> db.Bag:
 
 
 class ParDo(DaskBagOp):
-
   def apply(self, input_bag: OpInput) -> db.Bag:
     transform = t.cast(apache_beam.ParDo, self.applied.transform)
     return input_bag.map(
-        transform.fn.process,
-        *transform.args,
-        **transform.kwargs
-    ).flatten()
+        transform.fn.process, *transform.args, **transform.kwargs).flatten()
 
 
 class Map(DaskBagOp):
-
   def apply(self, input_bag: OpInput) -> db.Bag:
     transform = t.cast(apache_beam.Map, self.applied.transform)
     return input_bag.map(
-        transform.fn.process,
-        *transform.args,
-        **transform.kwargs
-    )
+        transform.fn.process, *transform.args, **transform.kwargs)
 
 
 class GroupByKey(DaskBagOp):
-
   def apply(self, input_bag: OpInput) -> db.Bag:
     def key(item):
       return item[0]
@@ -101,7 +90,6 @@ def value(item):
 
 
 class Flatten(DaskBagOp):
-
   def apply(self, input_bag: OpInput) -> db.Bag:
     assert type(input_bag) is list, 'Must take a sequence of bags!'
     return db.concat(input_bag)
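
For context, a minimal usage sketch of the runner touched by this diff, mirroring the
TestPipeline(runner=DaskRunner()) pattern in dask_runner_test.py. It assumes
apache_beam[dask] is installed; the scheduler address is a placeholder, and omitting
--dask_client_address falls back to a dask.LocalCluster() per the DaskOptions help text.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

# Placeholder address; without this flag, DaskOptions defaults to a
# local dask.LocalCluster().
options = PipelineOptions(['--dask_client_address=tcp://127.0.0.1:8786'])

with beam.Pipeline(runner=DaskRunner(), options=options) as p:
  (p
   | beam.Create([1, 2, 3])
   | beam.Map(lambda x: x * 2)
   | beam.Map(print))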