Skip to content

Commit

Permalink
Add numerical equivalence test for prefetch train pipeline (#1717)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1717

Add test for prefetch train pipeline for embedding offloading. Test checks that predictions are equal.

Reviewed By: henrylhtsang

Differential Revision: D53673182

fbshipit-source-id: a0b0ecef0afac4a7f009e9367fd231688fa61ceb
  • Loading branch information
sarckk authored and facebook-github-bot committed Feb 20, 2024
1 parent 03faa0b commit cff411e
Showing 1 changed file with 179 additions and 43 deletions.
222 changes: 179 additions & 43 deletions torchrec/distributed/tests/test_train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import os
import unittest
from dataclasses import dataclass
from typing import cast, Dict, List, Optional, Tuple
from typing import Any, cast, Dict, List, Optional, Tuple
from unittest.mock import MagicMock

import torch
import torch.distributed as dist
from hypothesis import given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, KJTList
from torchrec.distributed.embeddingbag import (
Expand All @@ -41,6 +42,7 @@
from torchrec.distributed.train_pipeline import (
DataLoadingThread,
EvalPipelineSparseDist,
PrefetchTrainPipelineSparseDist,
TrainPipelineBase,
TrainPipelineSparseDist,
)
Expand All @@ -52,7 +54,7 @@
ShardingPlan,
ShardingType,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

from torchrec.optim.keyed import KeyedOptimizerWrapper
Expand Down Expand Up @@ -177,10 +179,12 @@ def test_equal_to_non_pipelined(self) -> None:
pred_gpu = pipeline.progress(dataloader)

self.assertEqual(pred_gpu.device, self.device)
# Results will be close but not exactly equal as one model is on CPU and other on GPU
# If both were on GPU, the results will be exactly the same
self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))


class TrainPipelineSparseDistTest(unittest.TestCase):
class TrainPipelineSparseDistTestBase(unittest.TestCase):
def setUp(self) -> None:
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
Expand All @@ -191,7 +195,6 @@ def setUp(self) -> None:

num_features = 4
num_weighted_features = 2

self.tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 100,
Expand All @@ -217,6 +220,31 @@ def tearDown(self) -> None:
super().tearDown()
dist.destroy_process_group(self.pg)

def _generate_data(
    self,
    num_batches: int = 5,
    batch_size: int = 1,
) -> List[ModelInput]:
    """Generate ``num_batches`` random batches of model input.

    Batches are built from the tables configured in ``setUp`` with 10
    dense float features and ``world_size=1``.

    Args:
        num_batches: number of batches to produce.
        batch_size: number of samples per batch.

    Returns:
        A list of ``num_batches`` ``ModelInput`` objects.
    """
    # ModelInput.generate returns a tuple; index [0] selects the
    # single-rank input used by these tests — presumably the global
    # input given world_size=1 (confirm against ModelInput.generate).
    return [
        ModelInput.generate(
            tables=self.tables,
            weighted_tables=self.weighted_tables,
            batch_size=batch_size,
            world_size=1,
            num_float_features=10,
        )[0]
        # `_` instead of `i`: the loop variable was unused.
        for _ in range(num_batches)
    ]

def _set_table_weights_precision(self, dtype: DataType) -> None:
    """Set the embedding weight data type on every configured table.

    Mutates the ``EmbeddingBagConfig`` objects in ``self.tables`` and
    ``self.weighted_tables`` in place.
    """
    for table in self.tables:
        table.data_type = dtype
    for weighted_table in self.weighted_tables:
        weighted_table.data_type = dtype


class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
not torch.cuda.is_available(),
Expand Down Expand Up @@ -336,7 +364,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
optimizer_no_pipeline.step()

pred_pipeline = pipeline.progress(dataloader)
torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
self.assertTrue(torch.equal(pred_pipeline.cpu(), pred.cpu()))

self.assertEqual(len(pipeline._pipelined_modules), 1)
self.assertIsInstance(
Expand Down Expand Up @@ -405,43 +433,6 @@ def _setup_cpu_model_and_opt(self) -> Tuple[TestSparseNN, optim.SGD]:
cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=0.1)
return cpu_model, cpu_optimizer

def _generate_data(self, num_batches: int = 5) -> List[ModelInput]:
    """Produce ``num_batches`` single-sample batches of random model input."""
    batches: List[ModelInput] = []
    for _ in range(num_batches):
        # ModelInput.generate returns a tuple; [0] is the input these
        # tests consume (world_size=1).
        batch = ModelInput.generate(
            tables=self.tables,
            weighted_tables=self.weighted_tables,
            batch_size=1,
            world_size=1,
            num_float_features=10,
        )[0]
        batches.append(batch)
    return batches

def _test_pipelining(
    self,
    sharder: EmbeddingBagCollectionSharder,
    execute_all_batches: bool,
) -> None:
    # Runs the sharded, pipelined model and a plain CPU reference model
    # over the same generated batches, then sanity-checks the pipeline's
    # output device and shape against the CPU prediction.
    pipeline = self._setup_pipeline(sharder, execute_all_batches)
    cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt()
    data = self._generate_data()

    dataloader = iter(data)
    if not execute_all_batches:
        # Drop the last two batches from the reference loop —
        # presumably matching the number of batches the pipeline keeps
        # in flight when it does not drain fully; TODO confirm.
        data = data[:-2]

    for batch in data:
        # Reference step on CPU.
        cpu_optimizer.zero_grad()
        loss, pred = cpu_model(batch)
        loss.backward()
        cpu_optimizer.step()

        # Pipelined step; progress() pulls from the shared dataloader.
        pred_gpu = pipeline.progress(dataloader)

        self.assertEqual(len(pipeline._pipelined_modules), 2)
        self.assertEqual(pred_gpu.device, self.device)
        # Only the shape is compared, not values: the two models run on
        # different devices, so exact numerical equality is not expected.
        self.assertEqual(pred_gpu.cpu().size(), pred.size())

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
Expand All @@ -459,7 +450,6 @@ def test_pipelining(self, execute_all_batches: bool) -> None:
)
cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt()
data = self._generate_data()

