Skip to content

Commit 555d3db

Browse files
author
pytorchbot
committed
2024-10-17 nightly release (54ec8aa)
1 parent 5e6ebb9 commit 555d3db

16 files changed

+604
-226
lines changed

torchrec/metrics/test_utils/__init__.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def rec_metric_gpu_sync_test_launcher(
365365
entry_point: Callable[..., None],
366366
batch_size: int = BATCH_SIZE,
367367
batch_window_size: int = BATCH_WINDOW_SIZE,
368-
**kwargs: Any,
368+
**kwargs: Dict[str, Any],
369369
) -> None:
370370
with tempfile.TemporaryDirectory() as tmpdir:
371371
lc = get_launch_config(
@@ -385,6 +385,7 @@ def rec_metric_gpu_sync_test_launcher(
385385
should_validate_update,
386386
batch_size,
387387
batch_window_size,
388+
kwargs.get("n_classes", None),
388389
)
389390

390391

@@ -402,6 +403,7 @@ def sync_test_helper(
402403
batch_window_size: int = BATCH_WINDOW_SIZE,
403404
n_classes: Optional[int] = None,
404405
zero_weights: bool = False,
406+
**kwargs: Dict[str, Any],
405407
) -> None:
406408
rank = int(os.environ["RANK"])
407409
world_size = int(os.environ["WORLD_SIZE"])
@@ -413,13 +415,19 @@ def sync_test_helper(
413415

414416
tasks = gen_test_tasks(task_names)
415417

418+
if n_classes:
419+
# pyre-ignore[6]: Incompatible parameter type
420+
kwargs["number_of_classes"] = n_classes
421+
416422
auc = target_clazz(
417423
world_size=world_size,
418424
batch_size=batch_size,
419425
my_rank=rank,
420426
compute_on_all_ranks=compute_on_all_ranks,
421427
tasks=tasks,
422428
window_size=batch_window_size * world_size,
429+
# pyre-ignore[6]: Incompatible parameter type
430+
**kwargs,
423431
)
424432

425433
weight_value: Optional[torch.Tensor] = None
@@ -466,10 +474,17 @@ def sync_test_helper(
466474
res = auc.compute()
467475

468476
if rank == 0:
469-
assert torch.allclose(
470-
test_metrics[1][task_names[0]],
471-
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
472-
)
477+
# Serving Calibration uses Calibration naming inconsistently
478+
if metric_name == "serving_calibration":
479+
assert torch.allclose(
480+
test_metrics[1][task_names[0]],
481+
res[f"{metric_name}-{task_names[0]}|window_calibration"],
482+
)
483+
else:
484+
assert torch.allclose(
485+
test_metrics[1][task_names[0]],
486+
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
487+
)
473488

474489
# we also test the case where other rank has more tensors than rank 0
475490
auc.reset()
@@ -489,10 +504,17 @@ def sync_test_helper(
489504
res = auc.compute()
490505

491506
if rank == 0:
492-
assert torch.allclose(
493-
test_metrics[1][task_names[0]],
494-
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
495-
)
507+
# Serving Calibration uses Calibration naming inconsistently
508+
if metric_name == "serving_calibration":
509+
assert torch.allclose(
510+
test_metrics[1][task_names[0]],
511+
res[f"{metric_name}-{task_names[0]}|window_calibration"],
512+
)
513+
else:
514+
assert torch.allclose(
515+
test_metrics[1][task_names[0]],
516+
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
517+
)
496518

497519
dist.destroy_process_group()
498520

torchrec/metrics/tests/test_accuracy.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
1818
from torchrec.metrics.test_utils import (
1919
metric_test_helper,
20+
rec_metric_gpu_sync_test_launcher,
2021
rec_metric_value_test_launcher,
2122
RecTaskInfo,
23+
sync_test_helper,
2224
TestMetric,
2325
)
2426

@@ -251,3 +253,24 @@ def test_accuracy(self) -> None:
251253
except AssertionError:
252254
print("Assertion error caught with data set ", inputs)
253255
raise
256+
257+
258+
class AccuracyGPUSyncTest(unittest.TestCase):
259+
clazz: Type[RecMetric] = AccuracyMetric
260+
task_name: str = "accuracy"
261+
262+
def test_sync_accuracy(self) -> None:
263+
rec_metric_gpu_sync_test_launcher(
264+
target_clazz=AccuracyMetric,
265+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
266+
test_clazz=TestAccuracyMetric,
267+
metric_name=AccuracyGPUSyncTest.task_name,
268+
task_names=["t1"],
269+
fused_update_limit=0,
270+
compute_on_all_ranks=False,
271+
should_validate_update=False,
272+
world_size=2,
273+
batch_size=5,
274+
batch_window_size=20,
275+
entry_point=sync_test_helper,
276+
)

torchrec/metrics/tests/test_auprc.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
)
2424
from torchrec.metrics.test_utils import (
2525
metric_test_helper,
26+
rec_metric_gpu_sync_test_launcher,
2627
rec_metric_value_test_launcher,
28+
sync_test_helper,
2729
TestMetric,
2830
)
2931

@@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None:
346348
)
347349

348350
self.assertIn("grouping_keys", auprc.get_required_inputs())
351+
352+
353+
class AUPRCGPUSyncTest(unittest.TestCase):
354+
clazz: Type[RecMetric] = AUPRCMetric
355+
task_name: str = "auprc"
356+
357+
def test_sync_auprc(self) -> None:
358+
rec_metric_gpu_sync_test_launcher(
359+
target_clazz=AUPRCMetric,
360+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
361+
test_clazz=TestAUPRCMetric,
362+
metric_name=AUPRCGPUSyncTest.task_name,
363+
task_names=["t1"],
364+
fused_update_limit=0,
365+
compute_on_all_ranks=False,
366+
should_validate_update=False,
367+
world_size=2,
368+
batch_size=5,
369+
batch_window_size=20,
370+
entry_point=sync_test_helper,
371+
)

torchrec/metrics/tests/test_calibration.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
1616
from torchrec.metrics.test_utils import (
1717
metric_test_helper,
18+
rec_metric_gpu_sync_test_launcher,
1819
rec_metric_value_test_launcher,
20+
sync_test_helper,
1921
TestMetric,
2022
)
2123

@@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None:
7779
world_size=WORLD_SIZE,
7880
entry_point=metric_test_helper,
7981
)
82+
83+
84+
class CalibrationGPUSyncTest(unittest.TestCase):
85+
clazz: Type[RecMetric] = CalibrationMetric
86+
task_name: str = "calibration"
87+
88+
def test_sync_calibration(self) -> None:
89+
rec_metric_gpu_sync_test_launcher(
90+
target_clazz=CalibrationMetric,
91+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
92+
test_clazz=TestCalibrationMetric,
93+
metric_name=CalibrationGPUSyncTest.task_name,
94+
task_names=["t1"],
95+
fused_update_limit=0,
96+
compute_on_all_ranks=False,
97+
should_validate_update=False,
98+
world_size=2,
99+
batch_size=5,
100+
batch_window_size=20,
101+
entry_point=sync_test_helper,
102+
)

torchrec/metrics/tests/test_ctr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
1616
from torchrec.metrics.test_utils import (
1717
metric_test_helper,
18+
rec_metric_gpu_sync_test_launcher,
1819
rec_metric_value_test_launcher,
20+
sync_test_helper,
1921
TestMetric,
2022
)
2123

@@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None:
7173
world_size=WORLD_SIZE,
7274
entry_point=metric_test_helper,
7375
)
76+
77+
78+
class CTRGPUSyncTest(unittest.TestCase):
79+
clazz: Type[RecMetric] = CTRMetric
80+
task_name: str = "ctr"
81+
82+
def test_sync_ctr(self) -> None:
83+
rec_metric_gpu_sync_test_launcher(
84+
target_clazz=CTRMetric,
85+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
86+
test_clazz=TestCTRMetric,
87+
metric_name=CTRGPUSyncTest.task_name,
88+
task_names=["t1"],
89+
fused_update_limit=0,
90+
compute_on_all_ranks=False,
91+
should_validate_update=False,
92+
world_size=2,
93+
batch_size=5,
94+
batch_window_size=20,
95+
entry_point=sync_test_helper,
96+
)

