Skip to content

Commit

Permalink
corrected more GPU logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ptbrown1729 committed Apr 27, 2023
1 parent 7a601a1 commit 62f89f9
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions mcsim/analysis/sim_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,9 @@ def estimate_parameters(self,
raise ValueError(f"frq_estimation_mode=`band-correlation`, but this requires phase guesses,"
f"and no phase guesses were provided")

mempool = cp.get_default_memory_pool()
memory_start = mempool.used_bytes()
if self.use_gpu:
mempool = cp.get_default_memory_pool()
memory_start = mempool.used_bytes()

self.print_log("starting parameter estimation...")

Expand Down Expand Up @@ -913,10 +914,11 @@ def estimate_parameters(self,
phases = np.array(phases)
amps = np.array(amps)

# find this is necessary, else mempool gets too big for 8GB GPU's
mempool.free_all_blocks()
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
if self.use_gpu:
# find this is necessary, else mempool gets too big for 8GB GPU's
mempool.free_all_blocks()
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")

elif self._recon_settings["phase_estimation_mode"] == "real-space":
phase_guess = self.phases_guess
Expand All @@ -941,10 +943,11 @@ def estimate_parameters(self,
phases = np.array(results).reshape((self.nangles, self.nphases))
amps = np.ones((self.nangles, self.nphases))

# find this is necessary, else mempool gets too big for 8GB GPU's
mempool.free_all_blocks()
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
if self.use_gpu:
# find this is necessary, else mempool gets too big for 8GB GPU's
mempool.free_all_blocks()
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
else:
raise ValueError(f"phase_estimation_mode must be one of {self.allowed_phase_estimation_modes}"
f" but was '{self._recon_settings['phase_estimation_mode']:s}'")
Expand Down Expand Up @@ -997,9 +1000,10 @@ def estimate_parameters(self,
(self.dy, self.dx),
self.upsample_fact)

# self.print_log(f"after upsampling and band shifting used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
mempool.free_all_blocks()
if self.use_gpu:
# self.print_log(f"after upsampling and band shifting used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
mempool.free_all_blocks()

# upsample and shift OTFs
otf_us = resample_bandlimited_ft(self.otf,
Expand Down Expand Up @@ -1058,9 +1062,10 @@ def estimate_parameters(self,

self.print_log(f"estimated global phases and modulation depths in {time.perf_counter() - tstart_mod_depth:.2f}s")

# self.print_log(f"after phase correction used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
mempool.free_all_blocks()
if self.use_gpu:
# self.print_log(f"after phase correction used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
mempool.free_all_blocks()

del mask
del otf_shifted
Expand Down

0 comments on commit 62f89f9

Please sign in to comment.