Skip to content

Commit

Permalink
chore: merge dev into release
Browse files: browse the repository at this point in the history
chore: merge dev into release
  • Loading branch information
versey-sherry authored Apr 5, 2023
2 parents 879456c + be14f6b commit dd4d572
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 22 deletions.
2 changes: 1 addition & 1 deletion moseq2_pca/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = 'v1.1.3'
__version__ = 'v1.2.0'
9 changes: 9 additions & 0 deletions moseq2_pca/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,17 @@ def common_dask_parameters(function):
@click.option('--output-file', default='pca', type=str, help='Name of h5 file for storing pca results')
@click.option('--local-processes', default=False, type=bool, help='Used with a local cluster. If True: use processes, If False: use threads')
@click.option('--overwrite-pca-train', default=False, type=bool, help='Used to bypass the pca overwrite question. If True: skip question, run automatically')
@click.option('--camera-type', default='k2', type=str, help='specify the camera type (k2 or azure), default is k2')
def train_pca(input_dir, output_dir, output_file, **cli_args):
# function writes output pca path to config_data
if cli_args.get('camera_type') == 'azure':
# replace any filter parameter still equal to its k2 default with the Azure default; other values are left unchanged
print('Updating parameters for Azure Kinect camera...')
if cli_args['gaussfilter_space'] == (1.5, 1):
cli_args['gaussfilter_space'] = (2.25, 1.5)
if cli_args['tailfilter_size'] == (9, 9):
cli_args['tailfilter_size'] = (13, 13)

config_data = train_pca_wrapper(input_dir, cli_args, output_dir, output_file)
# write config_data to config_file if there is one
if cli_args.get('config_file'):
Expand Down
31 changes: 16 additions & 15 deletions moseq2_pca/helpers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,44 @@ def get_pca_paths(config_data, output_dir):
pca_file_scores (str): path to pca_scores file
"""

# Get path to pre-computed PCA file
pca_file_components = join(output_dir, 'pca.h5')
if 'pca_file_components' not in config_data:
config_data['pca_file_components'] = pca_file_components
elif config_data['pca_file_components'] is not None:
pca_file_components = config_data['pca_file_components']
# Check if there is PCA file from config_data
if config_data.get('pca_file', None) is not None:
pca_file = config_data['pca_file']
else:
# Assume PCA file is in output_dir
pca_file = join(output_dir, 'pca.h5')
config_data['pca_file'] = pca_file

if not exists(pca_file_components):
raise IOError(f'Could not find PCA components file {pca_file_components}')
if not exists(pca_file):
raise IOError(f'Could not find PCA components file {pca_file}')

# Get path to PCA Scores
pca_file_scores = config_data.get('pca_file_scores', join(output_dir, 'pca_scores.h5'))
config_data['pca_file_scores'] = pca_file_scores

return config_data, pca_file_components, pca_file_scores
return config_data, pca_file, pca_file_scores

def load_pcs_for_cp(pca_file_components, config_data):
def load_pcs_for_cp(pca_file, config_data):
"""
Load computed Principal Components for Model-free Changepoint Analysis.
Args:
pca_file_components (str): path to pca h5 file to read PCs
pca_file (str): path to pca h5 file to read PCs
config_data (dict): config parameters
Returns:
pca_components (str): path to pca components
pca_components (array): principal components read from the pca h5 file
changepoint_params (dict): dict of relevant changepoint parameters
missing_data (bool): Indicates whether to use mask_params for missing data pca
mask_params (dict): Mask parameters to use when computing CPs
"""

print(f'Loading PCs from {pca_file_components}')
with h5py.File(pca_file_components, 'r') as f:
print(f'Loading PCs from {pca_file}')
with h5py.File(pca_file, 'r') as f:
pca_components = f[config_data['pca_path']][()]

# get the yaml for pca, check parameters, if we used fft, be sure to turn on here...
pca_yaml = splitext(pca_file_components)[0] + '.yaml'
pca_yaml = splitext(pca_file)[0] + '.yaml'

if exists(pca_yaml):
with open(pca_yaml, 'r') as f:
Expand Down
6 changes: 3 additions & 3 deletions moseq2_pca/helpers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def apply_pca_wrapper(input_dir, config_data, output_dir, output_file):
config_data, pca_file, pca_file_scores = get_pca_paths(config_data, output_dir)

print('Loading PCs from', pca_file)
with h5py.File(config_data['pca_file_components'], 'r') as f:
with h5py.File(config_data['pca_file'], 'r') as f:
pca_components = f[config_data['pca_path']][()]

# Get the yaml for pca, check parameters, if we used fft, be sure to turn on here...
Expand Down Expand Up @@ -313,10 +313,10 @@ def compute_changepoints_wrapper(input_dir, config_data, output_dir, output_file
save_file = join(output_dir, output_file)

# Get paths to PCA, PCA Scores file
config_data, pca_file_components, pca_file_scores = get_pca_paths(config_data, output_dir)
config_data, pca_file, pca_file_scores = get_pca_paths(config_data, output_dir)

# Load Principal components, set up changepoint parameter dict, and optionally load reconstructed PCs.
pca_components, changepoint_params, missing_data, mask_params = load_pcs_for_cp(pca_file_components, config_data)
pca_components, changepoint_params, missing_data, mask_params = load_pcs_for_cp(pca_file, config_data)

# Initialize Dask client
client, cluster, workers = \
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_compute_changepoints(self):
with open(config, 'r') as f:
config_data = yaml.safe_load(f)

config_data['pca_file_components'] = None
config_data['pca_file'] = None
config_data['pca_file_scores'] = None

with open(config, 'w') as f:
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_apply_pca_command(self):

with open(config_file, 'r') as f:
config_data = yaml.safe_load(f)
config_data['pca_file_components'] = join(outpath, 'pca.h5')
config_data['pca_file'] = join(outpath, 'pca.h5')

config_data['use_fft'] = True
config_data['missing_data'] = False
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_compute_changepoints_command(self):

with open(config_file, 'r') as f:
config_data = yaml.safe_load(f)
config_data['pca_file_components'] = join(outpath, 'pca.h5')
config_data['pca_file'] = join(outpath, 'pca.h5')

config_data['use_fft'] = True
config_data['missing_data'] = False
Expand Down

0 comments on commit dd4d572

Please sign in to comment.