Skip to content

Commit

Permalink
Merge pull request #5 from LSSTDESC/issue/4/FlowCreator
Browse files Browse the repository at this point in the history
Issue/4/flow creator
  • Loading branch information
eacharles authored Sep 8, 2022
2 parents b2e5acc + 2d0ab96 commit a25333f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"pz-rail",
"pz-rail-hub",
"pz-rail-hub>=0.0.3",
]

[project.optional-dependencies]
Expand Down
10 changes: 5 additions & 5 deletions src/rail/pipelines/examples/goldenspike/goldenspike.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self):
band_dict = {band:f'mag_{band}_lsst' for band in bands}
rename_dict = {f'mag_{band}_lsst_err':f'mag_err_{band}_lsst' for band in bands}

self.flow_engine_train = FlowEngine.build(
flow=flow_file,
self.flow_engine_train = FlowCreator.build(
model=flow_file,
n_samples=50,
seed=1235,
output=os.path.join(namer.get_data_dir(DataType.catalog, CatalogType.created), "output_flow_engine_train.pq"),
Expand Down Expand Up @@ -73,8 +73,8 @@ def __init__(self):
output=os.path.join(namer.get_data_dir(DataType.catalog, CatalogType.degraded), "output_table_conv_train.hdf5"),
)

self.flow_engine_test = FlowEngine.build(
flow=flow_file,
self.flow_engine_test = FlowCreator.build(
model=flow_file,
n_samples=50,
output=os.path.join(namer.get_data_dir(DataType.catalog, CatalogType.degraded), "output_flow_engine_test.pq"),
)
Expand Down Expand Up @@ -176,5 +176,5 @@ def __init__(self):

if __name__ == '__main__':
pipe = GoldenspikePipeline()
pipe.initialize(dict(flow=flow_file), dict(output_dir='.', log_dir='.', resume=False), None)
pipe.initialize(dict(model=flow_file), dict(output_dir='.', log_dir='.', resume=False), None)
pipe.save('tmp_goldenspike.yml')
2 changes: 1 addition & 1 deletion tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ def test_golden():
except OSError: # pragma: no cover
pass
pipe = GoldenspikePipeline()
pipe.initialize(dict(flow=flow_file), dict(output_dir=output_dir, log_dir=output_dir, resume=False), None)
pipe.initialize(dict(model=flow_file), dict(output_dir=output_dir, log_dir=output_dir, resume=False), None)
pipe.save('tmp_goldenspike.yml')
os.system(f"\\rm -rf {output_dir}")

0 comments on commit a25333f

Please sign in to comment.