diff --git a/pycaption/pl_stt.py b/pycaption/pl_stt.py index d1e03de0..49d330d5 100644 --- a/pycaption/pl_stt.py +++ b/pycaption/pl_stt.py @@ -44,21 +44,42 @@ def detect(self, content): else: return False + def _guess_framerate(self, nonempty_splits): + # Try to guess the framerate by taking the highest frame number encountered across all start and end times. + # Once found, add 1 to it to get the best guess framerate, clamp it to 24fps as the minimum framerate and return it. + frame_nums = [] + for sub in nonempty_splits: + sub_start, sub_end, sub_text = self._parse_sub(sub) + frame_nums.append(int(sub_start.split(":")[-1])) + frame_nums.append(int(sub_end.split(":")[-1])) + + frame_nums = sorted(list(set(frame_nums)), reverse=True) + return float(max(24, frame_nums[0] + 1)) + def read(self, content, lang="en-US"): if type(content) != str: raise InvalidInputError("The content is not a unicode string.") try: header = self._get_header(content) + if not header: + raise InvalidInputError("Invalid or missing header.") + except: + raise InvalidInputError("Invalid or missing header.") + + framerate = None + try: framerate = float(header.get("TIME_FRAME_RATE")) except: - raise InvalidInputError("Invalid or missing header or cannot get or parse TIME_FRAME_RATE.") + framerate = None body = self._get_body(content) captions = CaptionList() all_splits = re.split(PLSTTReader.RE_SUBS_SPLIT, body, flags=re.MULTILINE) nonempty_splits = [split.strip() for split in all_splits if split and split.strip()] + if framerate is None: + framerate = self._guess_framerate(nonempty_splits) for sub in nonempty_splits: sub_start, sub_end, sub_text = self._parse_sub(sub) diff --git a/tests/samples/pl_stt.py b/tests/samples/pl_stt.py index ea8d8440..c2989b26 100644 --- a/tests/samples/pl_stt.py +++ b/tests/samples/pl_stt.py @@ -53,10 +53,6 @@ SAMPLE_PL_STT_NO_HEADER = f"""{SAMPLE_PL_STT_BODY} """ -SAMPLE_PL_STT_BAD_HEADER_1 = f"""{SAMPLE_PL_STT_HEADER_NO_FRAMERATE} -{SAMPLE_PL_STT_BODY} -""" - -SAMPLE_PL_STT_BAD_HEADER_2 = f"""{SAMPLE_PL_STT_HEADER_WRONG_FORMAT} +SAMPLE_PL_STT_BAD_HEADER_1 = f"""{SAMPLE_PL_STT_HEADER_WRONG_FORMAT} {SAMPLE_PL_STT_BODY} """ diff --git a/tests/test_pl_stt.py b/tests/test_pl_stt.py index 5e99135c..cf174c7f 100644 --- a/tests/test_pl_stt.py +++ b/tests/test_pl_stt.py @@ -8,7 +8,6 @@ SAMPLE_PL_STT, SAMPLE_PL_STT_NO_HEADER, SAMPLE_PL_STT_BAD_HEADER_1, - SAMPLE_PL_STT_BAD_HEADER_2, ) @@ -50,7 +49,3 @@ def test_no_header_file(self): def test_bad_header_1(self): self.assertRaises(InvalidInputError, self.reader.read, SAMPLE_PL_STT_BAD_HEADER_1) - self.assertRaises(InvalidInputError, self.reader.read, SAMPLE_PL_STT_BAD_HEADER_2) - - def test_bad_header_2(self): - self.assertRaises(InvalidInputError, self.reader.read, SAMPLE_PL_STT_BAD_HEADER_2)