Add numerical equivalence test for prefetch train pipeline #1717

Closed · wants to merge 1 commit
torchrec/distributed/tests/test_train_pipeline.py — 222 changes: 179 additions & 43 deletions
@@ -9,14 +9,15 @@
 import os
 import unittest
 from dataclasses import dataclass
-from typing import cast, Dict, List, Optional, Tuple
+from typing import Any, cast, Dict, List, Optional, Tuple
 from unittest.mock import MagicMock

 import torch
 import torch.distributed as dist
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, KJTList
 from torchrec.distributed.embeddingbag import (
@@ -41,6 +42,7 @@
 from torchrec.distributed.train_pipeline import (
     DataLoadingThread,
     EvalPipelineSparseDist,
+    PrefetchTrainPipelineSparseDist,
     TrainPipelineBase,
     TrainPipelineSparseDist,
 )
@@ -52,7 +54,7 @@
     ShardingPlan,
     ShardingType,
 )
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection

 from torchrec.optim.keyed import KeyedOptimizerWrapper
@@ -177,10 +179,12 @@ def test_equal_to_non_pipelined(self) -> None:
             pred_gpu = pipeline.progress(dataloader)

             self.assertEqual(pred_gpu.device, self.device)
+            # Results will be close but not exactly equal, as one model is on the CPU and the other on the GPU;
+            # if both were on the GPU, the results would be exactly the same.
             self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))


-class TrainPipelineSparseDistTest(unittest.TestCase):
+class TrainPipelineSparseDistTestBase(unittest.TestCase):
     def setUp(self) -> None:
         os.environ["MASTER_ADDR"] = str("localhost")
         os.environ["MASTER_PORT"] = str(get_free_port())
@@ -191,7 +195,6 @@ def setUp(self) -> None:

         num_features = 4
         num_weighted_features = 2
-
         self.tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 100,
@@ -217,6 +220,31 @@ def tearDown(self) -> None:
         super().tearDown()
         dist.destroy_process_group(self.pg)

+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+    ) -> List[ModelInput]:
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=1,
+                num_float_features=10,
+            )[0]
+            for i in range(num_batches)
+        ]
+
+    def _set_table_weights_precision(self, dtype: DataType) -> None:
+        for i in range(len(self.tables)):
+            self.tables[i].data_type = dtype
+
+        for i in range(len(self.weighted_tables)):
+            self.weighted_tables[i].data_type = dtype
+
+
+class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),
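The two helpers moved into the base class above are shared by the existing pipeline tests and the new prefetch test; a minimal sketch of the intended call pattern in a subclass (assuming setUp() has populated self.tables and self.weighted_tables, as it does in this file):

    # Inside a TrainPipelineSparseDistTestBase subclass:
    self._set_table_weights_precision(DataType.FP16)  # store table weights in fp16
    batches = self._generate_data(num_batches=12, batch_size=32)
    dataloader = iter(batches)  # train pipelines consume an iterator of ModelInput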
@@ -336,7 +364,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
             optimizer_no_pipeline.step()

             pred_pipeline = pipeline.progress(dataloader)
-            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+            self.assertTrue(torch.equal(pred_pipeline.cpu(), pred.cpu()))

         self.assertEqual(len(pipeline._pipelined_modules), 1)
         self.assertIsInstance(
@@ -405,43 +433,6 @@ def _setup_cpu_model_and_opt(self) -> Tuple[TestSparseNN, optim.SGD]:
         cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=0.1)
         return cpu_model, cpu_optimizer

-    def _generate_data(self, num_batches: int = 5) -> List[ModelInput]:
-        return [
-            ModelInput.generate(
-                tables=self.tables,
-                weighted_tables=self.weighted_tables,
-                batch_size=1,
-                world_size=1,
-                num_float_features=10,
-            )[0]
-            for i in range(num_batches)
-        ]
-
-    def _test_pipelining(
-        self,
-        sharder: EmbeddingBagCollectionSharder,
-        execute_all_batches: bool,
-    ) -> None:
-        pipeline = self._setup_pipeline(sharder, execute_all_batches)
-        cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt()
-        data = self._generate_data()
-
-        dataloader = iter(data)
-        if not execute_all_batches:
-            data = data[:-2]
-
-        for batch in data:
-            cpu_optimizer.zero_grad()
-            loss, pred = cpu_model(batch)
-            loss.backward()
-            cpu_optimizer.step()
-
-            pred_gpu = pipeline.progress(dataloader)
-
-            self.assertEqual(len(pipeline._pipelined_modules), 2)
-            self.assertEqual(pred_gpu.device, self.device)
-            self.assertEqual(pred_gpu.cpu().size(), pred.size())
-
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
@@ -459,7 +450,6 @@ def test_pipelining(self, execute_all_batches: bool) -> None:
         )
         cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt()
         data = self._generate_data()
-
         dataloader = iter(data)
         if not execute_all_batches:
             data = data[:-2]
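A note on the data[:-n] trimming used here and in the new test below (an inference from the pattern, not stated in the diff): when execute_all_batches is False, the pipeline stops as soon as its lookahead buffer can no longer be refilled, leaving the last batches unexecuted, so the non-pipelined reference loop must skip those same trailing batches. Sketch:

    data = self._generate_data()  # e.g. 5 batches
    dataloader = iter(data)       # the pipeline consumes the full iterator
    if not execute_all_batches:
        data = data[:-2]          # reference loop drops the batches left in flight;
                                  # the prefetch pipeline below drops 3, one extra stage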
@@ -574,6 +564,152 @@ def test_multi_dataloader_pipelining(self) -> None:
         )


+class PrefetchTrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
+    def generate_sharded_model_and_optimizer(
+        self,
+        model: nn.Module,
+        sharding_type: str,
+        kernel_type: str,
+        fused_params: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[nn.Module, Optimizer]:
+        sharder = TestEBCSharder(
+            sharding_type=sharding_type,
+            kernel_type=kernel_type,
+            fused_params=fused_params,
+        )
+        sharded_model = DistributedModelParallel(
+            module=copy.deepcopy(model),
+            env=ShardingEnv.from_process_group(self.pg),
+            init_data_parallel=False,
+            device=self.device,
+            sharders=[
+                cast(
+                    ModuleSharder[nn.Module],
+                    sharder,
+                )
+            ],
+        )
+        optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
+        return sharded_model, optimizer
+
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-ignore[56]
+    @given(
+        execute_all_batches=st.booleans(),
+        weight_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+                0.6,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_equal_to_non_pipelined(
+        self,
+        execute_all_batches: bool,
+        weight_precision: DataType,
+        cache_precision: DataType,
+        load_factor: float,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+        """
+        mixed_precision: bool = weight_precision != cache_precision
+        self._set_table_weights_precision(weight_precision)
+        data = self._generate_data(
+            num_batches=12,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        fused_params = {
+            "cache_load_factor": load_factor,
+            "cache_precision": cache_precision,
+            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
+        }
+        fused_params_pipelined = {
+            **fused_params,
+            "prefetch_pipeline": True,
+        }
+
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            dense_device=self.device,
+            sparse_device=torch.device("meta"),
+        )
+        sharded_model, optim = self.generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self.generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        pipeline = PrefetchTrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=execute_all_batches,
+        )
+
+        if not execute_all_batches:
+            data = data[:-3]
+
+        for batch in data:
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+
+            if not mixed_precision:
+                self.assertTrue(torch.equal(pred, pred_pipeline))
+            else:
+                # Rounding error is expected when weights and cache use different precisions
+                torch.testing.assert_close(pred, pred_pipeline)
+
+
 class DataLoadingThreadTest(unittest.TestCase):
     def test_fetch_data(self) -> None:
         data = []
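To try the new pipeline outside the test harness, here is a minimal sketch of the pattern the test exercises, distilled from the diff rather than copied from it (assumes a model sharded with FUSED_UVM_CACHING kernels and prefetch_pipeline=True in its fused params, plus a dataloader yielding ModelInput batches, as constructed above):

    pipeline = PrefetchTrainPipelineSparseDist(
        model=sharded_model_pipelined,  # sharded with prefetch_pipeline=True
        optimizer=optim_pipelined,
        device=device,
        execute_all_batches=True,  # drain in-flight batches at the end
    )
    while True:
        try:
            # Each call advances the pipeline one batch, overlapping the
            # host-to-device copy and embedding-cache prefetch of upcoming
            # batches with compute, and returns the predictions of the batch
            # that just completed.
            pred = pipeline.progress(dataloader)
        except StopIteration:
            break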