Applied yapf from tox.
alxmrs committed Oct 9, 2022
1 parent 8ad9813 commit 1b6ec0f
Showing 4 changed files with 31 additions and 61 deletions.
57 changes: 28 additions & 29 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.
The DaskRunner is a runner implementation that executes a graph of
@@ -39,7 +40,6 @@


class DaskOptions(PipelineOptions):

@staticmethod
def _parse_timeout(candidate):
try:
@@ -51,37 +51,37 @@ def _parse_timeout(candidate):
@classmethod
def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'--dask_client_address',
dest='address',
type=str,
default=None,
help='Address of a dask Scheduler server. Will default to a '
'`dask.LocalCluster()`.')
'--dask_client_address',
dest='address',
type=str,
default=None,
help='Address of a dask Scheduler server. Will default to a '
'`dask.LocalCluster()`.')
parser.add_argument(
'--dask_connection_timeout',
dest='timeout',
type=DaskOptions._parse_timeout,
help='Timeout duration for initial connection to the scheduler.')
'--dask_connection_timeout',
dest='timeout',
type=DaskOptions._parse_timeout,
help='Timeout duration for initial connection to the scheduler.')
parser.add_argument(
'--dask_scheduler_file',
type=str,
default=None,
help='Path to a file with scheduler information if available.')
'--dask_scheduler_file',
type=str,
default=None,
help='Path to a file with scheduler information if available.')
# TODO(alxr): Add options for security.
parser.add_argument(
'--dask_client_name',
dest='name',
type=str,
default=None,
help='Gives the client a name that will be included in logs generated on '
'the scheduler for matters relating to this client.')
'--dask_client_name',
dest='name',
type=str,
default=None,
help='Gives the client a name that will be included in logs generated on '
'the scheduler for matters relating to this client.')
parser.add_argument(
'--dask_connection_limit',
dest='connection_limit',
type=int,
default=512,
help='The number of open comms to maintain at once in the connection '
'pool.')
'--dask_connection_limit',
dest='connection_limit',
type=int,
default=512,
help='The number of open comms to maintain at once in the connection '
'pool.')


@dataclasses.dataclass
@@ -120,7 +120,6 @@ def metrics(self):

class DaskRunner(BundleBasedDirectRunner):
"""Executes a pipeline on a Dask distributed client."""

@staticmethod
def to_dask_bag_visitor() -> PipelineVisitor:
from dask import bag as db
@@ -168,7 +167,7 @@ def run_pipeline(self, pipeline, options):
import dask.distributed as ddist
except ImportError:
raise ImportError(
'DaskRunner is not available. Please install apache_beam[dask].')
'DaskRunner is not available. Please install apache_beam[dask].')

dask_options = options.view_as(DaskOptions).get_all_options(
drop_default=True)
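
The flags defined in DaskOptions above are ordinary Beam pipeline options, so they can be passed on the command line or built programmatically and handed to the DaskRunner. A minimal sketch follows; the scheduler address, timeout, and client name are placeholder values, not taken from this commit.

```python
# Hedged sketch: running a pipeline on the DaskRunner with the options above.
# The scheduler address, timeout, and client name are placeholder values.
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

options = PipelineOptions([
    '--dask_client_address=tcp://127.0.0.1:8786',  # omit to fall back to dask.LocalCluster()
    '--dask_connection_timeout=30',
    '--dask_client_name=beam-dask-example',
])

with beam.Pipeline(runner=DaskRunner(), options=options) as p:
    _ = (p
         | beam.Create([1, 2, 3])
         | beam.Map(lambda x: x * 2))
```
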
8 changes: 1 addition & 7 deletions sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -30,7 +30,6 @@

class DaskRunnerRunPipelineTest(unittest.TestCase):
"""Test class used to introspect the dask runner via a debugger."""

def setUp(self) -> None:
self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())

@@ -40,7 +39,6 @@ def test_create(self):
assert_that(pcoll, equal_to([1]))

def test_create_and_map(self):

def double(x):
return x * 2

@@ -49,18 +47,14 @@ def double(x):
assert_that(pcoll, equal_to([2]))

def test_create_map_and_groupby(self):

def double(x):
return x * 2, x

with self.pipeline as p:
pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
assert_that(pcoll, equal_to([
(2, [1])
]))
assert_that(pcoll, equal_to([(2, [1])]))

def test_map_with_side_inputs(self):

def mult_by(x, y):
return x * y

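
The body of test_map_with_side_inputs is truncated in this view, so only the mult_by helper is visible. As a rough sketch of the pattern such a test typically exercises; the use of pvalue.AsSingleton and the concrete values here are assumptions, not taken from the commit.

```python
# Hedged sketch of a side-input pipeline in the style of the tests above.
# pvalue.AsSingleton and the values used are assumptions.
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.testing.util import assert_that, equal_to


def mult_by(x, y):
    return x * y


with beam.Pipeline() as p:
    side = p | 'side' >> beam.Create([3])
    pcoll = (p
             | 'main' >> beam.Create([1])
             | beam.Map(mult_by, pvalue.AsSingleton(side)))
    assert_that(pcoll, equal_to([3]))
```
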
11 changes: 0 additions & 11 deletions sdks/python/apache_beam/runners/dask/overrides.py
@@ -45,7 +45,6 @@ def get_windowing(self, inputs: t.Any) -> beam.Windowing:
@typehints.with_input_types(K)
@typehints.with_output_types(K)
class _Reshuffle(beam.PTransform):

def expand(self, input_or_inputs):
return beam.pvalue.PCollection.from_(input_or_inputs)

@@ -61,7 +60,6 @@ def expand(self, input_or_inputs):
@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKeyOnly(beam.PTransform):

def expand(self, input_or_inputs):
return beam.pvalue.PCollection.from_(input_or_inputs)

@@ -77,7 +75,6 @@ def infer_output_type(self, input_type):
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupAlsoByWindow(beam.ParDo):
"""Not used yet..."""

def __init__(self, windowing):
super().__init__(_GroupAlsoByWindowDoFn(windowing))
self.windowing = windowing
@@ -89,22 +86,18 @@ def expand(self, input_or_inputs):
@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKey(beam.PTransform):

def expand(self, input_or_inputs):
return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()


class _Flatten(beam.PTransform):

def expand(self, input_or_inputs):
is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)


def dask_overrides() -> t.List[PTransformOverride]:

class CreateOverride(PTransformOverride):

def matches(self, applied_ptransform: AppliedPTransform) -> bool:
return applied_ptransform.transform.__class__ == beam.Create

@@ -113,7 +106,6 @@ def get_replacement_transform_for_applied_ptransform(
return _Create(t.cast(beam.Create, applied_ptransform.transform).values)

class ReshuffleOverride(PTransformOverride):

def matches(self, applied_ptransform: AppliedPTransform) -> bool:
return applied_ptransform.transform.__class__ == beam.Reshuffle

@@ -122,7 +114,6 @@ def get_replacement_transform_for_applied_ptransform(
return _Reshuffle()

class ReadOverride(PTransformOverride):

def matches(self, applied_ptransform: AppliedPTransform) -> bool:
return applied_ptransform.transform.__class__ == beam.io.Read

@@ -131,7 +122,6 @@ def get_replacement_transform_for_applied_ptransform(
return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source)

class GroupByKeyOverride(PTransformOverride):

def matches(self, applied_ptransform: AppliedPTransform) -> bool:
return applied_ptransform.transform.__class__ == beam.GroupByKey

@@ -140,7 +130,6 @@ def get_replacement_transform_for_applied_ptransform(
return _GroupByKey()

class FlattenOverride(PTransformOverride):

def matches(self, applied_ptransform: AppliedPTransform) -> bool:
return applied_ptransform.transform.__class__ == beam.Flatten

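
Each override above pairs a matches() predicate with a factory for the replacement transform, and a runner applies the whole list to a pipeline before translating it. A minimal sketch of how dask_overrides() would be consumed, assuming Beam's Pipeline.replace_all API; the corresponding call inside DaskRunner.run_pipeline is not visible in this diff.

```python
# Hedged sketch: swapping native transforms for the Dask-friendly overrides.
# Whether DaskRunner invokes replace_all() exactly like this is an assumption.
import apache_beam as beam
from apache_beam.runners.dask.overrides import dask_overrides

p = beam.Pipeline()
_ = (p
     | beam.Create([('a', 1), ('a', 2), ('b', 3)])
     | beam.GroupByKey())

# After this call, the Create and GroupByKey nodes in the graph are the
# _Create and _GroupByKey stand-ins defined above, which the Dask
# transform evaluator knows how to translate into dask.bag operations.
p.replace_all(dask_overrides())
```
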
16 changes: 2 additions & 14 deletions sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -51,13 +51,11 @@ def apply(self, input_bag: OpInput) -> db.Bag:


class NoOp(DaskBagOp):

def apply(self, input_bag: OpInput) -> db.Bag:
return input_bag


class Create(DaskBagOp):

def apply(self, input_bag: OpInput) -> db.Bag:
assert input_bag is None, 'Create expects no input!'
original_transform = t.cast(_Create, self.applied.transform)
@@ -66,29 +64,20 @@ def apply(self, input_bag: OpInput) -> db.Bag:


class ParDo(DaskBagOp):

def apply(self, input_bag: OpInput) -> db.Bag:
transform = t.cast(apache_beam.ParDo, self.applied.transform)
return input_bag.map(
transform.fn.process,
*transform.args,
**transform.kwargs
).flatten()
transform.fn.process, *transform.args, **transform.kwargs).flatten()


class Map(DaskBagOp):

def apply(self, input_bag: OpInput) -> db.Bag:
transform = t.cast(apache_beam.Map, self.applied.transform)
return input_bag.map(
transform.fn.process,
*transform.args,
**transform.kwargs
)
transform.fn.process, *transform.args, **transform.kwargs)


class GroupByKey(DaskBagOp):

def apply(self, input_bag: OpInput) -> db.Bag:
def key(item):
return item[0]
@@ -101,7 +90,6 @@ def value(item):


class Flatten(DaskBagOp):

def apply(self, input_bag: OpInput) -> db.Bag:
assert type(input_bag) is list, 'Must take a sequence of bags!'
return db.concat(input_bag)
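
The ops above lean on a handful of dask.bag primitives: ParDo maps a DoFn's process over the bag and then flattens, since process may yield several outputs per element, while Map is one-to-one and GroupByKey appears to build on dask.bag's groupby. A standalone illustration of those primitives follows; the values, lambdas, and the final GroupByKey reshape are assumptions, since the GroupByKey body is truncated above.

```python
# Standalone illustration of the dask.bag primitives used above.
# Values, lambdas, and the GroupByKey reshape are illustrative assumptions.
import dask.bag as db

bag = db.from_sequence([1, 2, 3])

# ParDo-like: each element may produce several outputs, so the per-element
# results are flattened into one bag.
pardo_like = bag.map(lambda x: [x, x * 10]).flatten()

# Map-like: exactly one output per input, no flatten needed.
map_like = bag.map(lambda x: x * 2)

# GroupByKey-like: group (key, value) pairs by key, then keep only the values
# in each group.
kv = db.from_sequence([('a', 1), ('a', 2), ('b', 3)])
gbk_like = kv.groupby(lambda pair: pair[0]).map(
    lambda group: (group[0], [pair[1] for pair in group[1]]))

print(pardo_like.compute())  # e.g. [1, 10, 2, 20, 3, 30]
print(map_like.compute())    # e.g. [2, 4, 6]
print(gbk_like.compute())    # e.g. [('a', [1, 2]), ('b', [3])] (order may vary)
```
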
