update UT and fix format
chunyuan-w committed Oct 19, 2023
1 parent b4de21e commit a792605
Showing 4 changed files with 31 additions and 26 deletions.
8 changes: 5 additions & 3 deletions setup.py
@@ -129,9 +129,11 @@ def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(this_dir, "torchvision", "csrc")
 
-    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
-        os.path.join(extensions_dir, "ops", "*.cpp")
-    ) + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
+    main_file = (
+        glob.glob(os.path.join(extensions_dir, "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
+    )
     source_cpu = (
         glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
         + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
25 changes: 9 additions & 16 deletions test/test_ops.py
@@ -122,8 +122,7 @@ def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, determinist
                 tol = 5e-3
             else:
                 tol = 4e-3
-
-        if x_dtype == torch.bfloat16:
+        elif x_dtype == torch.bfloat16:
             tol = 5e-3
 
         pool_size = 5
@@ -509,7 +508,7 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
                 aligned=aligned,
                 x_dtype=x_dtype,
                 rois_dtype=rois_dtype,
-                )
+            )
 
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@@ -730,14 +729,19 @@ def _create_tensors_with_iou(self, N, iou_thresh):

     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("seed", range(10))
-    def test_nms_ref(self, iou, seed):
+    def test_nms_ref(self, iou, seed, dtype=torch.float):
         torch.random.manual_seed(seed)
         err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
         boxes, scores = self._create_tensors_with_iou(1000, iou)
         keep_ref = self._reference_nms(boxes, scores, iou)
         keep = ops.nms(boxes, scores, iou)
         torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))
 
+        if dtype == torch.bfloat16:
+            keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
+            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
+            torch.testing.assert_close(keep_ref_float, keep_dtype)
+
     def test_nms_input_errors(self):
         with pytest.raises(RuntimeError):
             ops.nms(torch.rand(4), torch.rand(3), 0.5)
@@ -769,17 +773,6 @@ def test_qnms(self, iou, scale, zero_point):

         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
 
-    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cpu(self, iou, dtype=torch.float):
-        err_msg = "NMS incompatible between float and {dtype} for IoU={}"
-
-        boxes, scores = self._create_tensors_with_iou(1000, iou)
-        r_ref = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
-        r_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
-
-        is_eq = torch.allclose(r_ref, r_dtype)
-        assert is_eq, err_msg.format(iou)
-
     @pytest.mark.parametrize(
         "device",
         (
@@ -815,7 +808,7 @@ def test_autocast(self, iou, dtype):
     @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
     def test_autocast_cpu(self, iou, dtype):
         with torch.cpu.amp.autocast():
-            self.test_nms_cpu(iou=iou, dtype=dtype)
+            self.test_nms_ref(iou=iou, seed=0, dtype=dtype)
 
     @pytest.mark.parametrize(
         "device",
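Note: the `dtype` parameter added to `test_nms_ref` folds the deleted `test_nms_cpu` check into the reference test. A minimal standalone sketch of what the new bfloat16 branch verifies; the random box construction below is an illustrative stand-in for the test's `_create_tensors_with_iou` helper, not the real thing:

```python
import torch
from torchvision import ops

# Stand-in for _create_tensors_with_iou: random well-formed boxes and scores.
boxes = torch.rand(1000, 4) * 100
boxes[:, 2:] += boxes[:, :2]  # guarantee x2 >= x1 and y2 >= y1
scores = torch.rand(1000)

# NMS on bfloat16 tensors should keep the same boxes as NMS on the
# same values upcast to float32.
keep_float = ops.nms(boxes.to(torch.bfloat16).float(), scores.to(torch.bfloat16).float(), 0.5)
keep_bf16 = ops.nms(boxes.to(torch.bfloat16), scores.to(torch.bfloat16), 0.5)
torch.testing.assert_close(keep_float, keep_bf16)
```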
14 changes: 10 additions & 4 deletions torchvision/csrc/ops/autocast/nms_kernel.cpp
@@ -9,13 +9,13 @@ namespace ops {

 namespace {
 
-template<c10::DispatchKey autocast_key, c10::DeviceType device_type>
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor nms_autocast(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
   c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
 
   return nms(
       at::autocast::cached_cast(at::kFloat, dets, device_type),
       at::autocast::cached_cast(at::kFloat, scores, device_type),
@@ -25,11 +25,17 @@ at::Tensor nms_autocast(
 } // namespace
 
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN((nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
 }
 
 TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN((nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
 }
 
 } // namespace ops
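The template now covers both `(Autocast, CUDA)` and `(AutocastCPU, CPU)`, so `torchvision::nms` picks up a CPU autocast kernel that upcasts its inputs to float32 via `cached_cast` before dispatching. A hedged sketch of the behavior this registration enables; the tensor setup is illustrative only:

```python
import torch
from torchvision import ops

boxes = torch.rand(500, 4) * 100
boxes[:, 2:] += boxes[:, :2]  # well-formed corner coordinates
scores = torch.rand(500)

with torch.cpu.amp.autocast():
    # Dispatches through nms_autocast<AutocastCPU, CPU>, which runs the
    # real kernel in float32.
    keep_amp = ops.nms(boxes.to(torch.bfloat16), scores.to(torch.bfloat16), 0.5)

# Upcasting the same bfloat16 values by hand takes the same float32 path,
# so the kept indices should match exactly.
keep_ref = ops.nms(boxes.to(torch.bfloat16).float(), scores.to(torch.bfloat16).float(), 0.5)
torch.testing.assert_close(keep_amp, keep_ref)
```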
10 changes: 7 additions & 3 deletions torchvision/csrc/ops/autocast/roi_align_kernel.cpp
@@ -9,7 +9,7 @@ namespace ops {

 namespace {
 
-template<c10::DispatchKey autocast_key, c10::DeviceType device_type>
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor roi_align_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
@@ -35,13 +35,17 @@ at::Tensor roi_align_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_align"),
-      TORCH_FN((roi_align_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
+      TORCH_FN((roi_align_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
 }
 
 TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_align"),
-      TORCH_FN((roi_align_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
+      TORCH_FN((roi_align_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
 }
 
 } // namespace ops
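`roi_align` gets the same `(DispatchKey, DeviceType)` template treatment. A minimal sketch of calling it under CPU autocast with a bfloat16 input; shapes and values are made up for illustration:

```python
import torch
from torchvision import ops

x = torch.rand(1, 3, 32, 32, dtype=torch.bfloat16)
# One RoI per row: (batch_index, x1, y1, x2, y2).
rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]], dtype=torch.bfloat16)

with torch.cpu.amp.autocast():
    # Dispatches through roi_align_autocast<AutocastCPU, CPU>: input and
    # rois are upcast to float32 before the real kernel runs.
    out = ops.roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0)

print(out.shape)  # torch.Size([1, 3, 7, 7])
```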
