Skip to content

Commit

Permalink
Merge branch 'facebookresearch:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sdy623 authored Nov 13, 2023
2 parents b695e24 + 017abbf commit f51c508
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 24 deletions.
22 changes: 7 additions & 15 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,7 @@ def build_batch_data_loader(
num_workers=0,
collate_fn=None,
drop_last: bool = True,
prefetch_factor=None,
persistent_workers=False,
pin_memory=False,
**kwargs,
):
"""
Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
Expand All @@ -328,6 +326,8 @@ def build_batch_data_loader(
total_batch_size, world_size
)
batch_size = total_batch_size // world_size
logger = logging.getLogger(__name__)
logger.info("Making batched data loader with batch_size=%d", batch_size)

if isinstance(dataset, torchdata.IterableDataset):
assert sampler is None, "sampler must be None if dataset is IterableDataset"
Expand All @@ -341,9 +341,7 @@ def build_batch_data_loader(
num_workers=num_workers,
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
worker_init_fn=worker_init_reset_seed,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
**kwargs
) # yield individual mapped dict
data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
if collate_fn is None:
Expand All @@ -357,9 +355,7 @@ def build_batch_data_loader(
num_workers=num_workers,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
worker_init_fn=worker_init_reset_seed,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
**kwargs
)


Expand Down Expand Up @@ -499,9 +495,7 @@ def build_detection_train_loader(
aspect_ratio_grouping=True,
num_workers=0,
collate_fn=None,
prefetch_factor=None,
persistent_workers=False,
pin_memory=False,
**kwargs
):
"""
Build a dataloader for object detection with some default features.
Expand Down Expand Up @@ -553,9 +547,7 @@ def build_detection_train_loader(
aspect_ratio_grouping=aspect_ratio_grouping,
num_workers=num_workers,
collate_fn=collate_fn,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
**kwargs
)


Expand Down
2 changes: 2 additions & 0 deletions detectron2/engine/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,10 @@ def __call__(self, original_image):
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
image.to(self.cfg.MODEL.DEVICE)

inputs = {"image": image, "height": height, "width": width}

predictions = self.model([inputs])[0]
return predictions

Expand Down
39 changes: 30 additions & 9 deletions detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,19 +225,40 @@ HOST_DEVICE_INLINE int convex_hull_graham(
}
#else
// CPU version
// Sort points q[1..num_in) counter-clockwise around q[0] (the bottom-most
// point, already subtracted out as the origin), breaking ties between
// collinear points by squared distance to the origin. This is the same
// ordering the previous std::sort comparator produced:
//   A before B  iff  cross_2d(A, B) > 0, or
//                    |cross| < 1e-6 and dot(A, A) < dot(B, B).
// A simple O(n^2) exchange (selection-style) sort replaces std::sort here —
// presumably for toolchain/device-compilation compatibility (TODO confirm);
// num_in is tiny (at most the size of the dist[] buffer), so the quadratic
// cost is negligible.
for (int i = 0; i < num_in; i++) {
  dist[i] = dot_2d<T>(q[i], q[i]);
}

for (int i = 1; i < num_in - 1; i++) {
  for (int j = i + 1; j < num_in; j++) {
    T crossProduct = cross_2d<T>(q[i], q[j]);
    // Swap when q[j] must precede q[i]: q[j] is CCW-before q[i]
    // (cross < 0), or they are collinear and q[j] is closer to the origin.
    if ((crossProduct < -1e-6) ||
        (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
      auto q_tmp = q[i];
      q[i] = q[j];
      q[j] = q_tmp;
      auto dist_tmp = dist[i];
      dist[i] = dist[j];
      dist[j] = dist_tmp;
    }
  }
}

// NOTE: no post-sort recomputation of dist[] is needed — every swap above
// exchanges q[i] and dist[i] together, so dist[i] == dot_2d(q[i], q[i])
// still holds for all i after sorting.

#endif

// Step 4:
Expand Down

0 comments on commit f51c508

Please sign in to comment.