Skip to content

Commit

Permalink
Merge pull request #949 from astronomerritt/indexing_sql
Browse files Browse the repository at this point in the history
Adding indexes to SQL outputs.
  • Loading branch information
mschwamb authored Jun 3, 2024
2 parents 9d63109 + 36ea33f commit 81cb463
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
20 changes: 15 additions & 5 deletions src/sorcha/modules/PPOutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def PPOutWriteHDF5(pp_results, outf, keyin):
return of


def PPOutWriteSqlite3(pp_results, outf):
def PPOutWriteSqlite3(pp_results, outf, lastchunk=False, tablename="sorcha_results"):
"""
    Writes a pandas dataframe out to a SQLite database at a location given by the user.
Expand All @@ -90,12 +90,22 @@ def PPOutWriteSqlite3(pp_results, outf):

cnx = sqlite3.connect(outf)

pp_results.to_sql("sorcha_results", con=cnx, if_exists="append", index=False)
pp_results.to_sql(tablename, con=cnx, if_exists="append", index=False)

pplogger.info("SQL results saved in table sorcha_results in database {}.".format(outf))
# we don't want to index the table until we're sure we're done appending to it
# as recreating the indexes on every append is slow
if lastchunk:
pplogger.info("Last chunk detected. Indexing SQL table...")
cur = cnx.cursor()
cur.execute("CREATE INDEX ObjID ON {} (ObjID)".format(tablename))
cur.execute("CREATE INDEX fieldMJD_TAI ON {} (fieldMJD_TAI)".format(tablename))
cur.execute("CREATE INDEX optFilter ON {} (optFilter)".format(tablename))
cnx.commit()

pplogger.info("SQL results saved in table {} in database {}.".format(tablename, outf))

def PPWriteOutput(cmd_args, configs, observations_in, endChunk=0, verbose=False):

def PPWriteOutput(cmd_args, configs, observations_in, endChunk=0, verbose=False, lastchunk=False):
"""
Writes the output in the format specified in the config file to a location
specified by the user.
Expand Down Expand Up @@ -214,7 +224,7 @@ def PPWriteOutput(cmd_args, configs, observations_in, endChunk=0, verbose=False)
outputsuffix = ".db"
out = os.path.join(cmd_args.outpath, cmd_args.outfilestem + outputsuffix)
verboselog("Output to sqlite3 database...")
observations = PPOutWriteSqlite3(observations, out)
observations = PPOutWriteSqlite3(observations, out, lastchunk)

elif configs["output_format"] == "hdf5" or configs["output_format"] == "h5":
outputsuffix = ".h5"
Expand Down
6 changes: 5 additions & 1 deletion src/sorcha/sorcha.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def runLSSTSimulation(args, configs):
startChunk = 0
endChunk = 0
loopCounter = 0
lastChunk = False

ii = -1
with open(args.orbinfile) as f:
Expand All @@ -178,6 +179,9 @@ def runLSSTSimulation(args, configs):
endChunk = startChunk + configs["size_serial_chunk"]
verboselog("Working on objects {}-{}".format(startChunk, endChunk))

if endChunk >= lenf:
lastChunk = True

# Processing begins, all processing is done for chunks
if configs["ephemerides_type"].casefold() == "external":
verboselog("Reading in chunk of orbits and associated ephemeris from an external file")
Expand Down Expand Up @@ -334,7 +338,7 @@ def runLSSTSimulation(args, configs):
pplogger.info("Output results for this chunk")

# write output
PPWriteOutput(args, configs, observations, endChunk, verbose=args.verbose)
PPWriteOutput(args, configs, observations, endChunk, verbose=args.verbose, lastchunk=lastChunk)

startChunk = startChunk + configs["size_serial_chunk"]
loopCounter = loopCounter + 1
Expand Down
9 changes: 8 additions & 1 deletion tests/sorcha/test_PPOutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_PPWriteOutput_sql(tmp_path):
dtype=object,
)

PPWriteOutput(args, configs, observations, 10)
PPWriteOutput(args, configs, observations, 10, lastchunk=True)
cnx = sqlite3.connect(os.path.join(tmp_path, "PPOutput_test_out.db"))
cur = cnx.cursor()
cur.execute("select * from sorcha_results")
Expand All @@ -130,6 +130,13 @@ def test_PPWriteOutput_sql(tmp_path):

assert_equal(sql_test_in.loc[0, :].values, expected)

# check indexes were properly created
cur.execute("PRAGMA index_list('sorcha_results')")
indexes = cur.fetchall()

index_list = [indexes[i][1] for i in range(0, 3)]
assert index_list == ["optFilter", "fieldMJD_TAI", "ObjID"]


def test_PPWriteOutput_all(tmp_path):
# additional test to ensure that "all" output option and no rounding works
Expand Down

0 comments on commit 81cb463

Please sign in to comment.