Skip to content

Commit

Permalink
Improve API documentation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714103523
  • Loading branch information
iindyk authored and copybara-github committed Jan 10, 2025
1 parent 10fdd0d commit 65b9b85
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 68 deletions.
6 changes: 6 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
autoapi_dirs = [
'../grain/_src/core',
'../grain/_src/python',
'../grain/python',
]

autoapi_options = [
'members',
'imported-members',
]

autoapi_ignore = [
Expand Down
41 changes: 15 additions & 26 deletions grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,24 @@ licenses(["notice"])

exports_files(["LICENSE"])

py_library(
name = "core",
srcs = ["core.py"],
srcs_version = "PY3",
# Implicit build flag
deps = [
"//grain/_src/core:config", # build_cleaner: keep
"//grain/_src/core:constants", # build_cleaner: keep
"//grain/_src/core:sharding", # build_cleaner: keep
],
)

py_library(
name = "python",
srcs = ["python.py"],
srcs = [
"__init__.py",
"_src/__init__.py",
"core/__init__.py",
"python/__init__.py",
"python/experimental.py",
"python/fast_proto.py",
],
srcs_version = "PY3",
# Implicit build flag
visibility = ["//visibility:public"],
deps = [
":core", # build_cleaner: keep
":python_experimental", # build_cleaner: keep
"//grain/_src/core:config", # build_cleaner: keep
"//grain/_src/core:constants", # build_cleaner: keep
"//grain/_src/core:monitoring", # build_cleaner: keep
"//grain/_src/core:sharding", # build_cleaner: keep
"//grain/_src/core:transforms", # build_cleaner: keep
"//grain/_src/python:checkpoint_handlers", # build_cleaner: keep
"//grain/_src/python:data_loader", # build_cleaner: keep
Expand All @@ -38,19 +33,13 @@ py_library(
"//grain/_src/python:load", # build_cleaner: keep
"//grain/_src/python:operations", # build_cleaner: keep
"//grain/_src/python:options", # build_cleaner: keep
"//grain/_src/python:record",
"//grain/_src/python:samplers", # build_cleaner: keep
],
)

py_library(
name = "python_experimental",
srcs = ["python_experimental.py"],
data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"],
srcs_version = "PY3",
# Implicit build flag
deps = [
"//grain/_src/core:transforms", # build_cleaner: keep
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"//grain/_src/python/dataset:visualize", # build_cleaner: keep
"//grain/_src/python/dataset/transformations:core_transformations",
"//grain/_src/python/dataset/transformations:flatmap", # build_cleaner: keep
"//grain/_src/python/dataset/transformations:interleave", # build_cleaner: keep
"//grain/_src/python/dataset/transformations:packing", # build_cleaner: keep
Expand Down
Empty file added grain/__init__.py
Empty file.
Empty file added grain/_src/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions grain/core.py → grain/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# pylint: disable=unused-import
# pylint: disable=g-importing-member

from ._src.core.config import config
from ._src.core.constants import (
from grain._src.core.config import config
from grain._src.core.constants import (
DATASET_INDEX,
EPOCH,
INDEX,
Expand All @@ -27,4 +27,4 @@
RECORD_KEY,
SEED,
)
from ._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
47 changes: 29 additions & 18 deletions grain/python.py → grain/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,23 @@
"""APIs for Grain Python backend."""

# pylint: disable=g-importing-member
# pylint: disable=g-import-not-at-top
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=wildcard-import

from . import python_experimental as experimental

from ._src.core.config import config
from ._src.core.constants import DATASET_INDEX, EPOCH, INDEX, META_FEATURES, RECORD, RECORD_KEY, SEED
from ._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
from ._src.core.transforms import (
from grain._src.core.config import config
from grain._src.core.constants import (
DATASET_INDEX,
EPOCH,
INDEX,
META_FEATURES,
RECORD,
RECORD_KEY,
SEED,
)
from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
from grain._src.core.transforms import (
BatchTransform as Batch,
FilterTransform,
MapTransform,
Expand All @@ -33,43 +40,47 @@
Transformations,
)

from ._src.python.checkpoint_handlers import PyGrainCheckpointHandler
from ._src.python.data_loader import (
from grain._src.python.checkpoint_handlers import PyGrainCheckpointHandler
from grain._src.python.data_loader import (
DataLoader,
PyGrainDatasetIterator,
)
from ._src.python.data_sources import (
from grain._src.python.data_sources import (
ArrayRecordDataSource,
InMemoryDataSource,
RandomAccessDataSource,
RangeDataSource,
)
from ._src.python.dataset.base import DatasetSelectionMap
from ._src.python.dataset.dataset import (
from grain._src.python.dataset.base import DatasetSelectionMap
from grain._src.python.dataset.dataset import (
MapDataset,
IterDataset,
DatasetIterator,
)

from ._src.python.load import load
from ._src.python.operations import (
from grain._src.python.load import load
from grain._src.python.operations import (
BatchOperation,
FilterOperation,
MapOperation,
Operation,
RandomMapOperation,
)
from ._src.python.options import ReadOptions, MultiprocessingOptions
from ._src.python.record import (Record, RecordMetadata)
from ._src.python.samplers import (
from grain._src.python.options import ReadOptions, MultiprocessingOptions
from grain._src.python.record import (Record, RecordMetadata)
from grain._src.python.samplers import (
IndexSampler,
Sampler,
SequentialSampler,
)
from ._src.python.shared_memory_array import SharedMemoryArray
from grain._src.python.shared_memory_array import SharedMemoryArray
from grain.python import experimental, fast_proto

# These are imported only if Orbax is present.
try:
from ._src.python.checkpoint_handlers import PyGrainCheckpointSave, PyGrainCheckpointRestore # pylint: disable=g-import-not-at-top
from grain._src.python.checkpoint_handlers import (
PyGrainCheckpointSave,
PyGrainCheckpointRestore,
)
except ImportError:
pass
38 changes: 22 additions & 16 deletions grain/python_experimental.py → grain/python/experimental.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google LLC
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,36 +17,42 @@
# pylint: disable=g-bad-import-order
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=wildcard-import
# pylint: disable=g-import-not-at-top

from ._src.python.dataset.base import (
from grain._src.python.dataset.base import (
DatasetOptions,
ExecutionTrackingMode,
)
from ._src.python.dataset.dataset import (
from grain._src.python.dataset.dataset import (
apply_transformations,
WithOptionsIterDataset,
)
from ._src.python.dataset.transformations.flatmap import (
from grain._src.python.dataset.transformations.flatmap import (
FlatMapMapDataset,
FlatMapIterDataset,
)
from ._src.python.dataset.transformations.interleave import (
from grain._src.python.dataset.transformations.interleave import (
InterleaveIterDataset,
)
from ._src.python.dataset.transformations.map import RngPool
from ._src.python.dataset.transformations.mix import ConcatenateMapDataset
from ._src.python.dataset.transformations.packing import FirstFitPackIterDataset
from ._src.python.dataset.transformations.prefetch import (
from grain._src.python.dataset.transformations.map import RngPool
from grain._src.python.dataset.transformations.mix import ConcatenateMapDataset
from grain._src.python.dataset.transformations.packing import (
FirstFitPackIterDataset,
)
from grain._src.python.dataset.transformations.prefetch import (
MultiprocessPrefetchIterDataset,
ThreadPrefetchIterDataset,
)
from ._src.python.dataset.transformations.shuffle import WindowShuffleMapDataset
from ._src.python.dataset.transformations.zip import ZipMapDataset
from ._src.core.transforms import (
from grain._src.python.dataset.transformations.shuffle import (
WindowShuffleMapDataset,
)
from grain._src.python.dataset.transformations.zip import ZipMapDataset
from grain._src.core.transforms import (
FlatMapTransform,
MapWithIndexTransform,
)
from ._src.python.experimental.example_packing.packing import PackAndBatchOperation
from ._src.python.experimental.index_shuffle.python.index_shuffle_module import index_shuffle
from grain._src.python.experimental.example_packing.packing import (
PackAndBatchOperation,
)
from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import (
index_shuffle,
)
10 changes: 5 additions & 5 deletions grain/python_proto.py → grain/python/fast_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
"""Proto PyGrain APIs."""

# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=wildcard-import

from ._src.python.proto.decode import parse_tf_example
from ._src.python.proto.decode import parse_tf_example_experimental
from ._src.python.proto.encode import make_tf_example
from grain._src.python.proto.decode import (
parse_tf_example,
parse_tf_example_experimental,
)
from grain._src.python.proto.encode import make_tf_example

0 comments on commit 65b9b85

Please sign in to comment.