Skip to content

Commit

Permalink
update do_seg
Browse files Browse the repository at this point in the history
  • Loading branch information
dbuscombe-usgs committed Nov 26, 2024
1 parent 1dcc146 commit 0df563e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions doodleverse_utils/prediction_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,13 @@ def est_label_binary(image,M,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE,w,h):
est_label = est_label + est_label2 + est_label3 + est_label4
# del est_label2, est_label3, est_label4

est_label = est_label.numpy().astype('float32')
# est_label = est_label.numpy().astype('float32')

if not isinstance(est_label, np.ndarray):
# If not, convert it to a numpy array
est_label = est_label.numpy()
# Now, convert to 'float32'
est_label = est_label.astype('float32')

if MODEL=='segformer':
est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0],TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze()
Expand Down Expand Up @@ -396,7 +402,13 @@ def do_seg(

est_label /= counter + 1
# est_label cannot be float16 so convert to float32
est_label = est_label.numpy().astype('float32')
# est_label = est_label.numpy().astype('float32')

if not isinstance(est_label, np.ndarray):
# If not, convert it to a numpy array
est_label = est_label.numpy()
# Now, convert to 'float32'
est_label = est_label.astype('float32')

if MODEL=='segformer':
est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0],TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze()
Expand Down

0 comments on commit 0df563e

Please sign in to comment.