Skip to content

Commit 620ba99

Browse files
committed
feat: add scene index and search
1 parent df45aea commit 620ba99

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

videodb/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from videodb._constants import (
99
VIDEO_DB_API,
1010
MediaType,
11+
SceneModels,
1112
SearchType,
1213
SubtitleAlignment,
1314
SubtitleBorderStyle,
@@ -37,6 +38,7 @@
3738
"SubtitleAlignment",
3839
"SubtitleBorderStyle",
3940
"SubtitleStyle",
41+
"SceneModels",
4042
]
4143

4244

videodb/_constants.py

+9
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@ class MediaType:
1414
class SearchType:
1515
semantic = "semantic"
1616
keyword = "keyword"
17+
scene = "scene"
1718

1819

1920
class IndexType:
2021
semantic = "semantic"
22+
scene = "scene"
23+
24+
25+
class SceneModels:
26+
gemini_vision: str = "gemini-vision"
27+
gpt4_vision: str = "gpt4-v"
28+
all: str = "all"
2129

2230

2331
class Workflows:
@@ -44,6 +52,7 @@ class ApiPath:
4452
compile = "compile"
4553
workflow = "workflow"
4654
timeline = "timeline"
55+
delete = "delete"
4756

4857

4958
class Status:

videodb/search.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def search_inside_video(
112112
result_threshold: Optional[int] = None,
113113
score_threshold: Optional[int] = None,
114114
dynamic_score_percentage: Optional[int] = None,
115+
**kwargs,
115116
):
116117
search_data = self._connection.post(
117118
path=f"{ApiPath.video}/{video_id}/{ApiPath.search}",
@@ -133,6 +134,7 @@ def search_inside_collection(
133134
result_threshold: Optional[int] = None,
134135
score_threshold: Optional[int] = None,
135136
dynamic_score_percentage: Optional[int] = None,
137+
**kwargs,
136138
):
137139
search_data = self._connection.post(
138140
path=f"{ApiPath.collection}/{collection_id}/{ApiPath.search}",
@@ -176,7 +178,42 @@ def search_inside_collection(**kwargs):
176178
raise NotImplementedError("Keyword search will be implemented in the future")
177179

178180

179-
search_type = {SearchType.semantic: SemanticSearch, SearchType.keyword: KeywordSearch}
181+
class SceneSearch(Search):
182+
def __init__(self, _connection):
183+
self._connection = _connection
184+
185+
def search_inside_video(
186+
self,
187+
video_id: str,
188+
query: str,
189+
scene_model: Optional[str] = None,
190+
result_threshold: Optional[int] = None,
191+
score_threshold: Optional[int] = None,
192+
dynamic_score_percentage: Optional[int] = None,
193+
**kwargs,
194+
):
195+
search_data = self._connection.post(
196+
path=f"{ApiPath.video}/{video_id}/{ApiPath.search}",
197+
data={
198+
"index_type": SearchType.scene,
199+
"query": query,
200+
"model_name": scene_model,
201+
"score_threshold": score_threshold,
202+
"result_threshold": result_threshold,
203+
},
204+
)
205+
print(search_data)
206+
return SearchResult(self._connection, **search_data)
207+
208+
def search_inside_collection(**kwargs):
209+
raise NotImplementedError("Scene search will be implemented in the future")
210+
211+
212+
search_type = {
213+
SearchType.semantic: SemanticSearch,
214+
SearchType.keyword: KeywordSearch,
215+
SearchType.scene: SceneSearch,
216+
}
180217

181218

182219
class SearchFactory:

videodb/video.py

+54-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional
1+
from typing import Optional, Union
22
from videodb._utils._video import play_stream
33
from videodb._constants import (
44
ApiPath,
5-
SearchType,
65
IndexType,
7-
Workflows,
6+
SceneModels,
7+
SearchType,
88
SubtitleStyle,
9+
Workflows,
910
)
1011
from videodb.search import SearchFactory, SearchResult
1112
from videodb.shot import Shot
@@ -24,6 +25,7 @@ def __init__(self, _connection, id: str, collection_id: str, **kwargs) -> None:
2425
self.length = float(kwargs.get("length", 0.0))
2526
self.transcript = kwargs.get("transcript", None)
2627
self.transcript_text = kwargs.get("transcript_text", None)
28+
self.scenes = kwargs.get("scenes", None)
2729

2830
def __repr__(self) -> str:
2931
return (
@@ -45,17 +47,19 @@ def search(
4547
self,
4648
query: str,
4749
search_type: Optional[str] = SearchType.semantic,
50+
scene_model: Optional[str] = SceneModels.gemini_vision,
4851
result_threshold: Optional[int] = None,
4952
score_threshold: Optional[int] = None,
5053
dynamic_score_percentage: Optional[int] = None,
5154
) -> SearchResult:
5255
search = SearchFactory(self._connection).get_search(search_type)
5356
return search.search_inside_video(
54-
self.id,
55-
query,
56-
result_threshold,
57-
score_threshold,
58-
dynamic_score_percentage,
57+
video_id=self.id,
58+
query=query,
59+
result_threshold=result_threshold,
60+
score_threshold=score_threshold,
61+
dynamic_score_percentage=dynamic_score_percentage,
62+
scene_model=scene_model,
5963
)
6064

6165
def delete(self) -> None:
@@ -130,6 +134,48 @@ def index_spoken_words(self) -> None:
130134
},
131135
)
132136

137+
def index_scenes(
138+
self,
139+
scene_model: str = SceneModels.gemini_vision,
140+
force: bool = False,
141+
prompt: str = None,
142+
callback_url: str = None,
143+
) -> None:
144+
self._connection.post(
145+
path=f"{ApiPath.video}/{self.id}/{ApiPath.index}",
146+
data={
147+
"index_type": IndexType.scene,
148+
"model_name": scene_model,
149+
"force": force,
150+
"prompt": prompt,
151+
"callback_url": callback_url,
152+
},
153+
)
154+
155+
def get_scenes(
156+
self, scene_model: str = SceneModels.gemini_vision
157+
) -> Union[list, None]:
158+
if self.scenes:
159+
return self.scenes
160+
scene_data = self._connection.get(
161+
path=f"{ApiPath.video}/{self.id}/{ApiPath.index}",
162+
params={
163+
"index_type": IndexType.scene,
164+
"model_name": scene_model,
165+
},
166+
)
167+
self.scenes = scene_data
168+
return scene_data if scene_data else None
169+
170+
def delete_scene_index(self, scene_model: str = SceneModels.gemini_vision) -> None:
171+
self._connection.post(
172+
path=f"{ApiPath.video}/{self.id}/{ApiPath.index}/{ApiPath.delete}",
173+
data={
174+
"index_type": IndexType.scene,
175+
"model_name": scene_model,
176+
},
177+
)
178+
133179
def add_subtitle(self, style: SubtitleStyle = SubtitleStyle()) -> str:
134180
if not isinstance(style, SubtitleStyle):
135181
raise ValueError("style must be of type SubtitleStyle")

0 commit comments

Comments
 (0)