Skip to content

Commit

Permalink
revamp how ai init openai creating a deployment figures out what it can/should do with deployments; a few other changes too (
Browse files Browse the repository at this point in the history
#336)

* revamp how ai init openai creating a deployment figures out what it can/should do with deployments; a few other changes too

* updates for realtime deployment

* updated test

* updated with realtime samples, and tests for those

* updated to filter to realtime deployments appropriately during ai init flows
  • Loading branch information
robch authored Oct 8, 2024
1 parent 2312add commit d47c7e1
Show file tree
Hide file tree
Showing 30 changed files with 1,532 additions and 92 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
using NAudio.Wave;

#nullable disable

/// <summary>
/// Uses the NAudio library (https://github.com/naudio/NAudio) to provide a rudimentary abstraction of microphone
/// input as a stream. Captured audio accumulates in a fixed-size ring buffer; <see cref="Read"/> blocks until
/// the requested number of bytes has been recorded.
/// </summary>
public class MicrophoneAudioInputStream : Stream, IDisposable
{
    public MicrophoneAudioInputStream()
    {
        _waveInEvent = new()
        {
            WaveFormat = new WaveFormat(SAMPLES_PER_SECOND, BYTES_PER_SAMPLE * 8, CHANNELS),
        };
        _waveInEvent.DataAvailable += (_, e) =>
        {
            lock (_bufferLock)
            {
                // Copy the newly recorded bytes into the ring buffer, wrapping at the end if needed.
                int bytesToCopy = e.BytesRecorded;
                if (_bufferWritePos + bytesToCopy >= _buffer.Length)
                {
                    int bytesToCopyBeforeWrap = _buffer.Length - _bufferWritePos;
                    Array.Copy(e.Buffer, 0, _buffer, _bufferWritePos, bytesToCopyBeforeWrap);
                    bytesToCopy -= bytesToCopyBeforeWrap;
                    _bufferWritePos = 0;
                }
                Array.Copy(e.Buffer, e.BytesRecorded - bytesToCopy, _buffer, _bufferWritePos, bytesToCopy);
                _bufferWritePos += bytesToCopy;
            }
        };
        _waveInEvent.StartRecording();
    }

    public override bool CanRead => true;

    public override bool CanSeek => false;

    public override bool CanWrite => false;

    // Length/position are not meaningful for a live capture stream. Per the Stream contract,
    // unsupported members throw NotSupportedException (NotImplementedException means "not written yet").
    public override long Length => throw new NotSupportedException();

    public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }

    /// <summary>
    /// No-op: this stream is read-only (<see cref="CanWrite"/> is false), so there is never
    /// buffered write data to flush. Throwing here would break generic Stream consumers.
    /// </summary>
    public override void Flush()
    {
    }

    /// <summary>
    /// Reads exactly <paramref name="count"/> bytes of captured audio into <paramref name="buffer"/>,
    /// blocking until enough audio has been recorded. Partial reads are never performed.
    /// </summary>
    /// <returns>Always <paramref name="count"/>.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="count"/> exceeds the ring buffer capacity (such a request could never be satisfied
    /// and would block forever).
    /// </exception>
    public override int Read(byte[] buffer, int offset, int count)
    {
        if (count > _buffer.Length)
        {
            throw new ArgumentOutOfRangeException(nameof(count), "Requested more bytes than the ring buffer can hold.");
        }

        int totalCount = count;

        int GetBytesAvailable() => _bufferWritePos < _bufferReadPos
            ? _bufferWritePos + (_buffer.Length - _bufferReadPos)
            : _bufferWritePos - _bufferReadPos;

        // For simplicity, we'll block until all requested data is available and not perform partial reads.
        // The read/write positions are sampled under the lock so we always see a consistent pair.
        while (true)
        {
            lock (_bufferLock)
            {
                if (GetBytesAvailable() >= count) break;
            }
            Thread.Sleep(100);
        }

        lock (_bufferLock)
        {
            if (_bufferReadPos + count >= _buffer.Length)
            {
                // The requested range wraps past the end of the ring buffer: copy the tail first.
                int bytesBeforeWrap = _buffer.Length - _bufferReadPos;
                Array.Copy(
                    sourceArray: _buffer,
                    sourceIndex: _bufferReadPos,
                    destinationArray: buffer,
                    destinationIndex: offset,
                    length: bytesBeforeWrap);
                _bufferReadPos = 0;
                count -= bytesBeforeWrap;
                offset += bytesBeforeWrap;
            }

            Array.Copy(_buffer, _bufferReadPos, buffer, offset, count);
            _bufferReadPos += count;
        }

        return totalCount;
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        throw new NotSupportedException();
    }

    public override void SetLength(long value)
    {
        throw new NotSupportedException();
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
        throw new NotSupportedException();
    }

    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            // Stop capture before releasing the device so no DataAvailable callback races disposal.
            _waveInEvent?.StopRecording();
            _waveInEvent?.Dispose();
        }
        base.Dispose(disposing);
    }

    private const int SAMPLES_PER_SECOND = 24000;
    private const int BYTES_PER_SAMPLE = 2;
    private const int CHANNELS = 1;

    // For simplicity, this is configured to use a static 10-second ring buffer.
    private readonly byte[] _buffer = new byte[BYTES_PER_SAMPLE * SAMPLES_PER_SECOND * CHANNELS * 10];
    private readonly object _bufferLock = new();
    private int _bufferReadPos = 0;
    private int _bufferWritePos = 0;

    private readonly WaveInEvent _waveInEvent;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<Project Sdk="Microsoft.NET.Sdk">

