forked from oladeha2/shot_boudary_detector
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestVideo.py
52 lines (40 loc) · 1.88 KB
/
TestVideo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Dataset object for running shot boundary detection inference on a video.
# Videos are processed in windows of 100 frames, with consecutive windows overlapping by 9 frames.
from torch.utils.data import Dataset, DataLoader
import numpy as np
from snippet import getSnippet
from math import floor
from utilities import normalize_frame, print_shape
from IPython import embed
def return_start_and_end(idx, sample_size=100, overlap=9):
    """Return the ``(start, end)`` frame indices of the ``idx``-th window.

    Each window begins ``sample_size - overlap`` frames after the previous
    one, so neighbouring windows share ``overlap`` frames.  ``end`` is
    exclusive and is NOT clamped to the video length here.

    Args:
        idx: Zero-based window index.
        sample_size: Number of frames spanned by one window.
        overlap: Number of frames shared by two consecutive windows.

    Returns:
        A ``(start, end)`` tuple with ``end == start + sample_size``.
    """
    # For idx == 0 both products are zero, so the first window starts at 0
    # without needing a dedicated branch.
    first_frame = idx * sample_size - idx * overlap
    return first_frame, first_frame + sample_size
def get_len(total_frames, sample_size=100, overlap=9):
    """Return how many overlapping windows are needed to cover every frame.

    Window ``k`` starts at frame ``k * (sample_size - overlap)`` and spans up
    to ``sample_size`` frames; the final window may be shorter.  A window is
    counted only if it contains at least one frame that the previous window
    did not already cover.

    The previous ``floor(total_frames / stride)`` computation dropped the
    tail of the video (e.g. 150 frames produced a single window, leaving
    frames 100-149 unvisited) and returned 0 for any video shorter than one
    stride.

    Args:
        total_frames: Number of frames in the video.
        sample_size: Number of frames spanned by one window.
        overlap: Number of frames shared by two consecutive windows.

    Returns:
        The number of windows as a non-negative ``int``; 0 when there are no
        frames.

    Raises:
        ValueError: If ``sample_size`` is not strictly greater than
            ``overlap`` (the windows would never advance).
    """
    stride = sample_size - overlap
    if stride <= 0:
        raise ValueError("sample_size must be greater than overlap")
    if total_frames <= 0:
        return 0

    # Window k adds new frames only when its start plus the overlap shared
    # with its predecessor is still inside the video:
    #     k * stride + overlap < total_frames
    remaining = total_frames - overlap
    if remaining <= 0:
        # Fewer frames than the overlap: one (short) window still holds them.
        return 1
    # Integer ceiling division, avoiding float rounding on long videos.
    return -(-remaining // stride)
class TestVideo(Dataset):
    """Dataset that serves one video as a sequence of overlapping windows.

    The video is described by a text file with one frame listing per line.
    Item ``idx`` is the window whose bounds come from
    ``return_start_and_end``, clamped to the number of listed frames.
    """

    def __init__(self, video_file, sample_size=100, overlap=9):
        """Read the frame listing and remember the windowing parameters.

        Args:
            video_file: Path to a text file with one frame listing per line.
            sample_size: Number of frames spanned by one window.
            overlap: Number of frames shared by two consecutive windows.
        """
        self.sample_size = sample_size
        self.overlap = overlap
        # Surrounding whitespace (including the newline) is stripped from
        # every entry; blank lines are kept as empty strings.
        with open(video_file) as listing:
            self.lines = [entry.strip() for entry in listing]
        self.line_number = len(self.lines)

    def __len__(self):
        """Return the number of windows, as computed by ``get_len``."""
        return get_len(
            self.line_number,
            sample_size=self.sample_size,
            overlap=self.overlap,
        )

    def __getitem__(self, idx):
        """Load, normalise and re-order the frames of window ``idx``.

        Returns:
            A 4-D numpy array: each frame is moved from axis order
            ``(0, 1, 2)`` to ``(2, 0, 1)``, the frames are stacked, and the
            first two axes of the stack are swapped.  Presumably this turns
            ``(frames, H, W, C)`` into ``(C, frames, H, W)`` -- TODO confirm
            against what ``getSnippet`` / ``normalize_frame`` return.
        """
        first, last = return_start_and_end(
            idx=idx, sample_size=self.sample_size, overlap=self.overlap
        )
        # The final window may run past the listing; cut it at the last frame.
        last = min(last, self.line_number)

        raw_frames = np.array(getSnippet(self.lines, first, last))
        normalized = np.array([normalize_frame(frame) for frame in raw_frames])
        # Move each frame's last axis to the front, then swap the first two
        # axes of the stacked result.
        reordered = np.array(
            [np.transpose(frame, (2, 0, 1)) for frame in normalized]
        )
        return np.transpose(reordered, (1, 0, 2, 3))

    def get_line_number(self):
        """Return the number of frame listings read from the video file."""
        return self.line_number