diff --git a/mutagen/_riff.py b/mutagen/_riff.py index d17eb202..c5486bc2 100644 --- a/mutagen/_riff.py +++ b/mutagen/_riff.py @@ -14,7 +14,12 @@ from ._compat import text_type -from mutagen._util import resize_bytes, delete_bytes, MutagenError +from mutagen._util import ( + MutagenError, + delete_bytes, + insert_bytes, + resize_bytes, +) class error(MutagenError): @@ -55,39 +60,50 @@ def assert_valid_chunk_id(id): raise ValueError("Invalid RIFF-chunk-ID.") -class RiffChunkHeader(object): - """ RIFF chunk header""" +class RiffChunk(object): + """Generic RIFF chunk""" # Chunk headers are 8 bytes long (4 for ID and 4 for the size) HEADER_SIZE = 8 - def __init__(self, fileobj, parent_chunk): - self.__fileobj = fileobj - self.parent_chunk = parent_chunk - self.offset = fileobj.tell() - - header = fileobj.read(self.HEADER_SIZE) - if len(header) < self.HEADER_SIZE: - raise InvalidChunk('Header size < %i' % self.HEADER_SIZE) - - self.id, self.data_size = struct.unpack('<4sI', header) - self.data_offset = fileobj.tell() + @classmethod + def parse(cls, fileobj, parent_chunk=None): + header = fileobj.read(cls.HEADER_SIZE) + if len(header) < cls.HEADER_SIZE: + raise InvalidChunk('Header size < %i' % cls.HEADER_SIZE) + id, data_size = struct.unpack('<4sI', header) try: - self.id = self.id.decode('ascii').rstrip() + id = id.decode('ascii').rstrip() except UnicodeDecodeError as e: raise InvalidChunk(e) - if not is_valid_chunk_id(self.id): - raise InvalidChunk('Invalid chunk ID %s' % self.id) + if not is_valid_chunk_id(id): + raise InvalidChunk('Invalid chunk ID %s' % id) + return cls.get_class(id)(fileobj, id, data_size, parent_chunk) + + @classmethod + def get_class(cls, id): + if id in (u'LIST', u'RIFF'): + return ListRiffChunk + else: + return cls + + def __init__(self, fileobj, id, data_size, parent_chunk): + self._fileobj = fileobj + self.id = id + self.data_size = data_size + self.parent_chunk = parent_chunk + self.data_offset = fileobj.tell() + self.offset = self.data_offset - self.HEADER_SIZE self._calculate_size() def read(self): """Read the chunks data""" - self.__fileobj.seek(self.data_offset) - return self.__fileobj.read(self.data_size) + self._fileobj.seek(self.data_offset) + return self._fileobj.read(self.data_size) def write(self, data): """Write the chunk data""" @@ -95,37 +111,37 @@ def write(self, data): if len(data) > self.data_size: raise ValueError - self.__fileobj.seek(self.data_offset) - self.__fileobj.write(data) + self._fileobj.seek(self.data_offset) + self._fileobj.write(data) # Write the padding bytes padding = self.padding() if padding: - self.__fileobj.seek(self.data_offset + self.data_size + 1) - self.__fileobj.write(b'\x00' * padding) + self._fileobj.seek(self.data_offset + self.data_size + 1) + self._fileobj.write(b'\x00' * padding) def delete(self): """Removes the chunk from the file""" - delete_bytes(self.__fileobj, self.size, self.offset) + delete_bytes(self._fileobj, self.size, self.offset) if self.parent_chunk is not None: - self.parent_chunk._update_size( - self.parent_chunk.data_size - self.size) + self.parent_chunk._remove_subchunk(self) - def _update_size(self, data_size): + def _update_size(self, size_diff, changed_subchunk=None): """Update the size of the chunk""" - self.__fileobj.seek(self.offset + 4) - self.__fileobj.write(pack('