Skip to content

Commit

Permalink
2024-10-17 nightly release (54ec8aa)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Oct 17, 2024
1 parent 5e6ebb9 commit 555d3db
Show file tree
Hide file tree
Showing 16 changed files with 604 additions and 226 deletions.
40 changes: 31 additions & 9 deletions torchrec/metrics/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def rec_metric_gpu_sync_test_launcher(
entry_point: Callable[..., None],
batch_size: int = BATCH_SIZE,
batch_window_size: int = BATCH_WINDOW_SIZE,
**kwargs: Any,
**kwargs: Dict[str, Any],
) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
lc = get_launch_config(
Expand All @@ -385,6 +385,7 @@ def rec_metric_gpu_sync_test_launcher(
should_validate_update,
batch_size,
batch_window_size,
kwargs.get("n_classes", None),
)


Expand All @@ -402,6 +403,7 @@ def sync_test_helper(
batch_window_size: int = BATCH_WINDOW_SIZE,
n_classes: Optional[int] = None,
zero_weights: bool = False,
**kwargs: Dict[str, Any],
) -> None:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
Expand All @@ -413,13 +415,19 @@ def sync_test_helper(

tasks = gen_test_tasks(task_names)

if n_classes:
# pyre-ignore[6]: Incompatible parameter type
kwargs["number_of_classes"] = n_classes

auc = target_clazz(
world_size=world_size,
batch_size=batch_size,
my_rank=rank,
compute_on_all_ranks=compute_on_all_ranks,
tasks=tasks,
window_size=batch_window_size * world_size,
# pyre-ignore[6]: Incompatible parameter type
**kwargs,
)

weight_value: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -466,10 +474,17 @@ def sync_test_helper(
res = auc.compute()

if rank == 0:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)
# serving_calibration reports its window value under the "window_calibration" key suffix rather than "window_{metric_name}", so it needs a dedicated lookup
if metric_name == "serving_calibration":
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_calibration"],
)
else:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)

# we also test the case where other rank has more tensors than rank 0
auc.reset()
Expand All @@ -489,10 +504,17 @@ def sync_test_helper(
res = auc.compute()

if rank == 0:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)
# serving_calibration reports its window value under the "window_calibration" key suffix rather than "window_{metric_name}", so it needs a dedicated lookup
if metric_name == "serving_calibration":
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_calibration"],
)
else:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)

dist.destroy_process_group()

Expand Down
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
RecTaskInfo,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -251,3 +253,24 @@ def test_accuracy(self) -> None:
except AssertionError:
print("Assertion error caught with data set ", inputs)
raise


class AccuracyGPUSyncTest(unittest.TestCase):
    """Sync test case for ``AccuracyMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = AccuracyMetric
    task_name: str = "accuracy"

    def test_sync_accuracy(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2."""
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestAccuracyMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
)
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None:
)

self.assertIn("grouping_keys", auprc.get_required_inputs())


class AUPRCGPUSyncTest(unittest.TestCase):
    """Sync test case for ``AUPRCMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = AUPRCMetric
    task_name: str = "auprc"

    def test_sync_auprc(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2."""
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestAUPRCMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CalibrationGPUSyncTest(unittest.TestCase):
    """Sync test case for ``CalibrationMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = CalibrationMetric
    task_name: str = "calibration"

    def test_sync_calibration(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2."""
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestCalibrationMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CTRGPUSyncTest(unittest.TestCase):
    """Sync test case for ``CTRMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = CTRMetric
    task_name: str = "ctr"

    def test_sync_ctr(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2."""
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestCTRMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -74,3 +76,24 @@ def test_fused_mae(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MAEGPUSyncTest(unittest.TestCase):
    """Sync test case for ``MAEMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = MAEMetric
    task_name: str = "mae"

    def test_sync_mae(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2."""
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestMAEMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -123,3 +125,24 @@ def test_fused_rmse(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MSEGPUSyncTest(unittest.TestCase):
    """Sync test case for ``MSEMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = MSEMetric
    task_name: str = "mse"

    def test_sync_mse(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2."""
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestMSEMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
25 changes: 25 additions & 0 deletions torchrec/metrics/tests/test_multiclass_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -113,3 +115,26 @@ def test_multiclass_recall_update_fused(self) -> None:
batch_window_size=10,
n_classes=N_CLASSES,
)


class MulticlassRecallGPUSyncTest(unittest.TestCase):
    """Sync test case for ``MulticlassRecallMetric`` driven by the shared launcher."""

    # Metric class under test and the metric name used to look up results.
    clazz: Type[RecMetric] = MulticlassRecallMetric
    task_name: str = "multiclass_recall"

    def test_sync_multiclass_recall(self) -> None:
        """Run ``sync_test_helper`` via the GPU sync launcher with world_size=2.

        ``n_classes`` is passed as an extra keyword argument to the launcher.
        """
        rec_metric_gpu_sync_test_launcher(
            # Launch configuration.
            entry_point=sync_test_helper,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            # Metric under test and its reference implementation.
            target_clazz=self.clazz,
            test_clazz=TestMulticlassRecallMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            # Computation settings.
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            # pyre-ignore[6] Incompatible parameter type
            n_classes=N_CLASSES,
        )
Loading

0 comments on commit 555d3db

Please sign in to comment.