torchrec/metrics/tests/test_mae.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
1616
from torchrec.metrics.test_utils import (
1717
metric_test_helper,
18+
rec_metric_gpu_sync_test_launcher,
1819
rec_metric_value_test_launcher,
20+
sync_test_helper,
1921
TestMetric,
2022
)
2123

@@ -74,3 +76,24 @@ def test_fused_mae(self) -> None:
7476
world_size=WORLD_SIZE,
7577
entry_point=metric_test_helper,
7678
)
79+
80+
81+
class MAEGPUSyncTest(unittest.TestCase):
82+
clazz: Type[RecMetric] = MAEMetric
83+
task_name: str = "mae"
84+
85+
def test_sync_mae(self) -> None:
86+
rec_metric_gpu_sync_test_launcher(
87+
target_clazz=MAEMetric,
88+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
89+
test_clazz=TestMAEMetric,
90+
metric_name=MAEGPUSyncTest.task_name,
91+
task_names=["t1"],
92+
fused_update_limit=0,
93+
compute_on_all_ranks=False,
94+
should_validate_update=False,
95+
world_size=2,
96+
batch_size=5,
97+
batch_window_size=20,
98+
entry_point=sync_test_helper,
99+
)

torchrec/metrics/tests/test_mse.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
1616
from torchrec.metrics.test_utils import (
1717
metric_test_helper,
18+
rec_metric_gpu_sync_test_launcher,
1819
rec_metric_value_test_launcher,
20+
sync_test_helper,
1921
TestMetric,
2022
)
2123