dataloader = iter(data)
if not execute_all_batches:
data = data[:-2]
Expand Down Expand Up @@ -574,6 +564,152 @@ def test_multi_dataloader_pipelining(self) -> None:
)


class PrefetchTrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
    """Numerical-equivalence tests for ``PrefetchTrainPipelineSparseDist``."""

    def generate_sharded_model_and_optimizer(
        self,
        model: nn.Module,
        sharding_type: str,
        kernel_type: str,
        fused_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[nn.Module, Optimizer]:
        """Shard a deep copy of ``model`` with DMP and pair it with SGD.

        Args:
            model: unsharded model to copy and shard.
            sharding_type: a ``ShardingType`` value string.
            kernel_type: an ``EmbeddingComputeKernel`` value string.
            fused_params: fused-optimizer/cache parameters forwarded to
                the sharder.

        Returns:
            The sharded model and an SGD optimizer over its parameters.
        """
        sharder = TestEBCSharder(
            sharding_type=sharding_type,
            kernel_type=kernel_type,
            fused_params=fused_params,
        )
        # deepcopy so the caller can shard the same source model twice
        # (pipelined and non-pipelined) from identical starting weights.
        sharded_model = DistributedModelParallel(
            module=copy.deepcopy(model),
            env=ShardingEnv.from_process_group(self.pg),
            init_data_parallel=False,
            device=self.device,
            sharders=[
                cast(
                    ModuleSharder[nn.Module],
                    sharder,
                )
            ],
        )
        optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
        return sharded_model, optimizer

    @unittest.skipIf(
        not torch.cuda.is_available(),
        "Not enough GPUs, this test requires at least one GPU",
    )
    # pyre-ignore[56]
    @given(
        execute_all_batches=st.booleans(),
        weight_precision=st.sampled_from(
            [
                DataType.FP16,
                DataType.FP32,
            ]
        ),
        cache_precision=st.sampled_from(
            [
                DataType.FP16,
                DataType.FP32,
            ]
        ),
        load_factor=st.sampled_from(
            [
                0.2,
                0.4,
                0.6,
            ]
        ),
        sharding_type=st.sampled_from(
            [
                ShardingType.TABLE_WISE.value,
                ShardingType.ROW_WISE.value,
                ShardingType.COLUMN_WISE.value,
            ]
        ),
        kernel_type=st.sampled_from(
            [
                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
            ]
        ),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
    def test_equal_to_non_pipelined(
        self,
        execute_all_batches: bool,
        weight_precision: DataType,
        cache_precision: DataType,
        load_factor: float,
        sharding_type: str,
        kernel_type: str,
    ) -> None:
        """
        Checks that pipelined training is equivalent to non-pipelined training.
        """
        # Different precisions for weights and cache force fp32<->fp16
        # round-trips, so only approximate equality can be expected below.
        mixed_precision: bool = weight_precision != cache_precision
        self._set_table_weights_precision(weight_precision)
        data = self._generate_data(
            num_batches=12,
            batch_size=32,
        )
        dataloader = iter(data)

        fused_params = {
            "cache_load_factor": load_factor,
            "cache_precision": cache_precision,
            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
        }
        fused_params_pipelined = {
            **fused_params,
            "prefetch_pipeline": True,
        }

        model = TestSparseNN(
            tables=self.tables,
            weighted_tables=self.weighted_tables,
            dense_device=self.device,
            sparse_device=torch.device("meta"),
        )
        # Named `optimizer` (not `optim`) to avoid shadowing the
        # `torch.optim` module imported at file level.
        sharded_model, optimizer = self.generate_sharded_model_and_optimizer(
            model, sharding_type, kernel_type, fused_params
        )

        (
            sharded_model_pipelined,
            optim_pipelined,
        ) = self.generate_sharded_model_and_optimizer(
            model, sharding_type, kernel_type, fused_params_pipelined
        )
        # Start both sharded models from identical weights so their
        # predictions are directly comparable.
        copy_state_dict(
            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
        )

        pipeline = PrefetchTrainPipelineSparseDist(
            model=sharded_model_pipelined,
            optimizer=optim_pipelined,
            device=self.device,
            execute_all_batches=execute_all_batches,
        )

        if not execute_all_batches:
            # Drop the last three batches from the reference loop —
            # presumably the number of batches the prefetch pipeline
            # keeps in flight when it does not drain; TODO confirm.
            data = data[:-3]

        for batch in data:
            # Forward + backward w/o pipelining
            batch = batch.to(self.device)
            optimizer.zero_grad()
            loss, pred = sharded_model(batch)
            loss.backward()
            optimizer.step()

            # Forward + backward w/ pipelining
            pred_pipeline = pipeline.progress(dataloader)

            if not mixed_precision:
                # Same precision for weights and cache: predictions must
                # match exactly.
                self.assertTrue(torch.equal(pred, pred_pipeline))
            else:
                # Rounding error is expected when using different
                # precisions for weights and cache, so compare within
                # tolerance instead of bit-for-bit.
                torch.testing.assert_close(pred, pred_pipeline)

class DataLoadingThreadTest(unittest.TestCase):
def test_fetch_data(self) -> None:
data = []
Expand Down

0 comments on commit cff411e

Please sign in to comment.