forked from susheelsk/image-background-removal
-
Notifications
You must be signed in to change notification settings - Fork 278
/
Copy pathconftest.py
182 lines (141 loc) · 4.9 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from typing import Callable, Tuple, List, Union, Optional, Any, Iterator

import pytest
import torch
from PIL import Image

from carvekit.api.high import HiInterface
from carvekit.api.interface import Interface
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.generator import TrimapGenerator
from carvekit.utils.image_utils import convert_image, load_image
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
@pytest.fixture()
def u2net_model() -> Callable[[bool], U2NET]:
    """Provide a factory that builds a pretrained, full-size U2NET model.

    The returned callable takes one flag that is forwarded to ``U2NET`` as
    ``fp16``. Its parameter is spelled ``fb16`` (apparently a typo of
    "fp16"); the spelling is kept so any keyword callers keep working.
    Every call constructs a fresh model with 320 px input and a batch size
    of 10, placed on CUDA when available and on the CPU otherwise.
    """

    def build_u2net(fb16: bool) -> U2NET:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return U2NET(
            layers_cfg="full",
            device=target_device,
            input_image_size=320,
            batch_size=10,
            load_pretrained=True,
            fp16=fb16,
        )

    return build_u2net
@pytest.fixture()
def tracer_model() -> Callable[[bool], TracerUniversalB7]:
    """Provide a factory that builds a pretrained TracerUniversalB7 model.

    The returned callable takes one flag that is forwarded as ``fp16``. Its
    parameter is spelled ``fb16`` (apparently a typo of "fp16"); the spelling
    is kept so any keyword callers keep working. Every call constructs a
    fresh model with 320 px input and a batch size of 10, on CUDA when
    available and on the CPU otherwise.
    """

    def build_tracer(fb16: bool) -> TracerUniversalB7:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return TracerUniversalB7(
            device=target_device,
            input_image_size=320,
            batch_size=10,
            load_pretrained=True,
            fp16=fb16,
        )

    return build_tracer
@pytest.fixture()
def trimap_instance() -> Callable[[], TrimapGenerator]:
    """Provide a zero-argument factory for a default-configured TrimapGenerator."""

    def make_trimap_generator() -> TrimapGenerator:
        return TrimapGenerator()

    return make_trimap_generator
@pytest.fixture()
def cv2_trimap_instance() -> Callable[[], CV2TrimapGenerator]:
    """Provide a zero-argument factory for a CV2TrimapGenerator.

    Each generator is built with ``kernel_size=30`` and ``erosion_iters=0``.
    """

    def make_cv2_trimap_generator() -> CV2TrimapGenerator:
        return CV2TrimapGenerator(kernel_size=30, erosion_iters=0)

    return make_cv2_trimap_generator
@pytest.fixture()
def preprocessing_stub_instance() -> Callable[[], PreprocessingStub]:
    """Provide a zero-argument factory for a PreprocessingStub pre-pipeline."""

    def make_preprocessing_stub() -> PreprocessingStub:
        return PreprocessingStub()

    return make_preprocessing_stub
@pytest.fixture()
def matting_method_instance(fba_model, trimap_instance) -> Callable[[], MattingMethod]:
    """Provide a zero-argument factory for a MattingMethod post-pipeline.

    Each call builds a fresh FBA matting module (``fba_model(False)``, i.e.
    fp16 disabled) and a fresh trimap generator (``trimap_instance()``), in
    that order, and wires them into a ``MattingMethod`` pinned to ``"cpu"``.
    """

    def make_matting_method() -> MattingMethod:
        matting_module = fba_model(False)
        trimap_generator = trimap_instance()
        # NOTE(review): the fba_model fixture in this file places FBAMatting on
        # CUDA when available, while MattingMethod itself is pinned to CPU here —
        # confirm this split is intentional.
        return MattingMethod(
            matting_module=matting_module,
            trimap_generator=trimap_generator,
            device="cpu",
        )

    return make_matting_method
@pytest.fixture()
def high_interface_instance() -> Callable[[], HiInterface]:
    """Provide a zero-argument factory for the high-level ``HiInterface``.

    Each call builds an interface with segmentation batch size 5, matting
    batch size 1, a 320 px segmentation mask size and a 2048 px matting mask
    size. It runs on CUDA when available and on the CPU otherwise.
    """

    def make_hi_interface() -> HiInterface:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return HiInterface(
            batch_size_seg=5,
            batch_size_matting=1,
            device=target_device,
            seg_mask_size=320,
            matting_mask_size=2048,
        )

    return make_hi_interface
@pytest.fixture()
def interface_instance(
    u2net_model, preprocessing_stub_instance, matting_method_instance
) -> Callable[[], Interface]:
    """Provide a zero-argument factory for a fully wired ``Interface``.

    Each call creates, in order, a U2NET segmentation model (fp16 disabled),
    a pre-processing stub and a matting post-pipeline. These are combined
    into an ``Interface`` that runs on CUDA when available and on the CPU
    otherwise.
    """

    def make_interface() -> Interface:
        segmentation_model = u2net_model(False)
        pre_pipeline = preprocessing_stub_instance()
        post_pipeline = matting_method_instance()
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return Interface(
            segmentation_model,
            pre_pipe=pre_pipeline,
            post_pipe=post_pipeline,
            device=target_device,
        )

    return make_interface
