diff --git a/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/AudioBuffer.java b/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/AudioBuffer.java new file mode 100644 index 00000000000..ca258f48cd6 --- /dev/null +++ b/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/AudioBuffer.java @@ -0,0 +1,161 @@ +package org.tensorflow.lite.examples.soundclassifier; + +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.provider.MediaStore; +import android.util.Log; + +import java.nio.FloatBuffer; + +import java.util.ArrayList; +import java.util.List; + +import kotlin.collections.ArrayDeque; + +public class AudioBuffer { + + // TODO: What if the ring buffer is not yet fully filled before invocation? + private class RingBuffer { + private float[] buffer; + private int current; + + RingBuffer(int size) { + buffer = new float[size]; + } + + public int getCapacity() { + return buffer.length; + } + + public void feed(float v) { + buffer[current] = v; + current = (current + 1) % buffer.length; + } + + public void feed(float[] data, int size) { + for (int i = 0; i < size; i++) { + feed(data[i]); + } + } + + public float[] getArray() { + float[] output = new float[buffer.length]; + for (int i = current; i < buffer.length; i++) { + output[i - current] = buffer[i]; + } + for (int i = 0; i < current; i++) { + output[buffer.length - current + i] = buffer[i]; + } + return output; + } + } + + // Do we actually need it in Java here? 
+ private AudioFormat audioFormat; + private RingBuffer ringBuffer; + private int ringBufferIndex = 0; + + public AudioBuffer(AudioFormat audioFormat, int sampleCount) { + this.audioFormat = audioFormat; + this.ringBuffer = new RingBuffer(sampleCount); + } + + // PCM float + public int feed(float[] data) { + return feed(data, data.length); + } + + // TODO: what's the correct name? + public int feed(float[] data, int size) { + ringBuffer.feed(data, size); + return size; + } + + // PCM int16 + public int feed(short[] data) { + return feed(data, data.length); + } + + private float pcm16ToFloat(short v) { + return (float) v / 32768; + } + + public int feed(short[] data, int size) { + for (int i = 0; i < size; i++) { + ringBuffer.feed(pcm16ToFloat(data[i])); + } + return size; + } + + // Read from AudioRecord as a helper function + public int feed(AudioRecord record) { + return feed(record, ringBuffer.getCapacity()); + } + + private int feed(AudioRecord record, float[] temporary) { + int readSamples = record.read(temporary, 0, temporary.length, AudioRecord.READ_BLOCKING); + if (readSamples > 0) { + feed(temporary, readSamples); + } + return readSamples; + } + + public int feed(AudioRecord record, short[] temporary) { + int readSamples = record.read(temporary, 0, temporary.length, AudioRecord.READ_BLOCKING); + if (readSamples > 0) { + feed(temporary, readSamples); + } + return readSamples; + } + + public int feed(AudioRecord record, int size) { +// assert record.getChannelCount() == 1; +// assert record.getSampleRate() == this.audioFormat.getSampleRate(); + + int readSamples = 0; + switch (record.getAudioFormat()) { + case AudioFormat.ENCODING_PCM_FLOAT: + readSamples = feed(record, new float[size]); + break; + case AudioFormat.ENCODING_PCM_16BIT: + readSamples = feed(record, new short[size]); + break; + default: + Log.e(TAG, "Unsupported AudioFormat. Requires either PCM float or PCM 16."); + } + + // Report errors. 
+ switch (readSamples) { + case AudioRecord.ERROR_INVALID_OPERATION: + Log.w(TAG, "AudioRecord.ERROR_INVALID_OPERATION"); + break; + case AudioRecord.ERROR_BAD_VALUE: + Log.w(TAG, "AudioRecord.ERROR_BAD_VALUE"); + break; + case AudioRecord.ERROR_DEAD_OBJECT: + Log.w(TAG, "AudioRecord.ERROR_DEAD_OBJECT"); + break; + case AudioRecord.ERROR: + Log.w(TAG, "AudioRecord.ERROR"); + break; + } + return readSamples; + + } + + public AudioFormat getAudioFormat() { + return audioFormat; + } + + // TODO: Convert this to byte buffer. + + // TODO: ownership + public FloatBuffer GetAudioBufferInFloat() { + FloatBuffer output = FloatBuffer.wrap(this.ringBuffer.getArray()); + // TODO: Is this needed? +// output.rewind(); + return output; + } + + private String TAG = "AudioBuffer"; +} diff --git a/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/SoundClassifier.kt b/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/SoundClassifier.kt index b9ec5aeb343..854fcc16b6b 100644 --- a/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/SoundClassifier.kt +++ b/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/SoundClassifier.kt @@ -34,7 +34,6 @@ import java.nio.FloatBuffer import java.util.Locale import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import kotlin.concurrent.withLock import kotlin.math.ceil import kotlin.math.sin import org.tensorflow.lite.Interpreter @@ -56,8 +55,6 @@ class SoundClassifier(context: Context, private val options: Options = Options() val modelPath: String = "sound_classifier.tflite", /** The required audio sample rate in Hz. */ val sampleRate: Int = 44_100, - /** How many milliseconds to sleep between successive audio sample pulls. 
*/ - val audioPullPeriod: Long = 50L, /** Number of warm up runs to do after loading the TFLite model. */ val warmupRuns: Int = 3, /** Number of points in average to reduce noise. */ @@ -148,12 +145,17 @@ class SoundClassifier(context: Context, private val options: Options = Options() private lateinit var recordingBuffer: ShortArray /** Buffer that holds audio PCM sample that are fed to the TFLite model for inference. */ - private lateinit var inputBuffer: FloatBuffer + + + private lateinit var newAudioBuffer: AudioBuffer + private var record: AudioRecord? = null init { loadLabels(context) setupInterpreter(context) warmUpModel() + startRecording() + startRecognition() } override fun onResume(owner: LifecycleOwner) = start() @@ -166,7 +168,8 @@ class SoundClassifier(context: Context, private val options: Options = Options() */ fun start() { if (!isPaused) { - startAudioRecord() + startRecording() + startRecognition() } } @@ -176,7 +179,7 @@ class SoundClassifier(context: Context, private val options: Options = Options() */ fun stop() { if (isClosed || !isRecording) return - recordingThread?.interrupt() + record?.stop() recognitionThread?.interrupt() _probabilities.postValue(labelList.associateWith { 0f }) @@ -194,7 +197,8 @@ class SoundClassifier(context: Context, private val options: Options = Options() /** Retrieve labels from "labels.txt" file */ private fun loadLabels(context: Context) { try { - val reader = BufferedReader(InputStreamReader(context.assets.open(options.metadataPath))) + val reader = + BufferedReader(InputStreamReader(context.assets.open(options.metadataPath))) val wordList = mutableListOf() reader.useLines { lines -> lines.forEach { @@ -229,16 +233,19 @@ class SoundClassifier(context: Context, private val options: Options = Options() Log.e( TAG, "Mismatch between metadata number of classes (${labelList.size})" + - " and model output length ($modelNumClasses)" + " and model output length ($modelNumClasses)" ) } // Fill the array with NaNs 
initially. predictionProbs = FloatArray(modelNumClasses) { Float.NaN } - inputBuffer = FloatBuffer.allocate(modelInputLength) +// inputBuffer = FloatBuffer.allocate(modelInputLength) + newAudioBuffer = AudioBuffer(null, modelInputLength) } private fun warmUpModel() { + var inputBuffer = FloatBuffer.allocate(modelInputLength) + generateDummyAudioInput(inputBuffer) for (n in 0 until options.warmupRuns) { val t0 = SystemClock.elapsedRealtimeNanos() @@ -267,16 +274,6 @@ class SoundClassifier(context: Context, private val options: Options = Options() } } - /** Start a thread to pull audio samples in continuously. */ - @Synchronized - private fun startAudioRecord() { - if (isRecording) return - recordingThread = AudioRecordingThread().apply { - start() - } - isClosed = false - } - /** Start a thread that runs model inference (i.e., recognition) at a regular interval. */ private fun startRecognition() { recognitionThread = RecognitionThread().apply { @@ -284,83 +281,35 @@ class SoundClassifier(context: Context, private val options: Options = Options() } } - /** Runnable class to run a thread for audio recording */ - private inner class AudioRecordingThread : Thread() { - override fun run() { - var bufferSize = AudioRecord.getMinBufferSize( - options.sampleRate, - AudioFormat.CHANNEL_IN_MONO, - AudioFormat.ENCODING_PCM_16BIT - ) - if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) { - bufferSize = options.sampleRate * 2 - Log.w(TAG, "bufferSize has error or bad value") - } - Log.i(TAG, "bufferSize = $bufferSize") - val record = AudioRecord( - // including MIC, UNPROCESSED, and CAMCORDER. 
- MediaRecorder.AudioSource.VOICE_RECOGNITION, - options.sampleRate, - AudioFormat.CHANNEL_IN_MONO, - AudioFormat.ENCODING_PCM_16BIT, - bufferSize - ) - if (record.state != AudioRecord.STATE_INITIALIZED) { - Log.e(TAG, "AudioRecord failed to initialize") - return - } - Log.i(TAG, "Successfully initialized AudioRecord") - val bufferSamples = bufferSize / 2 - val audioBuffer = ShortArray(bufferSamples) - val recordingBufferSamples = - ceil(modelInputLength.toFloat() / bufferSamples.toDouble()) - .toInt() * bufferSamples - Log.i(TAG, "recordingBufferSamples = $recordingBufferSamples") - recordingOffset = 0 - recordingBuffer = ShortArray(recordingBufferSamples) - record.startRecording() - Log.i(TAG, "Successfully started AudioRecord recording") - - // Start recognition (model inference) thread. - startRecognition() - - while (!isInterrupted) { - try { - TimeUnit.MILLISECONDS.sleep(options.audioPullPeriod) - } catch (e: InterruptedException) { - Log.w(TAG, "Sleep interrupted in audio recording thread.") - break - } - when (record.read(audioBuffer, 0, audioBuffer.size)) { - AudioRecord.ERROR_INVALID_OPERATION -> { - Log.w(TAG, "AudioRecord.ERROR_INVALID_OPERATION") - } - AudioRecord.ERROR_BAD_VALUE -> { - Log.w(TAG, "AudioRecord.ERROR_BAD_VALUE") - } - AudioRecord.ERROR_DEAD_OBJECT -> { - Log.w(TAG, "AudioRecord.ERROR_DEAD_OBJECT") - } - AudioRecord.ERROR -> { - Log.w(TAG, "AudioRecord.ERROR") - } - bufferSamples -> { - // We apply locks here to avoid two separate threads (the recording and - // recognition threads) reading and writing from the recordingBuffer at the same - // time, which can cause the recognition thread to read garbled audio snippets. 
-          recordingBufferLock.withLock {
-            audioBuffer.copyInto(
-              recordingBuffer,
-              recordingOffset,
-              0,
-              bufferSamples
-            )
-            recordingOffset = (recordingOffset + bufferSamples) % recordingBufferSamples
-          }
-        }
-      }
-    }
+  private fun startRecording() {
+    var bufferSize = options.sampleRate * 2
+    var minBufferSize = AudioRecord.getMinBufferSize(
+      options.sampleRate,
+      AudioFormat.CHANNEL_IN_MONO,
+      AudioFormat.ENCODING_PCM_16BIT
+    )
+    // Minimal buffer size is greater than the default buffer size.
+    if (minBufferSize > options.sampleRate * 2) {
+      bufferSize = minBufferSize
+    }
+    Log.i(TAG, "bufferSize = $bufferSize")
+
+    record = AudioRecord(
+      // including MIC, UNPROCESSED, and CAMCORDER.
+      MediaRecorder.AudioSource.VOICE_RECOGNITION,
+      options.sampleRate,
+      AudioFormat.CHANNEL_IN_MONO,
+      AudioFormat.ENCODING_PCM_16BIT,
+      bufferSize
+    )
+    if (record?.state != AudioRecord.STATE_INITIALIZED) {
+      Log.e(TAG, "AudioRecord failed to initialize")
+      return
     }
+    Log.i(TAG, "Successfully initialized AudioRecord")
+    record?.startRecording()
+    Log.i(TAG, "Successfully started AudioRecord recording")
+  }

   private inner class RecognitionThread : Thread() {
@@ -370,44 +319,45 @@ class SoundClassifier(context: Context, private val options: Options = Options()
         return
       }
       val outputBuffer = FloatBuffer.allocate(modelNumClasses)
+
+      var lastInvokeMs = SystemClock.elapsedRealtime()
+
       while (!isInterrupted) {
-        try {
-          TimeUnit.MILLISECONDS.sleep(recognitionPeriod)
-        } catch (e: InterruptedException) {
-          Log.w(TAG, "Sleep interrupted in recognition thread.")
-          break
-        }
-        var samplesAreAllZero = true
-        recordingBufferLock.withLock {
-          var j = (recordingOffset - modelInputLength) % modelInputLength
-          if (j < 0) {
-            j += modelInputLength
-          }
+        val currentMs = SystemClock.elapsedRealtime()

-          for (i in 0 until modelInputLength) {
-            val s = if (i >= options.pointsInAverage && j >= options.pointsInAverage) {
-              ((j - options.pointsInAverage + 1)..j).map { recordingBuffer[it % modelInputLength] }
-
.average()
-            } else {
-              recordingBuffer[j % modelInputLength]
-            }
-            j += 1
-
-            if (samplesAreAllZero && s.toInt() != 0) {
-              samplesAreAllZero = false
-            }
-            inputBuffer.put(i, s.toFloat())
+        if (currentMs - lastInvokeMs < recognitionPeriod) {
+          try {
+            TimeUnit.MILLISECONDS.sleep(recognitionPeriod - (currentMs - lastInvokeMs))
+          } catch (e: InterruptedException) {
+            Log.w(TAG, "Sleep interrupted in recognition thread.")
+            break
           }
         }
-        if (samplesAreAllZero) {
-          Log.w(TAG, "No audio input: All audio samples are zero!")
-          continue
+
+        // TODO: Check output against 0?
+        val cnt = newAudioBuffer.feed(record)
+        Log.i(TAG, "Loaded $cnt samples from recorder")
+        var newInputBuffer = newAudioBuffer.GetAudioBufferInFloat()
+
+        var averageBuffer = FloatBuffer.allocate(modelInputLength)
+        for (i in 0 until modelInputLength) {
+          val s = if (i >= options.pointsInAverage) {
+            ((i - options.pointsInAverage + 1)..i).map {
+              newInputBuffer[it]
+            }.average()
+          } else {
+            newInputBuffer[i]
+          }
+
+          averageBuffer.put(i, s.toFloat())
         }
+
         val t0 = SystemClock.elapsedRealtimeNanos()
-        inputBuffer.rewind()
+        averageBuffer.rewind()
         outputBuffer.rewind()
-        interpreter.run(inputBuffer, outputBuffer)
+        interpreter.run(averageBuffer, outputBuffer)
         outputBuffer.rewind()
         outputBuffer.get(predictionProbs) // Copy data to predictionProbs.
@@ -416,7 +366,9 @@ class SoundClassifier(context: Context, private val options: Options = Options()
         }
         _probabilities.postValue(labelList.zip(probList).toMap())
-        latestPredictionLatencyMs = ((SystemClock.elapsedRealtimeNanos() - t0) / 1e6).toFloat()
+        latestPredictionLatencyMs =
+          ((SystemClock.elapsedRealtimeNanos() - t0) / 1e6).toFloat()
+        Log.i(TAG, "Latency: $latestPredictionLatencyMs")
       }
     }
   }