From 131d884a22895397145a4c4f0f85bed35c2fcbae Mon Sep 17 00:00:00 2001 From: marcosolime Date: Tue, 29 Oct 2024 14:43:01 +0900 Subject: [PATCH] fix window selection --- src/dataset/tigre_mlg.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dataset/tigre_mlg.py b/src/dataset/tigre_mlg.py index 853f1b2..ba2f9f4 100644 --- a/src/dataset/tigre_mlg.py +++ b/src/dataset/tigre_mlg.py @@ -140,8 +140,9 @@ def __getitem__(self, index): # 全是 valid 的 window projs_window_valid_indx = ((projs_window > 0).sum(dim=-1).sum(dim=-1) == self.window_size[0] * self.window_size[1]) # 选取 window_inds - select_inds_window = np.random.choice(projs_window_valid_indx.shape[0], size=[self.window_num], replace=False) # 从 0 ~ 64-1 中选取 window_num 个值 - + valid_inds = torch.where(projs_window_valid_indx)[0] + select_inds_window = valid_inds[torch.randperm(len(valid_inds))[:self.window_num]] # shape: (window_num) + projs_window_select = projs_window[select_inds_window] # [36, 32, 32] rays_window_select = rays_window[select_inds_window] # [36, 32, 32, 8] # stx()