<!-- net8.0 console executable with implicit usings and nullable reference types enabled. -->
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<EnableDefaultCompileItems>true</EnableDefaultCompileItems>
<OutputType>Exe</OutputType>
</PropertyGroup>

<!-- Azure.Identity/Azure.AI.OpenAI for the realtime conversation client (beta package);
     NAudio for microphone capture and speaker playback. -->
<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.12.1" />
<PackageReference Include="Azure.AI.OpenAI" Version="2.1.0-beta.1" />
<PackageReference Include="NAudio" Version="2.2.1" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
using Azure.AI.OpenAI;
using OpenAI.RealtimeConversation;
using System.ClientModel;
using System.Text;

#pragma warning disable OPENAI002

/// <summary>
/// Template class (the {ClassName} placeholder is substituted at generation time) that bridges a microphone
/// input stream and a speaker output stream to an Azure OpenAI realtime conversation session, reporting
/// transcription text through a (role, content) callback.
/// </summary>
public class {ClassName}
{
/// <summary>
/// Creates the Azure OpenAI client and realtime conversation client, and captures the audio devices.
/// </summary>
/// <param name="apiKey">Azure OpenAI API key.</param>
/// <param name="endpoint">Azure OpenAI endpoint URI.</param>
/// <param name="model">Deployment/model name passed to GetRealtimeConversationClient.</param>
/// <param name="instructions">System instructions applied to the conversation session.</param>
/// <param name="microphone">Microphone stream whose audio is sent to the session.</param>
/// <param name="speaker">Speaker used to play audio deltas received from the session.</param>
public {ClassName}(string apiKey, string endpoint, string model, string instructions, MicrophoneAudioInputStream microphone, SpeakerAudioOutputStream speaker)
{
_client = new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey));
_conversationClient = _client.GetRealtimeConversationClient(model);

_microphone = microphone;
_speaker = speaker;

// Input (user) speech is transcribed with whisper-1; assistant behavior comes from the caller's instructions.
_sessionOptions = new ConversationSessionOptions()
{
Instructions = instructions,
InputTranscriptionOptions = new()
{
Model = "whisper-1",
}
};
}

/// <summary>
/// Starts the realtime conversation session and applies the configured session options.
/// Must be called before <see cref="GetSessionUpdatesAsync"/>.
/// </summary>
public async Task StartSessionAsync()
{
if (Program.Debug) Console.WriteLine("Starting session...");

_session = await _conversationClient.StartConversationSessionAsync();
await _session.ConfigureSessionAsync(_sessionOptions);
}

