Skip to content

Commit

Permalink
fix: forced memory layout
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardodalinky committed Aug 31, 2024
1 parent c832ba4 commit 8e99afd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
11 changes: 5 additions & 6 deletions python/fpsample/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ def fps_sampling(pc: np.ndarray, n_samples: int, start_idx: Optional[int] = None
assert (
start_idx is None or 0 <= start_idx < n_pts
), "start_idx should be None or 0 <= start_idx < n_pts"
pc = pc.astype(np.float32)
# best performance with fortran array
pc = np.asfortranarray(pc)
pc = np.asfortranarray(pc, dtype=np.float32)
# Random pick a start
start_idx = np.random.randint(low=0, high=n_pts) if start_idx is None else start_idx
return _fps_sampling(pc, n_samples, start_idx)
Expand Down Expand Up @@ -60,7 +59,7 @@ def fps_npdu_sampling(
assert (
start_idx is None or 0 <= start_idx < n_pts
), "start_idx should be None or 0 <= start_idx < n_pts"
pc = pc.astype(np.float32)
pc = np.ascontiguousarray(pc, dtype=np.float32)
w = w or int(n_pts / n_samples * 16)
if w >= n_pts - 1:
warnings.warn(f"k is too large, set to {n_pts - 1}")
Expand Down Expand Up @@ -93,7 +92,7 @@ def fps_npdu_kdtree_sampling(
assert (
start_idx is None or 0 <= start_idx < n_pts
), "start_idx should be None or 0 <= start_idx < n_pts"
pc = pc.astype(np.float32)
pc = np.ascontiguousarray(pc, dtype=np.float32)
w = w or int(n_pts / n_samples * 16)
if w >= n_pts:
warnings.warn(f"k is too large, set to {n_pts}")
Expand Down Expand Up @@ -123,7 +122,7 @@ def bucket_fps_kdtree_sampling(
assert (
start_idx is None or 0 <= start_idx < n_pts
), "start_idx should be None or 0 <= start_idx < n_pts"
pc = pc.astype(np.float32)
pc = np.ascontiguousarray(pc, dtype=np.float32)
# Random pick a start
start_idx = np.random.randint(low=0, high=n_pts) if start_idx is None else start_idx
return _bucket_fps_kdtree_sampling(pc, n_samples, start_idx)
Expand Down Expand Up @@ -155,7 +154,7 @@ def bucket_fps_kdline_sampling(
assert (
start_idx is None or 0 <= start_idx < n_pts
), "start_idx should be None or 0 <= start_idx < n_pts"
pc = pc.astype(np.float32)
pc = np.ascontiguousarray(pc, dtype=np.float32)
# Random pick a start
start_idx = np.random.randint(low=0, high=n_pts) if start_idx is None else start_idx
return _bucket_fps_kdline_sampling(pc, n_samples, h, start_idx)
8 changes: 4 additions & 4 deletions src/bucket_fps/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ pub fn bucket_fps_kdtree_sampling(
start_idx: usize,
) -> Array1<usize> {
let[p, c] = points.shape() else {panic !("points must be a 2D array")};
let raw_data = points.as_standard_layout().as_ptr();
let raw_data = points.as_standard_layout();
let mut sampled_point_indices = vec![0; n_samples];
let ret_code;
unsafe {
ret_code = ffi::bucket_fps_kdtree(
raw_data,
raw_data.as_ptr(),
*p,
*c,
n_samples,
Expand All @@ -33,12 +33,12 @@ pub fn bucket_fps_kdline_sampling(
start_idx: usize,
) -> Array1<usize> {
let[p, c] = points.shape() else {panic !("points must be a 2D array")};
let raw_data = points.as_standard_layout().as_ptr();
let raw_data = points.as_standard_layout();
let mut sampled_point_indices = vec![0; n_samples];
let ret_code;
unsafe {
ret_code = ffi::bucket_fps_kdline(
raw_data,
raw_data.as_ptr(),
*p,
*c,
n_samples,
Expand Down

0 comments on commit 8e99afd

Please sign in to comment.