Skip to content

Commit

Permalink
Passing API tests
Browse files Browse the repository at this point in the history
Add cleanup service tests for word alignment

Fix flaky test

Fix naming

Fix more names

Passing statistical engine service tests - first pass

Add hangfire implementation and tests

Refactor alignment data structure; revert to pretranslate/word-align where appropriate

Use parallel data when inferencing for word alignment

Fix JSON serialization

Rebase-related changes: extend executionData, commit-related issues to WA, small rebase mistakes

Get rid of WordAlignmentResult
  • Loading branch information
Enkidu93 committed Feb 3, 2025
1 parent 82c81c8 commit 6da569a
Show file tree
Hide file tree
Showing 58 changed files with 1,631 additions and 795 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: Start MongoDB
uses: supercharge/[email protected]
with:
mongodb-version: "6.0"
mongodb-version: "8.0"
mongodb-replica-set: rs0

# Pull in a matching machine repo branch if it exists to use it rather than the released version of Machine.
Expand Down
13 changes: 6 additions & 7 deletions src/Echo/src/EchoEngine/TranslationEngineServiceV1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ await client.BuildStartedAsync(
try
{
using (
AsyncClientStreamingCall<InsertInferencesRequest, Empty> call = client.InsertInferences(
cancellationToken: cancellationToken
)
AsyncClientStreamingCall<InsertPretranslationsRequest, Empty> call =
client.InsertPretranslations(cancellationToken: cancellationToken)
)
{
foreach (ParallelCorpus corpus in request.Corpora)
Expand Down Expand Up @@ -133,7 +132,7 @@ await client.BuildStartedAsync(
if (sourceLine.Length > 0 && targetLine.Length == 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Expand Down Expand Up @@ -166,7 +165,7 @@ await call.RequestStream.WriteAsync(
if (sourceLine.Length > 0 && targetLine.Length == 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Expand All @@ -191,7 +190,7 @@ await call.RequestStream.WriteAsync(
if (sourceLine.Length > 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Expand All @@ -212,7 +211,7 @@ await call.RequestStream.WriteAsync(
if (sourceLine.Length > 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Expand Down
6 changes: 3 additions & 3 deletions src/Echo/src/EchoEngine/WordAlignmentEngineServiceV1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ await client.BuildStartedAsync(
try
{
using (
AsyncClientStreamingCall<InsertInferencesRequest, Empty> call = client.InsertInferences(
AsyncClientStreamingCall<InsertWordAlignmentsRequest, Empty> call = client.InsertWordAlignments(
cancellationToken: cancellationToken
)
)
Expand Down Expand Up @@ -128,7 +128,7 @@ await client.BuildStartedAsync(
targetLine.Split().Length
);
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertWordAlignmentsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Expand Down Expand Up @@ -168,7 +168,7 @@ await call.RequestStream.WriteAsync(
targetLine.Split().Length
);
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertWordAlignmentsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder, ICon
return builder;
}

public static IMachineBuilder AddWordAlignmentModel(this IMachineBuilder builder)
{
builder.Services.Configure<WordAlignmentModelOptions>(
builder.Configuration.GetSection(WordAlignmentModelOptions.Key)
);
builder.Services.AddSingleton<IWordAlignmentModelFactory, WordAlignmentModelFactory>();
return builder;
}

public static IMachineBuilder AddTransferEngine(this IMachineBuilder builder)
{
builder.Services.AddSingleton<ITransferEngineFactory, TransferEngineFactory>();
Expand Down Expand Up @@ -485,7 +494,7 @@ public static IMachineBuilder AddThot(this IMachineBuilder builder)
{
try
{
builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser();
builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser().AddWordAlignmentModel();
}
catch (ArgumentException)
{
Expand Down Expand Up @@ -516,9 +525,9 @@ public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder)
var smtTransferEngineOptions = new SmtTransferEngineOptions();
builder.Configuration.GetSection(SmtTransferEngineOptions.Key).Bind(smtTransferEngineOptions);
string? smtDriveLetter = Path.GetPathRoot(smtTransferEngineOptions.EnginesDir)?[..1];
var statisticsEngineOptions = new WordAlignmentEngineOptions();
builder.Configuration.GetSection(WordAlignmentEngineOptions.Key).Bind(statisticsEngineOptions);
string? statisticsDriveLetter = Path.GetPathRoot(statisticsEngineOptions.EnginesDir)?[..1];
var statisticalEngineOptions = new WordAlignmentEngineOptions();
builder.Configuration.GetSection(WordAlignmentEngineOptions.Key).Bind(statisticalEngineOptions);
string? statisticsDriveLetter = Path.GetPathRoot(statisticalEngineOptions.EnginesDir)?[..1];
if (smtDriveLetter is null || statisticsDriveLetter is null)
throw new InvalidOperationException("SMT Engine and Statistical directory is required");
if (smtDriveLetter != statisticsDriveLetter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ public class BuildJobService<TEngine>(IEnumerable<IBuildJobRunner> runners, IRep
: IBuildJobService<TEngine>
where TEngine : ITrainingEngine
{
// TODO: make some sort of service to get the engine repos.
protected readonly Dictionary<BuildJobRunnerType, IBuildJobRunner> Runners = runners.ToDictionary(r => r.Type);
protected readonly IRepository<TEngine> Engines = engines;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Task BuildCompletedAsync(
Task BuildFaultedAsync(string buildId, string message, CancellationToken cancellationToken = default);
Task BuildRestartingAsync(string buildId, CancellationToken cancellationToken = default);

Task InsertPretranslationsAsync(
Task InsertInferencesAsync(
string engineId,
Stream pretranslationsStream,
CancellationToken cancellationToken = default
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
namespace Serval.Machine.Shared.Services;

public interface ISmtModelFactory
public interface ISmtModelFactory : IModelFactory
{
IInteractiveTranslationModel Create(
string engineDir,
IRangeTokenizer<string, int, string> tokenizer,
IDetokenizer<string, string> detokenizer,
ITruecaser truecaser
);
ITrainer CreateTrainer(
string engineDir,
IRangeTokenizer<string, int, string> tokenizer,
IParallelTextCorpus corpus
);
void InitNew(string engineDir);
void Cleanup(string engineDir);
Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default);
Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace Serval.Machine.Shared.Services;
using Serval.WordAlignment.V1;

namespace Serval.Machine.Shared.Services;

public interface IWordAlignmentEngineService
{
Expand All @@ -13,7 +15,7 @@ Task<WordAlignmentEngine> CreateAsync(
);
Task DeleteAsync(string engineId, CancellationToken cancellationToken = default);

Task<WordAlignmentResult> GetBestPhraseAlignmentAsync(
Task<WordAlignmentResult> GetBestWordAlignmentAsync(
string engineId,
string sourceSegment,
string targetSegment,
Expand All @@ -24,7 +26,7 @@ Task StartBuildAsync(
string engineId,
string buildId,
string? buildOptions,
IReadOnlyList<ParallelCorpus> corpora,
IReadOnlyList<SIL.ServiceToolkit.Models.ParallelCorpus> corpora,
CancellationToken cancellationToken = default
);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
namespace Serval.Machine.Shared.Services;

public interface IWordAlignmentModelFactory
public interface IWordAlignmentModelFactory : IModelFactory
{
IWordAlignmentModel Create(string engineDir);
ITrainer CreateTrainer(string engineDir, ITokenizer<string, int, string> tokenizer, IParallelTextCorpus corpus);
void InitNew(string engineDir);
void Cleanup(string engineDir);
Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default);
Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CancellationToken cancellationToken
bool sourceTagInBaseModel = ResolveLanguageCodeForBaseModel(engine.SourceLanguage, out string srcLang);
bool targetTagInBaseModel = ResolveLanguageCodeForBaseModel(engine.TargetLanguage, out string trgLang);

(int trainCount, int pretranslateCount) = await WriteDataFilesAsync(
(int trainCount, int inferenceCount) = await WriteDataFilesAsync(
buildId,
data,
buildOptions,
Expand All @@ -70,7 +70,7 @@ CancellationToken cancellationToken
{ "EngineId", engineId },
{ "BuildId", buildId },
{ "NumTrainRows", trainCount },
{ "NumInferenceRows", pretranslateCount },
{ "NumInferenceRows", inferenceCount },
{ "SourceLanguageResolved", srcLang },
{ "TargetLanguageResolved", trgLang }
};
Expand All @@ -86,7 +86,7 @@ CancellationToken cancellationToken
var executionData = new Dictionary<string, string>()
{
{ "trainCount", trainCount.ToString(CultureInfo.InvariantCulture) },
{ "pretranslateCount", pretranslateCount.ToString(CultureInfo.InvariantCulture) }
{ "inference", inferenceCount.ToString(CultureInfo.InvariantCulture) }
};
await PlatformService.UpdateBuildExecutionDataAsync(engineId, buildId, executionData, cancellationToken);

Expand All @@ -105,6 +105,7 @@ CancellationToken cancellationToken
throw new OperationCanceledException();
}

//TODO: Move this method to translation-specific PreprocessBuildJob
protected virtual async Task<(int TrainCount, int InferenceCount)> WriteDataFilesAsync(
string buildId,
IReadOnlyList<ParallelCorpus> corpora,
Expand Down Expand Up @@ -142,9 +143,9 @@ await ParallelCorpusPreprocessingService.PreprocessAsync(
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
trainCount++;
},
async (row, corpus) =>
async (row, isInTrainingData, corpus) =>
{
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0)
if (row.SourceSegment.Length > 0 && !isInTrainingData)
{
pretranslateWriter.WriteStartObject();
pretranslateWriter.WriteString("corpusId", corpus.Id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ public static class ServalTranslationPlatformOutboxConstants
public const string BuildFaulted = "BuildFaulted";
public const string BuildRestarting = "BuildRestarting";
public const string InsertPretranslations = "InsertPretranslations";
public const string IncrementTranslationEngineCorpusSize = "IncrementTranslationEngineCorpusSize";
public const string IncrementTrainEngineCorpusSize = "IncrementTrainEngineCorpusSize";
public const string UpdateBuildExecutionData = "UpdateBuildExecutionData";
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ public class ServalTranslationPlatformOutboxMessageHandler(TranslationPlatformAp
: IOutboxMessageHandler
{
private readonly TranslationPlatformApi.TranslationPlatformApiClient _client = client;
private static readonly JsonSerializerOptions JsonSerializerOptions =
new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
private readonly JsonSerializerOptions _jsonSerializerOptions = MessageOutboxOptions.JsonSerializerOptions;

public string OutboxId => ServalTranslationPlatformOutboxConstants.OutboxId;

public async Task HandleMessageAsync(
string groupId,
string method,
string? content,
Stream? contentStream,
Expand All @@ -22,51 +22,51 @@ public async Task HandleMessageAsync(
{
case ServalTranslationPlatformOutboxConstants.BuildStarted:
await _client.BuildStartedAsync(
JsonSerializer.Deserialize<BuildStartedRequest>(content!),
JsonSerializer.Deserialize<BuildStartedRequest>(content!, _jsonSerializerOptions),
cancellationToken: cancellationToken
);
break;
case ServalTranslationPlatformOutboxConstants.BuildCompleted:
await _client.BuildCompletedAsync(
JsonSerializer.Deserialize<BuildCompletedRequest>(content!),
JsonSerializer.Deserialize<BuildCompletedRequest>(content!, _jsonSerializerOptions),
cancellationToken: cancellationToken
);
break;
case ServalTranslationPlatformOutboxConstants.BuildCanceled:
await _client.BuildCanceledAsync(
JsonSerializer.Deserialize<BuildCanceledRequest>(content!),
JsonSerializer.Deserialize<BuildCanceledRequest>(content!, _jsonSerializerOptions),
cancellationToken: cancellationToken
);
break;
case ServalTranslationPlatformOutboxConstants.BuildFaulted:
await _client.BuildFaultedAsync(
JsonSerializer.Deserialize<BuildFaultedRequest>(content!),
JsonSerializer.Deserialize<BuildFaultedRequest>(content!, _jsonSerializerOptions),
cancellationToken: cancellationToken
);
break;
case ServalTranslationPlatformOutboxConstants.BuildRestarting:
await _client.BuildRestartingAsync(
JsonSerializer.Deserialize<BuildRestartingRequest>(content!),
JsonSerializer.Deserialize<BuildRestartingRequest>(content!, _jsonSerializerOptions),
cancellationToken: cancellationToken
);
break;
case ServalTranslationPlatformOutboxConstants.InsertInferences:
case ServalTranslationPlatformOutboxConstants.InsertPretranslations:
IAsyncEnumerable<Pretranslation> pretranslations = JsonSerializer
.DeserializeAsyncEnumerable<Pretranslation>(
contentStream!,
JsonSerializerOptions,
_jsonSerializerOptions,
cancellationToken
)
.OfType<Pretranslation>();

using (var call = _client.InsertInferences(cancellationToken: cancellationToken))
using (var call = _client.InsertPretranslations(cancellationToken: cancellationToken))
{
await foreach (Pretranslation pretranslation in pretranslations)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = content!,
EngineId = groupId,
CorpusId = pretranslation.CorpusId,
TextId = pretranslation.TextId,
Refs = { pretranslation.Refs },
Expand Down
Loading

0 comments on commit 6da569a

Please sign in to comment.