/// <summary>
/// Receives session updates until an error update arrives (or the update stream ends), dispatching each
/// update type to a handler. The callback receives (role, content) pairs, where role is "user" or "assistant".
/// </summary>
/// <param name="callback">Invoked with transcription/status text as updates arrive.</param>
/// <exception cref="InvalidOperationException"><see cref="StartSessionAsync"/> was not called first.</exception>
public async Task GetSessionUpdatesAsync(Action<string, string> callback)
{
if (_session == null) throw new InvalidOperationException("Session has not been started.");

await foreach (var update in _session.ReceiveUpdatesAsync())
{
if (Program.Debug) Console.WriteLine($"Received update: {update.GetType().Name}");

switch (update)
{
case ConversationSessionStartedUpdate:
HandleSessionStarted(callback);
break;

case ConversationAudioDeltaUpdate audioDeltaUpdate:
HandleAudioDelta(audioDeltaUpdate);
break;

case ConversationInputSpeechStartedUpdate:
HandleInputSpeechStarted(callback);
break;

case ConversationInputSpeechFinishedUpdate:
HandleInputSpeechFinished();
break;

case ConversationInputTranscriptionFinishedUpdate transcriptionFinishedUpdate:
HandleInputTranscriptionFinished(callback, transcriptionFinishedUpdate);
break;

case ConversationOutputTranscriptionDeltaUpdate outputTranscriptionDeltaUpdate:
HandleOutputTranscriptionDelta(callback, outputTranscriptionDeltaUpdate);
break;

case ConversationOutputTranscriptionFinishedUpdate:
HandleOutputTranscriptionFinished(callback);
break;

// An error update terminates the loop; other unlisted update types are intentionally ignored.
case ConversationErrorUpdate errorUpdate:
Console.WriteLine($"ERROR: {errorUpdate.ErrorMessage}");
return;
}
}
}

// Once the session reports started, begin streaming microphone audio to it on a background task.
// NOTE(review): this is fire-and-forget — exceptions thrown inside the Task.Run lambda (including a
// NullReferenceException if _session were cleared) would be unobserved; confirm that is acceptable.
private void HandleSessionStarted(Action<string, string> callback)
{
if (Program.Debug) Console.WriteLine("Connected: session started");
_ = Task.Run(async () =>
{
callback("assistant", "Listening...\n");
await _session.SendAudioAsync(_microphone);
callback("user", "");
});
}

// Queue each received audio delta for speaker playback.
private void HandleAudioDelta(ConversationAudioDeltaUpdate audioUpdate)
{
_speaker.EnqueueForPlayback(audioUpdate.Delta);
}

// User started speaking: stop any in-progress assistant playback (barge-in) and switch the callback to the user role.
private void HandleInputSpeechStarted(Action<string, string> callback)
{
if (Program.Debug) Console.WriteLine("Start of speech detected");
_speaker.ClearPlayback();
callback("user", "");
}

// User stopped speaking: start buffering assistant transcription deltas so they are not interleaved
// with the (still pending) user input transcription.
private void HandleInputSpeechFinished()
{
if (Program.Debug) Console.WriteLine("End of speech detected");
StartBufferingOutputTranscriptionDeltas();
}

// User input transcription is complete: report it, then flush any assistant deltas buffered in the meantime.
private void HandleInputTranscriptionFinished(Action<string, string> callback, ConversationInputTranscriptionFinishedUpdate transcriptionUpdate)
{
callback?.Invoke("user", $"{transcriptionUpdate.Transcript}");
StopBufferingOutputTranscriptionDeltas(callback);
}

// Assistant transcription delta: buffer it while waiting on the user transcript, otherwise emit immediately.
private void HandleOutputTranscriptionDelta(Action<string, string> callback, ConversationOutputTranscriptionDeltaUpdate transcriptionUpdate)
{
if (IsBufferingOutputTranscriptionDeltas())
{
BufferOutputTranscriptionDelta(transcriptionUpdate.Delta);
}
else
{
callback?.Invoke("assistant", transcriptionUpdate.Delta);
}
}

// Assistant finished a response: emit a newline to terminate the transcript line.
private void HandleOutputTranscriptionFinished(Action<string, string> callback)
{
callback?.Invoke("assistant", "\n");
}

// Buffering is "on" exactly when the StringBuilder exists; a null builder means pass deltas straight through.
private bool IsBufferingOutputTranscriptionDeltas()
{
return _bufferOutputTranscriptionDeltas != null;
}

private void StartBufferingOutputTranscriptionDeltas()
{
_bufferOutputTranscriptionDeltas = new StringBuilder();
}

private void BufferOutputTranscriptionDelta(string delta)
{
_bufferOutputTranscriptionDeltas?.Append(delta);
}

// Emit everything buffered so far (if any) and return to pass-through mode.
private void StopBufferingOutputTranscriptionDeltas(Action<string, string> callback)
{
if (_bufferOutputTranscriptionDeltas != null)
{
callback?.Invoke("assistant", _bufferOutputTranscriptionDeltas.ToString());
_bufferOutputTranscriptionDeltas = null;
}
}