@@ -123,3 +125,24 @@ def test_fused_rmse(self) -> None:
123125
world_size=WORLD_SIZE,
124126
entry_point=metric_test_helper,
125127
)
128+
129+
130+
class MSEGPUSyncTest(unittest.TestCase):
131+
clazz: Type[RecMetric] = MSEMetric
132+
task_name: str = "mse"
133+
134+
def test_sync_mse(self) -> None:
135+
rec_metric_gpu_sync_test_launcher(
136+
target_clazz=MSEMetric,
137+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
138+
test_clazz=TestMSEMetric,
139+
metric_name=MSEGPUSyncTest.task_name,
140+
task_names=["t1"],
141+
fused_update_limit=0,
142+
compute_on_all_ranks=False,
143+
should_validate_update=False,
144+
world_size=2,
145+
batch_size=5,
146+
batch_window_size=20,
147+
entry_point=sync_test_helper,
148+
)

torchrec/metrics/tests/test_multiclass_recall.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
2020
from torchrec.metrics.test_utils import (
2121
metric_test_helper,
22+
rec_metric_gpu_sync_test_launcher,
2223
rec_metric_value_test_launcher,
24+
sync_test_helper,
2325
TestMetric,
2426
)
2527

@@ -113,3 +115,26 @@ def test_multiclass_recall_update_fused(self) -> None:
113115
batch_window_size=10,
114116
n_classes=N_CLASSES,
115117
)
118+
119+
120+
class MulticlassRecallGPUSyncTest(unittest.TestCase):
121+
clazz: Type[RecMetric] = MulticlassRecallMetric
122+
task_name: str = "multiclass_recall"
123+
124+
def test_sync_multiclass_recall(self) -> None:
125+
rec_metric_gpu_sync_test_launcher(
126+
target_clazz=MulticlassRecallMetric,
127+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
128+
test_clazz=TestMulticlassRecallMetric,
129+
metric_name=MulticlassRecallGPUSyncTest.task_name,
130+
task_names=["t1"],
131+
fused_update_limit=0,
132+
compute_on_all_ranks=False,
133+
should_validate_update=False,
134+
world_size=2,
135+
batch_size=5,
136+
batch_window_size=20,
137+
entry_point=sync_test_helper,
138+
# pyre-ignore[6] Incompatible parameter type
139+
n_classes=N_CLASSES,
140+
)

0 commit comments

Comments
 (0)