Skip to content

Commit

Permalink
Merge pull request #3 from LaurentMazare/opus-stream
Browse files Browse the repository at this point in the history
Sketch the opus stream api.
  • Loading branch information
LaurentMazare authored Aug 24, 2024
2 parents e79b4c1 + d91d3b6 commit d1b43cb
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ crate-type = ["cdylib"]
[dependencies]
anyhow = "1.0.79"
byteorder = "1.5.0"
futures-util = "0.3.30"
numpy = "0.21.0"
ogg = "0.9.1"
ogg = { version = "0.9.1", features = ["async"] }
opus = "0.3.0"
pyo3 = "0.21.0"
rayon = "1.8.1"
rubato = "0.15.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.113"
symphonia = { version = "0.5.4", features = ["all"] }
tokio = { version = "1.39.3", features = ["full"] }

[profile.release]
debug = true
Expand Down
35 changes: 35 additions & 0 deletions py_src/sphn/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def read_opus_bytes(bytes):
"""
pass

@staticmethod
def resample(pcm, src_sample_rate, dst_sample_rate):
    """
    Resamples some pcm data from `src_sample_rate` to `dst_sample_rate`.

    The resampled data is returned to the caller.
    NOTE(review): presumably `pcm` is a float array with time on the last
    axis and both rates are integers in Hz -- confirm against the Rust
    implementation.
    """
    pass

@staticmethod
def write_opus(filename, data, sample_rate):
"""
Expand Down Expand Up @@ -112,3 +119,31 @@ class FileReader:
The sample rate as an int.
"""
pass

class OpusStreamReader:
    """
    Incremental ogg/opus decoder: ogg/opus bytes go in through
    `append_bytes`, decoded pcm comes out through `read_pcm`.
    """

    def __init__(self, sample_rate):
        """
        Creates a stream reader decoding mono audio at `sample_rate`.

        `sample_rate` has to be 48000 or 24000, any other value is rejected
        with an error.
        """
        pass

    def append_bytes(self, data):
        """
        Write some ogg/opus bytes to the current stream.
        """
        pass

    def read_pcm(self):
        """
        Get some pcm data out of the stream.

        Returns a one dimensional float32 array, this array is empty if no
        data is currently available. Returns None at the end of the stream.
        """
        pass

class OpusStreamWriter:
    """
    Incremental ogg/opus encoder: pcm data goes in through `append_pcm`,
    encoded ogg/opus bytes come out through `read_bytes`.
    """

    def __init__(self, sample_rate):
        """
        Creates a stream writer encoding mono audio sampled at `sample_rate`.

        `sample_rate` has to be 48000 or 24000, any other value is rejected
        with an error.
        """
        pass

    def append_pcm(self, pcm):
        """
        Encodes one frame of pcm data and adds it to the stream.

        `pcm` is a one dimensional float32 array, its length has to be one of
        120, 240, 480, 960, 1920, or 2880.
        """
        pass

    def read_bytes(self):
        """
        Get the next chunk of ogg/opus bytes out of the stream.

        Returns an empty bytes object if no data is currently available.
        """
        pass
76 changes: 76 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,85 @@ fn read_opus_bytes(bytes: Vec<u8>, py: Python) -> PyResult<(PyObject, u32)> {
Ok((data, sample_rate))
}

/// Incremental ogg/opus encoder exposed to Python, pcm frames go in via `append_pcm` and the
/// encoded bytes come out via `read_bytes`.
#[pyclass]
struct OpusStreamWriter {
    inner: opus::StreamWriter,
    sample_rate: u32,
}

#[pymethods]
impl OpusStreamWriter {
    #[new]
    fn new(sample_rate: u32) -> PyResult<Self> {
        // The sample rate is only kept around for `__str__`, the validation is done by the
        // inner writer.
        opus::StreamWriter::new(sample_rate).w().map(|inner| Self { inner, sample_rate })
    }

    fn __str__(&self) -> String {
        let sample_rate = self.sample_rate;
        format!("OpusStreamWriter(sample_rate={sample_rate})")
    }

    /// Encode one frame of pcm data and add it to the stream.
    fn append_pcm(&mut self, pcm: numpy::PyReadonlyArray1<f32>) -> PyResult<()> {
        let view = pcm.as_array();
        if let Some(samples) = view.as_slice() {
            // Contiguous input, the numpy buffer can be encoded without a copy.
            self.inner.append_pcm(samples).w()
        } else {
            // Non-contiguous input (e.g. a strided view), gather it in a temporary buffer.
            let samples = view.to_vec();
            self.inner.append_pcm(&samples).w()
        }
    }

    /// Get the next chunk of ogg/opus bytes out of the stream.
    fn read_bytes(&mut self) -> PyResult<PyObject> {
        let chunk = self.inner.read_bytes().w()?;
        Ok(Python::with_gil(|py| pyo3::types::PyBytes::new_bound(py, &chunk).into_py(py)))
    }
}

/// Incremental ogg/opus decoder exposed to Python, ogg/opus bytes go in via `append_bytes` and
/// the decoded pcm comes out via `read_pcm`.
#[pyclass]
struct OpusStreamReader {
    inner: opus::StreamReader,
    sample_rate: u32,
}

#[pymethods]
impl OpusStreamReader {
    #[new]
    fn new(sample_rate: u32) -> PyResult<Self> {
        // The sample rate is only kept around for `__str__`, the validation is done by the
        // inner reader.
        opus::StreamReader::new(sample_rate).w().map(|inner| Self { inner, sample_rate })
    }

    fn __str__(&self) -> String {
        let sample_rate = self.sample_rate;
        format!("OpusStreamReader(sample_rate={sample_rate})")
    }

    /// Write some ogg/opus bytes to the current stream.
    fn append_bytes(&mut self, data: &[u8]) -> PyResult<()> {
        self.inner.append(data).w()
    }

    // TODO(laurent): maybe we should also have a pyo3_async api here.
    /// Get some pcm data out of the stream.
    ///
    /// Returns `None` when the inner reader reports the end of the stream, otherwise a one
    /// dimensional float32 array holding the decoded samples.
    fn read_pcm(&mut self) -> PyResult<PyObject> {
        // Copy the samples out of the reader's internal buffer so that the array handed over
        // to Python owns its data.
        let samples = self.inner.read_pcm().w()?.map(|pcm| pcm.to_vec());
        Ok(Python::with_gil(|py| match samples {
            Some(samples) => numpy::PyArray1::from_vec_bound(py, samples).into_py(py),
            None => py.None(),
        }))
    }
}

#[pymodule]
fn sphn(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<FileReader>()?;
m.add_class::<OpusStreamReader>()?;
m.add_class::<OpusStreamWriter>()?;
m.add_function(wrap_pyfunction!(durations, m)?)?;
m.add_function(wrap_pyfunction!(read, m)?)?;
m.add_function(wrap_pyfunction!(write_wav, m)?)?;
Expand Down
138 changes: 138 additions & 0 deletions src/opus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use anyhow::Result;
// https://opus-codec.org/docs/opus_api-1.2/group__opus__encoder.html#ga4ae9905859cd241ef4bb5c59cd5e5309
const OPUS_ENCODER_FRAME_SIZE: usize = 960;
const OPUS_SAMPLE_RATE: u32 = 48000;
// Input lengths (in samples) accepted by `StreamWriter::append_pcm`, any other length is rejected
// with an error before reaching the encoder.
const OPUS_ALLOWED_FRAME_SIZES: [usize; 6] = [120, 240, 480, 960, 1920, 2880];

/// See https://www.opus-codec.org/docs/opusfile_api-0.4/structOpusHead.html
#[allow(unused)]
Expand Down Expand Up @@ -198,3 +199,140 @@ pub fn write_ogg_stereo<W: std::io::Write>(
write_ogg_48khz(w, &pcm, sample_rate, true)
}
}

/// `std::io::Write` adapter that forwards every written buffer as an owned `Vec<u8>` over a
/// channel, so that the bytes can be collected from the receiving end later on.
struct BufferStream(std::sync::mpsc::Sender<Vec<u8>>);

impl std::io::Write for BufferStream {
    /// Sends a copy of `buf` over the channel, the whole buffer is always consumed.
    ///
    /// Fails with `NotConnected` if the receiving end of the channel has been dropped.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self.0.send(buf.to_vec()) {
            Ok(()) => Ok(buf.len()),
            Err(_) => Err(std::io::Error::new(
                std::io::ErrorKind::NotConnected,
                "opus stream writer error".to_string(),
            )),
        }
    }

    /// Nothing is buffered on this side of the channel so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

/// Incremental mono ogg/opus encoder.
///
/// Pcm frames are pushed with [`StreamWriter::append_pcm`], the resulting ogg pages are written
/// to an in-memory channel and can be collected with [`StreamWriter::read_bytes`].
pub struct StreamWriter {
    pw: ogg::PacketWriter<'static, BufferStream>,
    encoder: opus::Encoder,
    // Scratch buffer receiving the encoded packet, reused across calls.
    out_encoded: Vec<u8>,
    // Granule position of the last packet written, expressed in 48kHz samples.
    total_data: u64,
    // Number of 48kHz granule units covered by a single input sample.
    granule_per_sample: u64,
    rx: std::sync::mpsc::Receiver<Vec<u8>>,
}

impl StreamWriter {
    // Ogg/opus granule positions always count samples at 48kHz, whatever the input sample rate
    // is, see RFC 7845 section 4.
    const GRANULE_SAMPLE_RATE: u32 = 48_000;
    // Serial number of the single logical stream written to the ogg container.
    const STREAM_SERIAL: u32 = 42;

    /// Creates a writer for mono pcm data sampled at `sample_rate`.
    ///
    /// The header and tags packets are written straight away so they are the first bytes
    /// handed out by `read_bytes`.
    ///
    /// # Errors
    ///
    /// Fails if `sample_rate` is neither 48000 nor 24000, or if the encoder or the headers
    /// cannot be created.
    pub fn new(sample_rate: u32) -> Result<Self> {
        if sample_rate != 48000 && sample_rate != 24000 {
            anyhow::bail!("sample-rate has to be 48000 or 24000, got {sample_rate}")
        }
        let encoder =
            opus::Encoder::new(sample_rate, opus::Channels::Mono, opus::Application::Voip)?;
        let (tx, rx) = std::sync::mpsc::channel();
        let mut pw = ogg::PacketWriter::new(BufferStream(tx));
        let out_encoded = vec![0u8; 50_000];
        let mut head = Vec::new();
        write_opus_header(&mut head, 1u8, sample_rate)?;
        pw.write_packet(head, Self::STREAM_SERIAL, ogg::PacketWriteEndInfo::EndPage, 0)?;
        let mut tags = Vec::new();
        write_opus_tags(&mut tags)?;
        pw.write_packet(tags, Self::STREAM_SERIAL, ogg::PacketWriteEndInfo::EndPage, 0)?;
        // Both supported sample rates divide 48000 so this is exact (1 or 2).
        let granule_per_sample = u64::from(Self::GRANULE_SAMPLE_RATE / sample_rate);
        Ok(Self { pw, encoder, out_encoded, total_data: 0, granule_per_sample, rx })
    }

    /// Encodes `pcm` as a single opus packet and writes it as its own ogg page.
    ///
    /// # Errors
    ///
    /// Fails if the length of `pcm` is not one of `OPUS_ALLOWED_FRAME_SIZES`, or if the
    /// encoding or the page write fails.
    pub fn append_pcm(&mut self, pcm: &[f32]) -> Result<()> {
        if !OPUS_ALLOWED_FRAME_SIZES.contains(&pcm.len()) {
            anyhow::bail!(
                "pcm length has to match an allowed frame size {OPUS_ALLOWED_FRAME_SIZES:?}, got {}", pcm.len()
            )
        }

        let size = self.encoder.encode_float(pcm, &mut self.out_encoded)?;
        let msg = self.out_encoded[..size].to_vec();
        // The granule position has to be expressed in 48kHz samples, so input samples are
        // rescaled when the encoder runs at 24kHz.
        self.total_data += pcm.len() as u64 * self.granule_per_sample;
        self.pw.write_packet(
            msg,
            Self::STREAM_SERIAL,
            ogg::PacketWriteEndInfo::EndPage,
            self.total_data,
        )?;
        Ok(())
    }

    /// Returns the next pending chunk of encoded bytes without blocking, the returned vector
    /// is empty when nothing is available at the moment.
    ///
    /// Only one chunk is returned per call, callers should keep calling until an empty vector
    /// comes back to drain everything that has been written so far.
    ///
    /// # Errors
    ///
    /// Fails if the sending end of the channel has been disconnected.
    pub fn read_bytes(&mut self) -> Result<Vec<u8>> {
        match self.rx.try_recv() {
            Ok(data) => Ok(data),
            Err(std::sync::mpsc::TryRecvError::Empty) => Ok(vec![]),
            Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                anyhow::bail!("opus stream writer disconnected")
            }
        }
    }
}

/// Incremental mono ogg/opus decoder.
///
/// Ogg/opus bytes are pushed with [`StreamReader::append`], the decoded pcm data is retrieved
/// one packet at a time with [`StreamReader::read_pcm`].
pub struct StreamReader {
    pr: ogg::reading::async_api::PacketReader<tokio::io::DuplexStream>,
    decoder: opus::Decoder,
    // Writing end of the in-memory pipe that `pr` reads from.
    tx: tokio::io::DuplexStream,
    runtime: tokio::runtime::Runtime,
    // Scratch buffer receiving the decoded samples, reused across calls.
    pcm_buf: Vec<f32>,
}

// The StreamReader implementation uses tokio under the hood, this is a bit of a bummer and comes
// from the ogg crate PacketReader non-async api requiring Seek on its input stream (and so not
// having to keep much inner state), the async version just requires read so is more adapted to
// what we want to provide here.
impl StreamReader {
    /// Creates a reader decoding mono audio at `sample_rate`.
    ///
    /// # Errors
    ///
    /// Fails if `sample_rate` is neither 48000 nor 24000, or if the decoder or the tokio
    /// runtime cannot be created.
    pub fn new(sample_rate: u32) -> Result<Self> {
        if sample_rate != 48000 && sample_rate != 24000 {
            anyhow::bail!("sample-rate has to be 48000 or 24000, got {sample_rate}")
        }
        // TODO(laurent): look whether there is a more adapted channel type.
        let (tx, rx) = tokio::io::duplex(100_000);
        let decoder = opus::Decoder::new(sample_rate, opus::Channels::Mono)?;
        let pr = ogg::reading::async_api::PacketReader::new(rx);
        // TODO(laurent): consider spawning a thread so that the process happens in the background.
        let runtime = tokio::runtime::Runtime::new()?;
        let pcm_buf = vec![0f32; 24_000 * 10];
        Ok(Self { pr, decoder, tx, runtime, pcm_buf })
    }

    /// Pushes some ogg/opus bytes to the stream, the bytes do not have to be aligned on page
    /// or packet boundaries.
    ///
    /// NOTE(review): the pipe is created with a 100_000 bytes buffer and is only drained by
    /// `read_pcm`, so appending more than the remaining capacity looks like it would block
    /// here forever -- confirm and consider chunking on the caller side.
    pub fn append(&mut self, data: &[u8]) -> Result<()> {
        use tokio::io::AsyncWriteExt;

        self.runtime.block_on({
            // Maybe we should wait for a bit of time here to allow for the PacketReader to do some
            // processing.
            self.tx.write_all(data)
        })?;
        Ok(())
    }

    /// Decodes the next available audio packet.
    ///
    /// Returns None at the end of the stream and an empty slice if no data is currently available.
    /// The `OpusHead` and `OpusTags` packets do not carry any audio and are skipped.
    ///
    /// # Errors
    ///
    /// Fails if the ogg stream cannot be parsed or if the opus packet cannot be decoded.
    pub fn read_pcm(&mut self) -> Result<Option<&[f32]>> {
        use futures_util::StreamExt;
        use std::future::Future;

        // The packet reader is polled by hand with a waker that does nothing: either a packet
        // can be produced from the bytes appended so far, or there is nothing to return yet.
        let waker = futures_util::task::noop_waker();
        let mut cx = futures_util::task::Context::from_waker(&waker);
        let num_samples = loop {
            let mut next = self.pr.next();
            match std::pin::Pin::new(&mut next).poll(&mut cx) {
                std::task::Poll::Ready(None) => return Ok(None),
                std::task::Poll::Pending => break 0,
                std::task::Poll::Ready(Some(packet)) => {
                    let packet = packet?;
                    // Every ogg/opus stream starts with these two packets, skip them and try
                    // the following packet rather than feeding them to the decoder.
                    if packet.data.starts_with(b"OpusHead") || packet.data.starts_with(b"OpusTags")
                    {
                        continue;
                    }
                    break self.decoder.decode_float(
                        &packet.data,
                        &mut self.pcm_buf,
                        /* Forward Error Correction */ false,
                    )?;
                }
            }
        };
        Ok(Some(&self.pcm_buf[..num_samples]))
    }
}
30 changes: 30 additions & 0 deletions test/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sphn

# Decode the input file and bring it to the sample rate used for the opus stream.
filename = "bria.mp3"
data, sr = sphn.read(filename)
print(data.shape, sr)

data = sphn.resample(data, sr, 24000)
print(data.shape)

stream_writer = sphn.OpusStreamWriter(24000)
# This must be an allowed value among 120, 240, 480, 960, 1920, and 2880.
packet_size = 960
# Feed the first channel to the writer one packet at a time, the trailing partial packet is
# dropped as the writer only accepts the sizes listed above.
for lo in range(0, data.shape[-1], packet_size):
    up = lo + packet_size
    packet = data[0, lo:up]
    print(packet.shape)
    if packet.shape[-1] == packet_size:
        stream_writer.append_pcm(packet)
    else:
        break

# Drain the encoded bytes, an empty chunk means that nothing is left to read.
with open("myfile.opus", "wb") as fobj:
    opus = stream_writer.read_bytes()
    while len(opus) != 0:
        fobj.write(opus)
        opus = stream_writer.read_bytes()

# Check that the generated file can be read back and written again.
data_roundtrip, sr_roundtrip = sphn.read_opus("myfile.opus")
print(data_roundtrip.shape, sr_roundtrip)
sphn.write_opus("myfile2.opus", data_roundtrip, sr_roundtrip)

0 comments on commit d1b43cb

Please sign in to comment.