@pytest.fixture()
def fba_model() -> Callable[[bool], FBAMatting]:
    """Provide a factory that builds a pretrained FBAMatting model.

    The returned callable takes the ``fp16`` flag. Every call constructs a
    fresh model with a 1024 px input tensor and a batch size of 2, on CUDA
    when available and on the CPU otherwise.
    """

    def build_fba(fp16: bool) -> FBAMatting:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return FBAMatting(
            device=target_device,
            input_tensor_size=1024,
            batch_size=2,
            load_pretrained=True,
            fp16=fp16,
        )

    return build_fba
@pytest.fixture()
def deeplabv3_model() -> Callable[[bool], DeepLabV3]:
    """Provide a factory that builds a pretrained DeepLabV3 model.

    The returned callable takes the ``fp16`` flag. Every call constructs a
    fresh model with a batch size of 10, on CUDA when available and on the
    CPU otherwise. The input size is left at the wrapper's default.
    """

    def build_deeplab(fp16: bool) -> DeepLabV3:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return DeepLabV3(
            device=target_device,
            batch_size=10,
            load_pretrained=True,
            fp16=fp16,
        )

    return build_deeplab
@pytest.fixture()
def basnet_model() -> Callable[[bool], BASNET]:
    """Provide a factory that builds a pretrained BASNET model.

    The returned callable takes the ``fp16`` flag. Every call constructs a
    fresh model with 320 px input and a batch size of 10, on CUDA when
    available and on the CPU otherwise.
    """

    def build_basnet(fp16: bool) -> BASNET:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return BASNET(
            device=target_device,
            input_image_size=320,
            batch_size=10,
            load_pretrained=True,
            fp16=fp16,
        )

    return build_basnet
@pytest.fixture()
def image_str(image_path) -> str:
    """Return the absolute path of the test image as a plain string."""
    absolute_location = image_path.absolute()
    return str(absolute_location)
@pytest.fixture()
def image_path() -> Path:
    """Return the path to ``tests/data/cat.jpg``, relative to this conftest's directory."""
    conftest_dir = Path(__file__).parent
    return conftest_dir / "tests" / "data" / "cat.jpg"
@pytest.fixture()
def image_mask(image_path) -> Iterator[Image.Image]:
    """Yield the reference mask ``cat_mask.png`` stored next to the test image.

    ``Image.open`` is lazy: it reads only the header and keeps the file handle
    open until the pixel data is loaded. Opening inside a ``with`` block and
    yielding keeps that lazy behaviour for the test. It also guarantees the
    handle is closed at fixture teardown, even if the test never reads pixels.
    """
    mask_path = image_path.with_name("cat_mask").with_suffix(".png")
    with Image.open(mask_path) as mask:
        yield mask
@pytest.fixture()
def image_trimap(image_path) -> Image.Image:
    """Return the reference trimap ``cat_trimap.png`` converted to grayscale (``"L"``)."""
    trimap_path = image_path.with_name("cat_trimap").with_suffix(".png")
    raw_trimap = Image.open(trimap_path)
    return raw_trimap.convert("L")
@pytest.fixture()
def image_pil(image_path) -> Iterator[Image.Image]:
    """Yield the test photo at ``image_path`` as a lazily opened PIL image.

    ``Image.open`` reads only the header and keeps the file handle open until
    the pixel data is loaded. Opening inside a ``with`` block and yielding,
    rather than returning, keeps the lazy behaviour for the test. It also
    guarantees the handle is closed at fixture teardown, even if the test
    never touches the pixels.
    """
    with Image.open(image_path) as image:
        yield image
@pytest.fixture()
def black_image_pil() -> Image.Image:
    """Return a new 512x512 RGB image.

    No fill colour is passed, so it relies on ``Image.new``'s default fill
    (0, i.e. black).
    """
    dimensions = (512, 512)
    return Image.new("RGB", dimensions)
@pytest.fixture()
def converted_pil_image(image_pil) -> Image.Image:
    """Return the test photo after passing it through carvekit's ``load_image`` then ``convert_image``."""
    loaded_image = load_image(image_pil)
    return convert_image(loaded_image)
@pytest.fixture()
def available_models(
    u2net_model,
    deeplabv3_model,
    basnet_model,
    preprocessing_stub_instance,
    matting_method_instance,
) -> Tuple[
    List[
        Union[
            Callable[[bool], U2NET],
            Callable[[bool], DeepLabV3],
            Callable[[bool], BASNET],
        ]
    ],
    List[Optional[Callable[[], PreprocessingStub]]],
    List[Optional[Callable[[], MattingMethod]]],
]:
    """Bundle the model and pipeline factories into three lists.

    Returns a tuple ``(models, pre_pipes, post_pipes)``:

    * ``models``: the U2NET, DeepLabV3 and BASNET factories. Each takes the
      ``fp16`` flag, so they are typed ``Callable[[bool], ...]``. The previous
      ``Callable[[], ...]`` annotation was wrong.
    * ``pre_pipes``: ``None`` and the PreprocessingStub factory.
    * ``post_pipes``: ``None`` and the MattingMethod factory. The previous
      ``Union[..., Any]`` annotation collapsed this to ``Any``.

    The ``None`` entries presumably mean "no pipeline" for the consuming
    tests (TODO confirm against the callers).
    """
    models = [u2net_model, deeplabv3_model, basnet_model]
    pre_pipes = [None, preprocessing_stub_instance]
    post_pipes = [None, matting_method_instance]
    return models, pre_pipes, post_pipes