diff --git a/Cargo.toml b/Cargo.toml index f1867c3..4950d0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,9 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0.79" byteorder = "1.5.0" +futures-util = "0.3.30" numpy = "0.21.0" -ogg = "0.9.1" +ogg = { version = "0.9.1", features = ["async"] } opus = "0.3.0" pyo3 = "0.21.0" rayon = "1.8.1" @@ -19,6 +20,7 @@ rubato = "0.15.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.113" symphonia = { version = "0.5.4", features = ["all"] } +tokio = { version = "1.39.3", features = ["full"] } [profile.release] debug = true diff --git a/py_src/sphn/__init__.pyi b/py_src/sphn/__init__.pyi index 9bafa3a..018e892 100644 --- a/py_src/sphn/__init__.pyi +++ b/py_src/sphn/__init__.pyi @@ -42,6 +42,13 @@ def read_opus_bytes(bytes): """ pass +@staticmethod +def resample(pcm, src_sample_rate, dst_sample_rate): + """ + Resamples some pcm data. + """ + pass + @staticmethod def write_opus(filename, data, sample_rate): """ @@ -112,3 +119,31 @@ class FileReader: The sample rate as an int. """ pass + +class OpusStreamReader: + def __init__(sample_rate): + pass + + def append_bytes(self, data): + """ + Write some ogg/opus bytes to the current stream. + """ + pass + + def read_pcm(self): + """ + Get some pcm data out of the stream. + """ + pass + +class OpusStreamWriter: + def __init__(sample_rate): + pass + + def append_pcm(self, pcm): + """ """ + pass + + def read_bytes(self): + """ """ + pass diff --git a/src/lib.rs b/src/lib.rs index 5347024..9dcc96e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -325,9 +325,85 @@ fn read_opus_bytes(bytes: Vec, py: Python) -> PyResult<(PyObject, u32)> { Ok((data, sample_rate)) } +#[pyclass] +struct OpusStreamWriter { + inner: opus::StreamWriter, + sample_rate: u32, +} + +#[pymethods] +impl OpusStreamWriter { + #[new] + fn new(sample_rate: u32) -> PyResult { + let inner = opus::StreamWriter::new(sample_rate).w()?; + Ok(Self { inner, sample_rate }) + } + + fn __str__(&self) -> String { + format!("OpusStreamWriter(sample_rate={})", self.sample_rate) + } + + fn append_pcm(&mut self, pcm: numpy::PyReadonlyArray1) -> PyResult<()> { + let pcm = pcm.as_array(); + match pcm.as_slice() { + None => { + let pcm = pcm.to_vec(); + self.inner.append_pcm(&pcm).w()? + } + Some(pcm) => self.inner.append_pcm(pcm).w()?, + }; + Ok(()) + } + + fn read_bytes(&mut self) -> PyResult { + let bytes = self.inner.read_bytes().w()?; + let bytes = Python::with_gil(|py| pyo3::types::PyBytes::new_bound(py, &bytes).into_py(py)); + Ok(bytes) + } +} + +#[pyclass] +struct OpusStreamReader { + inner: opus::StreamReader, + sample_rate: u32, +} + +#[pymethods] +impl OpusStreamReader { + #[new] + fn new(sample_rate: u32) -> PyResult { + let inner = opus::StreamReader::new(sample_rate).w()?; + Ok(Self { inner, sample_rate }) + } + + fn __str__(&self) -> String { + format!("OpusStreamReader(sample_rate={})", self.sample_rate) + } + + /// Write some ogg/opus bytes to the current stream. + fn append_bytes(&mut self, data: &[u8]) -> PyResult<()> { + self.inner.append(data).w() + } + + // TODO(laurent): maybe we should also have a pyo3_async api here. + /// Get some pcm data out of the stream. + fn read_pcm(&mut self) -> PyResult { + let pcm_data = self.inner.read_pcm().w()?; + Python::with_gil(|py| match pcm_data { + None => Ok(py.None()), + Some(data) => { + let data = numpy::PyArray1::from_vec_bound(py, data.to_vec()).into_py(py); + Ok(data) + } + }) + } +} + #[pymodule] fn sphn(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(durations, m)?)?; m.add_function(wrap_pyfunction!(read, m)?)?; m.add_function(wrap_pyfunction!(write_wav, m)?)?; diff --git a/src/opus.rs b/src/opus.rs index 55fa48e..6650f5c 100644 --- a/src/opus.rs +++ b/src/opus.rs @@ -5,6 +5,7 @@ use anyhow::Result; // https://opus-codec.org/docs/opus_api-1.2/group__opus__encoder.html#ga4ae9905859cd241ef4bb5c59cd5e5309 const OPUS_ENCODER_FRAME_SIZE: usize = 960; const OPUS_SAMPLE_RATE: u32 = 48000; +const OPUS_ALLOWED_FRAME_SIZES: [usize; 6] = [120, 240, 480, 960, 1920, 2880]; /// See https://www.opus-codec.org/docs/opusfile_api-0.4/structOpusHead.html #[allow(unused)] @@ -198,3 +199,140 @@ pub fn write_ogg_stereo( write_ogg_48khz(w, &pcm, sample_rate, true) } } + +struct BufferStream(std::sync::mpsc::Sender>); + +impl std::io::Write for BufferStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + if self.0.send(buf.to_vec()).is_err() { + return Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "opus stream writer error".to_string(), + )); + }; + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +pub struct StreamWriter { + pw: ogg::PacketWriter<'static, BufferStream>, + encoder: opus::Encoder, + out_encoded: Vec, + total_data: u64, + rx: std::sync::mpsc::Receiver>, +} + +impl StreamWriter { + pub fn new(sample_rate: u32) -> Result { + if sample_rate != 48000 && sample_rate != 24000 { + anyhow::bail!("sample-rate has to be 48000 or 24000, got {sample_rate}") + } + let encoder = + opus::Encoder::new(sample_rate, opus::Channels::Mono, opus::Application::Voip)?; + let (tx, rx) = std::sync::mpsc::channel(); + let mut pw = ogg::PacketWriter::new(BufferStream(tx)); + let out_encoded = vec![0u8; 50_000]; + let mut head = Vec::new(); + write_opus_header(&mut head, 1u8, sample_rate)?; + pw.write_packet(head, 42, ogg::PacketWriteEndInfo::EndPage, 0)?; + let mut tags = Vec::new(); + write_opus_tags(&mut tags)?; + pw.write_packet(tags, 42, ogg::PacketWriteEndInfo::EndPage, 0)?; + Ok(Self { pw, encoder, out_encoded, total_data: 0, rx }) + } + + pub fn append_pcm(&mut self, pcm: &[f32]) -> Result<()> { + if !OPUS_ALLOWED_FRAME_SIZES.contains(&pcm.len()) { + anyhow::bail!( + "pcm length has to match an allowed frame size {OPUS_ALLOWED_FRAME_SIZES:?}, got {}", pcm.len() + ) + } + + let size = self.encoder.encode_float(pcm, &mut self.out_encoded)?; + let msg = self.out_encoded[..size].to_vec(); + self.total_data += pcm.len() as u64; + self.pw.write_packet(msg, 42, ogg::PacketWriteEndInfo::EndPage, self.total_data)?; + Ok(()) + } + + pub fn read_bytes(&mut self) -> Result> { + match self.rx.try_recv() { + Ok(data) => Ok(data), + Err(std::sync::mpsc::TryRecvError::Empty) => Ok(vec![]), + Err(std::sync::mpsc::TryRecvError::Disconnected) => { + anyhow::bail!("opus stream writer disconnected") + } + } + } +} + +pub struct StreamReader { + pr: ogg::reading::async_api::PacketReader, + decoder: opus::Decoder, + tx: tokio::io::DuplexStream, + runtime: tokio::runtime::Runtime, + pcm_buf: Vec, +} + +// The StreamReader implementation uses tokio under the hood, this is a bit of a bummer and comes +// from the ogg crate PacketReader non-async api requiring Seek on its input stream (and so not +// having to keep much inner state), the async version just requires read so is more adapted to +// what we want to provide here. +impl StreamReader { + pub fn new(sample_rate: u32) -> Result { + if sample_rate != 48000 && sample_rate != 24000 { + anyhow::bail!("sample-rate has to be 48000 or 24000, got {sample_rate}") + } + // TODO(laurent): look whether there is a more adapted channel type. + let (tx, rx) = tokio::io::duplex(100_000); + let decoder = opus::Decoder::new(sample_rate, opus::Channels::Mono)?; + let pr = ogg::reading::async_api::PacketReader::new(rx); + // TODO(laurent): consider spawning a thread so that the process happens in the background. + let runtime = tokio::runtime::Runtime::new()?; + let pcm_buf = vec![0f32; 24_000 * 10]; + Ok(Self { pr, decoder, tx, runtime, pcm_buf }) + } + + pub fn append(&mut self, data: &[u8]) -> Result<()> { + use tokio::io::AsyncWriteExt; + + self.runtime.block_on({ + // Maybe we should wait for a bit of time here to allow for the PacketReader to do some + // processing. + self.tx.write_all(data) + })?; + Ok(()) + } + + /// Returns None at the end of the stream and an empty slice if no data is currently available. + pub fn read_pcm(&mut self) -> Result> { + use futures_util::StreamExt; + use std::future::Future; + + let waker = futures_util::task::noop_waker(); + let mut cx = futures_util::task::Context::from_waker(&waker); + let mut next = self.pr.next(); + let next = std::pin::Pin::new(&mut next); + let result = match next.poll(&mut cx) { + std::task::Poll::Ready(None) => None, + std::task::Poll::Ready(Some(packet)) => { + let packet = packet?; + if packet.data.starts_with(b"OpusHead") || packet.data.starts_with(b"OpusTags") { + todo!(); + } + let bytes_read = self.decoder.decode_float( + &packet.data, + &mut self.pcm_buf, + /* Forward Error Correction */ false, + )?; + Some(&self.pcm_buf[..bytes_read]) + } + std::task::Poll::Pending => Some(&self.pcm_buf[..0]), + }; + Ok(result) + } +} diff --git a/test/stream.py b/test/stream.py new file mode 100644 index 0000000..171ac67 --- /dev/null +++ b/test/stream.py @@ -0,0 +1,30 @@ +import sphn + +filename = "bria.mp3" +data, sr = sphn.read(filename) +print(data.shape, sr) + +data = sphn.resample(data, sr, 24000) +print(data.shape) + +stream_writer = sphn.OpusStreamWriter(24000) +# This must be an allowed value among 120, 240, 480, 960, 1920, and 2880. +packet_size = 960 +for lo in range(0, data.shape[-1], packet_size): + up = lo + packet_size + packet = data[0, lo:up] + print(packet.shape) + if packet.shape[-1] != packet_size: + break + stream_writer.append_pcm(packet) + +with open("myfile.opus", "wb") as fobj: + while True: + opus = stream_writer.read_bytes() + if len(opus) == 0: + break + fobj.write(opus) + +data_roundtrip, sr_roundtrip = sphn.read_opus("myfile.opus") +print(data_roundtrip.shape, sr_roundtrip) +sphn.write_opus("myfile2.opus", data_roundtrip, sr_roundtrip)