Skip to content

Commit

Permalink
make sure the video have min(frame_count, 64) frames
Browse files Browse the repository at this point in the history
  • Loading branch information
leftthomas committed May 1, 2019
1 parent 09ac160 commit 72718b0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
3 changes: 2 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def center_crop(image):

# read video
cap, retaining, clips = cv2.VideoCapture(VIDEO_NAME), True, []
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
while retaining:
retaining, frame = cap.read()
if not retaining and frame is None:
Expand All @@ -55,7 +56,7 @@ def center_crop(image):
tmp_ = center_crop(cv2.resize(frame, (resize_width, resize_height)))
tmp = tmp_.astype(np.float32) / 255.0
clips.append(tmp)
if len(clips) == clip_len:
if len(clips) == clip_len or len(clips) == frame_count:
inputs = np.array(clips)
inputs = np.expand_dims(inputs, axis=0)
inputs = np.transpose(inputs, (0, 4, 1, 2, 3))
Expand Down
10 changes: 8 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ def load_frames(self, file_dir):
def crop(self, buffer, clip_len, crop_size):
if self.split == 'train':
# randomly select time index for temporal jitter
time_index = np.random.randint(buffer.shape[0] - clip_len)
if buffer.shape[0] > clip_len:
time_index = np.random.randint(buffer.shape[0] - clip_len)
else:
time_index = 0
# randomly select start indices in order to crop the video
height_index = np.random.randint(buffer.shape[1] - crop_size)
width_index = np.random.randint(buffer.shape[2] - crop_size)
Expand All @@ -172,7 +175,10 @@ def crop(self, buffer, clip_len, crop_size):
# jitter takes place via the selection of consecutive frames
else:
# for val and test, select the middle and center frames
time_index = math.floor((buffer.shape[0] - clip_len) / 2)
if buffer.shape[0] > clip_len:
time_index = math.floor((buffer.shape[0] - clip_len) / 2)
else:
time_index = 0
height_index = math.floor((buffer.shape[1] - crop_size) / 2)
width_index = math.floor((buffer.shape[2] - crop_size) / 2)
buffer = buffer[time_index:time_index + clip_len, height_index:height_index + crop_size,
Expand Down

0 comments on commit 72718b0

Please sign in to comment.