From 7445171034708095228176812af4a6270cf8ff35 Mon Sep 17 00:00:00 2001 From: sgbaird Date: Mon, 20 Jun 2022 12:38:46 -0600 Subject: [PATCH] tqdm_if_verbose wrapper fn --- src/xtal2png/core.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/xtal2png/core.py b/src/xtal2png/core.py index 46f3ef7..892f2e9 100644 --- a/src/xtal2png/core.py +++ b/src/xtal2png/core.py @@ -206,6 +206,11 @@ def __init__( self.channels = channels self.verbose = verbose + if self.verbose: + self.tqdm_if_verbose = tqdm + else: + self.tqdm_if_verbose = lambda x: x + Path(save_dir).mkdir(exist_ok=True, parents=True) def xtal2png( @@ -480,7 +485,7 @@ def png2xtal( S = self.arrays_to_structures(data) if save: - for s in S: + for s in self.tqdm_if_verbose(S): fpath = path.join(self.save_dir, construct_save_name(s) + ".cif") CifWriter( s, @@ -554,7 +559,7 @@ def structures_to_arrays( distance_matrix_tmp: List[NDArray[np.float64]] = [] sym_structures = [] - for s in structures: + for s in self.tqdm_if_verbose(structures): spa = SpacegroupAnalyzer( s, symprec=self.encode_symprec, @@ -568,7 +573,7 @@ def structures_to_arrays( structures = sym_structures - for s in structures: + for s in self.tqdm_if_verbose(structures): n_sites = len(s.atomic_numbers) if n_sites > self.max_sites: raise ValueError( @@ -972,7 +977,8 @@ def arrays_to_structures( # build Structure-s S: List[Structure] = [] num_structures = len(atomic_numbers) - for i in range(num_structures): + + for i in self.tqdm_if_verbose(range(num_structures)): at = atomic_numbers[i] fr = frac_coords[i] site_ids = np.where(at > 0)