diff --git a/tests/estimators/object_detection/conftest.py b/tests/estimators/object_detection/conftest.py index 69dd17d9d9..ddbdd22a93 100644 --- a/tests/estimators/object_detection/conftest.py +++ b/tests/estimators/object_detection/conftest.py @@ -287,14 +287,16 @@ def get_pytorch_detr(get_default_cifar10_subset): "scores": np.ones_like(result[0]["labels"]), }, { - "boxes": result[0]["boxes"], - "labels": result[0]["labels"], - "scores": np.ones_like(result[0]["labels"]), + "boxes": result[1]["boxes"], + "labels": result[1]["labels"], + "scores": np.ones_like(result[1]["labels"]), }, ] + y_test[0]["scores"] = y_test[0]["scores"] * 0.5 + y_test[1]["scores"] = y_test[1]["scores"] * 0.5 print("y_test['scores'].shape") - print(y_test[0]["scores"].shape) - print(y_test[1]["scores"].shape) + print(y_test[0]) + print(y_test[1]) yield object_detector, x_test, y_test