From 4edd2d421fd7b9cf4e835420b56206e5ed94bf43 Mon Sep 17 00:00:00 2001
From: FichteFoll
Date: Tue, 1 Oct 2019 19:16:29 +0200
Subject: [PATCH 01/17] Use package structure and relative imports

---
 sushi/__init__.py                                |  0
 sushi.py => sushi/__main__.py                    | 12 ++++++------
 chapters.py => sushi/chapters.py                 |  2 +-
 common.py => sushi/common.py                     |  0
 demux.py => sushi/demux.py                       |  4 ++--
 keyframes.py => sushi/keyframes.py               |  2 +-
 regression-tests.py => sushi/regression-tests.py |  8 ++++----
 subs.py => sushi/subs.py                         |  2 +-
 wav.py => sushi/wav.py                           |  3 ++-
 tests/demuxing.py                                |  6 +++---
 tests/main.py                                    |  5 +++--
 tests/subtitles.py                               |  3 ++-
 tests/timecodes.py                               |  3 ++-
 13 files changed, 27 insertions(+), 23 deletions(-)
 create mode 100644 sushi/__init__.py
 rename sushi.py => sushi/__main__.py (99%)
 rename chapters.py => sushi/chapters.py (97%)
 rename common.py => sushi/common.py (100%)
 rename demux.py => sushi/demux.py (99%)
 rename keyframes.py => sushi/keyframes.py (89%)
 rename regression-tests.py => sushi/regression-tests.py (98%)
 rename subs.py => sushi/subs.py (99%)
 rename wav.py => sushi/wav.py (99%)

diff --git a/sushi/__init__.py b/sushi/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sushi.py b/sushi/__main__.py
similarity index 99%
rename from sushi.py
rename to sushi/__main__.py
index 65cb35d..33515c6 100755
--- a/sushi.py
+++ b/sushi/__main__.py
@@ -11,12 +11,12 @@

 import numpy as np

-import chapters
-from common import SushiError, get_extension, format_time, ensure_static_collection
-from demux import Timecodes, Demuxer
-import keyframes
-from subs import AssScript, SrtScript
-from wav import WavStream
+from . import chapters
+from .common import SushiError, get_extension, format_time, ensure_static_collection
+from .demux import Timecodes, Demuxer
+from . import keyframes
+from .subs import AssScript, SrtScript
+from .wav import WavStream


 try:
diff --git a/chapters.py b/sushi/chapters.py
similarity index 97%
rename from chapters.py
rename to sushi/chapters.py
index 4bf1de3..5eee928 100644
--- a/chapters.py
+++ b/sushi/chapters.py
@@ -1,5 +1,5 @@
 import re
-import common
+from . import common


 def parse_times(times):
diff --git a/common.py b/sushi/common.py
similarity index 100%
rename from common.py
rename to sushi/common.py
diff --git a/demux.py b/sushi/demux.py
similarity index 99%
rename from demux.py
rename to sushi/demux.py
index d5a5e37..bda06d8 100644
--- a/demux.py
+++ b/sushi/demux.py
@@ -5,8 +5,8 @@
 import logging
 import bisect

-from common import SushiError, get_extension
-import chapters
+from .common import SushiError, get_extension
+from . import chapters

 MediaStreamInfo = namedtuple('MediaStreamInfo', ['id', 'info', 'default', 'title'])
 SubtitlesStreamInfo = namedtuple('SubtitlesStreamInfo', ['id', 'info', 'type', 'default', 'title'])
diff --git a/keyframes.py b/sushi/keyframes.py
similarity index 89%
rename from keyframes.py
rename to sushi/keyframes.py
index f749393..776fa55 100644
--- a/keyframes.py
+++ b/sushi/keyframes.py
@@ -1,4 +1,4 @@
-from common import SushiError, read_all_text
+from .common import SushiError, read_all_text


 def parse_scxvid_keyframes(text):
diff --git a/regression-tests.py b/sushi/regression-tests.py
similarity index 98%
rename from regression-tests.py
rename to sushi/regression-tests.py
index f1161e8..efcd9bb 100644
--- a/regression-tests.py
+++ b/sushi/regression-tests.py
@@ -9,10 +9,10 @@
 import subprocess
 import argparse

-from common import format_time
-from demux import Timecodes
-from subs import AssScript
-from wav import WavStream
+from .common import format_time
+from .demux import Timecodes
+from .subs import AssScript
+from .wav import WavStream

 root_logger = logging.getLogger('')
diff --git a/subs.py b/sushi/subs.py
similarity index 99%
rename from subs.py
rename to sushi/subs.py
index 6b31d9a..66ccaa7 100644
--- a/subs.py
+++ b/sushi/subs.py
@@ -3,7 +3,7 @@
 import re
 import collections

-from common import SushiError, format_time, format_srt_time
+from .common import SushiError, format_time, format_srt_time


 def _parse_ass_time(string):
diff --git a/wav.py b/sushi/wav.py
similarity index 99%
rename from wav.py
rename to sushi/wav.py
index 4d4f785..e9cce25 100644
--- a/wav.py
+++ b/sushi/wav.py
@@ -6,7 +6,8 @@
 import math
 from time import time
 import os.path
-from common import SushiError, clip
+
+from .common import SushiError, clip

 WAVE_FORMAT_PCM = 0x0001
 WAVE_FORMAT_EXTENSIBLE = 0xFFFE
diff --git a/tests/demuxing.py b/tests/demuxing.py
index 8eb4d42..8e081b2 100644
--- a/tests/demuxing.py
+++ b/tests/demuxing.py
@@ -1,9 +1,9 @@
 import unittest
 import mock

-from demux import FFmpeg, MkvToolnix, SCXviD
-from common import SushiError
-import chapters
+from sushi.demux import FFmpeg, MkvToolnix, SCXviD
+from sushi.common import SushiError
+from sushi import chapters


 def create_popen_mock():
diff --git a/tests/main.py b/tests/main.py
index 17b0a8c..c234d53 100644
--- a/tests/main.py
+++ b/tests/main.py
@@ -3,8 +3,9 @@
 import re
 import unittest
 from mock import patch, ANY
-from common import SushiError, format_time
-import sushi
+
+from sushi.common import SushiError, format_time
+from sushi import __main__ as sushi

 here = os.path.dirname(os.path.abspath(__file__))
diff --git a/tests/subtitles.py b/tests/subtitles.py
index e442756..b24c852 100644
--- a/tests/subtitles.py
+++ b/tests/subtitles.py
@@ -2,7 +2,8 @@
 import tempfile
 import os
 import codecs
-from subs import AssEvent, AssScript, SrtEvent, SrtScript
+
+from sushi.subs import AssEvent, AssScript, SrtEvent, SrtScript

 SINGLE_LINE_SRT_EVENT = """1
 00:14:21,960 --> 00:14:22,960
diff --git a/tests/timecodes.py b/tests/timecodes.py
index 335006c..5f4fb04 100644
--- a/tests/timecodes.py
+++ b/tests/timecodes.py
@@ -1,5 +1,6 @@
 import unittest
-from demux import Timecodes
+
+from sushi.demux import Timecodes


 class CfrTimecodesTestCase(unittest.TestCase):

From 823f7a63137339423ee413a9e691157aeac8c89c Mon Sep 17 00:00:00 2001
From: FichteFoll
Date: Tue, 1 Oct 2019 19:49:32 +0200
Subject: [PATCH 02/17] Run 2to3 with manual adjustments

---
 sushi/__main__.py         | 46 ++++++++++-----------
 sushi/chapters.py         |  2 +-
 sushi/common.py           |  4 +-
 sushi/demux.py            |  2 +-
 sushi/regression-tests.py | 10 ++---
 sushi/subs.py             | 50 +++++++++++------------
 sushi/wav.py              |  3 +-
 tests/demuxing.py         | 12 +++---
 tests/main.py             | 14 +++----
 tests/subtitles.py        | 86 +++++++++++++++++++--------------------
 10 files changed, 115 insertions(+), 114 deletions(-)

diff --git a/sushi/__main__.py b/sushi/__main__.py
index 33515c6..696f691 100755
--- a/sushi/__main__.py
+++ b/sushi/__main__.py
@@ -6,7 +6,7 @@
 import os
 import bisect
 import collections
-from itertools import takewhile, izip, chain
+from itertools import takewhile, chain
 import time

 import numpy as np
@@ -70,26 +70,26 @@ def abs_diff(a, b):

 def interpolate_nones(data, points):
     data = ensure_static_collection(data)
-    values_lookup = {p: v for p, v in izip(points, data) if v is not None}
+    values_lookup = {p: v for p, v in zip(points, data) if v is not None}
     if not values_lookup:
         return []

-    zero_points = {p for p, v in izip(points, data) if v is None}
+    zero_points = {p for p, v in zip(points, data) if v is None}
     if not zero_points:
         return data

-    data_list = sorted(values_lookup.iteritems())
+    data_list = sorted(values_lookup.items())
     zero_points = sorted(x for x in zero_points if x not in values_lookup)

     out = np.interp(x=zero_points,
-                    xp=map(operator.itemgetter(0), data_list),
-                    fp=map(operator.itemgetter(1), data_list))
+                    xp=list(map(operator.itemgetter(0), data_list)),
+                    fp=list(map(operator.itemgetter(1), data_list)))

-    values_lookup.update(izip(zero_points, out))
+    values_lookup.update(zip(zero_points, out))

     return [
         values_lookup[point] if value is None else value
-        for point, value in izip(points, data)
+        for point, value in zip(points, data)
     ]
@@ -100,7 +100,7 @@ def running_median(values, window_size):
     half_window = window_size // 2
     medians = []
     items_count = len(values)
-    for idx in xrange(items_count):
+    for idx in range(items_count):
         radius = min(half_window, idx, items_count-idx-1)
         med = np.median(values[idx-radius:idx+radius+1])
         medians.append(med)
@@ -113,7 +113,7 @@ def smooth_events(events, radius):
     if not radius:
         return
     window_size = radius*2+1
     shifts = [e.shift for e in events]
     smoothed = running_median(shifts, window_size)
-    for event, new_shift in izip(events, smoothed):
+    for event, new_shift in zip(events, smoothed):
         event.set_shift(new_shift, event.diff)
@@ -128,7 +128,7 @@ def detect_groups(events_iter):


 def groups_from_chapters(events, times):
-    logging.info(u'Chapter start points: {0}'.format([format_time(t) for t in times]))
+    logging.info('Chapter start points: {0}'.format([format_time(t) for t in times]))
     groups = [[]]
     chapter_times = iter(times[1:] + [36000000000])  # very large event at the end
     current_chapter = next(chapter_times)
@@ -141,7 +141,7 @@ def groups_from_chapters(events, times):

         groups[-1].append(event)

-    groups = filter(None, groups)  # non-empty groups
+    groups = [g for g in groups if g]  # non-empty groups
     # check if we have any groups where every event is linked
     # for example a chapter with only comments inside
     broken_groups = [group for group in groups if not any(e for e in group if not e.linked)]
@@ -152,7 +152,7 @@ def groups_from_chapters(events, times):
                 parent_group = next(group for group in groups if parent in group)
                 parent_group.append(event)
             del group[:]
-        groups = filter(None, groups)
+        groups = [g for g in groups if g]
         # re-sort the groups again since we might break the order when inserting linked events
         # sorting everything again is far from optimal but python sorting is very fast for sorted arrays anyway
         for group in groups:
@@ -167,9 +167,9 @@ def split_broken_groups(groups):
     for g in groups:
         std = np.std([e.shift for e in g])
         if std > MAX_GROUP_STD:
-            logging.warn(u'Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). '
-                         u'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end),
-                                                                    std))
+            logging.warn('Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). '
+                         'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end),
+                                                                   std))
             correct_groups.extend(detect_groups(g))
             broken_found = True
         else:
@@ -281,10 +281,10 @@ def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_dist
         shifts = interpolate_nones(shifts, times)
         if shifts:
             mean_shift = np.mean(shifts)
-            shifts = zip(*(iter(shifts), ) * 2)
+            shifts = zip(*[iter(shifts)] * 2)

             logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift))
-            for group, (start_shift, end_shift) in izip(groups, shifts):
+            for group, (start_shift, end_shift) in zip(groups, shifts):
                 if abs(start_shift-end_shift) > 0.001 and len(group) > 1:
                     actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift))
                     logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}."
@@ -359,7 +359,7 @@ def prepare_search_groups(events, source_duration, chapter_times, max_ts_duratio
             event.link_event(last_unlinked)
             continue
         if (event.start + event.duration / 2.0) > source_duration:
-            logging.info('Event time outside of audio range, ignoring: %s' % unicode(event))
+            logging.info('Event time outside of audio range, ignoring: %s', event)
             event.link_event(last_unlinked)
             continue
         elif event.end == event.start:
@@ -400,7 +400,7 @@ def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_win
     def log_shift(state):
         logging.info('{0}-{1}: shift: {2:0.10f}, diff: {3:0.10f}'
-                        .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"]))
+                     .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"]))

     def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offset):
         logging.debug('{0}-{1}: shift: {2:0.5f} [{3:0.5f}, {4:0.5f}], search offset: {5:0.6f}'
@@ -495,7 +495,7 @@ def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_win
     for state in uncommitted_states:
         log_shift(state)

-    for idx, (search_group, group_state) in enumerate(izip(groups_list, chain(committed_states, uncommitted_states))):
+    for idx, (search_group, group_state) in enumerate(zip(groups_list, chain(committed_states, uncommitted_states))):
         if group_state["shift"] is None:
             for group in reversed(groups_list[:idx]):
                 link_to = next((x for x in reversed(group) if not x.linked), None)
@@ -698,8 +698,8 @@ def select_timecodes(external_file, fps_arg, demuxer):
         start_shift = g[0].shift
         end_shift = g[-1].shift
         avg_shift = average_shifts(g)
-        logging.info(u'Group (start: {0}, end: {1}, lines: {2}), '
-                     u'shifts (start: {3}, end: {4}, average: {5})'
+        logging.info('Group (start: {0}, end: {1}, lines: {2}), '
+                     'shifts (start: {3}, end: {4}, average: {5})'
                      .format(format_time(g[0].start), format_time(g[-1].end), len(g),
                              start_shift, end_shift, avg_shift))
diff --git a/sushi/chapters.py b/sushi/chapters.py
index 5eee928..104fb2d 100644
--- a/sushi/chapters.py
+++ b/sushi/chapters.py
@@ -5,7 +5,7 @@
 def parse_times(times):
     result = []
     for t in times:
-        hours, minutes, seconds = map(float, t.split(':'))
+        hours, minutes, seconds = list(map(float, t.split(':')))
         result.append(hours * 3600 + minutes * 60 + seconds)

     result.sort()
diff --git a/sushi/common.py b/sushi/common.py
index 595fbe5..7a6d96e 100644
--- a/sushi/common.py
+++ b/sushi/common.py
@@ -22,7 +22,7 @@ def ensure_static_collection(value):

 def format_srt_time(seconds):
     cs = round(seconds * 1000)
-    return u'{0:02d}:{1:02d}:{2:02d},{3:03d}'.format(
+    return '{0:02d}:{1:02d}:{2:02d},{3:03d}'.format(
         int(cs // 3600000),
         int((cs // 60000) % 60),
         int((cs // 1000) % 60),
@@ -31,7 +31,7 @@ def format_srt_time(seconds):

 def format_time(seconds):
     cs = round(seconds * 100)
-    return u'{0}:{1:02d}:{2:02d}.{3:02d}'.format(
+    return '{0}:{1:02d}:{2:02d}.{3:02d}'.format(
         int(cs // 360000),
         int((cs // 6000) % 60),
         int((cs // 100) % 60),
diff --git a/sushi/demux.py b/sushi/demux.py
index bda06d8..4743b3a 100644
--- a/sushi/demux.py
+++ b/sushi/demux.py
@@ -75,7 +75,7 @@ def _get_video_streams(info):

     @staticmethod
     def _get_chapters_times(info):
-        return map(float, re.findall(r'Chapter #0.\d+: start (\d+\.\d+)', info))
+        return list(map(float, re.findall(r'Chapter #0.\d+: start (\d+\.\d+)', info)))

     @staticmethod
     def _get_subtitles_streams(info):
diff --git a/sushi/regression-tests.py b/sushi/regression-tests.py
index efcd9bb..30ce5e7 100644
--- a/sushi/regression-tests.py
+++ b/sushi/regression-tests.py
@@ -54,19 +54,19 @@ def compare_scripts(ideal_path, test_path, timecodes, test_name, expected_errors
         test_end_frame = timecodes.get_frame_number(test.end)

         if ideal_start_frame != test_start_frame and ideal_end_frame != test_end_frame:
-            logging.debug(u'{0}: start and end time failed at "{1}". {2}-{3} vs {4}-{5}'.format(
+            logging.debug('{0}: start and end time failed at "{1}". {2}-{3} vs {4}-{5}'.format(
                 idx, strip_tags(ideal.text), ft(ideal.start), ft(ideal.end), ft(test.start), ft(test.end))
             )
             failed += 1
         elif ideal_end_frame != test_end_frame:
             logging.debug(
-                u'{0}: end time failed at "{1}". {2} vs {3}'.format(
+                '{0}: end time failed at "{1}". {2} vs {3}'.format(
                     idx, strip_tags(ideal.text), ft(ideal.end), ft(test.end))
             )
             failed += 1
         elif ideal_start_frame != test_start_frame:
             logging.debug(
-                u'{0}: start time failed at "{1}". {2} vs {3}'.format(
+                '{0}: start time failed at "{1}". {2} vs {3}'.format(
                     idx, strip_tags(ideal.text), ft(ideal.start), ft(test.start))
             )
             failed += 1
@@ -189,7 +189,7 @@ def should_run(name):
         return not args.run_only or name in args.run_only

     failed = ran = 0
-    for test_name, params in config.get('tests', {}).iteritems():
+    for test_name, params in config.get('tests', {}).items():
         if not should_run(test_name):
             continue
         if not params.get('disabled', False):
@@ -201,7 +201,7 @@ def should_run(name):
             logging.warn('Test "{0}" disabled'.format(test_name))

     if should_run("wavs"):
-        for test_name, params in config.get('wavs', {}).iteritems():
+        for test_name, params in config.get('wavs', {}).items():
             ran += 1
             if not run_wav_test(test_name, os.path.join(config['basepath'], params['file']), params):
                 failed += 1
diff --git a/sushi/subs.py b/sushi/subs.py
index 66ccaa7..670f2ad 100644
--- a/sushi/subs.py
+++ b/sushi/subs.py
@@ -80,7 +80,7 @@ def adjust_shift(self, value):
         self._shift += value

     def __repr__(self):
-        return unicode(self)
+        return str(self)


 class ScriptBase(object):
@@ -113,8 +113,8 @@ def from_string(cls, text):
         return SrtEvent(int(match.group(1)), start, end, match.group(4).strip())

     def __unicode__(self):
-        return u'{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start),
-                                               self._format_time(self.end), self.text)
+        return '{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start),
+                                              self._format_time(self.end), self.text)

     @staticmethod
     def parse_time(time_string):
@@ -142,7 +142,7 @@ def from_file(cls, path):
             raise SushiError("Script {0} not found".format(path))

     def save_to_file(self, path):
-        text = '\n\n'.join(map(unicode, self.events))
+        text = '\n\n'.join(map(str, self.events))
         with codecs.open(path, encoding='utf-8', mode='w') as script:
             script.write(text)
@@ -169,13 +169,13 @@ def __init__(self, text, position=0):
             self.effect = split[8]

     def __unicode__(self):
-        return u'{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer,
-                                                                       self._format_time(self.start),
-                                                                       self._format_time(self.end),
-                                                                       self.style, self.name,
-                                                                       self.margin_left, self.margin_right,
-                                                                       self.margin_vertical, self.effect,
-                                                                       self.text)
+        return '{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer,
+                                                                      self._format_time(self.start),
+                                                                      self._format_time(self.end),
+                                                                      self.style, self.name,
+                                                                      self.margin_left, self.margin_right,
+                                                                      self.margin_vertical, self.effect,
+                                                                      self.text)

     @staticmethod
     def _format_time(seconds):
@@ -195,17 +195,17 @@ def from_file(cls, path):
         other_sections = collections.OrderedDict()

         def parse_script_info_line(line):
-            if line.startswith(u'Format:'):
+            if line.startswith('Format:'):
                 return
             script_info.append(line)

         def parse_styles_line(line):
-            if line.startswith(u'Format:'):
+            if line.startswith('Format:'):
                 return
             styles.append(line)

         def parse_event_line(line):
-            if line.startswith(u'Format:'):
+            if line.startswith('Format:'):
                 return
             events.append(AssEvent(line, position=len(events)+1))
@@ -224,11 +224,11 @@ def create_generic_parse(section_name):
                 if not line:
                     continue
                 low = line.lower()
-                if low == u'[script info]':
+                if low == '[script info]':
                     parse_function = parse_script_info_line
-                elif low == u'[v4+ styles]':
+                elif low == '[v4+ styles]':
                     parse_function = parse_styles_line
-                elif low == u'[events]':
+                elif low == '[events]':
                     parse_function = parse_event_line
                 elif re.match(r'\[.+?\]', low):
                     parse_function = create_generic_parse(line)
@@ -248,27 +248,27 @@ def save_to_file(self, path):
         #     raise RuntimeError('File %s already exists' % path)
         lines = []
         if self.script_info:
-            lines.append(u'[Script Info]')
+            lines.append('[Script Info]')
             lines.extend(self.script_info)
             lines.append('')

         if self.styles:
-            lines.append(u'[V4+ Styles]')
-            lines.append(u'Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding')
+            lines.append('[V4+ Styles]')
+            lines.append('Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding')
             lines.extend(self.styles)
             lines.append('')

         if self.events:
             events = sorted(self.events, key=lambda x: x.source_index)
-            lines.append(u'[Events]')
-            lines.append(u'Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text')
-            lines.extend(map(unicode, events))
+            lines.append('[Events]')
+            lines.append('Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text')
+            lines.extend(map(str, events))

         if self.other:
-            for section_name, section_lines in self.other.iteritems():
+            for section_name, section_lines in self.other.items():
                 lines.append('')
                 lines.append(section_name)
                 lines.extend(section_lines)

         with codecs.open(path, encoding='utf-8-sig', mode='w') as script:
-            script.write(unicode(os.linesep).join(lines))
+            script.write(str(os.linesep).join(lines))
diff --git a/sushi/wav.py b/sushi/wav.py
index e9cce25..28fc472 100644
--- a/sushi/wav.py
+++ b/sushi/wav.py
@@ -8,6 +8,7 @@
 import os.path

 from .common import SushiError, clip
+from functools import reduce

 WAVE_FORMAT_PCM = 0x0001
 WAVE_FORMAT_EXTENSIBLE = 0xFFFE
@@ -86,7 +87,7 @@ def readframes(self, count):
         if min_length != real_length:
             logging.error("Length of audio channels didn't match. This might result in broken output")

-        channels = (unpacked[i::self.channels_count] for i in xrange(self.channels_count))
+        channels = (unpacked[i::self.channels_count] for i in range(self.channels_count))
         data = reduce(lambda a, b: a[:min_length]+b[:min_length], channels)
         data /= float(self.channels_count)
         return data
diff --git a/tests/demuxing.py b/tests/demuxing.py
index 8e081b2..4d9a610 100644
--- a/tests/demuxing.py
+++ b/tests/demuxing.py
@@ -1,5 +1,5 @@
 import unittest
-import mock
+from unittest import mock

 from sushi.demux import FFmpeg, MkvToolnix, SCXviD
 from sushi.common import SushiError
@@ -60,7 +60,7 @@ def test_parses_subtitles_stream(self):
     @mock.patch('subprocess.Popen', new_callable=create_popen_mock)
     def test_get_info_call_args(self, popen_mock):
         FFmpeg.get_info('random_file.mkv')
-        self.assertEquals(popen_mock.call_args[0][0], ['ffmpeg', '-hide_banner', '-i', 'random_file.mkv'])
+        self.assertEqual(popen_mock.call_args[0][0], ['ffmpeg', '-hide_banner', '-i', 'random_file.mkv'])

     @mock.patch('subprocess.Popen', new_callable=create_popen_mock)
     def test_get_info_fail_when_no_mmpeg(self, popen_mock):
@@ -112,8 +112,8 @@ def raise_no_app(cmd_args, **kwargs):
             raise OSError(2, 'ignored')

         popen_mock.side_effect = raise_no_app
-        self.assertRaisesRegexp(SushiError, '[fF][fF][mM][pP][eE][gG]',
-                                lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))
+        self.assertRaisesRegex(SushiError, '[fF][fF][mM][pP][eE][gG]',
+                               lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))

     @mock.patch('subprocess.Popen')
     def test_no_scxvid(self, popen_mock):
@@ -123,8 +123,8 @@ def raise_no_app(cmd_args, **kwargs):
             return mock.Mock()

         popen_mock.side_effect = raise_no_app
-        self.assertRaisesRegexp(SushiError, '[sS][cC][xX][vV][iI][dD]',
-                                lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))
+        self.assertRaisesRegex(SushiError, '[sS][cC][xX][vV][iI][dD]',
+                               lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))


 class ExternalChaptersTestCase(unittest.TestCase):
diff --git a/tests/main.py b/tests/main.py
index c234d53..99133df 100644
--- a/tests/main.py
+++ b/tests/main.py
@@ -2,7 +2,7 @@
 import os
 import re
 import unittest
-from mock import patch, ANY
+from unittest.mock import patch, ANY

 from sushi.common import SushiError, format_time
 from sushi import __main__ as sushi
@@ -34,7 +34,7 @@ def __eq__(self, other):

 class InterpolateNonesTestCase(unittest.TestCase):
     def test_returns_empty_array_when_passed_empty_array(self):
-        self.assertEquals(sushi.interpolate_nones([], []), [])
+        self.assertEqual(sushi.interpolate_nones([], []), [])

     def test_returns_false_when_no_valid_points(self):
         self.assertFalse(sushi.interpolate_nones([None, None, None], [1, 2, 3]))
@@ -112,7 +112,7 @@ def test_events_in_two_groups_one_chapter(self):
         self.assertItemsEqual([events[1], events[2]], groups[1])

     def test_multiple_groups_multiple_chapters(self):
-        events = [FakeEvent(end=x) for x in xrange(1, 10)]
+        events = [FakeEvent(end=x) for x in range(1, 10)]
         groups = sushi.groups_from_chapters(events, [0.0, 3.2, 4.4, 7.7])
         self.assertEqual(4, len(groups))
         self.assertItemsEqual(events[0:3], groups[0])
@@ -207,16 +207,16 @@ def test_checks_that_files_exist(self, mock_object):

     def test_raises_on_unknown_script_type(self, ignore):
         keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.mp4']
-        self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type'), lambda: sushi.parse_args_and_run(keys))
+        self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type'), lambda: sushi.parse_args_and_run(keys))

     def test_raises_on_script_type_not_matching(self, ignore):
         keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '-o', 'd.srt']
-        self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type.*match'),
-                                lambda: sushi.parse_args_and_run(keys))
+        self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type.*match'),
+                               lambda: sushi.parse_args_and_run(keys))

     def test_raises_on_timecodes_and_fps_being_defined_together(self, ignore):
         keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '--src-timecodes', 'tc.txt', '--src-fps', '25']
-        self.assertRaisesRegexp(SushiError, self.any_case_regex(r'timecodes'), lambda: sushi.parse_args_and_run(keys))
+        self.assertRaisesRegex(SushiError, self.any_case_regex(r'timecodes'), lambda: sushi.parse_args_and_run(keys))


 class FormatTimeTestCase(unittest.TestCase):
diff --git a/tests/subtitles.py b/tests/subtitles.py
index b24c852..6eaf196 100644
--- a/tests/subtitles.py
+++ b/tests/subtitles.py
@@ -23,45 +23,45 @@ class SrtEventTestCase(unittest.TestCase):
     def test_simple_parsing(self):
         event = SrtEvent.from_string(SINGLE_LINE_SRT_EVENT)
-        self.assertEquals(14*60+21.960, event.start)
-        self.assertEquals(14*60+22.960, event.end)
-        self.assertEquals("HOW DID IT END UP LIKE THIS?", event.text)
+        self.assertEqual(14*60+21.960, event.start)
+        self.assertEqual(14*60+22.960, event.end)
+        self.assertEqual("HOW DID IT END UP LIKE THIS?", event.text)

     def test_multi_line_event_parsing(self):
         event = SrtEvent.from_string(MULTILINE_SRT_EVENT)
-        self.assertEquals(13*60+12.140, event.start)
-        self.assertEquals(13*60+14.100, event.end)
-        self.assertEquals("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text)
+        self.assertEqual(13*60+12.140, event.start)
+        self.assertEqual(13*60+14.100, event.end)
+        self.assertEqual("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text)

     def test_parsing_and_printing(self):
-        self.assertEquals(SINGLE_LINE_SRT_EVENT, unicode(SrtEvent.from_string(SINGLE_LINE_SRT_EVENT)))
-        self.assertEquals(MULTILINE_SRT_EVENT, unicode(SrtEvent.from_string(MULTILINE_SRT_EVENT)))
+        self.assertEqual(SINGLE_LINE_SRT_EVENT, str(SrtEvent.from_string(SINGLE_LINE_SRT_EVENT)))
+        self.assertEqual(MULTILINE_SRT_EVENT, str(SrtEvent.from_string(MULTILINE_SRT_EVENT)))


 class AssEventTestCase(unittest.TestCase):
     def test_simple_parsing(self):
         event = AssEvent(ASS_EVENT)
         self.assertFalse(event.is_comment)
-        self.assertEquals("Dialogue", event.kind)
-        self.assertEquals(18*60+50.98, event.start)
-        self.assertEquals(18*60+55.28, event.end)
-        self.assertEquals("0", event.layer)
-        self.assertEquals("Default", event.style)
-        self.assertEquals("", event.name)
-        self.assertEquals("0", event.margin_left)
-        self.assertEquals("0", event.margin_right)
-        self.assertEquals("0", event.margin_vertical)
-        self.assertEquals("", event.effect)
-        self.assertEquals("Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?", event.text)
+        self.assertEqual("Dialogue", event.kind)
+        self.assertEqual(18*60+50.98, event.start)
+        self.assertEqual(18*60+55.28, event.end)
+        self.assertEqual("0", event.layer)
+        self.assertEqual("Default", event.style)
+        self.assertEqual("", event.name)
+        self.assertEqual("0", event.margin_left)
+        self.assertEqual("0", event.margin_right)
+        self.assertEqual("0", event.margin_vertical)
+        self.assertEqual("", event.effect)
+        self.assertEqual("Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?", event.text)

     def test_comment_parsing(self):
         event = AssEvent(ASS_COMMENT)
         self.assertTrue(event.is_comment)
-        self.assertEquals("Comment", event.kind)
+        self.assertEqual("Comment", event.kind)

     def test_parsing_and_printing(self):
-        self.assertEquals(ASS_EVENT, unicode(AssEvent(ASS_EVENT)))
-        self.assertEquals(ASS_COMMENT, unicode(AssEvent(ASS_COMMENT)))
+        self.assertEqual(ASS_EVENT, str(AssEvent(ASS_EVENT)))
+        self.assertEqual(ASS_COMMENT, str(AssEvent(ASS_COMMENT)))


 class ScriptTestBase(unittest.TestCase):
@@ -78,7 +78,7 @@ def test_write_to_file(self):
         SrtScript(events).save_to_file(self.script_path)
         with open(self.script_path) as script:
             text = script.read()
-        self.assertEquals(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text)
+        self.assertEqual(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text)

     def test_read_from_file(self):
         os.write(self.script_description, """1
@@ -98,18 +98,18 @@ def test_read_from_file(self):
 00:00:21,250 --> 00:00:22,750
 Serves you right.""")
         parsed = SrtScript.from_file(self.script_path).events
-        self.assertEquals(17.5, parsed[0].start)
-        self.assertEquals(18.87, parsed[0].end)
-        self.assertEquals("Yeah, really!", parsed[0].text)
-        self.assertEquals(17.5, parsed[1].start)
-        self.assertEquals(18.87, parsed[1].end)
-        self.assertEquals("", parsed[1].text)
-        self.assertEquals(17.5, parsed[2].start)
-        self.assertEquals(18.87, parsed[2].end)
-        self.assertEquals("House number\n35", parsed[2].text)
-        self.assertEquals(21.25, parsed[3].start)
-        self.assertEquals(22.75, parsed[3].end)
-        self.assertEquals("Serves you right.", parsed[3].text)
+        self.assertEqual(17.5, parsed[0].start)
+        self.assertEqual(18.87, parsed[0].end)
+        self.assertEqual("Yeah, really!", parsed[0].text)
+        self.assertEqual(17.5, parsed[1].start)
+        self.assertEqual(18.87, parsed[1].end)
+        self.assertEqual("", parsed[1].text)
+        self.assertEqual(17.5, parsed[2].start)
+        self.assertEqual(18.87, parsed[2].end)
+        self.assertEqual("House number\n35", parsed[2].text)
+        self.assertEqual(21.25, parsed[3].start)
+        self.assertEqual(22.75, parsed[3].end)
+        self.assertEqual("Serves you right.", parsed[3].text)


 class AssScriptTestCase(ScriptTestBase):
@@ -121,7 +121,7 @@ def test_write_to_file(self):

         with codecs.open(self.script_path, encoding='utf-8-sig') as script:
             text = script.read()
-        self.assertEquals("""[V4+ Styles]
+        self.assertEqual("""[V4+ Styles]
 Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
 Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1
@@ -147,10 +147,10 @@ def test_read_from_file(self):
         os.write(self.script_description, text)

         script = AssScript.from_file(self.script_path)
-        self.assertEquals(["; Script generated by Aegisub 3.1.1", "Title: script title"], script.script_info)
-        self.assertEquals(["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1",
-                           "Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1"],
-                          script.styles)
-        self.assertEquals([1, 2], [x.source_index for x in script.events])
-        self.assertEquals(u"Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,", unicode(script.events[0]))
-        self.assertEquals(u"Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.", unicode(script.events[1]))
+        self.assertEqual(["; Script generated by Aegisub 3.1.1", "Title: script title"], script.script_info)
+        self.assertEqual(["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1",
+                          "Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1"],
+                         script.styles)
+        self.assertEqual([1, 2], [x.source_index for x in script.events])
+        self.assertEqual("Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,", str(script.events[0]))
+        self.assertEqual("Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.", str(script.events[1]))

From 0d28fb5a7d07592719b88bb868a727ac3ca5deb7 Mon Sep 17 00:00:00 2001
From: FichteFoll
Date: Tue, 1 Oct 2019 20:02:49 +0200
Subject: [PATCH 03/17] Address flake8 warnings

Ignoring line length.
---
 run-tests.py              |  9 +++++----
 setup.py                  |  1 -
 sushi/__init__.py         |  1 +
 sushi/__main__.py         | 30 +++++++++++++++---------------
 sushi/demux.py            | 14 +++++++-------
 sushi/keyframes.py        |  3 ++-
 sushi/regression-tests.py |  4 ++--
 sushi/subs.py             |  4 ++--
 sushi/wav.py              |  8 ++++----
 tests/main.py             |  1 -
 tests/subtitles.py        | 14 +++++++-------
 tests/timecodes.py        | 38 ++++++++++++++++++--------------------
 12 files changed, 63 insertions(+), 64 deletions(-)

diff --git a/run-tests.py b/run-tests.py
index 30c8194..171a31f 100644
--- a/run-tests.py
+++ b/run-tests.py
@@ -1,7 +1,8 @@
 import unittest
-from tests.timecodes import *
-from tests.main import *
-from tests.subtitles import *
-from tests.demuxing import *
+
+from tests.timecodes import *  # noqa
+from tests.main import *  # noqa
+from tests.subtitles import *  # noqa
+from tests.demuxing import *  # noqa

 unittest.main(verbosity=0)
diff --git a/setup.py b/setup.py
index 5a2c7dd..312a0bb 100644
--- a/setup.py
+++ b/setup.py
@@ -9,4 +9,3 @@
     console=['sushi.py'],
     license='MIT'
 )
-
diff --git a/sushi/__init__.py b/sushi/__init__.py
index e69de29..63d497c 100644
--- a/sushi/__init__.py
+++ b/sushi/__init__.py
@@ -0,0 +1 @@
+from __main__ import ALLOWED_ERROR, MAX_GROUP_STD, VERSION  # noqa
diff --git a/sushi/__main__.py b/sushi/__main__.py
index 696f691..3702911 100755
--- a/sushi/__main__.py
+++ b/sushi/__main__.py
@@ -5,7 +5,6 @@
 import argparse
 import os
 import bisect
-import collections
 from itertools import takewhile, chain
 import time

@@ -57,7 +56,7 @@ def format(self, record):
         elif record.levelno == logging.WARN:
             self._fmt = self.warn_format
         elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
-            self._fmt = self.error_format 
+            self._fmt = self.error_format
         else:
             self._fmt = self.default_format
@@ -101,8 +100,8 @@ def running_median(values, window_size):
     medians = []
     items_count = len(values)
     for idx in range(items_count):
-        radius = min(half_window, idx, items_count-idx-1)
-        med = np.median(values[idx-radius:idx+radius+1])
+        radius = min(half_window, idx, items_count - idx - 1)
+        med = np.median(values[idx - radius:idx + radius + 1])
         medians.append(med)
     return medians
@@ -110,7 +109,7 @@ def smooth_events(events, radius):
     if not radius:
         return
-    window_size = radius*2+1
+    window_size = radius * 2 + 1
     shifts = [e.shift for e in events]
     smoothed = running_median(shifts, window_size)
     for event, new_shift in zip(events, smoothed):
@@ -254,7 +253,7 @@ def find_keyframe_distance(src_time, dst_time):
         dst = get_distance_to_closest_kf(dst_time, dst_keytimes)
         snapping_limit = timecodes.get_frame_size(src_time) * max_kf_distance

-        if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src-dst) < snapping_limit:
+        if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src - dst) < snapping_limit:
             return dst - src
         return 0
@@ -285,7 +284,7 @@ def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_dist
             logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift))
             for group, (start_shift, end_shift) in zip(groups, shifts):
-                if abs(start_shift-end_shift) > 0.001 and len(group) > 1:
+                if abs(start_shift - end_shift) > 0.001 and len(group) > 1:
                     actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift))
                     logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}."
                                     .format(format_time(group[0].start), start_shift, end_shift, actual_shift))
@@ -336,7 +335,7 @@ def merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts
         else:
             group = [event]
             group_end = event.end
-            i = idx+1
+            i = idx + 1
             while i < len(events) and abs(group_end - events[i].start) < max_ts_distance:
                 if events[i].end < next_chapter and events[i].duration <= max_ts_duration:
                     processed.add(i)
@@ -354,7 +353,7 @@ def prepare_search_groups(events, source_duration, chapter_times, max_ts_duratio
     for idx, event in enumerate(events):
         if event.is_comment:
             try:
-                event.link_event(events[idx+1])
+                event.link_event(events[idx + 1])
             except IndexError:
                 event.link_event(last_unlinked)
             continue
@@ -372,8 +371,9 @@ def prepare_search_groups(events, source_duration, chapter_times, max_ts_duratio

         # link lines with start and end times identical to some other event
         # assuming scripts are sorted by start time so we don't search the entire collection
-        same_start = lambda x: event.start == x.start
-        processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end),None)
+        def same_start(x):
+            return event.start == x.start
+        processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end), None)
         if processed:
             event.link_event(processed)
         else:
@@ -442,7 +442,7 @@ def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offs
             idx += 1
             continue

-        left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0])/2], axis=1)
+        left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0]) / 2], axis=1)
         right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate)
         terminate = False
         # searching from last committed shift
@@ -456,7 +456,7 @@ def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offs

         if not terminate and uncommitted_states and uncommitted_states[-1]["shift"] is not None \
                 and original_time + uncommitted_states[-1]["shift"] < dst_stream.duration_seconds:
-            start_offset = uncommitted_states[-1]["shift"] 
+            start_offset = uncommitted_states[-1]["shift"]
             diff, new_time = dst_stream.find_substream(tv_audio, original_time + start_offset, window)
             left_side_time = dst_stream.find_substream(left_audio_half, original_time + start_offset, window)[1]
             right_side_time = dst_stream.find_substream(right_audio_half, original_time + start_offset + right_half_offset, window)[1] - right_half_offset
@@ -580,7 +580,7 @@ def run(args):
         src_script_path = args.script_file
     else:
         stype = src_demuxer.get_subs_type(args.src_script_idx)
-        src_script_path = format_full_path(args.temp_dir, args.source, '.sushi'+ stype)
+        src_script_path = format_full_path(args.temp_dir, args.source, '.sushi' + stype)
         src_demuxer.set_script(stream_idx=args.src_script_idx, output_path=src_script_path)

     script_extension = get_extension(src_script_path)
@@ -726,7 +726,7 @@ def select_timecodes(external_file, fps_arg, demuxer):
         script.save_to_file(dst_script_path)

         if write_plot:
-            plt.plot([x.shift + (x._start_shift + x._end_shift)/2.0 for x in events], label='After correction')
+            plt.plot([x.shift + (x._start_shift + x._end_shift) / 2.0 for x in events], label='After correction')
             plt.legend(fontsize=5, frameon=False, fancybox=False)
             plt.savefig(args.plot_path, dpi=300)
diff --git a/sushi/demux.py b/sushi/demux.py
index 4743b3a..f833cc8 100644
--- a/sushi/demux.py
+++ b/sushi/demux.py
@@ -113,10 +113,11 @@ class SCXviD(object):
     def make_keyframes(cls, video_path, log_path):
         try:
             ffmpeg_process = subprocess.Popen(['ffmpeg', '-i', video_path,
-                                               '-f', 'yuv4mpegpipe',
-                                               '-vf', 'scale=640:360',
-                                               '-pix_fmt', 'yuv420p',
-                                               '-vsync', 'drop', '-'], stdout=subprocess.PIPE)
+                                               '-f', 'yuv4mpegpipe',
+                                               '-vf', 'scale=640:360',
+                                               '-pix_fmt', 'yuv420p',
+                                               '-vsync', 'drop', '-'],
+                                              stdout=subprocess.PIPE)
         except OSError as e:
             if e.errno == 2:
                 raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
@@ -143,7 +144,7 @@ def get_frame_time(self, number):
             return self.times[number]
         except IndexError:
             if not self.default_frame_duration:
-                return self.get_frame_time(len(self.times)-1)
+                return self.get_frame_time(len(self.times) - 1)
             if self.times:
                 return self.times[-1] + (self.default_frame_duration) * (number - len(self.times) + 1)
             else:
@@ -157,7 +158,7 @@ def get_frame_number(self, timestamp):
     def get_frame_size(self, timestamp):
         try:
             number = bisect.bisect_left(self.times, timestamp)
-        except:
+        except Exception:
             return self.default_frame_duration

         c = self.get_frame_time(number)
@@ -353,4 +354,3 @@ def _select_stream(self, streams, chosen_idx, name):
             raise SushiError("Stream with index {0} doesn't exist in {1}.\n"
                              "Here are all that do:\n"
                              "{2}".format(chosen_idx, self._path, self._format_streams_list(streams)))
-
diff --git a/sushi/keyframes.py b/sushi/keyframes.py
index 776fa55..f1b0282 100644
--- a/sushi/keyframes.py
+++ b/sushi/keyframes.py
@@ -2,7 +2,8 @@


 def parse_scxvid_keyframes(text):
-    return [i-3 for i,line in enumerate(text.splitlines()) if line and line[0] == 'i']
+    return [i - 3 for i, line in enumerate(text.splitlines()) if line and line[0] == 'i']
+

 def parse_keyframes(path):
     text = read_all_text(path)
diff --git a/sushi/regression-tests.py b/sushi/regression-tests.py
index 30ce5e7..825f787 100644
--- a/sushi/regression-tests.py
+++ b/sushi/regression-tests.py
@@ -71,7 +71,7 @@ def compare_scripts(ideal_path, test_path, timecodes, test_name, expected_errors
             )
             failed += 1

-    logging.info('Total lines: {0}, good: {1}, failed: {2}'.format(len(ideal_script.events), len(ideal_script.events)-failed, failed))
+    logging.info('Total lines: {0}, good: {1}, failed: {2}'.format(len(ideal_script.events), len(ideal_script.events) - failed, failed))

     if failed > expected_errors:
         logging.critical('Got more failed lines than expected ({0} actual vs {1} expected)'.format(failed, expected_errors))
@@ -141,7 +141,7 @@ def run_wav_test(test_name, file_path, params):
     gc.collect(2)
     before = resource.getrusage(resource.RUSAGE_SELF)

-    loaded = WavStream(file_path, params.get('sample_rate', 12000), params.get('sample_type', 'uint8'))
+    _ = WavStream(file_path, params.get('sample_rate', 12000), params.get('sample_type', 'uint8'))

     after = resource.getrusage(resource.RUSAGE_SELF)
     total_time = (after.ru_stime - before.ru_stime) + (after.ru_utime - before.ru_utime)
diff --git a/sushi/subs.py b/sushi/subs.py
index 670f2ad..3b12ed7 100644
--- a/sushi/subs.py
+++ b/sushi/subs.py
@@ -95,7 +95,7 @@ class SrtEvent(ScriptEventBase):
     is_comment = False
     style = None

-    EVENT_REGEX = re.compile("""
+    EVENT_REGEX = re.compile(r"""
                              (\d+?)\s+?  # line-number
                              (\d{1,2}:\d{1,2}:\d{1,2},\d+)\s-->\s(\d{1,2}:\d{1,2}:\d{1,2},\d+).  # timestamp
                              (.+?)  # actual text
@@ -207,7 +207,7 @@ def parse_styles_line(line):
         def parse_event_line(line):
             if line.startswith('Format:'):
                 return
-            events.append(AssEvent(line, position=len(events)+1))
+            events.append(AssEvent(line, position=len(events) + 1))
diff --git a/sushi/wav.py b/sushi/wav.py
index 28fc472..f35d213 100644
--- a/sushi/wav.py
+++ b/sushi/wav.py
@@ -51,7 +51,7 @@ def __init__(self, path):
                     chunk.skip()
             if not fmt_chunk_read or not data_chink_read:
                 raise SushiError('Invalid WAV file')
-        except:
+        except Exception:
             self.close()
             raise
@@ -88,7 +88,7 @@ def readframes(self, count):
             logging.error("Length of audio channels didn't match. This might result in broken output")

         channels = (unpacked[i::self.channels_count] for i in range(self.channels_count))
-        data = reduce(lambda a, b: a[:min_length]+b[:min_length], channels)
+        data = reduce(lambda a, b: a[:min_length] + b[:min_length], channels)
         data /= float(self.channels_count)
         return data
@@ -128,7 +128,7 @@ def __init__(self, path, sample_rate=12000, sample_type='uint8'):
                 data = stream.readframes(int(self.READ_CHUNK_SIZE * stream.framerate))
                 new_length = int(round(len(data) * downsample_rate))

-                dst_view = self.data[0][samples_read:samples_read+new_length]
+                dst_view = self.data[0][samples_read:samples_read + new_length]

                 if downsample_rate != 1:
                     data = data.reshape((1, len(data)))
@@ -140,7 +140,7 @@ def __init__(self, path, sample_rate=12000, sample_type='uint8'):

             # padding the audio from both sides
             self.data[0][0:self.padding_size].fill(self.data[0][self.padding_size])
-            self.data[0][-self.padding_size:].fill(self.data[0][-self.padding_size-1])
+            self.data[0][-self.padding_size:].fill(self.data[0][-self.padding_size - 1])

             # normalizing
             # also clipping the stream by 3*median value from both sides of zero
diff --git a/tests/main.py b/tests/main.py
index 99133df..e3e2fa7 100644
--- a/tests/main.py
+++ b/tests/main.py
@@ -1,4 +1,3 @@
-from collections import namedtuple
 import os
 import re
 import unittest
diff --git a/tests/subtitles.py b/tests/subtitles.py
index 6eaf196..fae397e 100644
--- a/tests/subtitles.py
+++ b/tests/subtitles.py
@@ -23,14 +23,14 @@ class SrtEventTestCase(unittest.TestCase):
     def test_simple_parsing(self):
         event = SrtEvent.from_string(SINGLE_LINE_SRT_EVENT)
-        self.assertEqual(14*60+21.960, event.start)
-        self.assertEqual(14*60+22.960, event.end)
+        self.assertEqual(14 * 60 + 21.960, event.start)
+        self.assertEqual(14 * 60 + 22.960, event.end)
         self.assertEqual("HOW DID IT END UP LIKE THIS?", event.text)

     def test_multi_line_event_parsing(self):
         event = SrtEvent.from_string(MULTILINE_SRT_EVENT)
-        self.assertEqual(13*60+12.140, event.start)
-        self.assertEqual(13*60+14.100, event.end)
+        self.assertEqual(13 * 60 + 12.140, event.start)
+        self.assertEqual(13 * 60 + 14.100, event.end)
self.assertEqual("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text) def test_parsing_and_printing(self): @@ -43,8 +43,8 @@ def test_simple_parsing(self): event = AssEvent(ASS_EVENT) self.assertFalse(event.is_comment) self.assertEqual("Dialogue", event.kind) - self.assertEqual(18*60+50.98, event.start) - self.assertEqual(18*60+55.28, event.end) + self.assertEqual(18 * 60 + 50.98, event.start) + self.assertEqual(18 * 60 + 55.28, event.end) self.assertEqual("0", event.layer) self.assertEqual("Default", event.style) self.assertEqual("", event.name) @@ -52,7 +52,7 @@ def test_simple_parsing(self): self.assertEqual("0", event.margin_right) self.assertEqual("0", event.margin_vertical) self.assertEqual("", event.effect) - self.assertEqual("Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?", event.text) + self.assertEqual("Are you trying to (ouch) crush it (ouch)\\N like a (ouch) vise (ouch, ouch)?", event.text) def test_comment_parsing(self): event = AssEvent(ASS_COMMENT) diff --git a/tests/timecodes.py b/tests/timecodes.py index 5f4fb04..dcc3f6f 100644 --- a/tests/timecodes.py +++ b/tests/timecodes.py @@ -12,25 +12,25 @@ def test_get_frame_time_zero(self): def test_get_frame_time_sane(self): tcs = Timecodes.cfr(23.976) t = tcs.get_frame_time(10) - self.assertAlmostEqual(10.0/23.976, t) + self.assertAlmostEqual(10.0 / 23.976, t) def test_get_frame_time_insane(self): tcs = Timecodes.cfr(23.976) t = tcs.get_frame_time(100000) - self.assertAlmostEqual(100000.0/23.976, t) + self.assertAlmostEqual(100000.0 / 23.976, t) def test_get_frame_size(self): tcs = Timecodes.cfr(23.976) t1 = tcs.get_frame_size(0) t2 = tcs.get_frame_size(1000) - self.assertAlmostEqual(1.0/23.976, t1) + self.assertAlmostEqual(1.0 / 23.976, t1) self.assertAlmostEqual(t1, t2) def test_get_frame_number(self): - tcs = Timecodes.cfr(24000.0/1001.0) + tcs = Timecodes.cfr(24000.0 / 1001.0) self.assertEqual(tcs.get_frame_number(0), 0) self.assertEqual(tcs.get_frame_number(1145.353), 27461) - self.assertEqual(tcs.get_frame_number(1001.0/24000.0 * 1234567), 1234567) + self.assertEqual(tcs.get_frame_number(1001.0 / 24000.0 * 1234567), 1234567) class TimecodesTestCase(unittest.TestCase): @@ -38,9 +38,9 @@ def test_cfr_timecodes_v2(self): text = '# timecode format v2\n' + '\n'.join(str(1000 * x / 23.976) for x in range(0, 30000)) parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0)) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25)) - self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(0)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(25)) + self.assertAlmostEqual(1.0 / 23.976 * 100, parsed.get_frame_time(100)) self.assertEqual(0, parsed.get_frame_time(0)) self.assertEqual(0, parsed.get_frame_number(0)) self.assertEqual(27461, parsed.get_frame_number(1145.353)) @@ -48,9 +48,9 @@ def test_cfr_timecodes_v2(self): def test_cfr_timecodes_v1(self): text = '# timecode format v1\nAssume 23.976024' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(0)) - self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(25)) - self.assertAlmostEqual(1.0/23.976024*100, parsed.get_frame_time(100)) + self.assertAlmostEqual(1.0 / 23.976024, parsed.get_frame_size(0)) + self.assertAlmostEqual(1.0 / 23.976024, parsed.get_frame_size(25)) + self.assertAlmostEqual(1.0 / 23.976024 * 100, parsed.get_frame_time(100)) 
self.assertEqual(0, parsed.get_frame_time(0)) self.assertEqual(0, parsed.get_frame_number(0)) self.assertEqual(27461, parsed.get_frame_number(1145.353)) @@ -58,30 +58,30 @@ def test_cfr_timecodes_v1(self): def test_cfr_timecodes_v1_with_overrides(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,23.976000\n3000,5000,23.976000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0)) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25)) - self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(0)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(25)) + self.assertAlmostEqual(1.0 / 23.976 * 100, parsed.get_frame_time(100)) self.assertEqual(0, parsed.get_frame_time(0)) def test_vfr_timecodes_v1_frame_size_at_first_frame(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=0)) + self.assertAlmostEqual(1.0 / 29.97, parsed.get_frame_size(timestamp=0)) def test_vfr_timecodes_v1_frame_size_outside_of_defined_range(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=5000.0)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(timestamp=5000.0)) def test_vft_timecodes_v1_frame_size_inside_override_block(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=49.983)) + self.assertAlmostEqual(1.0 / 29.97, parsed.get_frame_size(timestamp=49.983)) def test_vft_timecodes_v1_frame_size_between_override_blocks(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=87.496)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(timestamp=87.496)) def test_vfr_timecodes_v1_frame_time_at_first_frame(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' @@ -102,5 +102,3 @@ def test_vft_timecodes_v1_frame_time_between_override_blocks(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) self.assertAlmostEqual(87.579, parsed.get_frame_time(number=2500), places=3) - - From 164e79d30f9eb2b1bee8b3d6a7c61e623040e0e6 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 21:20:02 +0200 Subject: [PATCH 04/17] Some more restructuring for Python 3 --- sushi/__init__.py | 700 +++++++++++++++++++++++++++++++++++++++++++- sushi/__main__.py | 703 +-------------------------------------------- sushi/demux.py | 3 +- sushi/subs.py | 7 +- tests/main.py | 29 +- tests/subtitles.py | 4 +- 6 files changed, 724 insertions(+), 722 deletions(-) diff --git a/sushi/__init__.py b/sushi/__init__.py index 63d497c..524aaff 100644 --- a/sushi/__init__.py +++ b/sushi/__init__.py @@ -1 +1,699 @@ -from __main__ import ALLOWED_ERROR, MAX_GROUP_STD, VERSION # noqa +import bisect +from itertools import takewhile, chain +import logging +import operator +import os + +import numpy as np + +from . 
import chapters +from .common import SushiError, get_extension, format_time, ensure_static_collection +from .demux import Timecodes, Demuxer +from . import keyframes +from .subs import AssScript, SrtScript +from .wav import WavStream + + +try: + import matplotlib.pyplot as plt + plot_enabled = True +except ImportError: + plot_enabled = False + + +ALLOWED_ERROR = 0.01 +MAX_GROUP_STD = 0.025 +VERSION = '0.5.1' + + +def abs_diff(a, b): + return abs(a - b) + + +def interpolate_nones(data, points): + data = ensure_static_collection(data) + values_lookup = {p: v for p, v in zip(points, data) if v is not None} + if not values_lookup: + return [] + + zero_points = {p for p, v in zip(points, data) if v is None} + if not zero_points: + return data + + data_list = sorted(values_lookup.items()) + zero_points = sorted(x for x in zero_points if x not in values_lookup) + + out = np.interp(x=zero_points, + xp=list(map(operator.itemgetter(0), data_list)), + fp=list(map(operator.itemgetter(1), data_list))) + + values_lookup.update(zip(zero_points, out)) + + return [ + values_lookup[point] if value is None else value + for point, value in zip(points, data) + ] + + +# todo: implement this as a running median +def running_median(values, window_size): + if window_size % 2 != 1: + raise SushiError('Median window size should be odd') + half_window = window_size // 2 + medians = [] + items_count = len(values) + for idx in range(items_count): + radius = min(half_window, idx, items_count - idx - 1) + med = np.median(values[idx - radius:idx + radius + 1]) + medians.append(med) + return medians + + +def smooth_events(events, radius): + if not radius: + return + window_size = radius * 2 + 1 + shifts = [e.shift for e in events] + smoothed = running_median(shifts, window_size) + for event, new_shift in zip(events, smoothed): + event.set_shift(new_shift, event.diff) + + +def detect_groups(events_iter): + events_iter = iter(events_iter) + groups_list = [[next(events_iter)]] + for event in events_iter: + if abs_diff(event.shift, groups_list[-1][-1].shift) > ALLOWED_ERROR: + groups_list.append([]) + groups_list[-1].append(event) + return groups_list + + +def groups_from_chapters(events, times): + logging.info('Chapter start points: {0}'.format([format_time(t) for t in times])) + groups = [[]] + chapter_times = iter(times[1:] + [36000000000]) # very large event at the end + current_chapter = next(chapter_times) + + for event in events: + if event.end > current_chapter: + groups.append([]) + while event.end > current_chapter: + current_chapter = next(chapter_times) + + groups[-1].append(event) + + groups = [g for g in groups if g] # non-empty groups + # check if we have any groups where every event is linked + # for example a chapter with only comments inside + broken_groups = [group for group in groups if not any(e for e in group if not e.linked)] + if broken_groups: + for group in broken_groups: + for event in group: + parent = event.get_link_chain_end() + parent_group = next(group for group in groups if parent in group) + parent_group.append(event) + del group[:] + groups = [g for g in groups if g] + # re-sort the groups again since we might break the order when inserting linked events + # sorting everything again is far from optimal but python sorting is very fast for sorted arrays anyway + for group in groups: + group.sort(key=lambda event: event.start) + + return groups + + +def split_broken_groups(groups): + correct_groups = [] + broken_found = False + for g in groups: + std = np.std([e.shift for e in g]) + if std > 
MAX_GROUP_STD: + logging.warn('Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). ' + 'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end), + std)) + correct_groups.extend(detect_groups(g)) + broken_found = True + else: + correct_groups.append(g) + + if broken_found: + groups_iter = iter(correct_groups) + correct_groups = [list(next(groups_iter))] + for group in groups_iter: + if abs_diff(correct_groups[-1][-1].shift, group[0].shift) >= ALLOWED_ERROR \ + or np.std([e.shift for e in group + correct_groups[-1]]) >= MAX_GROUP_STD: + correct_groups.append([]) + + correct_groups[-1].extend(group) + return correct_groups + + +def fix_near_borders(events): + """ + We assume that all lines with diff greater than 5 * (median diff across all events) are broken + """ + def fix_border(event_list, median_diff): + last_ten_diff = np.median([x.diff for x in event_list[:10]], overwrite_input=True) + diff_limit = min(last_ten_diff, median_diff) + broken = [] + for event in event_list: + if not 0.2 < (event.diff / diff_limit) < 5: + broken.append(event) + else: + for x in broken: + x.link_event(event) + return len(broken) + return 0 + + median_diff = np.median([x.diff for x in events], overwrite_input=True) + + fixed_count = fix_border(events, median_diff) + if fixed_count: + logging.info('Fixing {0} border events right after {1}'.format(fixed_count, format_time(events[0].start))) + + fixed_count = fix_border(list(reversed(events)), median_diff) + if fixed_count: + logging.info('Fixing {0} border events right before {1}'.format(fixed_count, format_time(events[-1].end))) + + +def get_distance_to_closest_kf(timestamp, keyframes): + idx = bisect.bisect_left(keyframes, timestamp) + if idx == 0: + kf = keyframes[0] + elif idx == len(keyframes): + kf = keyframes[-1] + else: + before = keyframes[idx - 1] + after = keyframes[idx] + kf = after if after - timestamp < timestamp - before else before + return kf - timestamp + + +def find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance): + def get_distance(src_distance, dst_distance, limit): + if abs(dst_distance) > limit: + return None + shift = dst_distance - src_distance + return shift if abs(shift) < limit else None + + src_start = get_distance_to_closest_kf(group[0].start, src_keytimes) + src_end = get_distance_to_closest_kf(group[-1].end + src_timecodes.get_frame_size(group[-1].end), src_keytimes) + + dst_start = get_distance_to_closest_kf(group[0].shifted_start, dst_keytimes) + dst_end = get_distance_to_closest_kf(group[-1].shifted_end + dst_timecodes.get_frame_size(group[-1].end), dst_keytimes) + + snapping_limit_start = src_timecodes.get_frame_size(group[0].start) * max_kf_distance + snapping_limit_end = src_timecodes.get_frame_size(group[0].end) * max_kf_distance + + return (get_distance(src_start, dst_start, snapping_limit_start), + get_distance(src_end, dst_end, snapping_limit_end)) + + +def find_keyframes_distances(event, src_keytimes, dst_keytimes, timecodes, max_kf_distance): + def find_keyframe_distance(src_time, dst_time): + src = get_distance_to_closest_kf(src_time, src_keytimes) + dst = get_distance_to_closest_kf(dst_time, dst_keytimes) + snapping_limit = timecodes.get_frame_size(src_time) * max_kf_distance + + if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src - dst) < snapping_limit: + return dst - src + return 0 + + ds = find_keyframe_distance(event.start, event.shifted_start) + de = 
+def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_distance, src_keytimes, dst_keytimes,
+                             src_timecodes, dst_timecodes, max_kf_distance, kf_mode):
+    if not max_kf_distance:
+        return
+
+    groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance)
+
+    if kf_mode == 'all' or kf_mode == 'shift':
+        # step 1: snap events without changing their duration. Useful for some slight audio imprecision correction
+        shifts = []
+        times = []
+        for group in groups:
+            shifts.extend(find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance))
+            times.extend((group[0].shifted_start, group[-1].shifted_end))
+
+        shifts = interpolate_nones(shifts, times)
+        if shifts:
+            mean_shift = np.mean(shifts)
+            shifts = zip(*[iter(shifts)] * 2)
+
+            logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift))
+            for group, (start_shift, end_shift) in zip(groups, shifts):
+                if abs(start_shift - end_shift) > 0.001 and len(group) > 1:
+                    actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift))
+                    logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}."
+                                    .format(format_time(group[0].start), start_shift, end_shift, actual_shift))
+                    for e in group:
+                        e.adjust_shift(actual_shift)
+                else:
+                    for e in group:
+                        e.adjust_additional_shifts(start_shift, end_shift)
+
+    if kf_mode == 'all' or kf_mode == 'snap':
+        # step 2: snap start/end times separately
+        for group in groups:
+            if len(group) > 1:
+                continue  # we don't snap typesetting
+            start_shift, end_shift = find_keyframes_distances(group[0], src_keytimes, dst_keytimes, src_timecodes, max_kf_distance)
+            if abs(start_shift) > 0.01 or abs(end_shift) > 0.01:
+                logging.info('Snapping {0} to keyframes, start time by {1}, end: {2}'.format(format_time(group[0].start), start_shift, end_shift))
+                group[0].adjust_additional_shifts(start_shift, end_shift)
+
+
+def average_shifts(events):
+    events = [e for e in events if not e.linked]
+    shifts = [x.shift for x in events]
+    weights = [1 - x.diff for x in events]
+    avg = np.average(shifts, weights=weights)
+    for e in events:
+        e.set_shift(avg, e.diff)
+    return avg
+
+
+def merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance):
+    search_groups = []
+    chapter_times = iter(chapter_times[1:] + [100000000])
+    next_chapter = next(chapter_times)
+    events = ensure_static_collection(events)
+
+    processed = set()
+    for idx, event in enumerate(events):
+        if idx in processed:
+            continue
+
+        while event.end > next_chapter:
+            next_chapter = next(chapter_times)
+
+        if event.duration > max_ts_duration:
+            search_groups.append([event])
+            processed.add(idx)
+        else:
+            group = [event]
+            group_end = event.end
+            i = idx + 1
+            while i < len(events) and abs(group_end - events[i].start) < max_ts_distance:
+                if events[i].end < next_chapter and events[i].duration <= max_ts_duration:
+                    processed.add(i)
+                    group.append(events[i])
+                    group_end = max(group_end, events[i].end)
+                i += 1
+
+            search_groups.append(group)
+
+    return search_groups
+
+
+def prepare_search_groups(events, source_duration, chapter_times, max_ts_duration, max_ts_distance):
+    last_unlinked = None
+    for idx, event in enumerate(events):
+        if event.is_comment:
+            try:
+                event.link_event(events[idx + 1])
+            except IndexError:
+                event.link_event(last_unlinked)
+            continue
+        if 
(event.start + event.duration / 2.0) > source_duration: + logging.info('Event time outside of audio range, ignoring: %s', event) + event.link_event(last_unlinked) + continue + elif event.end == event.start: + logging.info('{0}: skipped because zero duration'.format(format_time(event.start))) + try: + event.link_event(events[idx + 1]) + except IndexError: + event.link_event(last_unlinked) + continue + + # link lines with start and end times identical to some other event + # assuming scripts are sorted by start time so we don't search the entire collection + def same_start(x): + return event.start == x.start + processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end), None) + if processed: + event.link_event(processed) + else: + last_unlinked = event + + events = (e for e in events if not e.linked) + + search_groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance) + + # link groups contained inside other groups to the larger group + passed_groups = [] + for idx, group in enumerate(search_groups): + try: + other = next(x for x in reversed(search_groups[:idx]) + if x[0].start <= group[0].start + and x[-1].end >= group[-1].end) + for event in group: + event.link_event(other[0]) + except StopIteration: + passed_groups.append(group) + return passed_groups + + +def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_window, rewind_thresh): + def log_shift(state): + logging.info('{0}-{1}: shift: {2:0.10f}, diff: {3:0.10f}' + .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"])) + + def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offset): + logging.debug('{0}-{1}: shift: {2:0.5f} [{3:0.5f}, {4:0.5f}], search offset: {5:0.6f}' + .format(format_time(state["start_time"]), format_time(state["end_time"]), + shift, left_side_shift, right_side_shift, search_offset)) + + small_window = 1.5 + idx = 0 + committed_states = [] + uncommitted_states = [] + window = normal_window + while idx < len(groups_list): + search_group = groups_list[idx] + tv_audio = src_stream.get_substream(search_group[0].start, search_group[-1].end) + original_time = search_group[0].start + group_state = {"start_time": search_group[0].start, "end_time": search_group[-1].end, "shift": None, "diff": None} + last_committed_shift = committed_states[-1]["shift"] if committed_states else 0 + diff = new_time = None + + if not uncommitted_states: + if original_time + last_committed_shift > dst_stream.duration_seconds: + # event outside of audio range, all events past it are also guaranteed to fail + for g in groups_list[idx:]: + committed_states.append({"start_time": g[0].start, "end_time": g[-1].end, "shift": None, "diff": None}) + logging.info("{0}-{1}: outside of audio range".format(format_time(g[0].start), format_time(g[-1].end))) + break + + if small_window < window: + diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, small_window) + + if new_time is not None and abs_diff(new_time - original_time, last_committed_shift) <= ALLOWED_ERROR: + # fastest case - small window worked, commit the group immediately + group_state.update({"shift": new_time - original_time, "diff": diff}) + committed_states.append(group_state) + log_shift(group_state) + if window != normal_window: + logging.info("Going back to window {0} from {1}".format(normal_window, window)) + window = normal_window + idx += 1 + continue + + 
left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0]) / 2], axis=1) + right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate) + terminate = False + # searching from last committed shift + if original_time + last_committed_shift < dst_stream.duration_seconds: + diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, window) + left_side_time = dst_stream.find_substream(left_audio_half, original_time + last_committed_shift, window)[1] + right_side_time = dst_stream.find_substream(right_audio_half, original_time + last_committed_shift + right_half_offset, window)[1] - right_half_offset + terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR + log_uncommitted(group_state, new_time - original_time, left_side_time - original_time, + right_side_time - original_time, last_committed_shift) + + if not terminate and uncommitted_states and uncommitted_states[-1]["shift"] is not None \ + and original_time + uncommitted_states[-1]["shift"] < dst_stream.duration_seconds: + start_offset = uncommitted_states[-1]["shift"] + diff, new_time = dst_stream.find_substream(tv_audio, original_time + start_offset, window) + left_side_time = dst_stream.find_substream(left_audio_half, original_time + start_offset, window)[1] + right_side_time = dst_stream.find_substream(right_audio_half, original_time + start_offset + right_half_offset, window)[1] - right_half_offset + terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR + log_uncommitted(group_state, new_time - original_time, left_side_time - original_time, + right_side_time - original_time, start_offset) + + shift = new_time - original_time + if not terminate: + # we aren't back on track yet - add this group to uncommitted + group_state.update({"shift": shift, "diff": diff}) + uncommitted_states.append(group_state) + idx += 1 + if rewind_thresh == len(uncommitted_states) and window < max_window: + logging.warn("Detected possibly broken segment starting at {0}, increasing the window from {1} to {2}" + .format(format_time(uncommitted_states[0]["start_time"]), window, max_window)) + window = max_window + idx = len(committed_states) + del uncommitted_states[:] + continue + + # we're back on track - apply current shift to all broken events + if uncommitted_states: + logging.warning("Events from {0} to {1} will most likely be broken!".format( + format_time(uncommitted_states[0]["start_time"]), + format_time(uncommitted_states[-1]["end_time"]))) + + uncommitted_states.append(group_state) + for state in uncommitted_states: + state.update({"shift": shift, "diff": diff}) + log_shift(state) + committed_states.extend(uncommitted_states) + del uncommitted_states[:] + idx += 1 + + for state in uncommitted_states: + log_shift(state) + + for idx, (search_group, group_state) in enumerate(zip(groups_list, chain(committed_states, uncommitted_states))): + if group_state["shift"] is None: + for group in reversed(groups_list[:idx]): + link_to = next((x for x in reversed(group) if not x.linked), None) + if link_to: + for e in search_group: + e.link_event(link_to) + break + else: + for e in search_group: + e.set_shift(group_state["shift"], group_state["diff"]) + + +def check_file_exists(path, file_title): + if path and not os.path.exists(path): + raise SushiError("{0} file doesn't exist".format(file_title)) + + +def format_full_path(temp_dir, base_path, postfix): + if 
+def format_full_path(temp_dir, base_path, postfix):
+    if temp_dir:
+        return os.path.join(temp_dir, os.path.basename(base_path) + postfix)
+    else:
+        return base_path + postfix
+
+
+def create_directory_if_not_exists(path):
+    if path and not os.path.exists(path):
+        os.makedirs(path)
+
+
+def run(args):
+    ignore_chapters = args.chapters_file is not None and args.chapters_file.lower() == 'none'
+    write_plot = plot_enabled and args.plot_path
+    if write_plot:
+        plt.clf()
+        plt.ylabel('Shift, seconds')
+        plt.xlabel('Event index')
+
+    # first part should do all possible validation and should NOT take significant amount of time
+    check_file_exists(args.source, 'Source')
+    check_file_exists(args.destination, 'Destination')
+    check_file_exists(args.src_timecodes, 'Source timecodes')
+    check_file_exists(args.dst_timecodes, 'Destination timecodes')
+    check_file_exists(args.script_file, 'Script')
+
+    if not ignore_chapters:
+        check_file_exists(args.chapters_file, 'Chapters')
+    if args.src_keyframes not in ('auto', 'make'):
+        check_file_exists(args.src_keyframes, 'Source keyframes')
+    if args.dst_keyframes not in ('auto', 'make'):
+        check_file_exists(args.dst_keyframes, 'Destination keyframes')
+
+    if (args.src_timecodes and args.src_fps) or (args.dst_timecodes and args.dst_fps):
+        raise SushiError('Both fps and timecodes file cannot be specified at the same time')
+
+    src_demuxer = Demuxer(args.source)
+    dst_demuxer = Demuxer(args.destination)
+
+    if src_demuxer.is_wav and not args.script_file:
+        raise SushiError("Script file isn't specified")
+
+    if (args.src_keyframes and not args.dst_keyframes) or (args.dst_keyframes and not args.src_keyframes):
+        raise SushiError('Either none or both of src and dst keyframes should be provided')
+
+    create_directory_if_not_exists(args.temp_dir)
+
+    # selecting source audio
+    if src_demuxer.is_wav:
+        src_audio_path = args.source
+    else:
+        src_audio_path = format_full_path(args.temp_dir, args.source, '.sushi.wav')
+        src_demuxer.set_audio(stream_idx=args.src_audio_idx, output_path=src_audio_path, sample_rate=args.sample_rate)
+
+    # selecting destination audio
+    if dst_demuxer.is_wav:
+        dst_audio_path = args.destination
+    else:
+        dst_audio_path = format_full_path(args.temp_dir, args.destination, '.sushi.wav')
+        dst_demuxer.set_audio(stream_idx=args.dst_audio_idx, output_path=dst_audio_path, sample_rate=args.sample_rate)
+
+    # selecting source subtitles
+    if args.script_file:
+        src_script_path = args.script_file
+    else:
+        stype = src_demuxer.get_subs_type(args.src_script_idx)
+        src_script_path = format_full_path(args.temp_dir, args.source, '.sushi' + stype)
+        src_demuxer.set_script(stream_idx=args.src_script_idx, output_path=src_script_path)
+
+    script_extension = get_extension(src_script_path)
+    if script_extension not in ('.ass', '.srt'):
+        raise SushiError('Unknown script type')
+
+    # selecting destination subtitles
+    if args.output_script:
+        dst_script_path = args.output_script
+        dst_script_extension = get_extension(args.output_script)
+        if dst_script_extension != script_extension:
+            raise SushiError("Source and destination script file types don't match ({0} vs {1})"
+                             .format(script_extension, dst_script_extension))
+    else:
+        dst_script_path = format_full_path(args.temp_dir, args.destination, '.sushi' + script_extension)
+
+    # selecting chapters
+    if args.grouping and not ignore_chapters:
+        if args.chapters_file:
+            if get_extension(args.chapters_file) == '.xml':
+                chapter_times = chapters.get_xml_start_times(args.chapters_file)
+            else:
+                chapter_times = chapters.get_ogm_start_times(args.chapters_file)
+        elif not 
src_demuxer.is_wav: + chapter_times = src_demuxer.chapters + output_path = format_full_path(args.temp_dir, src_demuxer.path, ".sushi.chapters.txt") + src_demuxer.set_chapters(output_path) + else: + chapter_times = [] + else: + chapter_times = [] + + # selecting keyframes and timecodes + if args.src_keyframes: + def select_keyframes(file_arg, demuxer): + auto_file = format_full_path(args.temp_dir, demuxer.path, '.sushi.keyframes.txt') + if file_arg in ('auto', 'make'): + if file_arg == 'make' or not os.path.exists(auto_file): + if not demuxer.has_video: + raise SushiError("Cannot make keyframes for {0} because it doesn't have any video!" + .format(demuxer.path)) + demuxer.set_keyframes(output_path=auto_file) + return auto_file + else: + return file_arg + + def select_timecodes(external_file, fps_arg, demuxer): + if external_file: + return external_file + elif fps_arg: + return None + elif demuxer.has_video: + path = format_full_path(args.temp_dir, demuxer.path, '.sushi.timecodes.txt') + demuxer.set_timecodes(output_path=path) + return path + else: + raise SushiError('Fps, timecodes or video files must be provided if keyframes are used') + + src_keyframes_file = select_keyframes(args.src_keyframes, src_demuxer) + dst_keyframes_file = select_keyframes(args.dst_keyframes, dst_demuxer) + src_timecodes_file = select_timecodes(args.src_timecodes, args.src_fps, src_demuxer) + dst_timecodes_file = select_timecodes(args.dst_timecodes, args.dst_fps, dst_demuxer) + + # after this point nothing should fail so it's safe to start slow operations + # like running the actual demuxing + src_demuxer.demux() + dst_demuxer.demux() + + try: + if args.src_keyframes: + src_timecodes = Timecodes.cfr(args.src_fps) if args.src_fps else Timecodes.from_file(src_timecodes_file) + src_keytimes = [src_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(src_keyframes_file)] + + dst_timecodes = Timecodes.cfr(args.dst_fps) if args.dst_fps else Timecodes.from_file(dst_timecodes_file) + dst_keytimes = [dst_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(dst_keyframes_file)] + + script = AssScript.from_file(src_script_path) if script_extension == '.ass' else SrtScript.from_file(src_script_path) + script.sort_by_time() + + src_stream = WavStream(src_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type) + dst_stream = WavStream(dst_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type) + + search_groups = prepare_search_groups(script.events, + source_duration=src_stream.duration_seconds, + chapter_times=chapter_times, + max_ts_duration=args.max_ts_duration, + max_ts_distance=args.max_ts_distance) + + calculate_shifts(src_stream, dst_stream, search_groups, + normal_window=args.window, + max_window=args.max_window, + rewind_thresh=args.rewind_thresh if args.grouping else 0) + + events = script.events + + if write_plot: + plt.plot([x.shift for x in events], label='From audio') + + if args.grouping: + if not ignore_chapters and chapter_times: + groups = groups_from_chapters(events, chapter_times) + for g in groups: + fix_near_borders(g) + smooth_events([x for x in g if not x.linked], args.smooth_radius) + groups = split_broken_groups(groups) + else: + fix_near_borders(events) + smooth_events([x for x in events if not x.linked], args.smooth_radius) + groups = detect_groups(events) + + if write_plot: + plt.plot([x.shift for x in events], label='Borders fixed') + + for g in groups: + start_shift = g[0].shift + end_shift = g[-1].shift + avg_shift = average_shifts(g) + 
logging.info('Group (start: {0}, end: {1}, lines: {2}), ' + 'shifts (start: {3}, end: {4}, average: {5})' + .format(format_time(g[0].start), format_time(g[-1].end), len(g), start_shift, end_shift, + avg_shift)) + + if args.src_keyframes: + for e in (x for x in events if x.linked): + e.resolve_link() + for g in groups: + snap_groups_to_keyframes(g, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes, + dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode) + else: + fix_near_borders(events) + if write_plot: + plt.plot([x.shift for x in events], label='Borders fixed') + + if args.src_keyframes: + for e in (x for x in events if x.linked): + e.resolve_link() + snap_groups_to_keyframes(events, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes, + dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode) + + for event in events: + event.apply_shift() + + script.save_to_file(dst_script_path) + + if write_plot: + plt.plot([x.shift + (x._start_shift + x._end_shift) / 2.0 for x in events], label='After correction') + plt.legend(fontsize=5, frameon=False, fancybox=False) + plt.savefig(args.plot_path, dpi=300) + + finally: + if args.cleanup: + src_demuxer.cleanup() + dst_demuxer.cleanup() diff --git a/sushi/__main__.py b/sushi/__main__.py index 3702911..a3c56ed 100755 --- a/sushi/__main__.py +++ b/sushi/__main__.py @@ -1,28 +1,11 @@ -#!/usr/bin/env python2 -import logging -import sys -import operator import argparse +import logging import os -import bisect -from itertools import takewhile, chain +import sys import time -import numpy as np - -from . import chapters -from .common import SushiError, get_extension, format_time, ensure_static_collection -from .demux import Timecodes, Demuxer -from . import keyframes -from .subs import AssScript, SrtScript -from .wav import WavStream - - -try: - import matplotlib.pyplot as plt - plot_enabled = True -except ImportError: - plot_enabled = False +from . 
import run, VERSION +from .common import SushiError if sys.platform == 'win32': try: @@ -35,11 +18,6 @@ console_colors_supported = True -ALLOWED_ERROR = 0.01 -MAX_GROUP_STD = 0.025 -VERSION = '0.5.1' - - class ColoredLogFormatter(logging.Formatter): bold_code = "\033[1m" reset_code = "\033[0m" @@ -63,679 +41,6 @@ def format(self, record): return super(ColoredLogFormatter, self).format(record) -def abs_diff(a, b): - return abs(a - b) - - -def interpolate_nones(data, points): - data = ensure_static_collection(data) - values_lookup = {p: v for p, v in zip(points, data) if v is not None} - if not values_lookup: - return [] - - zero_points = {p for p, v in zip(points, data) if v is None} - if not zero_points: - return data - - data_list = sorted(values_lookup.items()) - zero_points = sorted(x for x in zero_points if x not in values_lookup) - - out = np.interp(x=zero_points, - xp=list(map(operator.itemgetter(0), data_list)), - fp=list(map(operator.itemgetter(1), data_list))) - - values_lookup.update(zip(zero_points, out)) - - return [ - values_lookup[point] if value is None else value - for point, value in zip(points, data) - ] - - -# todo: implement this as a running median -def running_median(values, window_size): - if window_size % 2 != 1: - raise SushiError('Median window size should be odd') - half_window = window_size // 2 - medians = [] - items_count = len(values) - for idx in range(items_count): - radius = min(half_window, idx, items_count - idx - 1) - med = np.median(values[idx - radius:idx + radius + 1]) - medians.append(med) - return medians - - -def smooth_events(events, radius): - if not radius: - return - window_size = radius * 2 + 1 - shifts = [e.shift for e in events] - smoothed = running_median(shifts, window_size) - for event, new_shift in zip(events, smoothed): - event.set_shift(new_shift, event.diff) - - -def detect_groups(events_iter): - events_iter = iter(events_iter) - groups_list = [[next(events_iter)]] - for event in events_iter: - if abs_diff(event.shift, groups_list[-1][-1].shift) > ALLOWED_ERROR: - groups_list.append([]) - groups_list[-1].append(event) - return groups_list - - -def groups_from_chapters(events, times): - logging.info('Chapter start points: {0}'.format([format_time(t) for t in times])) - groups = [[]] - chapter_times = iter(times[1:] + [36000000000]) # very large event at the end - current_chapter = next(chapter_times) - - for event in events: - if event.end > current_chapter: - groups.append([]) - while event.end > current_chapter: - current_chapter = next(chapter_times) - - groups[-1].append(event) - - groups = [g for g in groups if g] # non-empty groups - # check if we have any groups where every event is linked - # for example a chapter with only comments inside - broken_groups = [group for group in groups if not any(e for e in group if not e.linked)] - if broken_groups: - for group in broken_groups: - for event in group: - parent = event.get_link_chain_end() - parent_group = next(group for group in groups if parent in group) - parent_group.append(event) - del group[:] - groups = [g for g in groups if g] - # re-sort the groups again since we might break the order when inserting linked events - # sorting everything again is far from optimal but python sorting is very fast for sorted arrays anyway - for group in groups: - group.sort(key=lambda event: event.start) - - return groups - - -def split_broken_groups(groups): - correct_groups = [] - broken_found = False - for g in groups: - std = np.std([e.shift for e in g]) - if std > MAX_GROUP_STD: - 
logging.warn('Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). ' - 'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end), - std)) - correct_groups.extend(detect_groups(g)) - broken_found = True - else: - correct_groups.append(g) - - if broken_found: - groups_iter = iter(correct_groups) - correct_groups = [list(next(groups_iter))] - for group in groups_iter: - if abs_diff(correct_groups[-1][-1].shift, group[0].shift) >= ALLOWED_ERROR \ - or np.std([e.shift for e in group + correct_groups[-1]]) >= MAX_GROUP_STD: - correct_groups.append([]) - - correct_groups[-1].extend(group) - return correct_groups - - -def fix_near_borders(events): - """ - We assume that all lines with diff greater than 5 * (median diff across all events) are broken - """ - def fix_border(event_list, median_diff): - last_ten_diff = np.median([x.diff for x in event_list[:10]], overwrite_input=True) - diff_limit = min(last_ten_diff, median_diff) - broken = [] - for event in event_list: - if not 0.2 < (event.diff / diff_limit) < 5: - broken.append(event) - else: - for x in broken: - x.link_event(event) - return len(broken) - return 0 - - median_diff = np.median([x.diff for x in events], overwrite_input=True) - - fixed_count = fix_border(events, median_diff) - if fixed_count: - logging.info('Fixing {0} border events right after {1}'.format(fixed_count, format_time(events[0].start))) - - fixed_count = fix_border(list(reversed(events)), median_diff) - if fixed_count: - logging.info('Fixing {0} border events right before {1}'.format(fixed_count, format_time(events[-1].end))) - - -def get_distance_to_closest_kf(timestamp, keyframes): - idx = bisect.bisect_left(keyframes, timestamp) - if idx == 0: - kf = keyframes[0] - elif idx == len(keyframes): - kf = keyframes[-1] - else: - before = keyframes[idx - 1] - after = keyframes[idx] - kf = after if after - timestamp < timestamp - before else before - return kf - timestamp - - -def find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance): - def get_distance(src_distance, dst_distance, limit): - if abs(dst_distance) > limit: - return None - shift = dst_distance - src_distance - return shift if abs(shift) < limit else None - - src_start = get_distance_to_closest_kf(group[0].start, src_keytimes) - src_end = get_distance_to_closest_kf(group[-1].end + src_timecodes.get_frame_size(group[-1].end), src_keytimes) - - dst_start = get_distance_to_closest_kf(group[0].shifted_start, dst_keytimes) - dst_end = get_distance_to_closest_kf(group[-1].shifted_end + dst_timecodes.get_frame_size(group[-1].end), dst_keytimes) - - snapping_limit_start = src_timecodes.get_frame_size(group[0].start) * max_kf_distance - snapping_limit_end = src_timecodes.get_frame_size(group[0].end) * max_kf_distance - - return (get_distance(src_start, dst_start, snapping_limit_start), - get_distance(src_end, dst_end, snapping_limit_end)) - - -def find_keyframes_distances(event, src_keytimes, dst_keytimes, timecodes, max_kf_distance): - def find_keyframe_distance(src_time, dst_time): - src = get_distance_to_closest_kf(src_time, src_keytimes) - dst = get_distance_to_closest_kf(dst_time, dst_keytimes) - snapping_limit = timecodes.get_frame_size(src_time) * max_kf_distance - - if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src - dst) < snapping_limit: - return dst - src - return 0 - - ds = find_keyframe_distance(event.start, event.shifted_start) - de = find_keyframe_distance(event.end, 
event.shifted_end) - return ds, de - - -def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_distance, src_keytimes, dst_keytimes, - src_timecodes, dst_timecodes, max_kf_distance, kf_mode): - if not max_kf_distance: - return - - groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance) - - if kf_mode == 'all' or kf_mode == 'shift': - # step 1: snap events without changing their duration. Useful for some slight audio imprecision correction - shifts = [] - times = [] - for group in groups: - shifts.extend(find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance)) - times.extend((group[0].shifted_start, group[-1].shifted_end)) - - shifts = interpolate_nones(shifts, times) - if shifts: - mean_shift = np.mean(shifts) - shifts = zip(*[iter(shifts)] * 2) - - logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift)) - for group, (start_shift, end_shift) in zip(groups, shifts): - if abs(start_shift - end_shift) > 0.001 and len(group) > 1: - actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift)) - logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}." - .format(format_time(group[0].start), start_shift, end_shift, actual_shift)) - for e in group: - e.adjust_shift(actual_shift) - else: - for e in group: - e.adjust_additional_shifts(start_shift, end_shift) - - if kf_mode == 'all' or kf_mode == 'snap': - # step 2: snap start/end times separately - for group in groups: - if len(group) > 1: - pass # we don't snap typesetting - start_shift, end_shift = find_keyframes_distances(group[0], src_keytimes, dst_keytimes, src_timecodes, max_kf_distance) - if abs(start_shift) > 0.01 or abs(end_shift) > 0.01: - logging.info('Snapping {0} to keyframes, start time by {1}, end: {2}'.format(format_time(group[0].start), start_shift, end_shift)) - group[0].adjust_additional_shifts(start_shift, end_shift) - - -def average_shifts(events): - events = [e for e in events if not e.linked] - shifts = [x.shift for x in events] - weights = [1 - x.diff for x in events] - avg = np.average(shifts, weights=weights) - for e in events: - e.set_shift(avg, e.diff) - return avg - - -def merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance): - search_groups = [] - chapter_times = iter(chapter_times[1:] + [100000000]) - next_chapter = next(chapter_times) - events = ensure_static_collection(events) - - processed = set() - for idx, event in enumerate(events): - if idx in processed: - continue - - while event.end > next_chapter: - next_chapter = next(chapter_times) - - if event.duration > max_ts_duration: - search_groups.append([event]) - processed.add(idx) - else: - group = [event] - group_end = event.end - i = idx + 1 - while i < len(events) and abs(group_end - events[i].start) < max_ts_distance: - if events[i].end < next_chapter and events[i].duration <= max_ts_duration: - processed.add(i) - group.append(events[i]) - group_end = max(group_end, events[i].end) - i += 1 - - search_groups.append(group) - - return search_groups - - -def prepare_search_groups(events, source_duration, chapter_times, max_ts_duration, max_ts_distance): - last_unlinked = None - for idx, event in enumerate(events): - if event.is_comment: - try: - event.link_event(events[idx + 1]) - except IndexError: - event.link_event(last_unlinked) - continue - if (event.start + event.duration / 2.0) 
> source_duration: - logging.info('Event time outside of audio range, ignoring: %s', event) - event.link_event(last_unlinked) - continue - elif event.end == event.start: - logging.info('{0}: skipped because zero duration'.format(format_time(event.start))) - try: - event.link_event(events[idx + 1]) - except IndexError: - event.link_event(last_unlinked) - continue - - # link lines with start and end times identical to some other event - # assuming scripts are sorted by start time so we don't search the entire collection - def same_start(x): - return event.start == x.start - processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end), None) - if processed: - event.link_event(processed) - else: - last_unlinked = event - - events = (e for e in events if not e.linked) - - search_groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance) - - # link groups contained inside other groups to the larger group - passed_groups = [] - for idx, group in enumerate(search_groups): - try: - other = next(x for x in reversed(search_groups[:idx]) - if x[0].start <= group[0].start - and x[-1].end >= group[-1].end) - for event in group: - event.link_event(other[0]) - except StopIteration: - passed_groups.append(group) - return passed_groups - - -def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_window, rewind_thresh): - def log_shift(state): - logging.info('{0}-{1}: shift: {2:0.10f}, diff: {3:0.10f}' - .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"])) - - def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offset): - logging.debug('{0}-{1}: shift: {2:0.5f} [{3:0.5f}, {4:0.5f}], search offset: {5:0.6f}' - .format(format_time(state["start_time"]), format_time(state["end_time"]), - shift, left_side_shift, right_side_shift, search_offset)) - - small_window = 1.5 - idx = 0 - committed_states = [] - uncommitted_states = [] - window = normal_window - while idx < len(groups_list): - search_group = groups_list[idx] - tv_audio = src_stream.get_substream(search_group[0].start, search_group[-1].end) - original_time = search_group[0].start - group_state = {"start_time": search_group[0].start, "end_time": search_group[-1].end, "shift": None, "diff": None} - last_committed_shift = committed_states[-1]["shift"] if committed_states else 0 - diff = new_time = None - - if not uncommitted_states: - if original_time + last_committed_shift > dst_stream.duration_seconds: - # event outside of audio range, all events past it are also guaranteed to fail - for g in groups_list[idx:]: - committed_states.append({"start_time": g[0].start, "end_time": g[-1].end, "shift": None, "diff": None}) - logging.info("{0}-{1}: outside of audio range".format(format_time(g[0].start), format_time(g[-1].end))) - break - - if small_window < window: - diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, small_window) - - if new_time is not None and abs_diff(new_time - original_time, last_committed_shift) <= ALLOWED_ERROR: - # fastest case - small window worked, commit the group immediately - group_state.update({"shift": new_time - original_time, "diff": diff}) - committed_states.append(group_state) - log_shift(group_state) - if window != normal_window: - logging.info("Going back to window {0} from {1}".format(normal_window, window)) - window = normal_window - idx += 1 - continue - - left_audio_half, right_audio_half = 
np.split(tv_audio, [len(tv_audio[0]) / 2], axis=1) - right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate) - terminate = False - # searching from last committed shift - if original_time + last_committed_shift < dst_stream.duration_seconds: - diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, window) - left_side_time = dst_stream.find_substream(left_audio_half, original_time + last_committed_shift, window)[1] - right_side_time = dst_stream.find_substream(right_audio_half, original_time + last_committed_shift + right_half_offset, window)[1] - right_half_offset - terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR - log_uncommitted(group_state, new_time - original_time, left_side_time - original_time, - right_side_time - original_time, last_committed_shift) - - if not terminate and uncommitted_states and uncommitted_states[-1]["shift"] is not None \ - and original_time + uncommitted_states[-1]["shift"] < dst_stream.duration_seconds: - start_offset = uncommitted_states[-1]["shift"] - diff, new_time = dst_stream.find_substream(tv_audio, original_time + start_offset, window) - left_side_time = dst_stream.find_substream(left_audio_half, original_time + start_offset, window)[1] - right_side_time = dst_stream.find_substream(right_audio_half, original_time + start_offset + right_half_offset, window)[1] - right_half_offset - terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR - log_uncommitted(group_state, new_time - original_time, left_side_time - original_time, - right_side_time - original_time, start_offset) - - shift = new_time - original_time - if not terminate: - # we aren't back on track yet - add this group to uncommitted - group_state.update({"shift": shift, "diff": diff}) - uncommitted_states.append(group_state) - idx += 1 - if rewind_thresh == len(uncommitted_states) and window < max_window: - logging.warn("Detected possibly broken segment starting at {0}, increasing the window from {1} to {2}" - .format(format_time(uncommitted_states[0]["start_time"]), window, max_window)) - window = max_window - idx = len(committed_states) - del uncommitted_states[:] - continue - - # we're back on track - apply current shift to all broken events - if uncommitted_states: - logging.warning("Events from {0} to {1} will most likely be broken!".format( - format_time(uncommitted_states[0]["start_time"]), - format_time(uncommitted_states[-1]["end_time"]))) - - uncommitted_states.append(group_state) - for state in uncommitted_states: - state.update({"shift": shift, "diff": diff}) - log_shift(state) - committed_states.extend(uncommitted_states) - del uncommitted_states[:] - idx += 1 - - for state in uncommitted_states: - log_shift(state) - - for idx, (search_group, group_state) in enumerate(zip(groups_list, chain(committed_states, uncommitted_states))): - if group_state["shift"] is None: - for group in reversed(groups_list[:idx]): - link_to = next((x for x in reversed(group) if not x.linked), None) - if link_to: - for e in search_group: - e.link_event(link_to) - break - else: - for e in search_group: - e.set_shift(group_state["shift"], group_state["diff"]) - - -def check_file_exists(path, file_title): - if path and not os.path.exists(path): - raise SushiError("{0} file doesn't exist".format(file_title)) - - -def format_full_path(temp_dir, base_path, postfix): - if temp_dir: - return os.path.join(temp_dir, 
os.path.basename(base_path) + postfix) - else: - return base_path + postfix - - -def create_directory_if_not_exists(path): - if path and not os.path.exists(path): - os.makedirs(path) - - -def run(args): - ignore_chapters = args.chapters_file is not None and args.chapters_file.lower() == 'none' - write_plot = plot_enabled and args.plot_path - if write_plot: - plt.clf() - plt.ylabel('Shift, seconds') - plt.xlabel('Event index') - - # first part should do all possible validation and should NOT take significant amount of time - check_file_exists(args.source, 'Source') - check_file_exists(args.destination, 'Destination') - check_file_exists(args.src_timecodes, 'Source timecodes') - check_file_exists(args.dst_timecodes, 'Source timecodes') - check_file_exists(args.script_file, 'Script') - - if not ignore_chapters: - check_file_exists(args.chapters_file, 'Chapters') - if args.src_keyframes not in ('auto', 'make'): - check_file_exists(args.src_keyframes, 'Source keyframes') - if args.dst_keyframes not in ('auto', 'make'): - check_file_exists(args.dst_keyframes, 'Destination keyframes') - - if (args.src_timecodes and args.src_fps) or (args.dst_timecodes and args.dst_fps): - raise SushiError('Both fps and timecodes file cannot be specified at the same time') - - src_demuxer = Demuxer(args.source) - dst_demuxer = Demuxer(args.destination) - - if src_demuxer.is_wav and not args.script_file: - raise SushiError("Script file isn't specified") - - if (args.src_keyframes and not args.dst_keyframes) or (args.dst_keyframes and not args.src_keyframes): - raise SushiError('Either none or both of src and dst keyframes should be provided') - - create_directory_if_not_exists(args.temp_dir) - - # selecting source audio - if src_demuxer.is_wav: - src_audio_path = args.source - else: - src_audio_path = format_full_path(args.temp_dir, args.source, '.sushi.wav') - src_demuxer.set_audio(stream_idx=args.src_audio_idx, output_path=src_audio_path, sample_rate=args.sample_rate) - - # selecting destination audio - if dst_demuxer.is_wav: - dst_audio_path = args.destination - else: - dst_audio_path = format_full_path(args.temp_dir, args.destination, '.sushi.wav') - dst_demuxer.set_audio(stream_idx=args.dst_audio_idx, output_path=dst_audio_path, sample_rate=args.sample_rate) - - # selecting source subtitles - if args.script_file: - src_script_path = args.script_file - else: - stype = src_demuxer.get_subs_type(args.src_script_idx) - src_script_path = format_full_path(args.temp_dir, args.source, '.sushi' + stype) - src_demuxer.set_script(stream_idx=args.src_script_idx, output_path=src_script_path) - - script_extension = get_extension(src_script_path) - if script_extension not in ('.ass', '.srt'): - raise SushiError('Unknown script type') - - # selection destination subtitles - if args.output_script: - dst_script_path = args.output_script - dst_script_extension = get_extension(args.output_script) - if dst_script_extension != script_extension: - raise SushiError("Source and destination script file types don't match ({0} vs {1})" - .format(script_extension, dst_script_extension)) - else: - dst_script_path = format_full_path(args.temp_dir, args.destination, '.sushi' + script_extension) - - # selecting chapters - if args.grouping and not ignore_chapters: - if args.chapters_file: - if get_extension(args.chapters_file) == '.xml': - chapter_times = chapters.get_xml_start_times(args.chapters_file) - else: - chapter_times = chapters.get_ogm_start_times(args.chapters_file) - elif not src_demuxer.is_wav: - chapter_times = 
src_demuxer.chapters - output_path = format_full_path(args.temp_dir, src_demuxer.path, ".sushi.chapters.txt") - src_demuxer.set_chapters(output_path) - else: - chapter_times = [] - else: - chapter_times = [] - - # selecting keyframes and timecodes - if args.src_keyframes: - def select_keyframes(file_arg, demuxer): - auto_file = format_full_path(args.temp_dir, demuxer.path, '.sushi.keyframes.txt') - if file_arg in ('auto', 'make'): - if file_arg == 'make' or not os.path.exists(auto_file): - if not demuxer.has_video: - raise SushiError("Cannot make keyframes for {0} because it doesn't have any video!" - .format(demuxer.path)) - demuxer.set_keyframes(output_path=auto_file) - return auto_file - else: - return file_arg - - def select_timecodes(external_file, fps_arg, demuxer): - if external_file: - return external_file - elif fps_arg: - return None - elif demuxer.has_video: - path = format_full_path(args.temp_dir, demuxer.path, '.sushi.timecodes.txt') - demuxer.set_timecodes(output_path=path) - return path - else: - raise SushiError('Fps, timecodes or video files must be provided if keyframes are used') - - src_keyframes_file = select_keyframes(args.src_keyframes, src_demuxer) - dst_keyframes_file = select_keyframes(args.dst_keyframes, dst_demuxer) - src_timecodes_file = select_timecodes(args.src_timecodes, args.src_fps, src_demuxer) - dst_timecodes_file = select_timecodes(args.dst_timecodes, args.dst_fps, dst_demuxer) - - # after this point nothing should fail so it's safe to start slow operations - # like running the actual demuxing - src_demuxer.demux() - dst_demuxer.demux() - - try: - if args.src_keyframes: - src_timecodes = Timecodes.cfr(args.src_fps) if args.src_fps else Timecodes.from_file(src_timecodes_file) - src_keytimes = [src_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(src_keyframes_file)] - - dst_timecodes = Timecodes.cfr(args.dst_fps) if args.dst_fps else Timecodes.from_file(dst_timecodes_file) - dst_keytimes = [dst_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(dst_keyframes_file)] - - script = AssScript.from_file(src_script_path) if script_extension == '.ass' else SrtScript.from_file(src_script_path) - script.sort_by_time() - - src_stream = WavStream(src_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type) - dst_stream = WavStream(dst_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type) - - search_groups = prepare_search_groups(script.events, - source_duration=src_stream.duration_seconds, - chapter_times=chapter_times, - max_ts_duration=args.max_ts_duration, - max_ts_distance=args.max_ts_distance) - - calculate_shifts(src_stream, dst_stream, search_groups, - normal_window=args.window, - max_window=args.max_window, - rewind_thresh=args.rewind_thresh if args.grouping else 0) - - events = script.events - - if write_plot: - plt.plot([x.shift for x in events], label='From audio') - - if args.grouping: - if not ignore_chapters and chapter_times: - groups = groups_from_chapters(events, chapter_times) - for g in groups: - fix_near_borders(g) - smooth_events([x for x in g if not x.linked], args.smooth_radius) - groups = split_broken_groups(groups) - else: - fix_near_borders(events) - smooth_events([x for x in events if not x.linked], args.smooth_radius) - groups = detect_groups(events) - - if write_plot: - plt.plot([x.shift for x in events], label='Borders fixed') - - for g in groups: - start_shift = g[0].shift - end_shift = g[-1].shift - avg_shift = average_shifts(g) - logging.info('Group (start: {0}, end: {1}, 
lines: {2}), ' - 'shifts (start: {3}, end: {4}, average: {5})' - .format(format_time(g[0].start), format_time(g[-1].end), len(g), start_shift, end_shift, - avg_shift)) - - if args.src_keyframes: - for e in (x for x in events if x.linked): - e.resolve_link() - for g in groups: - snap_groups_to_keyframes(g, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes, - dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode) - else: - fix_near_borders(events) - if write_plot: - plt.plot([x.shift for x in events], label='Borders fixed') - - if args.src_keyframes: - for e in (x for x in events if x.linked): - e.resolve_link() - snap_groups_to_keyframes(events, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes, - dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode) - - for event in events: - event.apply_shift() - - script.save_to_file(dst_script_path) - - if write_plot: - plt.plot([x.shift + (x._start_shift + x._end_shift) / 2.0 for x in events], label='After correction') - plt.legend(fontsize=5, frameon=False, fancybox=False) - plt.savefig(args.plot_path, dpi=300) - - finally: - if args.cleanup: - src_demuxer.cleanup() - dst_demuxer.cleanup() - - def create_arg_parser(): parser = argparse.ArgumentParser(description='Sushi - Automatic Subtitle Shifter') diff --git a/sushi/demux.py b/sushi/demux.py index f833cc8..bb7ca57 100644 --- a/sushi/demux.py +++ b/sushi/demux.py @@ -17,7 +17,8 @@ class FFmpeg(object): @staticmethod def get_info(path): try: - process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE) + # text=True is an alias for universal_newlines since 3.7 + process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE, universal_newlines=True) out, err = process.communicate() process.wait() return err diff --git a/sushi/subs.py b/sushi/subs.py index 3b12ed7..7d8ab39 100644 --- a/sushi/subs.py +++ b/sushi/subs.py @@ -79,9 +79,6 @@ def adjust_shift(self, value): assert not self.linked, 'Cannot adjust time of linked events' self._shift += value - def __repr__(self): - return str(self) - class ScriptBase(object): def __init__(self, events): @@ -112,7 +109,7 @@ def from_string(cls, text): end = cls.parse_time(match.group(3)) return SrtEvent(int(match.group(1)), start, end, match.group(4).strip()) - def __unicode__(self): + def __str__(self): return '{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start), self._format_time(self.end), self.text) @@ -168,7 +165,7 @@ def __init__(self, text, position=0): self.margin_vertical = split[7] self.effect = split[8] - def __unicode__(self): + def __str__(self): return '{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer, self._format_time(self.start), self._format_time(self.end), diff --git a/tests/main.py b/tests/main.py index e3e2fa7..ebf380a 100644 --- a/tests/main.py +++ b/tests/main.py @@ -4,7 +4,8 @@ from unittest.mock import patch, ANY from sushi.common import SushiError, format_time -from sushi import __main__ as sushi +import sushi +from sushi import __main__ as main here = os.path.dirname(os.path.abspath(__file__)) @@ -107,24 +108,24 @@ def test_events_in_two_groups_one_chapter(self): events = [FakeEvent(end=1), FakeEvent(end=2), FakeEvent(end=3)] groups = sushi.groups_from_chapters(events, [0.0, 1.5]) self.assertEqual(2, len(groups)) - self.assertItemsEqual([events[0]], groups[0]) - self.assertItemsEqual([events[1], events[2]], groups[1]) + 
self.assertEqual([events[0]], groups[0]) + self.assertEqual([events[1], events[2]], groups[1]) def test_multiple_groups_multiple_chapters(self): events = [FakeEvent(end=x) for x in range(1, 10)] groups = sushi.groups_from_chapters(events, [0.0, 3.2, 4.4, 7.7]) self.assertEqual(4, len(groups)) - self.assertItemsEqual(events[0:3], groups[0]) - self.assertItemsEqual(events[3:4], groups[1]) - self.assertItemsEqual(events[4:7], groups[2]) - self.assertItemsEqual(events[7:9], groups[3]) + self.assertEqual(events[0:3], groups[0]) + self.assertEqual(events[3:4], groups[1]) + self.assertEqual(events[4:7], groups[2]) + self.assertEqual(events[7:9], groups[3]) class SplitBrokenGroupsTestCase(unittest.TestCase): def test_doing_nothing_on_correct_groups(self): groups = [[FakeEvent(0.5), FakeEvent(0.5)], [FakeEvent(10.0)]] fixed = sushi.split_broken_groups(groups) - self.assertItemsEqual(groups, fixed) + self.assertEqual(groups, fixed) def test_split_groups_without_merging(self): groups = [ @@ -132,7 +133,7 @@ def test_split_groups_without_merging(self): [FakeEvent(0.5)] * 10, ] fixed = sushi.split_broken_groups(groups) - self.assertItemsEqual([ + self.assertEqual([ [FakeEvent(0.5)] * 10, [FakeEvent(10.0)] * 5, [FakeEvent(0.5)] * 10 @@ -144,7 +145,7 @@ def test_split_groups_with_merging(self): [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(15.0)] ] fixed = sushi.split_broken_groups(groups) - self.assertItemsEqual([ + self.assertEqual([ [FakeEvent(0.5)], [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(10.0)], [FakeEvent(15.0)] @@ -192,7 +193,7 @@ def test_checks_that_files_exist(self, mock_object): '--dst-keyframes', 'dst-keyframes', '--src-keyframes', 'src-keyframes', '--src-timecodes', 'src-tcs', '--dst-timecodes', 'dst-tcs'] try: - sushi.parse_args_and_run(keys) + main.parse_args_and_run(keys) except SushiError: pass mock_object.assert_any_call('src', ANY) @@ -206,16 +207,16 @@ def test_checks_that_files_exist(self, mock_object): def test_raises_on_unknown_script_type(self, ignore): keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.mp4'] - self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type'), lambda: sushi.parse_args_and_run(keys)) + self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type'), lambda: main.parse_args_and_run(keys)) def test_raises_on_script_type_not_matching(self, ignore): keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '-o', 'd.srt'] self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type.*match'), - lambda: sushi.parse_args_and_run(keys)) + lambda: main.parse_args_and_run(keys)) def test_raises_on_timecodes_and_fps_being_defined_together(self, ignore): keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '--src-timecodes', 'tc.txt', '--src-fps', '25'] - self.assertRaisesRegex(SushiError, self.any_case_regex(r'timecodes'), lambda: sushi.parse_args_and_run(keys)) + self.assertRaisesRegex(SushiError, self.any_case_regex(r'timecodes'), lambda: main.parse_args_and_run(keys)) class FormatTimeTestCase(unittest.TestCase): diff --git a/tests/subtitles.py b/tests/subtitles.py index fae397e..fe21fec 100644 --- a/tests/subtitles.py +++ b/tests/subtitles.py @@ -81,7 +81,7 @@ def test_write_to_file(self): self.assertEqual(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text) def test_read_from_file(self): - os.write(self.script_description, """1 + os.write(self.script_description, b"""1 00:00:17,500 --> 00:00:18,870 Yeah, really! 
@@ -131,7 +131,7 @@ def test_write_to_file(self): {0}""".format(ASS_EVENT), text) def test_read_from_file(self): - text = """[Script Info] + text = b"""[Script Info] ; Script generated by Aegisub 3.1.1 Title: script title From 6cf970e3480c34c05f74a9fe39b939a294c0dcee Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 21:31:11 +0200 Subject: [PATCH 05/17] Update Python versions for CI --- .travis.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index 23b3bd5..bca2c58 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,14 +3,16 @@ language: python virtualenv: system_site_packages: true before_install: - - sudo apt-get update - - sudo apt-get install python-opencv - - sudo dpkg -L python-opencv - - sudo ln /dev/null /dev/raw1394 + - sudo apt-get update + - sudo apt-get install python-opencv + - sudo dpkg -L python-opencv + - sudo ln /dev/null /dev/raw1394 install: - - "pip install -r requirements.txt" + - pip install -r requirements.txt python: - - "2.7" + - "3.5" + - "3.6" + - "3.7" script: - python run-tests.py From 953946a6a625f36098a1eb5395cd69e888ad1b5b Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 21:32:11 +0200 Subject: [PATCH 06/17] Remove mock requirement --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 54b39f5..24ce15a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ numpy -mock From 424d8e050ceeea9ac583ab5af4cf83f1f37bdca9 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 21:36:18 +0200 Subject: [PATCH 07/17] Add very basic flake8 config --- .flake8 | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..e78aa1e --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +exclude = + .git, + build, + dist, + *.egg-info, + venv, + __pycache__, + +extend-ignore = + E501, # max line length From 6f3cc4876ee283f87dc8414408bd1955bf7bf155 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 21:45:47 +0200 Subject: [PATCH 08/17] Use setuptools for installation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I don't know how to make an entry point with distutils – it might not even be supported. 
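
For reference, the console script that setuptools generates from the entry
point below behaves roughly like this wrapper (illustration only, not part
of the patch):

    import sys
    from sushi.__main__ import main

    if __name__ == '__main__':
        sys.exit(main())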
--- .gitignore | 1 + setup.py | 11 ++++++++--- sushi/__main__.py | 6 +++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 1f0dcc1..5057722 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ tests/media dist build tests.json +*.egg-info diff --git a/setup.py b/setup.py index 312a0bb..66d85e8 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,16 @@ -from distutils.core import setup +from setuptools import setup import sushi setup( name='Sushi', description='Automatic subtitle shifter based on audio', + packages=['sushi'], version=sushi.VERSION, url='https://github.com/tp7/Sushi', - console=['sushi.py'], - license='MIT' + license='MIT', + entry_points={ + 'console_scripts': [ + "sushi=sushi.__main__:main", + ], + }, ) diff --git a/sushi/__main__.py b/sushi/__main__.py index a3c56ed..ebfbea4 100755 --- a/sushi/__main__.py +++ b/sushi/__main__.py @@ -140,9 +140,13 @@ def format_arg(arg): logging.info('Done in {0}s'.format(time.time() - start_time)) -if __name__ == '__main__': +def main(): try: parse_args_and_run(sys.argv[1:]) except SushiError as e: logging.critical(e.message) sys.exit(2) + + +if __name__ == '__main__': + main() From 879c7cdfc343f36d2e271ab7e373a42ceea21048 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 21:53:14 +0200 Subject: [PATCH 09/17] Update README installation instructions Someone please verify macOS for me because I have no clue about that. I also removed the opencv Windows binary, because that's for 2.7. It can be re-added later. --- README.md | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 9771cbd..badc9b4 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,9 @@ Do note that WAV is not the only format Sushi can work with. It can process audi ### Requirements Sushi should work on Windows, Linux and OS X. Please open an issue if it doesn't. To run it, you have to have the following installed: -1. [Python 2.7.x][5] +1. [Python 3.5 or higher][5] 2. [NumPy][6] (1.8 or newer) -3. [OpenCV 2.4.x or newer][7] (on Windows putting [this file][8] in the same folder as Sushi should be enough, assuming you use x86 Python) +3. [OpenCV 2.4.x or newer][7] Optionally, you might want: @@ -41,16 +41,21 @@ The provided Windows binaries include all required components and Colorama so yo #### Installation on Mac OS X -No binary packages are provided for OS X right now so you'll have to use the script form. Assuming you have python 2, pip and [homebrew](http://brew.sh/) installed, run the following: +No binary packages are provided for OS X right now so you'll have to use the script form. Assuming you have Python 3, pip and [homebrew](http://brew.sh/) installed, run the following: ```bash brew tap homebrew/science brew install git opencv -pip install numpy -git clone https://github.com/tp7/sushi -# create a symlink if you want to run sushi globally -ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi +pip3 install numpy # install some optional dependencies brew install ffmpeg mkvtoolnix + +# fetch sushi +git clone https://github.com/tp7/sushi +# run from source +python3 -m sushi args… +# install globally (for your user) +python3 setup.py install --user +sushi args… ``` If you don't have pip, you can install numpy with homebrew, but that will probably add a few more dependencies. ```bash @@ -62,9 +67,13 @@ brew install numpy If you have apt-get available, the installation process is trivial. 
```bash sudo apt-get update -sudo apt-get install git python python-numpy python-opencv +sudo apt-get install git python3 python3-numpy python3-opencv git clone https://github.com/tp7/sushi -ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi +# run from source +python3 -m sushi args… +# install globally (for your user; ensure ~/.local/bin is in your PATH) +python3 setup.py install --user +sushi args… ``` ### Limitations @@ -82,7 +91,6 @@ In short, while this might be safe for immediate viewing, you probably shouldn't [5]: https://www.python.org/downloads/ [6]: http://www.scipy.org/scipylib/download.html [7]: http://opencv.org/ - [8]: https://www.dropbox.com/s/nlylgdh4bgrjgxv/cv2.pyd?dl=0 [9]: http://www.ffmpeg.org/download.html [10]: http://www.bunkus.org/videotools/mkvtoolnix/downloads.html [11]: https://github.com/soyokaze/SCXvid-standalone/releases From 622b65fe609135ccbf939dc4101c546730ff0b87 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 22:22:37 +0200 Subject: [PATCH 10/17] Attempt to build on multiple environments There is currently no environment on Travis with 3.7. Also fix opencv package being installed. --- .travis.yml | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index bca2c58..2ba2d1b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,18 +1,24 @@ language: python +matrix: + include: + - name: "Python 3.5 on Xenial" + dist: xenial + python: 3.5 + - name: "Python 3.6 on Bionic" + dist: bionic + python: 3.6 virtualenv: system_site_packages: true + before_install: - sudo apt-get update - - sudo apt-get install python-opencv - - sudo dpkg -L python-opencv + - sudo apt-get install python3-opencv + - sudo dpkg -L python3-opencv - sudo ln /dev/null /dev/raw1394 + install: - pip install -r requirements.txt -python: - - "3.5" - - "3.6" - - "3.7" script: - python run-tests.py From eb359e045ef4345539d9dd056e14f4bbcca90979 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Tue, 1 Oct 2019 22:46:13 +0200 Subject: [PATCH 11/17] Test all versions using unofficial binaries Courtesy of https://pypi.org/project/opencv-python/. 
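For reference: the opencv-python-headless wheels ship a prebuilt cv2
module, so no apt packages are needed on CI.  A minimal smoke test for
the wheel install (the printed version depends on whichever release pip
resolves):

    # Verify that the wheel-based OpenCV install is importable.
    import cv2

    print(cv2.__version__)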
--- .travis.yml | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2ba2d1b..2e1d495 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,23 +1,12 @@ +dist: bionic language: python -matrix: - include: - - name: "Python 3.5 on Xenial" - dist: xenial - python: 3.5 - - name: "Python 3.6 on Bionic" - dist: bionic - python: 3.6 - -virtualenv: - system_site_packages: true - -before_install: - - sudo apt-get update - - sudo apt-get install python3-opencv - - sudo dpkg -L python3-opencv - - sudo ln /dev/null /dev/raw1394 +python: + - 3.5 + - 3.6 + - 3.7 install: + - pip install opencv-python-headless - pip install -r requirements.txt script: From 8573e31392ce97392a156023c6c07e3f06c85ccc Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Wed, 16 Oct 2019 03:20:52 +0200 Subject: [PATCH 12/17] Operands comparing data of chunks must be bytes --- sushi/wav.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sushi/wav.py b/sushi/wav.py index f35d213..e0707be 100644 --- a/sushi/wav.py +++ b/sushi/wav.py @@ -22,9 +22,9 @@ def __init__(self, path): self._file = open(path, 'rb') try: riff = Chunk(self._file, bigendian=False) - if riff.getname() != 'RIFF': + if riff.getname() != b'RIFF': raise SushiError('File does not start with RIFF id') - if riff.read(4) != 'WAVE': + if riff.read(4) != b'WAVE': raise SushiError('Not a WAVE file') fmt_chunk_read = False @@ -37,10 +37,10 @@ def __init__(self, path): except EOFError: break - if chunk.getname() == 'fmt ': + if chunk.getname() == b'fmt ': self._read_fmt_chunk(chunk) fmt_chunk_read = True - elif chunk.getname() == 'data': + elif chunk.getname() == b'data': if file_size > 0xFFFFFFFF: # large broken wav self.frames_count = (file_size - self._file.tell()) // self.frame_size From 26f4112cb9ef6dddc1cd30039c6846363603550b Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Fri, 7 Feb 2020 22:12:07 +0100 Subject: [PATCH 13/17] Force integer division Co-Authored-By: al3xtjames --- sushi/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sushi/__init__.py b/sushi/__init__.py index 524aaff..7847ea8 100644 --- a/sushi/__init__.py +++ b/sushi/__init__.py @@ -405,7 +405,7 @@ def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offs idx += 1 continue - left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0]) / 2], axis=1) + left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0]) // 2], axis=1) right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate) terminate = False # searching from last committed shift From 6cdfd269e9eb3eb49c06ca7c3af2e09c31944a3a Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Fri, 7 Feb 2020 22:13:05 +0100 Subject: [PATCH 14/17] Run tests on Python 3.8 --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 2e1d495..d877c49 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,7 @@ python: - 3.5 - 3.6 - 3.7 + - 3.8 install: - pip install opencv-python-headless From ffb8065dacb126d2b564ec867126ab7b0d624b81 Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Sun, 19 Jul 2020 13:47:27 +0200 Subject: [PATCH 15/17] Update pyinstaller packaging Untested, but might be enough to get it working. 
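Background for the import change below: PyInstaller executes
sushi/__main__.py as a top-level script, where __package__ is empty, so
the relative imports fail (see the issue linked in the diff).  Absolute
imports keep working because the frozen "sushi" package is still
importable by name.  A sketch of the failure mode, assuming the package
layout from earlier in this series:

    # Inside sushi/__main__.py, when executed as a plain script or a
    # frozen binary rather than via "python -m sushi":
    #
    #   from . import run   # fails: no parent package is known here
    #                       # (the exact exception varies by Python version)
    #
    # Importing through the package name resolves normally:
    from sushi import run, VERSION
    from sushi.common import SushiError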
--- build-windows.bat | 5 +++-- sushi/__main__.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/build-windows.bat b/build-windows.bat index 2cecb33..05fb79f 100644 --- a/build-windows.bat +++ b/build-windows.bat @@ -4,9 +4,10 @@ pyinstaller --noupx --onefile --noconfirm ^ --exclude-module Tkconstants ^ --exclude-module Tkinter ^ --exclude-module matplotlib ^ - sushi.py + --name sushi ^ + sushi/__main__.py mkdir dist\licenses copy /Y licenses\* dist\licenses\* copy LICENSE dist\licenses\Sushi.txt -copy README.md dist\readme.md \ No newline at end of file +copy README.md dist\readme.md diff --git a/sushi/__main__.py b/sushi/__main__.py index ebfbea4..3add7af 100755 --- a/sushi/__main__.py +++ b/sushi/__main__.py @@ -4,8 +4,10 @@ import sys import time -from . import run, VERSION -from .common import SushiError +# Use absolute imports to support pyinstaller +# https://github.com/pyinstaller/pyinstaller/issues/2560 +from sushi import run, VERSION +from sushi.common import SushiError if sys.platform == 'win32': try: From d2e4f4105cb995fdd892f4934dccaa9235f8a25c Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Wed, 13 Jan 2021 15:37:18 +0100 Subject: [PATCH 16/17] Fix error logging when an exception is encountered --- sushi/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sushi/__main__.py b/sushi/__main__.py index 3add7af..41637c3 100755 --- a/sushi/__main__.py +++ b/sushi/__main__.py @@ -146,7 +146,7 @@ def main(): try: parse_args_and_run(sys.argv[1:]) except SushiError as e: - logging.critical(e.message) + logging.critical(e.args[0]) sys.exit(2) From abdd1e9cc8927b537417b9d67efeb0024359608f Mon Sep 17 00:00:00 2001 From: FichteFoll Date: Sun, 17 Jul 2022 20:19:05 +0200 Subject: [PATCH 17/17] Fix text encoding of ffmpeg call on Windows --- sushi/demux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sushi/demux.py b/sushi/demux.py index bb7ca57..fd7192d 100644 --- a/sushi/demux.py +++ b/sushi/demux.py @@ -18,7 +18,8 @@ class FFmpeg(object): def get_info(path): try: # text=True is an alias for universal_newlines since 3.7 - process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE, universal_newlines=True) + process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE, + universal_newlines=True, encoding='utf-8') out, err = process.communicate() process.wait() return err
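A note on the fix above: with universal_newlines=True but no explicit
encoding, Python decodes the child process output with the locale's
preferred codec, which on Windows is typically a legacy code page
rather than UTF-8, so ffmpeg's output can come back garbled.  The same
call written against the subprocess.run API, as a sketch only (the
ffmpeg_info helper is hypothetical and not part of the patch; the
encoding argument requires Python 3.6+):

    import subprocess

    def ffmpeg_info(path):
        # ffmpeg prints stream info to stderr and exits non-zero when no
        # output file is given; that is expected here, hence no check=True.
        result = subprocess.run(
            ['ffmpeg', '-hide_banner', '-i', path],
            stderr=subprocess.PIPE,
            universal_newlines=True,  # decode bytes to str
            encoding='utf-8',         # do not rely on the locale codec
        )
        return result.stderr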