Skip to content

Commit

Permalink
Deepgrow - Fix probability scalar value during collate (Project-MONAI…
Browse files Browse the repository at this point in the history
…#2539)

* Deepgrow - Fix probability scalar value during collate

Signed-off-by: Sachidanand Alle <[email protected]>

* Fix CI checks

Signed-off-by: Sachidanand Alle <[email protected]>

Co-authored-by: Sachidanand Alle <[email protected]>
  • Loading branch information
SachidanandAlle and SachidanandAlle authored Jul 7, 2021
1 parent c6793fd commit b4fca1d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,15 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)

batchdata.update({CommonKeys.PRED: predictions})
batchdata[self.key_probability] = torch.as_tensor(
([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
)

# decollate batch data to execute click transforms
batchdata_list = [self.transforms(i) for i in decollate_batch(batchdata, detach=True)]
batchdata_list = decollate_batch(batchdata, detach=True)
for i in range(len(batchdata_list)):
batchdata_list[i][self.key_probability] = (
(1.0 - ((1.0 / self.max_interactions) * j)) if self.train else 1.0
)
batchdata_list[i] = self.transforms(batchdata_list[i])

# collate list into a batch for next round interaction
batchdata = list_data_collate(batchdata_list)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_deepgrow_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run_interaction(self, train, compose):
engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)

engine.run()
self.assertIsNotNone(engine.state.batch[0].get("probability"), "Probability is missing")
self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
self.assertEqual(engine.state.best_metric, 9)

def test_train_interaction(self):
Expand Down

0 comments on commit b4fca1d

Please sign in to comment.