2024-03-03 nightly release (42874a1)
pytorchbot committed Mar 3, 2024
1 parent 0b14029 commit bd460e5
Showing 88 changed files with 939 additions and 736 deletions.
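
Most of what is shown below appears to be mechanical reformatting driven by the formatter bump recorded in the first two files (ufmt 2.5.1, black 24.2.0, usort 1.0.8). The pattern repeated across the Python hunks is that a long conditional (ternary) expression is now wrapped in parentheses instead of hanging off the enclosing argument. A generic before/after sketch with made-up names, not code taken from the repo:

    def build(per_feature, idx, override, dim):
        # Roughly the style the old black 22.3.0 pin produced: the ternary
        # dangles under the keyword argument it belongs to.
        old_style = dict(
            num_embeddings=per_feature[idx]
            if override is None
            else override,
            embedding_dim=dim,
        )
        # Roughly the style black 24.2.0 produces: the same expression is
        # parenthesized (and split further inside the parentheses when long).
        new_style = dict(
            num_embeddings=(
                per_feature[idx] if override is None else override
            ),
            embedding_dim=dim,
        )
        return old_style == new_style  # True: only the formatting differs
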
6 changes: 3 additions & 3 deletions .github/workflows/pre-commit.yaml
@@ -15,9 +15,9 @@ jobs:
python-version: 3.8
architecture: x64
packages: |
ufmt==1.3.2
black==22.3.0
usort==1.0.2
ufmt==2.5.1
black==24.2.0
usort==1.0.8
- name: Checkout Torchrec
uses: actions/checkout@v2
- name: Run pre-commit
4 changes: 2 additions & 2 deletions .lintrunner.toml
@@ -11,7 +11,7 @@ init_command = [
'python3',
'tools/lint/pip_init.py',
'--dry-run={{DRYRUN}}',
'black==22.3.0',
'black==24.2.0',
]
is_formatter = true

@@ -28,6 +28,6 @@ init_command = [
'python3',
'tools/lint/pip_init.py',
'--dry-run={{DRYRUN}}',
'usort==1.0.2',
'usort==1.0.8',
]
is_formatter = true
8 changes: 5 additions & 3 deletions examples/inference/dlrm_predict.py
@@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings,
num_embeddings=(
self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
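
Only the formatting changes in this hunk; the logic is the same: each table gets its per-feature row count unless a single global num_embeddings override is configured. For context, the comprehension that is cut off above builds one EmbeddingBagConfig per sparse feature, roughly like this (feature names and sizes below are placeholders, not values from this repo):

    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    feature_names = ["cat_0", "cat_1"]        # hypothetical sparse features
    num_embeddings_per_feature = [1000, 500]  # hypothetical per-table row counts
    num_embeddings = None                     # optional global override
    embedding_dim = 64

    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=embedding_dim,
            num_embeddings=(
                num_embeddings_per_feature[feature_idx]
                if num_embeddings is None
                else num_embeddings
            ),
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(feature_names)
    ]
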
8 changes: 5 additions & 3 deletions examples/inference/dlrm_predict_single_gpu.py
@@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings,
num_embeddings=(
self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
3 changes: 2 additions & 1 deletion examples/nvt_dataloader/nvt_binary_dataloader.py
@@ -94,7 +94,8 @@ def __getitem__(self, idx: int):
"""Numerical features are returned in the order they appear in the channel spec section
For performance reasons, this is required to be the order they are saved in, as specified
by the relevant chunk in source spec.
Categorical features are returned in the order they appear in the channel spec section"""
Categorical features are returned in the order they appear in the channel spec section
"""

if idx >= self._num_entries:
raise IndexError()
22 changes: 13 additions & 9 deletions examples/nvt_dataloader/train_torchrec.py
@@ -208,9 +208,11 @@ def main(argv: List[str]):
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=args.embedding_dim,
num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx]
if num_embeddings_per_feature is not None
else args.num_embeddings,
num_embeddings=(
none_throws(num_embeddings_per_feature)[feature_idx]
if num_embeddings_per_feature is not None
else args.num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
@@ -232,9 +234,9 @@ def main(argv: List[str]):

train_model = fuse_embedding_optimizer(
train_model,
optimizer_type=torchrec.optim.RowWiseAdagrad
if args.adagrad
else torch.optim.SGD,
optimizer_type=(
torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD
),
optimizer_kwargs={"learning_rate": args.learning_rate},
device=torch.device("meta"),
)
@@ -270,9 +272,11 @@ def main(argv: List[str]):

non_fused_optimizer = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adagrad(params, lr=args.learning_rate)
if args.adagrad
else torch.optim.SGD(params, lr=args.learning_rate),
lambda params: (
torch.optim.Adagrad(params, lr=args.learning_rate)
if args.adagrad
else torch.optim.SGD(params, lr=args.learning_rate)
),
)

opt = trec_optim.keyed.CombinedOptimizer(
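
The same parenthesization is applied to the optimizer-factory lambda; behavior is unchanged. As a usage sketch of the pattern in this hunk: KeyedOptimizerWrapper receives the parameters that were not already fused into the backward pass (selected by in_backward_optimizer_filter) plus a factory that builds the dense optimizer, and the factory picks Adagrad or SGD from a flag. The import paths below are my best guess and should be verified; the tiny model is a stand-in:

    import torch
    from torch import nn
    from torchrec.optim.keyed import KeyedOptimizerWrapper
    from torchrec.optim.optimizers import in_backward_optimizer_filter

    model = nn.Linear(8, 4)  # stand-in for the sharded TorchRec model
    learning_rate = 0.1
    use_adagrad = True       # plays the role of args.adagrad

    # Wrap every parameter that does not already have an in-backward (fused)
    # optimizer attached, choosing the dense optimizer from the flag.
    non_fused_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: (
            torch.optim.Adagrad(params, lr=learning_rate)
            if use_adagrad
            else torch.optim.SGD(params, lr=learning_rate)
        ),
    )
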
12 changes: 6 additions & 6 deletions examples/retrieval/modules/two_tower.py
@@ -71,12 +71,12 @@ def __init__(
embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
0
].embedding_dim
self._feature_names_query: List[
str
] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
self._candidate_feature_names: List[
str
] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
self._feature_names_query: List[str] = (
embedding_bag_collection.embedding_bag_configs()[0].feature_names
)
self._candidate_feature_names: List[str] = (
embedding_bag_collection.embedding_bag_configs()[1].feature_names
)
self.ebc = embedding_bag_collection
self.query_proj = MLP(
in_size=embedding_dim, layer_sizes=layer_sizes, device=device
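
Two things are visible in the two_tower.py hunk: black 24 now keeps the List[str] annotation on one line and parenthesizes the assigned value instead of splitting inside the annotation, and the unchanged logic assumes the first embedding-bag config belongs to the query tower and the second to the candidate tower. A rough sketch of that setup with placeholder names and sizes:

    from torchrec import EmbeddingBagCollection, EmbeddingBagConfig

    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t_query", embedding_dim=64, num_embeddings=1000,
                feature_names=["query_id"],      # -> embedding_bag_configs()[0]
            ),
            EmbeddingBagConfig(
                name="t_candidate", embedding_dim=64, num_embeddings=1000,
                feature_names=["candidate_id"],  # -> embedding_bag_configs()[1]
            ),
        ],
    )
    feature_names_query = ebc.embedding_bag_configs()[0].feature_names
    candidate_feature_names = ebc.embedding_bag_configs()[1].feature_names
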
10 changes: 5 additions & 5 deletions tools/lint/black_linter.py
@@ -176,11 +176,11 @@ def main() -> None:

logging.basicConfig(
format="<%(threadName)s:%(levelname)s> %(message)s",
level=logging.NOTSET
if args.verbose
else logging.DEBUG
if len(args.filenames) < 1000
else logging.INFO,
level=(
logging.NOTSET
if args.verbose
else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
),
stream=sys.stderr,
)
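
The black_linter.py change collapses the nested conditional expression onto fewer lines; the selected log level is the same. Spelled out as plain if/elif for clarity (equivalent to the expression above):

    import logging

    def pick_log_level(verbose: bool, num_filenames: int) -> int:
        # Same decision as the nested conditional passed to logging.basicConfig.
        if verbose:
            return logging.NOTSET  # let everything through
        elif num_filenames < 1000:
            return logging.DEBUG
        else:
            return logging.INFO    # keep output manageable for very large runs
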

14 changes: 8 additions & 6 deletions torchrec/datasets/criteo.py
@@ -234,11 +234,13 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
return dense, sparse, label

dense, sparse, labels = [], [], []
for (row_dense, row_sparse, row_label) in CriteoIterDataPipe(
for row_dense, row_sparse, row_label in CriteoIterDataPipe(
[in_file],
row_mapper=row_mapper
if not (dataset_name == "criteo_kaggle" and "test" in in_file)
else row_mapper_with_fake_label_constant,
row_mapper=(
row_mapper
if not (dataset_name == "criteo_kaggle" and "test" in in_file)
else row_mapper_with_fake_label_constant
),
):
dense.append(row_dense)
sparse.append(row_sparse)
@@ -261,7 +263,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
labels_np = labels_np.reshape((-1, 1))

path_manager = PathManagerFactory().get(path_manager_key)
for (fname, arr) in [
for fname, arr in [
(out_dense_file, dense_np),
(out_sparse_file, sparse_np),
(out_labels_file, labels_np),
@@ -665,7 +667,7 @@ def shuffle(
curr_first_row = curr_last_row

# Directly copy over the last day's files since they will be used for validation and testing.
for (part, input_dir) in [
for part, input_dir in [
("sparse", input_dir_sparse),
("dense", input_dir_labels_and_dense),
("labels", input_dir_labels_and_dense),
2 changes: 1 addition & 1 deletion torchrec/datasets/tests/test_criteo.py
@@ -70,7 +70,7 @@ def _validate_dataloader_sample(
) -> None:
unbatched_samples = [{} for _ in range(self._sample_len(sample))]
for k, batched_values in sample.items():
for (idx, value) in enumerate(batched_values):
for idx, value in enumerate(batched_values):
unbatched_samples[idx][k] = value
for sample in unbatched_samples:
self._validate_sample(sample, train=train)
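
Here, and in the criteo.py hunks above, the reformatting also drops the redundant parentheses around tuple targets in for loops; unpacking behaves identically with or without them. A trivial sketch:

    pairs = [("dense", 1), ("sparse", 2), ("labels", 3)]

    # Before: parentheses around the unpacking target (purely cosmetic).
    for (name, value) in pairs:
        pass

    # After: the equivalent parenthesis-free form now used throughout.
    for name, value in pairs:
        print(name, value)
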
4 changes: 1 addition & 3 deletions torchrec/datasets/utils.py
@@ -77,9 +77,7 @@ def train_filter(
decimal_places: int,
idx: int,
) -> bool:
return (key_fn(idx) % 10**decimal_places) < round(
train_perc * 10**decimal_places
)
return (key_fn(idx) % 10**decimal_places) < round(train_perc * 10**decimal_places)
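
The utils.py change simply reflows the return expression onto one line; the split logic is unchanged: an index is bucketed into 10**decimal_places buckets and kept for training when its bucket falls below the train_perc fraction (e.g. bucket < 80 out of 100 for train_perc=0.8, decimal_places=2). A standalone worked sketch of that logic, not the torchrec helper itself:

    def train_filter(key_fn, train_perc: float, decimal_places: int, idx: int) -> bool:
        # Keep idx when its bucket (0 .. 10**decimal_places - 1) lies below the
        # training fraction, e.g. bucket < 80 for train_perc=0.8, decimal_places=2.
        return (key_fn(idx) % 10**decimal_places) < round(train_perc * 10**decimal_places)

    kept = [idx for idx in range(10) if train_filter(lambda i: i * 37, 0.8, 2, idx)]
    # With this toy key function, 8 of the 10 indices land in the training split.
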


def val_filter(
4 changes: 3 additions & 1 deletion torchrec/distributed/benchmark/benchmark_inference.py
@@ -196,7 +196,9 @@ def main() -> None:
mb = int(float(num * dim) / 1024 / 1024)
tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb"

report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
report: str = (
f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
)
report += f"Module: {module_name}\n"
report += tables_info
report += "\n"
6 changes: 3 additions & 3 deletions torchrec/distributed/composable/tests/test_embedding.py
@@ -127,9 +127,9 @@ def _test_sharding( # noqa C901
kjt_input_per_rank[ctx.rank]
)

unsharded_model_pred_jt_dict_this_rank: Dict[
str, JaggedTensor
] = unsharded_model_pred_jt_dict[ctx.rank]
unsharded_model_pred_jt_dict_this_rank: Dict[str, JaggedTensor] = (
unsharded_model_pred_jt_dict[ctx.rank]
)

embedding_names = unsharded_model_pred_jt_dict_this_rank.keys()
assert set(unsharded_model_pred_jt_dict_this_rank.keys()) == set(
8 changes: 5 additions & 3 deletions torchrec/distributed/composable/tests/test_embeddingbag.py
@@ -385,9 +385,11 @@ def test_sharding_ebc(
},
kjt_input_per_rank=kjt_input_per_rank,
sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type),
backend="nccl"
if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
else "gloo",
backend=(
"nccl"
if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
else "gloo"
),
constraints=constraints,
is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value),
use_apply_optimizer_in_backward=use_apply_optimizer_in_backward,
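
The backend argument that black reflowed here encodes a common test-harness pattern: use NCCL only when at least two CUDA devices are available, otherwise fall back to Gloo so the test can still run on CPU-only machines. Roughly (the init_process_group call is illustrative and left commented out):

    import torch
    import torch.distributed as dist

    backend = (
        "nccl"
        if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
        else "gloo"
    )
    # dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
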
62 changes: 40 additions & 22 deletions torchrec/distributed/composable/tests/test_fused_optim_nccl.py
@@ -77,60 +77,78 @@ def _test_sharded_fused_optimizer_state_dict(
0
].state_dict()["state"][""]["table_0.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# sharded column, each shard will have rowwise state
else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# sharded column, each shard will have rowwise state
else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device)
),
)

ebc.embedding_bags["table_1"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_1.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# sharded rowwise
else torch.empty((tables[1].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# sharded rowwise
else torch.empty((tables[1].num_embeddings,), device=ctx.device)
),
)

ebc.embedding_bags["table_2"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_2.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[2].num_embeddings, tables[2].embedding_dim),
device=ctx.device,
out=(
None
if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[2].num_embeddings, tables[2].embedding_dim),
device=ctx.device,
)
),
)

ebc.embedding_bags["table_2"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_2.exp_avg_sq"].gather(
dst=0,
out=None if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device)
),
)

ebc.embedding_bags["table_3"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_3.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# Row wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[3].num_embeddings, tables[3].embedding_dim),
device=ctx.device,
out=(
None
if ctx.rank != 0
# Row wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[3].num_embeddings, tables[3].embedding_dim),
device=ctx.device,
)
),
)

ebc.embedding_bags["table_3"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_3.exp_avg_sq"].gather(
dst=0,
out=None if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((tables[2].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((tables[2].num_embeddings,), device=ctx.device)
),
)

# pyre-ignore
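
Every hunk in this test file reformats the same pattern: each rank calls gather on a sharded optimizer-state tensor, and only the destination rank (dst=0) passes a preallocated out buffer sized for the full state, while all other ranks pass None. Stripped of the table-specific shapes, the pattern looks roughly like this (sharded_state stands in for the state_dict entry being gathered):

    import torch

    def gather_full_state(sharded_state, rank: int, numel: int, device: torch.device):
        # Only the destination rank allocates the full-size output buffer.
        out = torch.empty((numel,), device=device) if rank == 0 else None
        sharded_state.gather(dst=0, out=out)  # collective: every rank participates
        return out                            # full tensor on rank 0, None elsewhere
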
10 changes: 7 additions & 3 deletions torchrec/distributed/dist_data.py
@@ -539,9 +539,13 @@ def forward(self, kjt: KeyedJaggedTensor) -> KJTList:
fx_marker("KJT_ONE_TO_ALL_FORWARD_BEGIN", kjt)
kjts: List[KeyedJaggedTensor] = kjt.split(self._splits)
dist_kjts = [
kjts[rank]
if self._device_type == "meta"
else kjts[rank].to(torch.device(self._device_type, rank), non_blocking=True)
(
kjts[rank]
if self._device_type == "meta"
else kjts[rank].to(
torch.device(self._device_type, rank), non_blocking=True
)
)
for rank in range(self._world_size)
]
ret = KJTList(dist_kjts)
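
The dist_data.py hunk only parenthesizes the per-rank expression; the logic still splits the incoming KeyedJaggedTensor by the configured feature counts and, unless the module sits on the meta device, copies each split to its destination rank's device with non_blocking=True. A small sketch of the split step with toy values:

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Three features over a batch of two samples: lengths has one entry per
    # (feature, sample) pair and values holds the concatenated ids.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f0", "f1", "f2"],
        values=torch.tensor([10, 11, 12, 13, 14, 15]),
        lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
    )
    splits = kjt.split([2, 1])  # features f0, f1 -> rank 0; f2 -> rank 1
    # Each split would then be moved with .to(torch.device("cuda", rank), non_blocking=True).
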
(The remaining changed files in this commit are not shown here.)

0 comments on commit bd460e5
