Skip to content

Commit

Permalink
Better error handling when sleep duration is less than 30 mins
Browse files Browse the repository at this point in the history
  • Loading branch information
angerhang committed Oct 18, 2023
1 parent 8839562 commit 1f48683
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/asleep/get_sleep.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,23 +310,30 @@ def main():
(binary_y, all_sleep_wins_df, sleep_wins_long_per_day_df, master_acc,
master_npids) = get_sleep_windows(data2model, times, non_wear, args)

y_pred, test_pids = start_sleep_net(
master_acc, master_npids, args.outdir,
args.model_weight_path, local_repo_path=args.local_repo_path,
device_id=args.pytorch_device)
sleepnet_output = binary_y
if len(master_npids) <= 0:
print("No sleep windows >30 mins detected. Exiting...")
print("Non-wear time has been written to %s" % non_wear_path)
print("Current sleep classification has been written to %s" %
os.path.join(args.outdir, 'ssl_sleep.npy'))
exit()
else:
y_pred, test_pids = start_sleep_net(
master_acc, master_npids, args.outdir,
args.model_weight_path, local_repo_path=args.local_repo_path,
device_id=args.pytorch_device)
sleepnet_output = binary_y

for block_id in range(len(all_sleep_wins_df)):
start_t = all_sleep_wins_df.iloc[block_id]['start']
end_t = all_sleep_wins_df.iloc[block_id]['end']
for block_id in range(len(all_sleep_wins_df)):
start_t = all_sleep_wins_df.iloc[block_id]['start']
end_t = all_sleep_wins_df.iloc[block_id]['end']

time_filter = (times >= start_t) & (times < end_t)
time_filter = (times >= start_t) & (times < end_t)

# get the corresponding sleepnet predictions
sleepnet_pred = y_pred[test_pids == block_id]
# get the corresponding sleepnet predictions
sleepnet_pred = y_pred[test_pids == block_id]

# fill the sleepnet predictions back to the original dataframe
sleepnet_output[time_filter] = sleepnet_pred
# fill the sleepnet predictions back to the original dataframe
sleepnet_output[time_filter] = sleepnet_pred

# 3. Skip this step if predictions already exist
# Output pandas dataframe
Expand Down

0 comments on commit 1f48683

Please sign in to comment.