Skip to content

Commit

Permalink
!fix: user specified pca.h5 gets ignored
Browse files Browse the repository at this point in the history
!fix: user specified pca.h5 gets ignored
  • Loading branch information
versey-sherry authored Apr 5, 2023
2 parents 9da5e9b + a1f8a71 commit be14f6b
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion moseq2_pca/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = 'v1.1.3'
__version__ = 'v1.2.0'
31 changes: 16 additions & 15 deletions moseq2_pca/helpers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,44 @@ def get_pca_paths(config_data, output_dir):
pca_file_scores (str): path to pca_scores file
"""

# Get path to pre-computed PCA file
pca_file_components = join(output_dir, 'pca.h5')
if 'pca_file_components' not in config_data:
config_data['pca_file_components'] = pca_file_components
elif config_data['pca_file_components'] is not None:
pca_file_components = config_data['pca_file_components']
# Check if there is PCA file from config_data
if config_data.get('pca_file', None) is not None:
pca_file = config_data['pca_file']
else:
# Assume PCA file is in output_dir
pca_file = join(output_dir, 'pca.h5')
config_data['pca_file'] = pca_file

if not exists(pca_file_components):
raise IOError(f'Could not find PCA components file {pca_file_components}')
if not exists(pca_file):
raise IOError(f'Could not find PCA components file {pca_file}')

# Get path to PCA Scores
pca_file_scores = config_data.get('pca_file_scores', join(output_dir, 'pca_scores.h5'))
config_data['pca_file_scores'] = pca_file_scores

return config_data, pca_file_components, pca_file_scores
return config_data, pca_file, pca_file_scores

def load_pcs_for_cp(pca_file_components, config_data):
def load_pcs_for_cp(pca_file, config_data):
"""
Load computed Principal Components for Model-free Changepoint Analysis.
Args:
pca_file_components (str): path to pca h5 file to read PCs
pca_file (str): path to pca h5 file to read PCs
config_data (dict): config parameters
Returns:
pca_components (str): path to pca components
pca_file (str): path to pca components
changepoint_params (dict): dict of relevant changepoint parameters
missing_data (bool): Indicates whether to use mask_params for missing data pca
mask_params (dict): Mask parameters to use when computing CPs
"""

print(f'Loading PCs from {pca_file_components}')
with h5py.File(pca_file_components, 'r') as f:
print(f'Loading PCs from {pca_file}')
with h5py.File(pca_file, 'r') as f:
pca_components = f[config_data['pca_path']][()]

# get the yaml for pca, check parameters, if we used fft, be sure to turn on here...
pca_yaml = splitext(pca_file_components)[0] + '.yaml'
pca_yaml = splitext(pca_file)[0] + '.yaml'

if exists(pca_yaml):
with open(pca_yaml, 'r') as f:
Expand Down
6 changes: 3 additions & 3 deletions moseq2_pca/helpers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def apply_pca_wrapper(input_dir, config_data, output_dir, output_file):
config_data, pca_file, pca_file_scores = get_pca_paths(config_data, output_dir)

print('Loading PCs from', pca_file)
with h5py.File(config_data['pca_file_components'], 'r') as f:
with h5py.File(config_data['pca_file'], 'r') as f:
pca_components = f[config_data['pca_path']][()]

# Get the yaml for pca, check parameters, if we used fft, be sure to turn on here...
Expand Down Expand Up @@ -313,10 +313,10 @@ def compute_changepoints_wrapper(input_dir, config_data, output_dir, output_file
save_file = join(output_dir, output_file)

# Get paths to PCA, PCA Scores file
config_data, pca_file_components, pca_file_scores = get_pca_paths(config_data, output_dir)
config_data, pca_file, pca_file_scores = get_pca_paths(config_data, output_dir)

# Load Principal components, set up changepoint parameter dict, and optionally load reconstructed PCs.
pca_components, changepoint_params, missing_data, mask_params = load_pcs_for_cp(pca_file_components, config_data)
pca_components, changepoint_params, missing_data, mask_params = load_pcs_for_cp(pca_file, config_data)

# Initialize Dask client
client, cluster, workers = \
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_compute_changepoints(self):
with open(config, 'r') as f:
config_data = yaml.safe_load(f)

config_data['pca_file_components'] = None
config_data['pca_file'] = None
config_data['pca_file_scores'] = None

with open(config, 'w') as f:
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_apply_pca_command(self):

with open(config_file, 'r') as f:
config_data = yaml.safe_load(f)
config_data['pca_file_components'] = join(outpath, 'pca.h5')
config_data['pca_file'] = join(outpath, 'pca.h5')

config_data['use_fft'] = True
config_data['missing_data'] = False
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_compute_changepoints_command(self):

with open(config_file, 'r') as f:
config_data = yaml.safe_load(f)
config_data['pca_file_components'] = join(outpath, 'pca.h5')
config_data['pca_file'] = join(outpath, 'pca.h5')

config_data['use_fft'] = True
config_data['missing_data'] = False
Expand Down

0 comments on commit be14f6b

Please sign in to comment.