Skip to content

Commit

Permalink
code style and api
Browse files Browse the repository at this point in the history
  • Loading branch information
KSemenenko committed Jan 5, 2025
1 parent 5ccddc3 commit a2edc20
Show file tree
Hide file tree
Showing 12 changed files with 405 additions and 89 deletions.
37 changes: 31 additions & 6 deletions Together.Tests/HttpCallsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Together.Models.Completions;
using Together.Models.Embeddings;
using Together.Models.Images;
using Together.Models.Rerank;
using ChatMessage = Microsoft.Extensions.AI.ChatMessage;

namespace Together.Tests;
Expand All @@ -19,6 +20,7 @@ private HttpClient CreateHttpClient()
httpClient.Timeout = TimeSpan.FromSeconds(TogetherConstants.TIMEOUT_SECS);
httpClient.BaseAddress = new Uri(TogetherConstants.BASE_URL);
httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", API_KEY);
httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
return httpClient;
}

Expand All @@ -28,7 +30,7 @@ public async Task CompletionTest()
var client = new TogetherClient(CreateHttpClient());


var responseAsync = await client.GetCompletionResponseAsync(new CompletionRequest()
var responseAsync = await client.Completions.CreateAsync(new CompletionRequest()
{
Prompt = "Hi",
Model = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
Expand All @@ -43,7 +45,7 @@ public async Task ChatCompletionTest()
{
var client = new TogetherClient(CreateHttpClient());

var responseAsync = await client.GetChatCompletionResponseAsync(new ChatCompletionRequest
var responseAsync = await client.ChatCompletions.CreateAsync(new ChatCompletionRequest
{
Messages = new List<ChatCompletionMessage>()
{
Expand All @@ -65,7 +67,7 @@ public async Task StreamChatCompletionTest()
{
var client = new TogetherClient(CreateHttpClient());

var responseAsync = await client.GetStreamChatCompletionResponseAsync(new ChatCompletionRequest
var responseAsync = await client.ChatCompletions.CreateStreamAsync(new ChatCompletionRequest
{
Messages = new List<ChatCompletionMessage>()
{
Expand All @@ -90,7 +92,7 @@ public async Task EmbeddingTest()
{
var client = new TogetherClient(CreateHttpClient());

var responseAsync = await client.GetEmbeddingResponseAsync(new EmbeddingRequest()
var responseAsync = await client.Embeddings.CreateAsync(new EmbeddingRequest()
{
Input = "Hi",
Model = "togethercomputer/m2-bert-80M-2k-retrieval",
Expand All @@ -104,7 +106,7 @@ public async Task ImageTest()
{
var client = new TogetherClient(CreateHttpClient());

var responseAsync = await client.GetImageResponseAsync(new ImageRequest()
var responseAsync = await client.Images.GenerateAsync(new ImageRequest()
{
Model = "black-forest-labs/FLUX.1-dev",
Prompt = "Cats eating popcorn",
Expand All @@ -117,14 +119,37 @@ public async Task ImageTest()
Assert.NotEmpty(responseAsync.Data.First().Url);
}

[Fact]
public async Task ModelsTest()
{
var client = new TogetherClient(CreateHttpClient());

var responseAsync = await client.Models.ListModelsAsync();

Assert.NotEmpty(responseAsync);
}

[Fact]
public async Task RerankTest()
{
var client = new TogetherClient(CreateHttpClient());

var responseAsync = await client.Rerank.CreateAsync(new RerankRequest()
{

});

Assert.NotEmpty(responseAsync.Results);
}

[Fact]
public async Task WrongModelTest()
{
var client = new TogetherClient(CreateHttpClient());

await Assert.ThrowsAsync<Exception>(async () =>
{
var responseAsync = await client.GetImageResponseAsync(new ImageRequest()
var responseAsync = await client.Images.GenerateAsync(new ImageRequest()
{
Model = "Wring-Model",
Prompt = "so wrong",
Expand Down
56 changes: 56 additions & 0 deletions Together/Clients/BaseClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using System.Net.Http.Json;
using Together.Models.Error;

namespace Together.Clients;

public abstract class BaseClient
{
protected readonly HttpClient HttpClient;

protected BaseClient(HttpClient httpClient)
{
HttpClient = httpClient;
}

protected async Task<TResponse> SendRequestAsync<TRequest, TResponse>(string requestUri, TRequest request, CancellationToken cancellationToken)
{
var responseMessage = await HttpClient.PostAsJsonAsync(requestUri, request, cancellationToken);
return await HandleResponseAsync<TResponse>(responseMessage, cancellationToken);
}

protected async Task<TResponse> SendRequestAsync<TResponse>(string requestUri, HttpMethod method, HttpContent? content, CancellationToken cancellationToken)
{
using var request = new HttpRequestMessage(method, requestUri);
if (content != null)
{
request.Content = content;
}

var responseMessage = await HttpClient.SendAsync(request, cancellationToken);
return await HandleResponseAsync<TResponse>(responseMessage, cancellationToken);
}

private static async Task<TResponse> HandleResponseAsync<TResponse>(HttpResponseMessage responseMessage, CancellationToken cancellationToken)
{
if (responseMessage.IsSuccessStatusCode)
{
if (typeof(TResponse) == typeof(HttpResponseMessage) && responseMessage is TResponse response)
{
return response;
}

var result = await responseMessage.Content.ReadFromJsonAsync<TResponse>(cancellationToken: cancellationToken);
return result!;
}

var errorResponse = await responseMessage.Content.ReadFromJsonAsync<ErrorResponse>(cancellationToken: cancellationToken);
if (errorResponse?.Error != null)
{
throw new Exception(errorResponse.Error.Message);
}

var statusCode = responseMessage.StatusCode;
var errorContent = await responseMessage.Content.ReadAsStringAsync(cancellationToken);
throw new Exception($"Request failed with status code {statusCode}: {errorContent}");
}
}
40 changes: 40 additions & 0 deletions Together/Clients/ChatCompletionClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using Together.Models.ChatCompletions;

namespace Together.Clients;

public class ChatCompletionClient(HttpClient httpClient) : BaseClient(httpClient)
{
public async Task<ChatCompletionResponse> CreateAsync(ChatCompletionRequest request,
CancellationToken cancellationToken = default)
{
return await SendRequestAsync<ChatCompletionRequest, ChatCompletionResponse>("/chat/completions", request, cancellationToken);
}

public async IAsyncEnumerable<ChatCompletionChunk> CreateStreamAsync(ChatCompletionRequest request,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var responseMessage = await SendRequestAsync<ChatCompletionRequest, HttpResponseMessage>("/chat/completions", request, cancellationToken);

await using var stream = await responseMessage.Content.ReadAsStreamAsync(cancellationToken);
using var reader = new StreamReader(stream);

while (await reader.ReadLineAsync(cancellationToken) is string line)
{
if (!line.StartsWith("data:"))
continue;

var eventData = line.Substring("data:".Length)
.Trim();
if (eventData is null or "[DONE]")
break;

var result = JsonSerializer.Deserialize<ChatCompletionChunk>(eventData);

if (result is not null)
yield return result;
}
}

}
17 changes: 17 additions & 0 deletions Together/Clients/CompletionClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System.Net.Http.Json;
using Together.Models.Completions;

namespace Together.Clients;

public class CompletionClient(HttpClient httpClient) : BaseClient(httpClient)
{


public async Task<CompletionResponse> CreateAsync(CompletionRequest request, CancellationToken cancellationToken = default)
{
return await SendRequestAsync<CompletionRequest, CompletionResponse>("/completions", request, cancellationToken);
}



}
15 changes: 15 additions & 0 deletions Together/Clients/EmbeddingClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Together.Models.Embeddings;

namespace Together.Clients;

public class EmbeddingClient(HttpClient httpClient) : BaseClient(httpClient)
{


public async Task<EmbeddingResponse> CreateAsync(EmbeddingRequest request, CancellationToken cancellationToken = default)
{
return await SendRequestAsync<EmbeddingRequest, EmbeddingResponse>("/embeddings", request, cancellationToken);
}


}
67 changes: 67 additions & 0 deletions Together/Clients/FileClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using System.Net.Http.Headers;
using Together.Models.Files;

namespace Together.Clients;

public class FileClient(HttpClient httpClient) : BaseClient(httpClient)
{
public async Task<FileResponse> UploadAsync(
string filePath,
FilePurpose? purpose = null,
bool checkFile = true,
CancellationToken cancellationToken = default)
{
purpose ??= FilePurpose.FineTune;

if (checkFile && !File.Exists(filePath))
{
throw new FileNotFoundException("File not found", filePath);
}

using var form = new MultipartFormDataContent();
using var fileStream = File.OpenRead(filePath);
using var content = new StreamContent(fileStream);

content.Headers.ContentType = new MediaTypeHeaderValue("application/octet-stream");
form.Add(content, "file", Path.GetFileName(filePath));
form.Add(new StringContent(purpose.ToString().ToLowerInvariant()), "purpose");

return await SendRequestAsync<FileResponse>("/files", HttpMethod.Post, form, cancellationToken);
}

public async Task<FileList> ListAsync(CancellationToken cancellationToken = default)
{
return await SendRequestAsync<FileList>("/files", HttpMethod.Get, null, cancellationToken);
}

public async Task<FileResponse> RetrieveAsync(string fileId, CancellationToken cancellationToken = default)
{
return await SendRequestAsync<FileResponse>($"/files/{fileId}", HttpMethod.Get, null, cancellationToken);
}

public async Task<FileObject> RetrieveContentAsync(string fileId, string? outputPath = null, CancellationToken cancellationToken = default)
{
var fileName = outputPath ?? NormalizeKey($"{fileId}.jsonl");
var response = await HttpClient.GetAsync($"/files/{fileId}/content", cancellationToken);
response.EnsureSuccessStatusCode();

await using var fs = File.Create(fileName);
await response.Content.CopyToAsync(fs, cancellationToken);

var fileInfo = new FileInfo(fileName);
return new FileObject
{
Object = "local",
Id = fileId,
Filename = fileName,
Size = (int)fileInfo.Length
};
}

public async Task<FileDeleteResponse> DeleteAsync(string fileId, CancellationToken cancellationToken = default)
{
return await SendRequestAsync<FileDeleteResponse>($"/files/{fileId}", HttpMethod.Delete, null, cancellationToken);
}

private static string NormalizeKey(string key) => string.Join("_", key.Split(Path.GetInvalidFileNameChars()));
}
Loading

0 comments on commit a2edc20

Please sign in to comment.