private readonly AzureOpenAIClient _client;
private RealtimeConversationClient _conversationClient;

// Null until StartSessionAsync completes.
private RealtimeConversationSession? _session;
private ConversationSessionOptions _sessionOptions;

private MicrophoneAudioInputStream _microphone;
private SpeakerAudioOutputStream _speaker;

// Non-null only while assistant transcription deltas are being held back (between end-of-user-speech
// and completion of the user input transcription).
private StringBuilder? _bufferOutputTranscriptionDeltas;
}
36 changes: 36 additions & 0 deletions src/ai/.x/templates/openai-realtime-chat-cs/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using OpenAI.Chat;
using System;

public class Program
{
    /// <summary>Enabled via a "--debug" or "debug" command-line argument; gates diagnostic console output.</summary>
    public static bool Debug { get; set; } = false;

    /// <summary>
    /// Entry point: wires a microphone and speaker to an Azure OpenAI realtime conversation session and
    /// streams role-labeled transcript text to the console as updates arrive.
    /// </summary>
    public static async Task Main(string[] args)
    {
        Debug = args.Contains("--debug") || args.Contains("debug");

        // Configuration comes from environment variables, with placeholder fallbacks for template users.
        var openAIAPIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? "<insert your OpenAI API key here>";
        var openAIEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? "<insert your OpenAI endpoint here>";
        var openAIChatDeploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_REALTIME_DEPLOYMENT") ?? "<insert your OpenAI chat deployment name here>";
        var openAISystemPrompt = Environment.GetEnvironmentVariable("AZURE_OPENAI_SYSTEM_PROMPT") ?? "You are a helpful AI assistant.";

        // Both audio endpoints are IDisposable; dispose them on exit so the capture/playback devices are released.
        using var microphone = new MicrophoneAudioInputStream();
        using var speaker = new SpeakerAudioOutputStream();
        var conversation = new {ClassName}(openAIAPIKey, openAIEndpoint, openAIChatDeploymentName, openAISystemPrompt, microphone, speaker);

        await conversation.StartSessionAsync();

        var lastRole = string.Empty;
        await conversation.GetSessionUpdatesAsync((role, content) => {
            // Culture-safe, case-insensitive role check (avoid ToLower() for comparisons).
            var isUser = string.Equals(role, "user", StringComparison.OrdinalIgnoreCase);
            role = isUser ? "User" : "Assistant";
            if (role != lastRole)
            {
                // Print a fresh "Role: " header each time the speaking role changes.
                if (!string.IsNullOrEmpty(lastRole)) Console.WriteLine();
                Console.Write($"{role}: ");
                lastRole = role;
            }
            Console.Write(content);
        });
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using NAudio.Wave;

/// <summary>
/// Uses the NAudio library (https://github.com/naudio/NAudio) to provide a rudimentary abstraction to output
/// BinaryData audio segments to the default output (speaker/headphone) device.
/// </summary>
public class SpeakerAudioOutputStream : IDisposable
{
    // Explicit access modifier + readonly: both fields are assigned only in the constructor.
    private readonly BufferedWaveProvider _waveProvider;
    private readonly WaveOutEvent _waveOutEvent;

    public SpeakerAudioOutputStream()
    {
        // 24 kHz, 16-bit, mono PCM output; buffer holds up to 2 minutes of queued audio.
        WaveFormat outputAudioFormat = new(
            rate: 24000,
            bits: 16,
            channels: 1);
        _waveProvider = new(outputAudioFormat)
        {
            BufferDuration = TimeSpan.FromMinutes(2),
        };
        _waveOutEvent = new();
        _waveOutEvent.Init(_waveProvider);
        _waveOutEvent.Play();
    }

    /// <summary>Appends an audio segment to the playback buffer.</summary>
    public void EnqueueForPlayback(BinaryData audioData)
    {
        byte[] buffer = audioData.ToArray();
        _waveProvider.AddSamples(buffer, 0, buffer.Length);
    }

    /// <summary>Discards any queued-but-unplayed audio (e.g. when the user interrupts the assistant).</summary>
    public void ClearPlayback()
    {
        _waveProvider.ClearBuffer();
    }

    public void Dispose()
    {
        // Stop playback before releasing the output device.
        _waveOutEvent?.Stop();
        _waveOutEvent?.Dispose();
    }
}
Loading

0 comments on commit d47c7e1

Please sign in to comment.