diff --git a/tests/unit_test.py b/tests/unit_test.py index ef336e3..56dc946 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -169,6 +169,34 @@ def test_batch_process(self): # assertion self.assertEqual(True, ret) + @idata(( + [np.ones([5, 5, 3, 1]), np.ones([5, 5, 3, 1]), False], + [np.ones([5, 5, 3]), np.ones([5, 5, 3]), True], + [np.ones([5, 5, 3]), np.ones([9, 9, 3]), True], + [np.ones([5, 5, 3]), np.ones([2, 2, 3]), True], + [np.ones([5, 5, 1]), np.ones([5, 5, 1]), False], + [np.ones([5, 5, 2]), np.ones([5, 5, 2]), False], + [np.ones([5, 5, 3]), np.ones([5, 5, 1]), False], + [np.ones([5, 5, 1]), np.ones([5, 5, 3]), False], + [np.ones([5, 5]), np.ones([5, 5]), False], + [np.ones([5]), np.ones([5]), False], + [np.ones([1]), np.ones([1]), False], + [np.ones([0]), np.ones([0]), False], + [None, None, False] + )) + @unpack + def test_img_dims(self, src_img, img_ref, exp_val): + + try: + obj = ColorMatcher(src=src_img, ref=img_ref) + res = obj.main() + ret = res.mean().astype('bool') + msg = '' + except BaseException as e: + ret = False + msg = e + + self.assertEqual(exp_val, ret, msg=msg) if __name__ == '__main__': unittest.main()