From 6da569a9133ebbdb6de5eb0fae053b4fcc42208d Mon Sep 17 00:00:00 2001 From: Enkidu93 Date: Mon, 16 Dec 2024 16:31:25 -0500 Subject: [PATCH] Passing API tests Add cleanup service tests for word alignment Fix flaky test Fix naming Fix more names Passing statistical engine service tests - first pass Add hangfire implementation and tests Refactor alignment data structure; revert to pretranslate/word-align where appropriate Use parallel data when inferencing for word alignment Fix JSON serialization Rebase-related changes: extend executionData, commit-related issues to WA, small rebase mistakes Get rid of WordAlignmentResult --- .github/workflows/ci.yml | 2 +- .../EchoEngine/TranslationEngineServiceV1.cs | 13 +- .../WordAlignmentEngineServiceV1.cs | 6 +- .../IMachineBuilderExtensions.cs | 17 +- .../Services/BuildJobService.cs | 1 - .../Services/IPlatformService.cs | 2 +- .../Services/ISmtModelFactory.cs | 11 +- .../Services/IWordAlignmentEngineService.cs | 8 +- .../Services/IWordAlignmentModelFactory.cs | 7 +- .../Services/PreprocessBuildJob.cs | 11 +- ...ervalTranslationPlatformOutboxConstants.cs | 2 +- ...TranslationPlatformOutboxMessageHandler.cs | 24 +- .../ServalTranslationPlatformService.cs | 47 +- .../ServalWordAlignmentEngineServiceV1.cs | 29 +- ...valWordAlignmentPlatformOutboxConstants.cs | 1 + ...rdAlignmentPlatformOutboxMessageHandler.cs | 108 +++- .../ServalWordAlignmentPlatformService.cs | 45 +- .../Services/StatisticalEngineService.cs | 52 +- .../StatisticalPostprocessBuildJob.cs | 8 +- .../Services/StatisticalTrainBuildJob.cs | 173 +++++- .../Services/WordAlignmentEngineState.cs | 37 +- .../WordAlignmentEngineStateService.cs | 13 +- .../Services/WordAlignmentModelFactory.cs | 8 +- .../WordAlignmentPreprocessBuildJob.cs | 11 +- .../Utils/AsyncDisposableBase.cs | 19 - .../Serval.Machine.Shared/Utils/AsyncTimer.cs | 70 --- .../Services/PreprocessBuildJobTests.cs | 240 --------- ...ServalPlatformOutboxMessageHandlerTests.cs | 7 +- 
.../Services/SmtTransferEngineServiceTests.cs | 10 +- .../Services/StatisticalEngineServiceTests.cs | 504 ++++++++++++++++++ src/Serval/src/Serval.Client/Client.g.cs | 43 +- .../serval/translation/v1/platform.proto | 4 +- .../serval/word_alignment/v1/platform.proto | 12 +- .../Services/EngineService.cs | 12 +- .../Services/TranslationPlatformServiceV1.cs | 10 +- .../IMongoDataAccessConfiguratorExtensions.cs | 5 + .../Contracts/WordAlignmentBuildConfigDto.cs | 2 +- .../Contracts/WordAlignmentBuildDto.cs | 3 +- .../Contracts/WordAlignmentCorpusConfigDto.cs | 8 + .../Contracts/WordAlignmentCorpusDto.cs | 8 + .../WordAlignmentEnginesController.cs | 83 ++- .../src/Serval.WordAlignment/Models/Build.cs | 3 +- .../Models/WordAlignmentCorpus.cs | 8 + .../Services/BuildService.cs | 32 +- .../Services/EngineService.cs | 23 +- .../WordAlignmentPlatformServiceV1.cs | 25 +- .../TranslationEngineTests.cs | 81 +-- .../WordAlignmentEngineTests.cs | 193 ++++--- .../test/Serval.E2ETests/ServalApiTests.cs | 44 +- .../Serval.E2ETests/ServalClientHelper.cs | 2 +- .../Services/PlatformServiceTests.cs | 20 +- .../Services/BuildCleanupServiceTests.cs | 56 ++ .../Services/EngineCleanupServiceTests.cs | 64 +++ .../Services/EngineServiceTests.cs | 4 +- .../Services/PlatformServiceTests.cs | 75 ++- .../IParallelCorpusPreprocessingService.cs | 2 +- .../ParallelCorpusPreprocessingService.cs | 104 +++- .../ParallelCorpusProcessingServiceTests.cs | 14 +- 58 files changed, 1631 insertions(+), 795 deletions(-) delete mode 100644 src/Machine/src/Serval.Machine.Shared/Utils/AsyncDisposableBase.cs delete mode 100644 src/Machine/src/Serval.Machine.Shared/Utils/AsyncTimer.cs create mode 100644 src/Machine/test/Serval.Machine.Shared.Tests/Services/StatisticalEngineServiceTests.cs create mode 100644 src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusConfigDto.cs create mode 100644 src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusDto.cs create mode 100644 
src/Serval/src/Serval.WordAlignment/Models/WordAlignmentCorpus.cs create mode 100644 src/Serval/test/Serval.WordAlignment.Tests/Services/BuildCleanupServiceTests.cs create mode 100644 src/Serval/test/Serval.WordAlignment.Tests/Services/EngineCleanupServiceTests.cs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6ffb75a7..af7e382c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - name: Start MongoDB uses: supercharge/mongodb-github-action@1.8.0 with: - mongodb-version: "6.0" + mongodb-version: "8.0" mongodb-replica-set: rs0 # Pull in a matching machine repo branch if it exists to use it rather than the released version of Machine. diff --git a/src/Echo/src/EchoEngine/TranslationEngineServiceV1.cs b/src/Echo/src/EchoEngine/TranslationEngineServiceV1.cs index 1a31331e..bd0f7771 100644 --- a/src/Echo/src/EchoEngine/TranslationEngineServiceV1.cs +++ b/src/Echo/src/EchoEngine/TranslationEngineServiceV1.cs @@ -78,9 +78,8 @@ await client.BuildStartedAsync( try { using ( - AsyncClientStreamingCall call = client.InsertInferences( - cancellationToken: cancellationToken - ) + AsyncClientStreamingCall call = + client.InsertPretranslations(cancellationToken: cancellationToken) ) { foreach (ParallelCorpus corpus in request.Corpora) @@ -133,7 +132,7 @@ await client.BuildStartedAsync( if (sourceLine.Length > 0 && targetLine.Length == 0) { await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertPretranslationsRequest { EngineId = request.EngineId, CorpusId = corpus.Id, @@ -166,7 +165,7 @@ await call.RequestStream.WriteAsync( if (sourceLine.Length > 0 && targetLine.Length == 0) { await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertPretranslationsRequest { EngineId = request.EngineId, CorpusId = corpus.Id, @@ -191,7 +190,7 @@ await call.RequestStream.WriteAsync( if (sourceLine.Length > 0) { await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new 
InsertPretranslationsRequest { EngineId = request.EngineId, CorpusId = corpus.Id, @@ -212,7 +211,7 @@ await call.RequestStream.WriteAsync( if (sourceLine.Length > 0) { await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertPretranslationsRequest { EngineId = request.EngineId, CorpusId = corpus.Id, diff --git a/src/Echo/src/EchoEngine/WordAlignmentEngineServiceV1.cs b/src/Echo/src/EchoEngine/WordAlignmentEngineServiceV1.cs index 3d3c5e0a..461938ff 100644 --- a/src/Echo/src/EchoEngine/WordAlignmentEngineServiceV1.cs +++ b/src/Echo/src/EchoEngine/WordAlignmentEngineServiceV1.cs @@ -69,7 +69,7 @@ await client.BuildStartedAsync( try { using ( - AsyncClientStreamingCall call = client.InsertInferences( + AsyncClientStreamingCall call = client.InsertWordAlignments( cancellationToken: cancellationToken ) ) @@ -128,7 +128,7 @@ await client.BuildStartedAsync( targetLine.Split().Length ); await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertWordAlignmentsRequest { EngineId = request.EngineId, CorpusId = corpus.Id, @@ -168,7 +168,7 @@ await call.RequestStream.WriteAsync( targetLine.Split().Length ); await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertWordAlignmentsRequest { EngineId = request.EngineId, CorpusId = corpus.Id, diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs index 616f0e11..4e398454 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs @@ -75,6 +75,15 @@ public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder, ICon return builder; } + public static IMachineBuilder AddWordAlignmentModel(this IMachineBuilder builder) + { + builder.Services.Configure( + builder.Configuration.GetSection(WordAlignmentModelOptions.Key) + ); + 
builder.Services.AddSingleton(); + return builder; + } + public static IMachineBuilder AddTransferEngine(this IMachineBuilder builder) { builder.Services.AddSingleton(); @@ -485,7 +494,7 @@ public static IMachineBuilder AddThot(this IMachineBuilder builder) { try { - builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser(); + builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser().AddWordAlignmentModel(); } catch (ArgumentException) { @@ -516,9 +525,9 @@ public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder) var smtTransferEngineOptions = new SmtTransferEngineOptions(); builder.Configuration.GetSection(SmtTransferEngineOptions.Key).Bind(smtTransferEngineOptions); string? smtDriveLetter = Path.GetPathRoot(smtTransferEngineOptions.EnginesDir)?[..1]; - var statisticsEngineOptions = new WordAlignmentEngineOptions(); - builder.Configuration.GetSection(WordAlignmentEngineOptions.Key).Bind(statisticsEngineOptions); - string? statisticsDriveLetter = Path.GetPathRoot(statisticsEngineOptions.EnginesDir)?[..1]; + var statisticalEngineOptions = new WordAlignmentEngineOptions(); + builder.Configuration.GetSection(WordAlignmentEngineOptions.Key).Bind(statisticalEngineOptions); + string? 
statisticsDriveLetter = Path.GetPathRoot(statisticalEngineOptions.EnginesDir)?[..1]; if (smtDriveLetter is null || statisticsDriveLetter is null) throw new InvalidOperationException("SMT Engine and Statistical directory is required"); if (smtDriveLetter != statisticsDriveLetter) diff --git a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs index 4b3bd5f4..25b2bf82 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs @@ -4,7 +4,6 @@ public class BuildJobService(IEnumerable runners, IRep : IBuildJobService where TEngine : ITrainingEngine { - // TODO: make some sort of service to get the engine repos. protected readonly Dictionary Runners = runners.ToDictionary(r => r.Type); protected readonly IRepository Engines = engines; diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IPlatformService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IPlatformService.cs index eb718920..24df0ebe 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IPlatformService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IPlatformService.cs @@ -23,7 +23,7 @@ Task BuildCompletedAsync( Task BuildFaultedAsync(string buildId, string message, CancellationToken cancellationToken = default); Task BuildRestartingAsync(string buildId, CancellationToken cancellationToken = default); - Task InsertPretranslationsAsync( + Task InsertInferencesAsync( string engineId, Stream pretranslationsStream, CancellationToken cancellationToken = default diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs index 01776084..4249c871 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs @@ -1,6 +1,6 @@ namespace 
Serval.Machine.Shared.Services; -public interface ISmtModelFactory +public interface ISmtModelFactory : IModelFactory { IInteractiveTranslationModel Create( string engineDir, @@ -8,13 +8,4 @@ IInteractiveTranslationModel Create( IDetokenizer detokenizer, ITruecaser truecaser ); - ITrainer CreateTrainer( - string engineDir, - IRangeTokenizer tokenizer, - IParallelTextCorpus corpus - ); - void InitNew(string engineDir); - void Cleanup(string engineDir); - Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default); - Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs index feaf4f97..f403d33e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentEngineService.cs @@ -1,4 +1,6 @@ -namespace Serval.Machine.Shared.Services; +using Serval.WordAlignment.V1; + +namespace Serval.Machine.Shared.Services; public interface IWordAlignmentEngineService { @@ -13,7 +15,7 @@ Task CreateAsync( ); Task DeleteAsync(string engineId, CancellationToken cancellationToken = default); - Task GetBestPhraseAlignmentAsync( + Task GetBestWordAlignmentAsync( string engineId, string sourceSegment, string targetSegment, @@ -24,7 +26,7 @@ Task StartBuildAsync( string engineId, string buildId, string? 
buildOptions, - IReadOnlyList corpora, + IReadOnlyList corpora, CancellationToken cancellationToken = default ); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentModelFactory.cs index b1ebe197..5ac39b85 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IWordAlignmentModelFactory.cs @@ -1,11 +1,6 @@ namespace Serval.Machine.Shared.Services; -public interface IWordAlignmentModelFactory +public interface IWordAlignmentModelFactory : IModelFactory { IWordAlignmentModel Create(string engineDir); - ITrainer CreateTrainer(string engineDir, ITokenizer tokenizer, IParallelTextCorpus corpus); - void InitNew(string engineDir); - void Cleanup(string engineDir); - Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default); - Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index 4b456bf9..8670a56b 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -56,7 +56,7 @@ CancellationToken cancellationToken bool sourceTagInBaseModel = ResolveLanguageCodeForBaseModel(engine.SourceLanguage, out string srcLang); bool targetTagInBaseModel = ResolveLanguageCodeForBaseModel(engine.TargetLanguage, out string trgLang); - (int trainCount, int pretranslateCount) = await WriteDataFilesAsync( + (int trainCount, int inferenceCount) = await WriteDataFilesAsync( buildId, data, buildOptions, @@ -70,7 +70,7 @@ CancellationToken cancellationToken { "EngineId", engineId }, { "BuildId", buildId }, { "NumTrainRows", trainCount }, - { 
"NumInferenceRows", pretranslateCount }, + { "NumInferenceRows", inferenceCount }, { "SourceLanguageResolved", srcLang }, { "TargetLanguageResolved", trgLang } }; @@ -86,7 +86,7 @@ CancellationToken cancellationToken var executionData = new Dictionary() { { "trainCount", trainCount.ToString(CultureInfo.InvariantCulture) }, - { "pretranslateCount", pretranslateCount.ToString(CultureInfo.InvariantCulture) } + { "inference", inferenceCount.ToString(CultureInfo.InvariantCulture) } }; await PlatformService.UpdateBuildExecutionDataAsync(engineId, buildId, executionData, cancellationToken); @@ -105,6 +105,7 @@ CancellationToken cancellationToken throw new OperationCanceledException(); } + //TODO: Move this method to translation-specific PreprocessBuildJob protected virtual async Task<(int TrainCount, int InferenceCount)> WriteDataFilesAsync( string buildId, IReadOnlyList corpora, @@ -142,9 +143,9 @@ await ParallelCorpusPreprocessingService.PreprocessAsync( if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; }, - async (row, corpus) => + async (row, isInTrainingData, corpus) => { - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) + if (row.SourceSegment.Length > 0 && !isInTrainingData) { pretranslateWriter.WriteStartObject(); pretranslateWriter.WriteString("corpusId", corpus.Id); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxConstants.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxConstants.cs index 560ca463..392d67b9 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxConstants.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxConstants.cs @@ -10,6 +10,6 @@ public static class ServalTranslationPlatformOutboxConstants public const string BuildFaulted = "BuildFaulted"; public const string BuildRestarting = "BuildRestarting"; public const string InsertPretranslations = 
"InsertPretranslations"; - public const string IncrementTranslationEngineCorpusSize = "IncrementTranslationEngineCorpusSize"; + public const string IncrementTrainEngineCorpusSize = "IncrementTrainEngineCorpusSize"; public const string UpdateBuildExecutionData = "UpdateBuildExecutionData"; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxMessageHandler.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxMessageHandler.cs index 4be48091..b1df9995 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxMessageHandler.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformOutboxMessageHandler.cs @@ -6,12 +6,12 @@ public class ServalTranslationPlatformOutboxMessageHandler(TranslationPlatformAp : IOutboxMessageHandler { private readonly TranslationPlatformApi.TranslationPlatformApiClient _client = client; - private static readonly JsonSerializerOptions JsonSerializerOptions = - new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + private readonly JsonSerializerOptions _jsonSerializerOptions = MessageOutboxOptions.JsonSerializerOptions; public string OutboxId => ServalTranslationPlatformOutboxConstants.OutboxId; public async Task HandleMessageAsync( + string groupId, string method, string? content, Stream? 
contentStream, @@ -22,51 +22,51 @@ public async Task HandleMessageAsync( { case ServalTranslationPlatformOutboxConstants.BuildStarted: await _client.BuildStartedAsync( - JsonSerializer.Deserialize(content!), + JsonSerializer.Deserialize(content!, _jsonSerializerOptions), cancellationToken: cancellationToken ); break; case ServalTranslationPlatformOutboxConstants.BuildCompleted: await _client.BuildCompletedAsync( - JsonSerializer.Deserialize(content!), + JsonSerializer.Deserialize(content!, _jsonSerializerOptions), cancellationToken: cancellationToken ); break; case ServalTranslationPlatformOutboxConstants.BuildCanceled: await _client.BuildCanceledAsync( - JsonSerializer.Deserialize(content!), + JsonSerializer.Deserialize(content!, _jsonSerializerOptions), cancellationToken: cancellationToken ); break; case ServalTranslationPlatformOutboxConstants.BuildFaulted: await _client.BuildFaultedAsync( - JsonSerializer.Deserialize(content!), + JsonSerializer.Deserialize(content!, _jsonSerializerOptions), cancellationToken: cancellationToken ); break; case ServalTranslationPlatformOutboxConstants.BuildRestarting: await _client.BuildRestartingAsync( - JsonSerializer.Deserialize(content!), + JsonSerializer.Deserialize(content!, _jsonSerializerOptions), cancellationToken: cancellationToken ); break; - case ServalTranslationPlatformOutboxConstants.InsertInferences: + case ServalTranslationPlatformOutboxConstants.InsertPretranslations: IAsyncEnumerable pretranslations = JsonSerializer .DeserializeAsyncEnumerable( contentStream!, - JsonSerializerOptions, + _jsonSerializerOptions, cancellationToken ) .OfType(); - using (var call = _client.InsertInferences(cancellationToken: cancellationToken)) + using (var call = _client.InsertPretranslations(cancellationToken: cancellationToken)) { await foreach (Pretranslation pretranslation in pretranslations) { await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertPretranslationsRequest { - EngineId = content!, + 
EngineId = groupId, CorpusId = pretranslation.CorpusId, TextId = pretranslation.TextId, Refs = { pretranslation.Refs }, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformService.cs index 155dc4fb..a3fefa37 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalTranslationPlatformService.cs @@ -17,7 +17,7 @@ await _outboxService.EnqueueMessageAsync( ServalTranslationPlatformOutboxConstants.OutboxId, ServalTranslationPlatformOutboxConstants.BuildStarted, buildId, - JsonSerializer.Serialize(new BuildStartedRequest { BuildId = buildId }), + new BuildStartedRequest { BuildId = buildId }, cancellationToken: cancellationToken ); } @@ -33,14 +33,12 @@ await _outboxService.EnqueueMessageAsync( ServalTranslationPlatformOutboxConstants.OutboxId, ServalTranslationPlatformOutboxConstants.BuildCompleted, buildId, - JsonSerializer.Serialize( - new BuildCompletedRequest - { - BuildId = buildId, - CorpusSize = trainSize, - Confidence = confidence - } - ), + new BuildCompletedRequest + { + BuildId = buildId, + CorpusSize = trainSize, + Confidence = confidence + }, cancellationToken: cancellationToken ); } @@ -51,7 +49,7 @@ await _outboxService.EnqueueMessageAsync( ServalTranslationPlatformOutboxConstants.OutboxId, ServalTranslationPlatformOutboxConstants.BuildCanceled, buildId, - JsonSerializer.Serialize(new BuildCanceledRequest { BuildId = buildId }), + new BuildCanceledRequest { BuildId = buildId }, cancellationToken: cancellationToken ); } @@ -62,7 +60,7 @@ await _outboxService.EnqueueMessageAsync( ServalTranslationPlatformOutboxConstants.OutboxId, ServalTranslationPlatformOutboxConstants.BuildFaulted, buildId, - JsonSerializer.Serialize(new BuildFaultedRequest { BuildId = buildId, Message = message }), + new BuildFaultedRequest { BuildId = buildId, Message = 
message }, cancellationToken: cancellationToken ); } @@ -73,7 +71,7 @@ await _outboxService.EnqueueMessageAsync( ServalTranslationPlatformOutboxConstants.OutboxId, ServalTranslationPlatformOutboxConstants.BuildRestarting, buildId, - JsonSerializer.Serialize(new BuildRestartingRequest { BuildId = buildId }), + new BuildRestartingRequest { BuildId = buildId }, cancellationToken: cancellationToken ); } @@ -112,10 +110,9 @@ public async Task InsertInferencesAsync( CancellationToken cancellationToken = default ) { - await _outboxService.EnqueueMessageAsync( + await _outboxService.EnqueueMessageStreamAsync( ServalTranslationPlatformOutboxConstants.OutboxId, - ServalTranslationPlatformOutboxConstants.InsertInferences, - engineId, + ServalTranslationPlatformOutboxConstants.InsertPretranslations, engineId, pretranslationsStream, cancellationToken: cancellationToken @@ -132,7 +129,25 @@ await _outboxService.EnqueueMessageAsync( ServalTranslationPlatformOutboxConstants.OutboxId, ServalTranslationPlatformOutboxConstants.IncrementTrainEngineCorpusSize, engineId, - JsonSerializer.Serialize(new IncrementTrainEngineCorpusSizeRequest { EngineId = engineId, Count = count }), + new IncrementTrainEngineCorpusSizeRequest { EngineId = engineId, Count = count }, + cancellationToken: cancellationToken + ); + } + + public async Task UpdateBuildExecutionDataAsync( + string engineId, + string buildId, + IReadOnlyDictionary executionData, + CancellationToken cancellationToken = default + ) + { + var request = new UpdateBuildExecutionDataRequest { EngineId = engineId, BuildId = buildId }; + request.ExecutionData.Add((IDictionary)executionData); + await _outboxService.EnqueueMessageAsync( + ServalTranslationPlatformOutboxConstants.OutboxId, + ServalTranslationPlatformOutboxConstants.UpdateBuildExecutionData, + engineId, + request, cancellationToken: cancellationToken ); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs 
b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs index 7330df08..a90f3a45 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs @@ -38,10 +38,10 @@ ServerCallContext context ) { IWordAlignmentEngineService engineService = GetEngineService(request.EngineType); - SIL.Machine.Translation.WordAlignmentResult result; + WordAlignmentResult result; try { - result = await engineService.GetBestPhraseAlignmentAsync( + result = await engineService.GetBestWordAlignmentAsync( request.EngineId, request.SourceSegment, request.TargetSegment, @@ -53,7 +53,7 @@ ServerCallContext context throw new RpcException(new Status(StatusCode.Aborted, e.Message, e)); } - return new GetWordAlignmentResponse { Result = Map(result) }; + return new GetWordAlignmentResponse { Result = result }; } public override async Task StartBuild(StartBuildRequest request, ServerCallContext context) @@ -116,29 +116,6 @@ private static EngineType GetEngineType(string engineTypeStr) ); } - private static WordAlignment.V1.WordAlignmentResult Map(SIL.Machine.Translation.WordAlignmentResult source) - { - return new WordAlignment.V1.WordAlignmentResult - { - SourceTokens = { source.SourceTokens }, - TargetTokens = { source.TargetTokens }, - Alignment = { Map(source.Alignment) }, - Confidences = { source.Confidences } - }; - } - - private static IEnumerable Map(WordAlignmentMatrix source) - { - for (int i = 0; i < source.RowCount; i++) - { - for (int j = 0; j < source.ColumnCount; j++) - { - if (source[i, j]) - yield return new WordAlignment.V1.AlignedWordPair { SourceIndex = i, TargetIndex = j }; - } - } - } - private static SIL.ServiceToolkit.Models.ParallelCorpus Map(WordAlignment.V1.ParallelCorpus source) { return new SIL.ServiceToolkit.Models.ParallelCorpus diff --git 
a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxConstants.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxConstants.cs index 573a40d9..b6913028 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxConstants.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxConstants.cs @@ -11,4 +11,5 @@ public static class ServalWordAlignmentPlatformOutboxConstants public const string BuildRestarting = "BuildRestarting"; public const string IncrementTrainEngineCorpusSize = "IncrementTrainEngineCorpusSize"; public const string InsertInferences = "InsertInferences"; + public const string UpdateBuildExecutionData = "UpdateBuildExecutionData"; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs index b063b8f7..87140b86 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformOutboxMessageHandler.cs @@ -8,11 +8,12 @@ WordAlignmentPlatformApi.WordAlignmentPlatformApiClient client { private readonly WordAlignmentPlatformApi.WordAlignmentPlatformApiClient _client = client; private static readonly JsonSerializerOptions JsonSerializerOptions = - new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase, Converters = { new WordAlignmentConverter() } }; public string OutboxId => ServalWordAlignmentPlatformOutboxConstants.OutboxId; public async Task HandleMessageAsync( + string groupId, string method, string? content, Stream? 
contentStream, @@ -52,8 +53,6 @@ await _client.BuildRestartingAsync( ); break; case ServalWordAlignmentPlatformOutboxConstants.InsertInferences: - var jsonSerializerOptions = new JsonSerializerOptions(JsonSerializerOptions); - jsonSerializerOptions.Converters.Add(new WordAlignmentJsonConverter()); IAsyncEnumerable wordAlignments = JsonSerializer .DeserializeAsyncEnumerable( contentStream!, @@ -62,14 +61,14 @@ await _client.BuildRestartingAsync( ) .OfType(); - using (var call = _client.InsertInferences(cancellationToken: cancellationToken)) + using (var call = _client.InsertWordAlignments(cancellationToken: cancellationToken)) { await foreach (Models.WordAlignment wordAlignment in wordAlignments) { await call.RequestStream.WriteAsync( - new InsertInferencesRequest + new InsertWordAlignmentsRequest { - EngineId = content!, + EngineId = groupId, CorpusId = wordAlignment.CorpusId, TextId = wordAlignment.TextId, Refs = { wordAlignment.Refs }, @@ -109,32 +108,83 @@ await _client.IncrementTrainEngineCorpusSizeAsync( }; } } -} -public class WordAlignmentJsonConverter : JsonConverter -{ - public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + internal class WordAlignmentConverter : JsonConverter { - switch (reader.TokenType) + public override Models.WordAlignment Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options + ) { - case JsonTokenType.True: - return true; - case JsonTokenType.False: - return false; - case JsonTokenType.Number when reader.TryGetInt64(out long l): - return l; - case JsonTokenType.Number: - return reader.GetDouble(); - case JsonTokenType.String: - var str = reader.GetString(); - if (SIL.Machine.Corpora.AlignedWordPair.TryParse(str, out var alignedWordPair)) - return alignedWordPair; - return str!; - default: - throw new JsonException(); + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException( + $"Expected StartObject token at the beginning 
of WordAlignment object but instead encountered {reader.TokenType}" + ); + } + string corpusId = "", + textId = ""; + IReadOnlyList confidences = []; + IReadOnlyList refs = [], + sourceTokens = [], + targetTokens = []; + IReadOnlyList alignedWordPairs = []; + while (reader.Read() && reader.TokenType != JsonTokenType.EndObject) + { + if (reader.TokenType == JsonTokenType.PropertyName) + { + string s = reader.GetString()!; + switch (s) + { + case "corpus_id": + reader.Read(); + corpusId = reader.GetString()!; + break; + case "text_id": + reader.Read(); + textId = reader.GetString()!; + break; + case "confidences": + reader.Read(); + confidences = JsonSerializer.Deserialize>(ref reader, options)!.ToArray(); + break; + case "refs": + reader.Read(); + refs = JsonSerializer.Deserialize>(ref reader, options)!.ToArray(); + break; + case "source_tokens": + reader.Read(); + sourceTokens = JsonSerializer.Deserialize>(ref reader, options)!.ToArray(); + break; + case "target_tokens": + reader.Read(); + targetTokens = JsonSerializer.Deserialize>(ref reader, options)!.ToArray(); + break; + case "alignment": + reader.Read(); + alignedWordPairs = SIL.Machine.Corpora.AlignedWordPair.Parse(reader.GetString()).ToArray(); + break; + default: + throw new JsonException( + $"Unexpected property name {s} when deserializing WordAlignment object" + ); + } + } + } + return new Models.WordAlignment() + { + CorpusId = corpusId, + TextId = textId, + Refs = refs, + Alignment = alignedWordPairs, + Confidences = confidences, + SourceTokens = sourceTokens, + TargetTokens = targetTokens + }; } - } - public override void Write(Utf8JsonWriter writer, object objectToWrite, JsonSerializerOptions options) => - JsonSerializer.Serialize(writer, objectToWrite, objectToWrite.GetType(), options); + public override void Write(Utf8JsonWriter writer, Models.WordAlignment value, JsonSerializerOptions options) => + throw new NotSupportedException(); + } } diff --git 
a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformService.cs index 23f611ee..eedf15d5 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentPlatformService.cs @@ -17,7 +17,7 @@ await _outboxService.EnqueueMessageAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.BuildStarted, buildId, - JsonSerializer.Serialize(new BuildStartedRequest { BuildId = buildId }), + new BuildStartedRequest { BuildId = buildId }, cancellationToken: cancellationToken ); } @@ -33,14 +33,12 @@ await _outboxService.EnqueueMessageAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.BuildCompleted, buildId, - JsonSerializer.Serialize( - new BuildCompletedRequest - { - BuildId = buildId, - CorpusSize = trainSize, - Confidence = confidence - } - ), + new BuildCompletedRequest + { + BuildId = buildId, + CorpusSize = trainSize, + Confidence = confidence + }, cancellationToken: cancellationToken ); } @@ -51,7 +49,7 @@ await _outboxService.EnqueueMessageAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.BuildCanceled, buildId, - JsonSerializer.Serialize(new BuildCanceledRequest { BuildId = buildId }), + new BuildCanceledRequest { BuildId = buildId }, cancellationToken: cancellationToken ); } @@ -62,7 +60,7 @@ await _outboxService.EnqueueMessageAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.BuildFaulted, buildId, - JsonSerializer.Serialize(new BuildFaultedRequest { BuildId = buildId, Message = message }), + new BuildFaultedRequest { BuildId = buildId, Message = message }, cancellationToken: cancellationToken ); } @@ -73,7 +71,7 @@ await 
_outboxService.EnqueueMessageAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.BuildRestarting, buildId, - JsonSerializer.Serialize(new BuildRestartingRequest { BuildId = buildId }), + new BuildRestartingRequest { BuildId = buildId }, cancellationToken: cancellationToken ); } @@ -112,11 +110,10 @@ public async Task InsertInferencesAsync( CancellationToken cancellationToken = default ) { - await _outboxService.EnqueueMessageAsync( + await _outboxService.EnqueueMessageStreamAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.InsertInferences, engineId, - engineId, wordAlignmentsStream, cancellationToken: cancellationToken ); @@ -132,7 +129,25 @@ await _outboxService.EnqueueMessageAsync( ServalWordAlignmentPlatformOutboxConstants.OutboxId, ServalWordAlignmentPlatformOutboxConstants.IncrementTrainEngineCorpusSize, engineId, - JsonSerializer.Serialize(new IncrementTrainEngineCorpusSizeRequest { EngineId = engineId, Count = count }), + new IncrementTrainEngineCorpusSizeRequest { EngineId = engineId, Count = count }, + cancellationToken: cancellationToken + ); + } + + public async Task UpdateBuildExecutionDataAsync( + string engineId, + string buildId, + IReadOnlyDictionary executionData, + CancellationToken cancellationToken = default + ) + { + var request = new UpdateBuildExecutionDataRequest { EngineId = engineId, BuildId = buildId }; + request.ExecutionData.Add((IDictionary)executionData); + await _outboxService.EnqueueMessageAsync( + ServalWordAlignmentPlatformOutboxConstants.OutboxId, + ServalWordAlignmentPlatformOutboxConstants.UpdateBuildExecutionData, + engineId, + request, cancellationToken: cancellationToken ); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs index e881485f..129cd5b3 100644 --- 
a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs @@ -1,4 +1,6 @@ -namespace Serval.Machine.Shared.Services; +using Serval.WordAlignment.V1; + +namespace Serval.Machine.Shared.Services; public class StatisticalEngineService( IDistributedReaderWriterLockFactory lockFactory, @@ -52,7 +54,7 @@ public async Task CreateAsync( return wordAlignmentEngine; } - public async Task GetBestPhraseAlignmentAsync( + public async Task GetBestWordAlignmentAsync( string engineId, string sourceSegment, string targetSegment, @@ -61,6 +63,8 @@ public async Task GetBestPhraseAlignmentAsync( { WordAlignmentEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); WordAlignmentEngineState state = _stateService.Get(engineId); + if (state.IsMarkedForDeletion) + throw new InvalidOperationException("Engine is marked for deletion."); IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); WordAlignmentResult result = await @lock.ReaderLockAsync( @@ -72,33 +76,30 @@ public async Task GetBestPhraseAlignmentAsync( // there is no way to cancel this call IReadOnlyList sourceTokens = tokenizer.Tokenize(sourceSegment).ToList(); IReadOnlyList targetTokens = tokenizer.Tokenize(targetSegment).ToList(); - IReadOnlyCollection wordPairs = wordAlignmentEngine.GetBestAlignedWordPairs( - sourceTokens, - targetTokens - ); + IReadOnlyCollection wordPairs = + wordAlignmentEngine.GetBestAlignedWordPairs(sourceTokens, targetTokens); wordAlignmentEngine.ComputeAlignedWordPairScores(sourceTokens, targetTokens, wordPairs); - return new WordAlignmentResult( - sourceTokens: sourceTokens, - targetTokens: targetTokens, - alignment: new WordAlignmentMatrix( - sourceTokens.Count, - targetTokens.Count, - wordPairs.Select(wp => (wp.SourceIndex, wp.TargetIndex)) - ), - confidences: wordPairs.Select(wp => wp.AlignmentScore * wp.TranslationScore).ToList() - ); + 
return new WordAlignmentResult() + { + SourceTokens = { sourceTokens }, + TargetTokens = { targetTokens }, + Alignment = { wordPairs.Select(Map) }, + Confidences = { wordPairs.Select(wp => wp.AlignmentScore).ToList() } + }; }, cancellationToken: cancellationToken ); state.Touch(); return result; - - throw new NotImplementedException(); } public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) { + // there is no way to cancel this call + WordAlignmentEngineState state = _stateService.Get(engineId); + state.IsMarkedForDeletion = true; + await CancelBuildJobAsync(engineId, cancellationToken); await _dataAccessContext.WithTransactionAsync( @@ -106,13 +107,11 @@ await _dataAccessContext.WithTransactionAsync( { await _engines.DeleteAsync(e => e.EngineId == engineId, ct); }, - cancellationToken: cancellationToken + cancellationToken: CancellationToken.None ); await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); - WordAlignmentEngineState state = _stateService.Get(engineId); _stateService.Remove(engineId); - // there is no way to cancel this call state.DeleteData(); state.Dispose(); await _lockFactory.DeleteAsync(engineId, CancellationToken.None); @@ -122,7 +121,7 @@ public async Task StartBuildAsync( string engineId, string buildId, string? 
buildOptions, - IReadOnlyList corpora, + IReadOnlyList corpora, CancellationToken cancellationToken = default ) { @@ -189,4 +188,13 @@ private async Task GetBuiltEngineAsync(string engineId, Can throw new EngineNotBuiltException("The engine must be built first."); return engine; } + + private static WordAlignment.V1.AlignedWordPair Map(SIL.Machine.Corpora.AlignedWordPair alignedWordPair) + { + return new WordAlignment.V1.AlignedWordPair + { + SourceIndex = alignedWordPair.SourceIndex, + TargetIndex = alignedWordPair.TargetIndex + }; + } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs index cfe8354b..932b6242 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalPostprocessBuildJob.cs @@ -8,7 +8,7 @@ public class StatisticalPostprocessBuildJob( ILogger logger, ISharedFileService sharedFileService, IDistributedReaderWriterLockFactory lockFactory, - ISmtModelFactory smtModelFactory, + IWordAlignmentModelFactory wordAlignmentModelFactory, IOptionsMonitor buildOptions, IOptionsMonitor engineOptions ) @@ -22,7 +22,7 @@ IOptionsMonitor engineOptions buildOptions ) { - private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; + private readonly IWordAlignmentModelFactory _wordAlignmentModelFactory = wordAlignmentModelFactory; private readonly IOptionsMonitor _engineOptions = engineOptions; private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; @@ -38,7 +38,7 @@ CancellationToken cancellationToken await using ( Stream wordAlignmentStream = await SharedFileService.OpenReadAsync( - $"builds/{buildId}/word_alignment_outputs.json", + $"builds/{buildId}/word_alignments.outputs.json", cancellationToken ) ) @@ -74,7 +74,7 @@ protected override async Task SaveModelAsync(string engineId, string buildI Stream 
engineStream = await SharedFileService.OpenReadAsync($"builds/{buildId}/model.tar.gz", ct) ) { - await _smtModelFactory.UpdateEngineFromAsync( + await _wordAlignmentModelFactory.UpdateEngineFromAsync( Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId), engineStream, ct diff --git a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs index d68376a0..b56d6f90 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/StatisticalTrainBuildJob.cs @@ -5,7 +5,9 @@ public class StatisticalTrainBuildJob( IRepository engines, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, - ILogger logger + ILogger logger, + ISharedFileService sharedFileService, + IWordAlignmentModelFactory wordAlignmentModelFactory ) : HangfireBuildJob( platformServices.First(ps => ps.EngineGroup == EngineGroup.WordAlignment), @@ -15,7 +17,15 @@ ILogger logger logger ) { - protected override Task DoWorkAsync( + private static readonly JsonWriterOptions WordAlignmentWriterOptions = new() { Indented = true }; + private static readonly JsonSerializerOptions JsonSerializerOptions = + new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + private const int BatchSize = 128; + + private readonly ISharedFileService _sharedFileService = sharedFileService; + private readonly IWordAlignmentModelFactory _wordAlignmentFactory = wordAlignmentModelFactory; + + protected override async Task DoWorkAsync( string engineId, string buildId, object? 
data, @@ -23,6 +33,163 @@ protected override Task DoWorkAsync( CancellationToken cancellationToken ) { - throw new NotImplementedException(); + using TempDirectory tempDir = new(buildId); + string corpusDir = Path.Combine(tempDir.Path, "corpus"); + await DownloadDataAsync(buildId, corpusDir, cancellationToken); + + // assemble corpus + ITextCorpus sourceCorpus = new TextFileTextCorpus(Path.Combine(corpusDir, "train.src.txt")); + ITextCorpus targetCorpus = new TextFileTextCorpus(Path.Combine(corpusDir, "train.trg.txt")); + IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus); + + // train word alignment model + string engineDir = Path.Combine(tempDir.Path, "engine"); + int trainCount = await TrainAsync(buildId, engineDir, parallelCorpus, cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + await GenerateWordAlignmentsAsync(buildId, engineDir, cancellationToken); + + bool canceling = !await BuildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + EngineType.Statistical, + engineId, + buildId, + BuildStage.Postprocess, + buildOptions: buildOptions, + data: (trainCount, 0.0), + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); + } + + protected override async Task CleanupAsync( + string engineId, + string buildId, + object? 
data, + JobCompletionStatus completionStatus + ) + { + if (completionStatus is JobCompletionStatus.Canceled) + { + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/"); + } + catch (Exception e) + { + Logger.LogWarning(e, "Unable to to delete job data for build {BuildId}.", buildId); + } + } + } + + private async Task DownloadDataAsync(string buildId, string corpusDir, CancellationToken cancellationToken) + { + Directory.CreateDirectory(corpusDir); + await using Stream srcText = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/train.src.txt", + cancellationToken + ); + await using FileStream srcFileStream = File.Create(Path.Combine(corpusDir, "train.src.txt")); + await srcText.CopyToAsync(srcFileStream, cancellationToken); + + await using Stream tgtText = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/train.trg.txt", + cancellationToken + ); + await using FileStream tgtFileStream = File.Create(Path.Combine(corpusDir, "train.trg.txt")); + await tgtText.CopyToAsync(tgtFileStream, cancellationToken); + } + + private async Task TrainAsync( + string buildId, + string engineDir, + IParallelTextCorpus parallelCorpus, + CancellationToken cancellationToken + ) + { + _wordAlignmentFactory.InitNew(engineDir); + LatinWordTokenizer tokenizer = new(); + using ITrainer wordAlignmentTrainer = _wordAlignmentFactory.CreateTrainer(engineDir, tokenizer, parallelCorpus); + cancellationToken.ThrowIfCancellationRequested(); + + var progress = new BuildProgress(PlatformService, buildId); + await wordAlignmentTrainer.TrainAsync(progress, cancellationToken); + + int trainCorpusSize = wordAlignmentTrainer.Stats.TrainCorpusSize; + + cancellationToken.ThrowIfCancellationRequested(); + + await wordAlignmentTrainer.SaveAsync(cancellationToken); + + await using Stream engineStream = await _sharedFileService.OpenWriteAsync( + $"builds/{buildId}/model.tar.gz", + cancellationToken + ); + await _wordAlignmentFactory.SaveEngineToAsync(engineDir, engineStream, 
cancellationToken); + return trainCorpusSize; + } + + private async Task GenerateWordAlignmentsAsync( + string buildId, + string engineDir, + CancellationToken cancellationToken + ) + { + await using Stream sourceStream = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/word_alignments.inputs.json", + cancellationToken + ); + + IAsyncEnumerable wordAlignments = JsonSerializer + .DeserializeAsyncEnumerable(sourceStream, JsonSerializerOptions, cancellationToken) + .OfType(); + + await using Stream targetStream = await _sharedFileService.OpenWriteAsync( + $"builds/{buildId}/word_alignments.outputs.json", + cancellationToken + ); + await using Utf8JsonWriter targetWriter = new(targetStream, WordAlignmentWriterOptions); + + LatinWordTokenizer tokenizer = new(); + LatinWordDetokenizer detokenizer = new(); + using IWordAlignmentModel wordAlignmentModel = _wordAlignmentFactory.Create(engineDir); + await foreach (IReadOnlyList batch in BatchAsync(wordAlignments)) + { + (IReadOnlyList Source, IReadOnlyList Target)[] segments = batch + .Select(p => (p.SourceTokens, p.TargetTokens)) + .ToArray(); + IReadOnlyList results = wordAlignmentModel.AlignBatch(segments); + foreach ((Models.WordAlignment wordAlignment, WordAlignmentMatrix result) in batch.Zip(results)) + { + JsonSerializer.Serialize( + targetWriter, + wordAlignment with + { + Alignment = result.ToAlignedWordPairs().ToList() + }, + JsonSerializerOptions + ); + } + } + } + + public static async IAsyncEnumerable> BatchAsync( + IAsyncEnumerable wordAlignments + ) + { + List batch = []; + await foreach (Models.WordAlignment item in wordAlignments) + { + batch.Add(item); + if (batch.Count == BatchSize) + { + yield return batch; + batch = []; + } + } + if (batch.Count > 0) + yield return batch; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineState.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineState.cs index d075af95..bca9760c 100644 --- 
a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineState.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineState.cs @@ -12,14 +12,15 @@ string engineId private readonly IOptionsMonitor _options = options; private readonly AsyncLock _lock = new(); - private IWordAlignmentEngine? _wordAlignmentEngine; + private IWordAlignmentModel? _wordAlignmentModel; public string EngineId { get; } = engineId; public bool IsUpdated { get; set; } + public bool IsMarkedForDeletion { get; set; } public int CurrentBuildRevision { get; set; } = -1; public DateTime LastUsedTime { get; private set; } = DateTime.UtcNow; - public bool IsLoaded => _wordAlignmentEngine != null; + public bool IsLoaded => _wordAlignmentModel != null; private string EngineDir => Path.Combine(_options.CurrentValue.EnginesDir, EngineId); @@ -33,17 +34,20 @@ public async Task GetEngineAsync( CancellationToken cancellationToken = default ) { + if (IsMarkedForDeletion) + throw new InvalidOperationException("Engine is marked for deletion"); + using (await _lock.LockAsync(cancellationToken)) { - if (_wordAlignmentEngine is not null && CurrentBuildRevision != -1 && buildRevision != CurrentBuildRevision) + if (_wordAlignmentModel is not null && CurrentBuildRevision != -1 && buildRevision != CurrentBuildRevision) { IsUpdated = false; Unload(); } - _wordAlignmentEngine ??= _wordAlignmentModelFactory.Create(EngineDir); + _wordAlignmentModel ??= _wordAlignmentModelFactory.Create(EngineDir); CurrentBuildRevision = buildRevision; - return _wordAlignmentEngine; + return _wordAlignmentModel; } } @@ -55,11 +59,14 @@ public void DeleteData() public void Commit(int buildRevision, TimeSpan inactiveTimeout) { - if (_wordAlignmentEngine is null) + if (_wordAlignmentModel is null) return; if (CurrentBuildRevision == -1) CurrentBuildRevision = buildRevision; + + SaveModel(); + if (buildRevision != CurrentBuildRevision) { Unload(); @@ -69,6 +76,10 @@ public void Commit(int buildRevision, TimeSpan 
inactiveTimeout) { Unload(); } + else + { + SaveModel(); + } } public void Touch() @@ -76,14 +87,22 @@ public void Touch() LastUsedTime = DateTime.UtcNow; } + private void SaveModel() + { + if (_wordAlignmentModel is not null && IsUpdated && !IsMarkedForDeletion) + { + _wordAlignmentModel.Save(); + } + } + private void Unload() { - if (_wordAlignmentEngine is null) + if (_wordAlignmentModel is null) return; - _wordAlignmentEngine.Dispose(); + _wordAlignmentModel.Dispose(); - _wordAlignmentEngine = null; + _wordAlignmentModel = null; CurrentBuildRevision = -1; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineStateService.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineStateService.cs index 03f03038..6faf8ef2 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineStateService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentEngineStateService.cs @@ -27,23 +27,28 @@ public void Remove(string engineId) public async Task CommitAsync( IDistributedReaderWriterLockFactory lockFactory, - IRepository engines, + IRepository engines, TimeSpan inactiveTimeout, CancellationToken cancellationToken = default ) { foreach (WordAlignmentEngineState state in _engineStates.Values) { + if (!state.IsLoaded || state.IsMarkedForDeletion) + { + continue; + } + try { IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId, cancellationToken); await @lock.WriterLockAsync( async ct => { - TranslationEngine? engine = await engines.GetAsync(state.EngineId, ct); - if (engine is not null && !(engine.CollectTrainSegmentPairs ?? false)) + WordAlignmentEngine? 
engine = await engines.GetAsync(state.EngineId, ct); + if (engine is not null) // there is no way to cancel this call - state.Commit(engine.BuildRevision, inactiveTimeout); + state.Commit(engine!.BuildRevision, inactiveTimeout); }, _options.CurrentValue.EngineCommitTimeout, cancellationToken: cancellationToken diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentModelFactory.cs index ac97d382..ff87abc7 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentModelFactory.cs @@ -26,13 +26,17 @@ IParallelTextCorpus corpus { var modelPath = Path.Combine(engineDir, "tm", "src_trg"); var directModel = ThotWordAlignmentModel.Create(ThotWordAlignmentModelType.Hmm); + directModel.SourceTokenizer = tokenizer; + directModel.TargetTokenizer = tokenizer; directModel.Load(modelPath + "_invswm"); var inverseModel = ThotWordAlignmentModel.Create(ThotWordAlignmentModelType.Hmm); + inverseModel.SourceTokenizer = tokenizer; + inverseModel.TargetTokenizer = tokenizer; inverseModel.Load(modelPath + "_swm"); - ITrainer directTrainer = directModel.CreateTrainer(corpus, tokenizer); - ITrainer inverseTrainer = inverseModel.CreateTrainer(corpus.Invert(), tokenizer); + ITrainer directTrainer = directModel.CreateTrainer(corpus); + ITrainer inverseTrainer = inverseModel.CreateTrainer(corpus.Invert()); return new SymmetrizedWordAlignmentModelTrainer(directTrainer, inverseTrainer); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs index be3147cb..8fd4e950 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs @@ -36,7 +36,7 @@ CancellationToken 
cancellationToken new(await SharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken)); await using Stream inferenceStream = await SharedFileService.OpenWriteAsync( - $"builds/{buildId}/word_alignment_inputs.json", + $"builds/{buildId}/word_alignments.inputs.json", cancellationToken ); await using Utf8JsonWriter inferenceWriter = new(inferenceStream, InferenceWriterOptions); @@ -48,17 +48,16 @@ await ParallelCorpusPreprocessingService.PreprocessAsync( corpora, async row => { - if (row.SourceSegment.Length > 0 || row.TargetSegment.Length > 0) + if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) { await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n"); await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n"); - } - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; + } }, - async (row, corpus) => + async (row, isInTrainingData, corpus) => { - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) + if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0 && !isInTrainingData) { inferenceWriter.WriteStartObject(); inferenceWriter.WriteString("corpusId", corpus.Id); diff --git a/src/Machine/src/Serval.Machine.Shared/Utils/AsyncDisposableBase.cs b/src/Machine/src/Serval.Machine.Shared/Utils/AsyncDisposableBase.cs deleted file mode 100644 index 6c4a5c0f..00000000 --- a/src/Machine/src/Serval.Machine.Shared/Utils/AsyncDisposableBase.cs +++ /dev/null @@ -1,19 +0,0 @@ -using SIL.ObjectModel; - -namespace Serval.Machine.Shared.Utils; - -public class AsyncDisposableBase : DisposableBase, IAsyncDisposable -{ - public async ValueTask DisposeAsync() - { - await DisposeAsyncCore(); - - Dispose(false); - GC.SuppressFinalize(this); - } - - protected virtual ValueTask DisposeAsyncCore() - { - return default; - } -} diff --git a/src/Machine/src/Serval.Machine.Shared/Utils/AsyncTimer.cs b/src/Machine/src/Serval.Machine.Shared/Utils/AsyncTimer.cs deleted file mode 100644 index 
86bcbe75..00000000 --- a/src/Machine/src/Serval.Machine.Shared/Utils/AsyncTimer.cs +++ /dev/null @@ -1,70 +0,0 @@ -namespace Serval.Machine.Shared.Utils; - -public class AsyncTimer : AsyncDisposableBase -{ - private readonly Timer _timer; - private readonly Func _callback; - private readonly AsyncLock _lock; - private bool _running; - - public AsyncTimer(Func callback) - { - _callback = callback; - _lock = new AsyncLock(); - _timer = new Timer(FireTimerAsync, null, Timeout.Infinite, Timeout.Infinite); - } - - public void Start(TimeSpan period) - { - _running = true; - _timer.Change(period, period); - } - - private async void FireTimerAsync(object? state) - { - using (await _lock.LockAsync()) - { - if (_running) - await _callback(); - } - } - - public async Task StopAsync() - { - using (await _lock.LockAsync()) - { - // FireTimer is *not* running _callback (since we got the lock) - StopTimer(); - } - // Now FireTimer will *never* run _callback - } - - public void Stop() - { - using (_lock.Lock()) - { - // FireTimer is *not* running _callback (since we got the lock) - StopTimer(); - } - // Now FireTimer will *never* run _callback - } - - private void StopTimer() - { - _timer.Change(Timeout.Infinite, Timeout.Infinite); - _running = false; - } - - protected override async ValueTask DisposeAsyncCore() - { - await base.DisposeAsyncCore(); - await StopAsync(); - _timer.Dispose(); - } - - protected override void DisposeManagedResources() - { - Stop(); - _timer.Dispose(); - } -} diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs index 1ad5fc53..8d4c1f65 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs @@ -519,246 +519,6 @@ public async Task ParallelCorpusAsync() }); } - [Test] - public async Task 
ParallelCorpusAsync_UseKeyTerms() - { - using TestEnvironment env = new(); - var corpora = new List() - { - new ParallelCorpus() - { - Id = "1", - SourceCorpora = new List() - { - new() - { - Id = "_1", - Language = "en", - Files = new List { env.ParatextFile("pt-source1") }, - TrainOnChapters = new() - { - { - "MAT", - new() { 1 } - }, - { - "LEV", - new() { } - } - }, - PretranslateChapters = new() - { - { - "1CH", - new() { } - } - } - }, - new() - { - Id = "_1", - Language = "en", - Files = new List { env.ParatextFile("pt-source2") }, - TrainOnChapters = new() - { - { - "MAT", - new() { 1 } - }, - { - "MRK", - new() { } - } - }, - PretranslateChapters = new() { } - }, - }, - TargetCorpora = new List() - { - new() - { - Id = "_1", - Language = "en", - Files = new List { env.ParatextFile("pt-target1") }, - TrainOnChapters = new() - { - { - "MAT", - new() { 1 } - }, - { - "MRK", - new() { } - } - } - }, - new() - { - Id = "_2", - Language = "en", - Files = new List { env.ParatextFile("pt-target2") }, - TrainOnChapters = new() - { - { - "MAT", - new() { 1 } - }, - { - "MRK", - new() { } - }, - { - "LEV", - new() { } - } - } - } - } - } - }; - await env.RunBuildJobAsync(corpora, useKeyTerms: true); - string source = await env.GetSourceExtractAsync(); - string target = await env.GetTargetExtractAsync(); - Assert.Multiple(() => - { - StringAssert.StartsWith( - @"Source one, chapter fourteen, verse fifty-five. Segment b. -Source one, chapter fourteen, verse fifty-six. -Source two, chapter one, verse one. -Source two, chapter one, verse two. -Source two, chapter one, verse three. -Source one, chapter one, verse four. -Source two, chapter one, verse five. Source two, chapter one, verse six. -Source one, chapter one, verse seven, eight, and nine. Source one, chapter one, verse ten. -Source two, chapter one, verse one. -", - source - ); - StringAssert.StartsWith( - @"Target two, chapter fourteen, verse fifty-five. -Target two, chapter fourteen, verse fifty-six. 
-Target one, chapter one, verse one. -Target one, chapter one, verse two. -Target one, chapter one, verse three. - -Target one, chapter one, verse five and six. -Target one, chapter one, verse seven and eight. Target one, chapter one, verse nine and ten. - -", - target - ); - StringAssert.Contains("Abraham", source); - StringAssert.Contains("Abraham", target); - StringAssert.DoesNotContain("Zedekiah", source); - StringAssert.DoesNotContain("Zedekiah", target); - }); - JsonArray? pretranslations = await env.GetPretranslationsAsync(); - Assert.That(pretranslations, Is.Not.Null); - Assert.That(pretranslations!.Count, Is.EqualTo(7), pretranslations.ToJsonString()); - Assert.That( - pretranslations[2]!["translation"]!.ToString(), - Is.EqualTo("Source one, chapter twelve, verse one.") - ); - } - - [Test] - public async Task ParallelCorpusAsync_UseKeyTerms_TextIds() - { - using TestEnvironment env = new(); - var corpora = new List() - { - new ParallelCorpus() - { - Id = "1", - SourceCorpora = new List() - { - new() - { - Id = "_1", - Language = "en", - Files = new List { env.ParatextFile("pt-source1") }, - TrainOnTextIds = ["MAT", "LEV"], - PretranslateTextIds = ["1CH"] - }, - new() - { - Id = "_1", - Language = "en", - Files = new List { env.ParatextFile("pt-source2") }, - TrainOnTextIds = ["MAT", "MRK"], - PretranslateTextIds = [] - }, - }, - TargetCorpora = new List() - { - new() - { - Id = "_1", - Language = "en", - Files = new List { env.ParatextFile("pt-target1") }, - TrainOnTextIds = ["MAT", "MRK"] - }, - new() - { - Id = "_2", - Language = "en", - Files = new List { env.ParatextFile("pt-target2") }, - TrainOnTextIds = ["MAT", "MRK", "LEV"] - } - } - } - }; - await env.RunBuildJobAsync(corpora, useKeyTerms: true); - string source = await env.GetSourceExtractAsync(); - string target = await env.GetTargetExtractAsync(); - Assert.Multiple(() => - { - StringAssert.StartsWith( - @"Source one, chapter fourteen, verse fifty-five. Segment b. 
-Source one, chapter fourteen, verse fifty-six. -Source two, chapter one, verse one. -Source two, chapter one, verse two. -Source two, chapter one, verse three. -Source one, chapter one, verse four. -Source two, chapter one, verse five. Source two, chapter one, verse six. -Source one, chapter one, verse seven, eight, and nine. Source one, chapter one, verse ten. -Source one, chapter two, verse one. -Source one, chapter two, verse two. - -Source two, chapter one, verse one. -", - source - ); - StringAssert.StartsWith( - @"Target two, chapter fourteen, verse fifty-five. -Target two, chapter fourteen, verse fifty-six. -Target one, chapter one, verse one. -Target one, chapter one, verse two. -Target one, chapter one, verse three. - -Target one, chapter one, verse five and six. -Target one, chapter one, verse seven and eight. Target one, chapter one, verse nine and ten. -Target one, chapter two, verse one. - -Target one, chapter two, verse three. - -", - target - ); - StringAssert.Contains("Abraham", source); - StringAssert.Contains("Abraham", target); - StringAssert.DoesNotContain("Zedekiah", source); - StringAssert.DoesNotContain("Zedekiah", target); - }); - JsonArray? 
pretranslations = await env.GetPretranslationsAsync(); - Assert.That(pretranslations, Is.Not.Null); - Assert.That(pretranslations!.Count, Is.EqualTo(7), pretranslations.ToJsonString()); - Assert.That( - pretranslations[2]!["translation"]!.ToString(), - Is.EqualTo("Source one, chapter twelve, verse one.") - ); - } - private class TestEnvironment : DisposableBase { private static readonly string TestDataPath = Path.Combine( diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/ServalPlatformOutboxMessageHandlerTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/ServalPlatformOutboxMessageHandlerTests.cs index 9afaeb8f..baf3c213 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/ServalPlatformOutboxMessageHandlerTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/ServalPlatformOutboxMessageHandlerTests.cs @@ -12,8 +12,12 @@ public async Task HandleMessageAsync_BuildStarted() TestEnvironment env = new(); await env.Handler.HandleMessageAsync( + "groupId", ServalTranslationPlatformOutboxConstants.BuildStarted, - JsonSerializer.Serialize(new BuildStartedRequest { BuildId = "C" }), + JsonSerializer.Serialize( + new BuildStartedRequest { BuildId = "C" }, + MessageOutboxOptions.JsonSerializerOptions + ), null ); @@ -43,6 +47,7 @@ await JsonSerializer.SerializeAsync( ); stream.Seek(0, SeekOrigin.Begin); await env.Handler.HandleMessageAsync( + "engine1", ServalTranslationPlatformOutboxConstants.InsertPretranslations, "engine1", stream diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs index 70a8859a..4a89ad59 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs @@ -96,7 +96,7 @@ public async Task CancelBuildAsync_Building(BuildJobRunnerType 
trainJobRunnerTyp await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); - Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.CancelBuildAsync(EngineId1); await env.WaitForBuildToFinishAsync(); _ = env.SmtBatchTrainer.DidNotReceive().SaveAsync(); @@ -122,12 +122,12 @@ public async Task StartBuildAsync_RestartUnfinishedBuild() await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); - Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active)); env.StopServer(); await env.WaitForBuildToRestartAsync(); engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); - Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Pending)); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Pending)); _ = env.PlatformService.Received().BuildRestartingAsync(BuildId1); env.SmtBatchTrainer.ClearSubstitute(ClearOptions.CallActions); env.StartServer(); @@ -147,7 +147,7 @@ public async Task DeleteAsync_WhileBuilding(BuildJobRunnerType trainJobRunnerTyp await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); - Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.DeleteAsync(EngineId1); await env.WaitForBuildToFinishAsync(); await env.WaitForAllHangfireJobsToFinishAsync(); @@ -167,7 +167,7 @@ public async Task TrainSegmentPairAsync(BuildJobRunnerType trainJobRunnerType) await env.WaitForBuildToStartAsync(); TranslationEngine engine = 
env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); - Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); env.StopTraining(); await env.WaitForBuildToFinishAsync(); diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/StatisticalEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/StatisticalEngineServiceTests.cs new file mode 100644 index 00000000..aeb8a332 --- /dev/null +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/StatisticalEngineServiceTests.cs @@ -0,0 +1,504 @@ +using WordAlignmentResult = Serval.WordAlignment.V1.WordAlignmentResult; + +namespace Serval.Machine.Shared.Services; + +[TestFixture] +public class StatisticalEngineServiceTests +{ + const string EngineId1 = "engine1"; + const string EngineId2 = "engine2"; + const string BuildId1 = "build1"; + const string CorpusId1 = "corpus1"; + + [Test] + public async Task CreateAsync() + { + using var env = new TestEnvironment(); + await env.Service.CreateAsync(EngineId2, "Engine 2", "es", "en"); + WordAlignmentEngine? 
engine = await env.Engines.GetAsync(e => e.EngineId == EngineId2); + Assert.Multiple(() => + { + Assert.That(engine, Is.Not.Null); + Assert.That(engine?.EngineId, Is.EqualTo(EngineId2)); + Assert.That(engine?.BuildRevision, Is.EqualTo(0)); + }); + string engineDir = Path.Combine("word_alignment_engines", EngineId2); + env.WordAlignmentModelFactory.Received().InitNew(engineDir); + } + + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task StartBuildAsync(BuildJobRunnerType trainJobRunnerType) + { + using var env = new TestEnvironment(trainJobRunnerType); + WordAlignmentEngine engine = env.Engines.Get(EngineId1); + Assert.That(engine.BuildRevision, Is.EqualTo(1)); + // ensure that the model was loaded before training + await env.Service.GetBestWordAlignmentAsync(EngineId1, "esto es una prueba.", "this is a test."); + await env.Service.StartBuildAsync( + EngineId1, + BuildId1, + null, + [ + new ParallelCorpus() + { + Id = CorpusId1, + SourceCorpora = new List() + { + new() + { + Id = "src", + Language = "es", + Files = [], + TrainOnTextIds = null, + InferenceTextIds = null + } + }, + TargetCorpora = new List() + { + new() + { + Id = "trg", + Language = "en", + Files = [], + TrainOnTextIds = null + } + }, + } + ] + ); + await env.WaitForBuildToFinishAsync(); + _ = env.WordAlignmentBatchTrainer.Received() + .TrainAsync(Arg.Any>(), Arg.Any()); + _ = env.WordAlignmentBatchTrainer.Received().SaveAsync(Arg.Any()); + engine = env.Engines.Get(EngineId1); + Assert.That(engine.CurrentBuild, Is.Null); + Assert.That(engine.BuildRevision, Is.EqualTo(2)); + // check if model was reloaded upon first use after training + env.WordAlignmentModel.ClearReceivedCalls(); + await env.Service.GetBestWordAlignmentAsync(EngineId1, "esto es una prueba.", "this is a test."); + env.WordAlignmentModel.Received().Dispose(); + } + + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task 
CancelBuildAsync_Building(BuildJobRunnerType trainJobRunnerType) + { + using var env = new TestEnvironment(trainJobRunnerType); + env.UseInfiniteTrainJob(); + + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); + await env.WaitForTrainingToStartAsync(); + WordAlignmentEngine engine = env.Engines.Get(EngineId1); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active)); + await env.Service.CancelBuildAsync(EngineId1); + await env.WaitForBuildToFinishAsync(); + _ = env.WordAlignmentBatchTrainer.DidNotReceive().SaveAsync(); + engine = env.Engines.Get(EngineId1); + Assert.That(engine.CurrentBuild, Is.Null); + } + + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public void CancelBuildAsync_NotBuilding(BuildJobRunnerType trainJobRunnerType) + { + using var env = new TestEnvironment(trainJobRunnerType); + Assert.ThrowsAsync(() => env.Service.CancelBuildAsync(EngineId1)); + } + + [TestCase(BuildJobRunnerType.Hangfire)] + [TestCase(BuildJobRunnerType.ClearML)] + public async Task DeleteAsync_WhileBuilding(BuildJobRunnerType trainJobRunnerType) + { + using var env = new TestEnvironment(trainJobRunnerType); + env.UseInfiniteTrainJob(); + + await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); + await env.WaitForTrainingToStartAsync(); + WordAlignmentEngine engine = env.Engines.Get(EngineId1); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active)); + await env.Service.DeleteAsync(EngineId1); + await env.WaitForBuildToFinishAsync(); + await env.WaitForAllHangfireJobsToFinishAsync(); + _ = env.WordAlignmentBatchTrainer.DidNotReceive().SaveAsync(); + Assert.That(env.Engines.Contains(EngineId1), Is.False); + } + + [Test] + public async Task GetBestWordAlignment() + { + using var env = new TestEnvironment(); + WordAlignmentResult result = await 
env.Service.GetBestWordAlignmentAsync( + EngineId1, + "esto es una prueba.", + "this is a test." + ); + Assert.That(string.Join(' ', result.TargetTokens), Is.EqualTo("this is a test .")); + Assert.That(result.Confidences, Has.Count.EqualTo(5)); + Assert.That(result.Alignment.First().SourceIndex, Is.EqualTo(0)); + Assert.That(result.Alignment.First().TargetIndex, Is.EqualTo(0)); + } + + private class TestEnvironment : DisposableBase + { + private readonly Hangfire.InMemory.InMemoryStorage _memoryStorage; + private readonly BackgroundJobClient _jobClient; + private BackgroundJobServer _jobServer; + private readonly IDistributedReaderWriterLockFactory _lockFactory; + private readonly BuildJobRunnerType _trainJobRunnerType; + private Task? _trainJobTask; + private readonly CancellationTokenSource _cancellationTokenSource = new(); + private bool _training = true; + + public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerType.ClearML) + { + _trainJobRunnerType = trainJobRunnerType; + Engines = new MemoryRepository(); + Engines.Add( + new WordAlignmentEngine + { + Id = EngineId1, + EngineId = EngineId1, + Type = EngineType.Statistical, + SourceLanguage = "es", + TargetLanguage = "en", + BuildRevision = 1, + } + ); + _memoryStorage = new Hangfire.InMemory.InMemoryStorage(); + _jobClient = new BackgroundJobClient(_memoryStorage); + PlatformService = Substitute.For(); + PlatformService.EngineGroup.Returns(EngineGroup.WordAlignment); + WordAlignmentModel = Substitute.For(); + WordAlignmentBatchTrainer = Substitute.For(); + WordAlignmentBatchTrainer.Stats.Returns(new TrainStats { TrainCorpusSize = 0 }); + WordAlignmentModelFactory = CreateWordAlignmentModelFactory(); + _lockFactory = new DistributedReaderWriterLockFactory( + new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), + new OptionsWrapper(new DistributedReaderWriterLockOptions()), + new MemoryRepository(), + new ObjectIdGenerator() + ); + SharedFileService = new 
SharedFileService(Substitute.For()); + var clearMLOptions = Substitute.For>(); + clearMLOptions.CurrentValue.Returns(new ClearMLOptions()); + var buildJobOptions = Substitute.For>(); + buildJobOptions.CurrentValue.Returns( + new BuildJobOptions + { + ClearML = + [ + new ClearMLBuildQueue() + { + EngineType = EngineType.Statistical.ToString(), + ModelType = "thot", + DockerImage = "default", + Queue = "default" + } + ] + } + ); + ClearMLService = Substitute.For(); + ClearMLService + .GetProjectIdAsync("engine1", Arg.Any()) + .Returns(Task.FromResult("project1")); + ClearMLService + .CreateTaskAsync( + "build1", + "project1", + Arg.Any(), + Arg.Any(), + Arg.Any() + ) + .Returns(Task.FromResult("job1")); + ClearMLService + .When(x => x.EnqueueTaskAsync("job1", Arg.Any(), Arg.Any())) + .Do(_ => _trainJobTask = Task.Run(RunTrainJob)); + ClearMLService + .When(x => x.StopTaskAsync("job1", Arg.Any())) + .Do(_ => _cancellationTokenSource.Cancel()); + ClearMLMonitorService = new ClearMLMonitorService( + Substitute.For(), + ClearMLService, + SharedFileService, + clearMLOptions, + buildJobOptions, + Substitute.For>() + ); + BuildJobService = new BuildJobService( + [ + new HangfireBuildJobRunner(_jobClient, [new StatisticalHangfireBuildJobFactory()]), + new ClearMLBuildJobRunner( + ClearMLService, + [new StatisticalClearMLBuildJobFactory(SharedFileService, Engines)], + buildJobOptions + ) + ], + Engines + ); + _jobServer = CreateJobServer(); + StateService = CreateStateService(); + Service = CreateService(); + } + + public StatisticalEngineService Service { get; private set; } + public WordAlignmentEngineStateService StateService { get; private set; } + public MemoryRepository Engines { get; } + public IWordAlignmentModelFactory WordAlignmentModelFactory { get; } + public ITrainer WordAlignmentBatchTrainer { get; } + public IWordAlignmentModel WordAlignmentModel { get; } + public IPlatformService PlatformService { get; } + + public IClearMLService ClearMLService { get; } + 
public IClearMLQueueService ClearMLMonitorService { get; } + + public ISharedFileService SharedFileService { get; } + + public IBuildJobService BuildJobService { get; } + + public async Task CommitAsync(TimeSpan inactiveTimeout) + { + await StateService.CommitAsync(_lockFactory, Engines, inactiveTimeout); + } + + public void StopServer() + { + _jobServer.Dispose(); + StateService.Dispose(); + } + + public void StartServer() + { + _jobServer = CreateJobServer(); + StateService = CreateStateService(); + Service = CreateService(); + } + + public void UseInfiniteTrainJob() + { + WordAlignmentBatchTrainer.TrainAsync( + Arg.Any>(), + Arg.Do(cancellationToken => + { + while (_training) + { + cancellationToken.ThrowIfCancellationRequested(); + Thread.Sleep(100); + } + }) + ); + } + + public void StopTraining() + { + _training = false; + } + + private BackgroundJobServer CreateJobServer() + { + var jobServerOptions = new BackgroundJobServerOptions + { + Activator = new EnvActivator(this), + Queues = new[] { "statistical" }, + CancellationCheckInterval = TimeSpan.FromMilliseconds(50), + }; + return new BackgroundJobServer(jobServerOptions, _memoryStorage); + } + + private WordAlignmentEngineStateService CreateStateService() + { + var options = Substitute.For>(); + options.CurrentValue.Returns(new WordAlignmentEngineOptions()); + return new WordAlignmentEngineStateService( + WordAlignmentModelFactory, + options, + Substitute.For>() + ); + } + + private StatisticalEngineService CreateService() + { + return new StatisticalEngineService( + _lockFactory, + new[] { PlatformService }, + new MemoryDataAccessContext(), + Engines, + StateService, + BuildJobService, + ClearMLMonitorService + ); + } + + private IWordAlignmentModelFactory CreateWordAlignmentModelFactory() + { + IWordAlignmentModelFactory factory = Substitute.For(); + + var alignedWordPair = new AlignedWordPair(0, 0); + WordAlignmentModel + .GetBestAlignedWordPairs(Arg.Any>(), Arg.Any>()) + .Returns([alignedWordPair, 
alignedWordPair, alignedWordPair, alignedWordPair, alignedWordPair]); + factory.Create(Arg.Any()).Returns(WordAlignmentModel); + factory + .CreateTrainer( + Arg.Any(), + Arg.Any>(), + Arg.Any() + ) + .Returns(WordAlignmentBatchTrainer); + return factory; + } + + public async Task WaitForAllHangfireJobsToFinishAsync() + { + IMonitoringApi monitoringApi = _memoryStorage.GetMonitoringApi(); + while (monitoringApi.EnqueuedCount("statistical") > 0 || monitoringApi.ProcessingCount() > 0) + await Task.Delay(50); + } + + public async Task WaitForBuildToFinishAsync() + { + await WaitForBuildState(e => e.CurrentBuild is null); + if (_trainJobTask is not null) + await _trainJobTask; + } + + public Task WaitForBuildToStartAsync() + { + return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Active); + } + + public Task WaitForTrainingToStartAsync() + { + return WaitForBuildState(e => + e.CurrentBuild!.JobState is BuildJobState.Active && e.CurrentBuild!.Stage is BuildStage.Train + ); + } + + public Task WaitForBuildToRestartAsync() + { + return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Pending); + } + + private async Task WaitForBuildState(Func predicate) + { + using ISubscription subscription = await Engines.SubscribeAsync(e => + e.EngineId == EngineId1 + ); + while (true) + { + WordAlignmentEngine? 
engine = subscription.Change.Entity; + if (engine is null || predicate(engine)) + break; + await subscription.WaitForChangeAsync(); + } + } + + protected override void DisposeManagedResources() + { + StateService.Dispose(); + _jobServer.Dispose(); + } + + private async Task RunTrainJob() + { + try + { + await BuildJobService.BuildJobStartedAsync("engine1", "build1", _cancellationTokenSource.Token); + + string engineDir = Path.Combine("word_alignment_engines", EngineId1); + WordAlignmentModelFactory.InitNew(engineDir); + ITextCorpus sourceCorpus = new DictionaryTextCorpus(); + ITextCorpus targetCorpus = new DictionaryTextCorpus(); + IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus); + LatinWordTokenizer tokenizer = new(); + using ITrainer wordAlignmentModelTrainer = WordAlignmentModelFactory.CreateTrainer( + engineDir, + tokenizer, + parallelCorpus + ); + await wordAlignmentModelTrainer.TrainAsync(null, _cancellationTokenSource.Token); + await wordAlignmentModelTrainer.SaveAsync(_cancellationTokenSource.Token); + + await using Stream engineStream = await SharedFileService.OpenWriteAsync( + $"builds/{BuildId1}/model.tar.gz", + _cancellationTokenSource.Token + ); + await using Stream targetStream = await SharedFileService.OpenWriteAsync( + $"builds/{BuildId1}/word_alignments.outputs.json", + _cancellationTokenSource.Token + ); + + await BuildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + EngineType.Statistical, + EngineId1, + BuildId1, + BuildStage.Postprocess, + data: (0, 0.0) + ); + } + catch (OperationCanceledException) + { + await BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); + } + } + + private class EnvActivator(TestEnvironment env) : JobActivator + { + private readonly TestEnvironment _env = env; + + public override object ActivateJob(Type jobType) + { + if (jobType == typeof(WordAlignmentPreprocessBuildJob)) + { + return new WordAlignmentPreprocessBuildJob( + new[] { 
_env.PlatformService }, + _env.Engines, + new MemoryDataAccessContext(), + Substitute.For>(), + _env.BuildJobService, + _env.SharedFileService, + new ParallelCorpusPreprocessingService(new CorpusService()) + ) + { + TrainJobRunnerType = _env._trainJobRunnerType + }; + } + if (jobType == typeof(StatisticalPostprocessBuildJob)) + { + var engineOptions = Substitute.For>(); + engineOptions.CurrentValue.Returns(new WordAlignmentEngineOptions()); + var buildJobOptions = Substitute.For>(); + buildJobOptions.CurrentValue.Returns(new BuildJobOptions()); + return new StatisticalPostprocessBuildJob( + new[] { _env.PlatformService }, + _env.Engines, + new MemoryDataAccessContext(), + _env.BuildJobService, + Substitute.For>(), + _env.SharedFileService, + _env._lockFactory, + _env.WordAlignmentModelFactory, + buildJobOptions, + engineOptions + ); + } + if (jobType == typeof(StatisticalTrainBuildJob)) + { + return new StatisticalTrainBuildJob( + new[] { _env.PlatformService }, + _env.Engines, + new MemoryDataAccessContext(), + _env.BuildJobService, + Substitute.For>(), + _env.SharedFileService, + _env.WordAlignmentModelFactory + ); + } + return base.ActivateJob(jobType); + } + } + } +} diff --git a/src/Serval/src/Serval.Client/Client.g.cs b/src/Serval/src/Serval.Client/Client.g.cs index d222c5bf..e6f42d12 100644 --- a/src/Serval/src/Serval.Client/Client.g.cs +++ b/src/Serval/src/Serval.Client/Client.g.cs @@ -7137,8 +7137,8 @@ public partial interface IWordAlignmentEnginesClient ///
* An auto-generated reference of `[TextId]:[lineNumber]`, 1 indexed. ///
* **SourceTokens**: the tokenized source segment ///
* **TargetTokens**: the tokenized target segment - ///
* **Confidences**: the confidence of the alignment ona scale from 0 to 1 - ///
* **Alignment**: the word alignment, 0 indexed for source and target positions + ///
* **Confidences**: the confidence of the alignment on a scale from 0 to 1 + ///
* **Alignment**: a list of aligned word pairs ///
///
Word alignments can be filtered by text id if provided. ///
Only word alignments for the most recent successful build of the engine are returned. @@ -8406,8 +8406,8 @@ public string BaseUrl ///
* An auto-generated reference of `[TextId]:[lineNumber]`, 1 indexed. ///
* **SourceTokens**: the tokenized source segment ///
* **TargetTokens**: the tokenized target segment - ///
* **Confidences**: the confidence of the alignment ona scale from 0 to 1 - ///
* **Alignment**: the word alignment, 0 indexed for source and target positions + ///
* **Confidences**: the confidence of the alignment on a scale from 0 to 1 + ///
* **Alignment**: a list of aligned word pairs ///
///
Word alignments can be filtered by text id if provided. ///
Only word alignments for the most recent successful build of the engine are returned. @@ -10679,7 +10679,7 @@ public partial class WordAlignmentBuild public System.Collections.Generic.IList? TrainOn { get; set; } = default!; [Newtonsoft.Json.JsonProperty("wordAlignOn", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] - public System.Collections.Generic.IList? WordAlignOn { get; set; } = default!; + public System.Collections.Generic.IList? WordAlignOn { get; set; } = default!; [Newtonsoft.Json.JsonProperty("step", Required = Newtonsoft.Json.Required.Always)] public int Step { get; set; } = default!; @@ -10710,6 +10710,9 @@ public partial class WordAlignmentBuild [Newtonsoft.Json.JsonProperty("deploymentVersion", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] public string? DeploymentVersion { get; set; } = default!; + [Newtonsoft.Json.JsonProperty("executionData", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IDictionary? ExecutionData { get; set; } = default!; + } [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "14.1.0.0 (NJsonSchema v11.0.2.0 (Newtonsoft.Json v13.0.0.0))")] @@ -10741,6 +10744,20 @@ public partial class ParallelCorpusFilter2 } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "14.1.0.0 (NJsonSchema v11.0.2.0 (Newtonsoft.Json v13.0.0.0))")] + public partial class WordAlignmentCorpus + { + [Newtonsoft.Json.JsonProperty("parallelCorpus", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public ResourceLink? ParallelCorpus { get; set; } = default!; + + [Newtonsoft.Json.JsonProperty("sourceFilters", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? 
SourceFilters { get; set; } = default!; + + [Newtonsoft.Json.JsonProperty("targetFilters", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? TargetFilters { get; set; } = default!; + + } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "14.1.0.0 (NJsonSchema v11.0.2.0 (Newtonsoft.Json v13.0.0.0))")] public partial class WordAlignmentBuildConfig { @@ -10751,7 +10768,7 @@ public partial class WordAlignmentBuildConfig public System.Collections.Generic.IList? TrainOn { get; set; } = default!; [Newtonsoft.Json.JsonProperty("wordAlignOn", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] - public System.Collections.Generic.IList? WordAlignOn { get; set; } = default!; + public System.Collections.Generic.IList? WordAlignOn { get; set; } = default!; [Newtonsoft.Json.JsonProperty("options", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] public object? Options { get; set; } = default!; @@ -10787,6 +10804,20 @@ public partial class ParallelCorpusFilterConfig2 } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "14.1.0.0 (NJsonSchema v11.0.2.0 (Newtonsoft.Json v13.0.0.0))")] + public partial class WordAlignmentCorpusConfig + { + [Newtonsoft.Json.JsonProperty("parallelCorpusId", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public string? ParallelCorpusId { get; set; } = default!; + + [Newtonsoft.Json.JsonProperty("sourceFilters", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? 
SourceFilters { get; set; } = default!; + + [Newtonsoft.Json.JsonProperty("targetFilters", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? TargetFilters { get; set; } = default!; + + } + [System.CodeDom.Compiler.GeneratedCode("NSwag", "14.1.0.0 (NJsonSchema v11.0.2.0 (Newtonsoft.Json v13.0.0.0))")] public partial class FileParameter { diff --git a/src/Serval/src/Serval.Grpc/Protos/serval/translation/v1/platform.proto b/src/Serval/src/Serval.Grpc/Protos/serval/translation/v1/platform.proto index fb608231..4c711e61 100644 --- a/src/Serval/src/Serval.Grpc/Protos/serval/translation/v1/platform.proto +++ b/src/Serval/src/Serval.Grpc/Protos/serval/translation/v1/platform.proto @@ -12,7 +12,7 @@ service TranslationPlatformApi { rpc BuildFaulted(BuildFaultedRequest) returns (google.protobuf.Empty); rpc BuildRestarting(BuildRestartingRequest) returns (google.protobuf.Empty); - rpc IncrementTranslationEngineCorpusSize(IncrementTranslationEngineCorpusSizeRequest) returns (google.protobuf.Empty); + rpc IncrementTrainEngineCorpusSize(IncrementTrainEngineCorpusSizeRequest) returns (google.protobuf.Empty); rpc InsertPretranslations(stream InsertPretranslationsRequest) returns (google.protobuf.Empty); rpc UpdateBuildExecutionData(UpdateBuildExecutionDataRequest) returns (google.protobuf.Empty); } @@ -53,7 +53,7 @@ message IncrementTrainEngineCorpusSizeRequest { int32 count = 2; } -message InsertInferencesRequest { +message InsertPretranslationsRequest { string engine_id = 1; string corpus_id = 2; string text_id = 3; diff --git a/src/Serval/src/Serval.Grpc/Protos/serval/word_alignment/v1/platform.proto b/src/Serval/src/Serval.Grpc/Protos/serval/word_alignment/v1/platform.proto index 7db42106..10f23dd0 100644 --- a/src/Serval/src/Serval.Grpc/Protos/serval/word_alignment/v1/platform.proto +++ b/src/Serval/src/Serval.Grpc/Protos/serval/word_alignment/v1/platform.proto @@ -15,7 +15,9 
@@ service WordAlignmentPlatformApi { rpc BuildRestarting(BuildRestartingRequest) returns (google.protobuf.Empty); rpc IncrementTrainEngineCorpusSize(IncrementTrainEngineCorpusSizeRequest) returns (google.protobuf.Empty); - rpc InsertInferences(stream InsertInferencesRequest) returns (google.protobuf.Empty); + rpc InsertWordAlignments(stream InsertWordAlignmentsRequest) returns (google.protobuf.Empty); + rpc UpdateBuildExecutionData(UpdateBuildExecutionDataRequest) returns (google.protobuf.Empty); + } message UpdateBuildStatusRequest { @@ -54,7 +56,7 @@ message IncrementTrainEngineCorpusSizeRequest { int32 count = 2; } -message InsertInferencesRequest { +message InsertWordAlignmentsRequest { string engine_id = 1; string corpus_id = 2; string text_id = 3; @@ -64,3 +66,9 @@ message InsertInferencesRequest { repeated double confidences = 7; repeated AlignedWordPair alignment = 8; } + +message UpdateBuildExecutionDataRequest { + string engine_id = 1; + string build_id = 2; + map execution_data = 3; +} diff --git a/src/Serval/src/Serval.Translation/Services/EngineService.cs b/src/Serval/src/Serval.Translation/Services/EngineService.cs index 59214306..4d2946ec 100644 --- a/src/Serval/src/Serval.Translation/Services/EngineService.cs +++ b/src/Serval/src/Serval.Translation/Services/EngineService.cs @@ -143,15 +143,11 @@ public override async Task CreateAsync(Engine engine, CancellationToken { engine.DateCreated = DateTime.UtcNow; await Entities.InsertAsync(engine, cancellationToken); - TranslationEngineApi.TranslationEngineApiClient? client; - try - { - client = _grpcClientFactory.CreateClient(engine.Type); - } - catch (InvalidOperationException) - { + TranslationEngineApi.TranslationEngineApiClient? 
client = + _grpcClientFactory.CreateClient(engine.Type); + if (client is null) throw new InvalidOperationException($"'{engine.Type}' is an invalid engine type."); - } + var request = new CreateRequest { EngineType = engine.Type, diff --git a/src/Serval/src/Serval.Translation/Services/TranslationPlatformServiceV1.cs b/src/Serval/src/Serval.Translation/Services/TranslationPlatformServiceV1.cs index 6ef34998..1bc9ff8d 100644 --- a/src/Serval/src/Serval.Translation/Services/TranslationPlatformServiceV1.cs +++ b/src/Serval/src/Serval.Translation/Services/TranslationPlatformServiceV1.cs @@ -284,8 +284,8 @@ await _builds.UpdateAsync( return new Empty(); } - public override async Task IncrementTranslationEngineCorpusSize( - IncrementTranslationEngineCorpusSizeRequest request, + public override async Task IncrementTrainEngineCorpusSize( + IncrementTrainEngineCorpusSizeRequest request, ServerCallContext context ) { @@ -297,8 +297,8 @@ await _engines.UpdateAsync( return Empty; } - public override async Task InsertInferences( - IAsyncStreamReader requestStream, + public override async Task InsertPretranslations( + IAsyncStreamReader requestStream, ServerCallContext context ) { @@ -306,7 +306,7 @@ ServerCallContext context int nextModelRevision = 0; var batch = new List(); - await foreach (InsertInferencesRequest request in requestStream.ReadAllAsync(context.CancellationToken)) + await foreach (InsertPretranslationsRequest request in requestStream.ReadAllAsync(context.CancellationToken)) { if (request.EngineId != engineId) { diff --git a/src/Serval/src/Serval.WordAlignment/Configuration/IMongoDataAccessConfiguratorExtensions.cs b/src/Serval/src/Serval.WordAlignment/Configuration/IMongoDataAccessConfiguratorExtensions.cs index aababccb..c617d23c 100644 --- a/src/Serval/src/Serval.WordAlignment/Configuration/IMongoDataAccessConfiguratorExtensions.cs +++ b/src/Serval/src/Serval.WordAlignment/Configuration/IMongoDataAccessConfiguratorExtensions.cs @@ -30,6 +30,11 @@ await 
c.Indexes.CreateOrUpdateAsync( await c.Indexes.CreateOrUpdateAsync( new CreateIndexModel(Builders.IndexKeys.Ascending(b => b.DateCreated)) ); + // migrate by adding ExecutionData field + await c.UpdateManyAsync( + Builders.Filter.Exists(b => b.ExecutionData, false), + Builders.Update.Set(b => b.ExecutionData, new Dictionary()) + ); } ); configurator.AddRepository( diff --git a/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildConfigDto.cs b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildConfigDto.cs index 3a79b0e7..115c78f8 100644 --- a/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildConfigDto.cs +++ b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildConfigDto.cs @@ -4,7 +4,7 @@ public record WordAlignmentBuildConfigDto { public string? Name { get; init; } public IReadOnlyList? TrainOn { get; init; } - public IReadOnlyList? WordAlignOn { get; init; } + public IReadOnlyList? WordAlignOn { get; init; } /// /// { diff --git a/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildDto.cs b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildDto.cs index 9fc55652..aa66a587 100644 --- a/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildDto.cs +++ b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentBuildDto.cs @@ -8,7 +8,7 @@ public record WordAlignmentBuildDto public string? Name { get; init; } public required ResourceLinkDto Engine { get; init; } public IReadOnlyList? TrainOn { get; init; } - public IReadOnlyList? WordAlignOn { get; init; } + public IReadOnlyList? WordAlignOn { get; init; } public required int Step { get; init; } public double? PercentCompleted { get; init; } public string? Message { get; init; } @@ -28,4 +28,5 @@ public record WordAlignmentBuildDto /// public object? Options { get; init; } public string? DeploymentVersion { get; init; } + public IReadOnlyDictionary? 
ExecutionData { get; init; } } diff --git a/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusConfigDto.cs b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusConfigDto.cs new file mode 100644 index 00000000..b35329fc --- /dev/null +++ b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusConfigDto.cs @@ -0,0 +1,8 @@ +namespace Serval.WordAlignment.Contracts; + +public record WordAlignmentCorpusConfigDto +{ + public string? ParallelCorpusId { get; init; } + public IReadOnlyList? SourceFilters { get; init; } + public IReadOnlyList? TargetFilters { get; init; } +} diff --git a/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusDto.cs b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusDto.cs new file mode 100644 index 00000000..1ba2db86 --- /dev/null +++ b/src/Serval/src/Serval.WordAlignment/Contracts/WordAlignmentCorpusDto.cs @@ -0,0 +1,8 @@ +namespace Serval.WordAlignment.Contracts; + +public record WordAlignmentCorpusDto +{ + public ResourceLinkDto? ParallelCorpus { get; init; } + public IReadOnlyList? SourceFilters { get; init; } + public IReadOnlyList? 
TargetFilters { get; init; } +} diff --git a/src/Serval/src/Serval.WordAlignment/Controllers/WordAlignmentEnginesController.cs b/src/Serval/src/Serval.WordAlignment/Controllers/WordAlignmentEnginesController.cs index 22d6b57e..bd950e1b 100644 --- a/src/Serval/src/Serval.WordAlignment/Controllers/WordAlignmentEnginesController.cs +++ b/src/Serval/src/Serval.WordAlignment/Controllers/WordAlignmentEnginesController.cs @@ -9,6 +9,7 @@ public class WordAlignmentEnginesController( IBuildService buildService, IWordAlignmentService wordAlignmentService, IOptionsMonitor apiOptions, + IConfiguration configuration, IUrlService urlService, ILogger logger ) : ServalControllerBase(authService) @@ -22,6 +23,7 @@ ILogger logger private readonly IOptionsMonitor _apiOptions = apiOptions; private readonly IUrlService _urlService = urlService; private readonly ILogger _logger = logger; + private readonly IConfiguration _configuration = configuration; /// /// Get all word alignment engines @@ -385,8 +387,8 @@ CancellationToken cancellationToken /// * An auto-generated reference of `[TextId]:[lineNumber]`, 1 indexed. /// * **SourceTokens**: the tokenized source segment /// * **TargetTokens**: the tokenized target segment - /// * **Confidences**: the confidence of the alignment ona scale from 0 to 1 - /// * **Alignment**: the word alignment, 0 indexed for source and target positions + /// * **Confidences**: the confidence of the alignment on a scale from 0 to 1 + /// * **Alignment**: a list of aligned word pairs /// /// Word alignments can be filtered by text id if provided. /// Only word alignments for the most recent successful build of the engine are returned. @@ -568,9 +570,11 @@ public async Task> StartBuildAsync( CancellationToken cancellationToken ) { + string deploymentVersion = _configuration.GetValue("deploymentVersion") ??
"Unknown"; + Engine engine = await _engineService.GetAsync(id, cancellationToken); await AuthorizeAsync(engine); - Build build = Map(engine, buildConfig); + Build build = Map(engine, buildConfig, deploymentVersion); await _engineService.StartBuildAsync(build, cancellationToken); WordAlignmentBuildDto dto = Map(build); @@ -756,7 +760,7 @@ private WordAlignmentParallelCorpusDto Map(string engineId, ParallelCorpus sourc }; } - private static Build Map(Engine engine, WordAlignmentBuildConfigDto source) + private static Build Map(Engine engine, WordAlignmentBuildConfigDto source, string deploymentVersion) { return new Build { @@ -764,10 +768,56 @@ private static Build Map(Engine engine, WordAlignmentBuildConfigDto source) Name = source.Name, WordAlignOn = Map(engine, source.WordAlignOn), TrainOn = Map(engine, source.TrainOn), - Options = Map(source.Options) + Options = Map(source.Options), + DeploymentVersion = deploymentVersion }; } + private static List? Map(Engine engine, IReadOnlyList? source) + { + if (source is null) + return null; + + var corpusIds = new HashSet(engine.ParallelCorpora.Select(c => c.Id)); + var wordAlignmentCorpora = new List(); + foreach (WordAlignmentCorpusConfigDto cc in source) + { + if (cc.ParallelCorpusId == null) + { + throw new InvalidOperationException($"ParallelCorpusId must be set."); + } + if (!corpusIds.Contains(cc.ParallelCorpusId)) + { + throw new InvalidOperationException( + $"The parallel corpus {cc.ParallelCorpusId} is not valid: This parallel corpus does not exist for engine {engine.Id}." + ); + } + if ( + cc.SourceFilters != null + && cc.SourceFilters.Count > 0 + && ( + cc.SourceFilters.Select(sf => sf.CorpusId).Distinct().Count() > 1 + || cc.SourceFilters[0].CorpusId + != engine.ParallelCorpora.Where(pc => pc.Id == cc.ParallelCorpusId).First().SourceCorpora[0].Id + ) + ) + { + throw new InvalidOperationException( + $"Only the first source corpus in a parallel corpus may be filtered for word alignment."
+ ); + } + wordAlignmentCorpora.Add( + new WordAlignmentCorpus + { + ParallelCorpusRef = cc.ParallelCorpusId, + SourceFilters = cc.SourceFilters?.Select(Map).ToList(), + TargetFilters = cc.TargetFilters?.Select(Map).ToList() + } + ); + } + return wordAlignmentCorpora; + } + private static List? Map(Engine engine, IReadOnlyList? source) { if (source is null) @@ -869,7 +919,8 @@ private WordAlignmentBuildDto Map(Build source) State = source.State, DateFinished = source.DateFinished, Options = source.Options, - DeploymentVersion = source.DeploymentVersion + DeploymentVersion = source.DeploymentVersion, + ExecutionData = source.ExecutionData }; } @@ -893,6 +944,26 @@ private TrainingCorpusDto Map(string engineId, TrainingCorpus source) }; } + private WordAlignmentCorpusDto Map(string engineId, WordAlignmentCorpus source) + { + return new WordAlignmentCorpusDto + { + ParallelCorpus = + source.ParallelCorpusRef != null + ? new ResourceLinkDto + { + Id = source.ParallelCorpusRef, + Url = _urlService.GetUrl( + Endpoints.GetParallelTranslationCorpus, + new { id = engineId, parallelCorpusId = source.ParallelCorpusRef } + ) + } + : null, + SourceFilters = source.SourceFilters?.Select(Map).ToList(), + TargetFilters = source.TargetFilters?.Select(Map).ToList() + }; + } + private ParallelCorpusFilterDto Map(ParallelCorpusFilter source) { return new ParallelCorpusFilterDto diff --git a/src/Serval/src/Serval.WordAlignment/Models/Build.cs b/src/Serval/src/Serval.WordAlignment/Models/Build.cs index b20e871c..594dff85 100644 --- a/src/Serval/src/Serval.WordAlignment/Models/Build.cs +++ b/src/Serval/src/Serval.WordAlignment/Models/Build.cs @@ -7,7 +7,7 @@ public record Build : IInitializableEntity public string? Name { get; init; } public required string EngineRef { get; init; } public IReadOnlyList? TrainOn { get; init; } - public IReadOnlyList? WordAlignOn { get; init; } + public IReadOnlyList? WordAlignOn { get; init; } public int Step { get; init; } public double? 
PercentCompleted { get; init; } public string? Message { get; init; } @@ -16,6 +16,7 @@ public record Build : IInitializableEntity public DateTime? DateFinished { get; init; } public IReadOnlyDictionary? Options { get; init; } public string? DeploymentVersion { get; init; } + public IReadOnlyDictionary ExecutionData { get; init; } = new Dictionary(); public bool? IsInitialized { get; set; } public DateTime? DateCreated { get; set; } } diff --git a/src/Serval/src/Serval.WordAlignment/Models/WordAlignmentCorpus.cs b/src/Serval/src/Serval.WordAlignment/Models/WordAlignmentCorpus.cs new file mode 100644 index 00000000..9cfcc36b --- /dev/null +++ b/src/Serval/src/Serval.WordAlignment/Models/WordAlignmentCorpus.cs @@ -0,0 +1,8 @@ +namespace Serval.WordAlignment.Models; + +public record WordAlignmentCorpus +{ + public string? ParallelCorpusRef { get; set; } + public IReadOnlyList? SourceFilters { get; set; } + public IReadOnlyList? TargetFilters { get; set; } +} diff --git a/src/Serval/src/Serval.WordAlignment/Services/BuildService.cs b/src/Serval/src/Serval.WordAlignment/Services/BuildService.cs index c3069135..857e458c 100644 --- a/src/Serval/src/Serval.WordAlignment/Services/BuildService.cs +++ b/src/Serval/src/Serval.WordAlignment/Services/BuildService.cs @@ -4,13 +4,30 @@ public class BuildService(IRepository builds) : EntityServiceBase( { public async Task> GetAllAsync(string parentId, CancellationToken cancellationToken = default) { - return await Entities.GetAllAsync(e => e.EngineRef == parentId, cancellationToken); + return await Entities.GetAllAsync( + e => e.EngineRef == parentId && (e.IsInitialized == null || e.IsInitialized.Value), + cancellationToken + ); + } + + public override async Task GetAsync(string id, CancellationToken cancellationToken = default) + { + Build? 
build = await Entities.GetAsync( + e => e.Id == id && (e.IsInitialized == null || e.IsInitialized.Value), + cancellationToken + ); + if (build == null) + throw new EntityNotFoundException($"Could not find the {typeof(Build).Name} '{id}'."); + return build; } public Task GetActiveAsync(string parentId, CancellationToken cancellationToken = default) { return Entities.GetAsync( - b => b.EngineRef == parentId && (b.State == JobState.Active || b.State == JobState.Pending), + b => + b.EngineRef == parentId + && (b.IsInitialized == null || b.IsInitialized.Value) + && (b.State == JobState.Active || b.State == JobState.Pending), cancellationToken ); } @@ -21,7 +38,11 @@ public Task> GetNewerRevisionAsync( CancellationToken cancellationToken = default ) { - return GetNewerRevisionAsync(e => e.Id == id, minRevision, cancellationToken); + return GetNewerRevisionAsync( + e => e.Id == id && (e.IsInitialized == null || e.IsInitialized.Value), + minRevision, + cancellationToken + ); } public Task> GetActiveNewerRevisionAsync( @@ -31,7 +52,10 @@ public Task> GetActiveNewerRevisionAsync( ) { return GetNewerRevisionAsync( - b => b.EngineRef == parentId && (b.State == JobState.Active || b.State == JobState.Pending), + b => + b.EngineRef == parentId + && (b.IsInitialized == null || b.IsInitialized.Value) + && (b.State == JobState.Active || b.State == JobState.Pending), minRevision, cancellationToken ); diff --git a/src/Serval/src/Serval.WordAlignment/Services/EngineService.cs b/src/Serval/src/Serval.WordAlignment/Services/EngineService.cs index 731bc41a..7850cf8d 100644 --- a/src/Serval/src/Serval.WordAlignment/Services/EngineService.cs +++ b/src/Serval/src/Serval.WordAlignment/Services/EngineService.cs @@ -70,17 +70,11 @@ public override async Task CreateAsync(Engine engine, CancellationToken { engine.DateCreated = DateTime.UtcNow; await Entities.InsertAsync(engine, cancellationToken); - WordAlignmentEngineApi.WordAlignmentEngineApiClient client; - try - { - client = 
_grpcClientFactory.CreateClient( - engine.Type - ); - } - catch (InvalidOperationException) - { + WordAlignmentEngineApi.WordAlignmentEngineApiClient? client = + _grpcClientFactory.CreateClient(engine.Type); + if (client is null) throw new InvalidOperationException($"'{engine.Type}' is an invalid engine type."); - } + var request = new CreateRequest { EngineType = engine.Type, @@ -92,6 +86,11 @@ public override async Task CreateAsync(Engine engine, CancellationToken if (engine.Name is not null) request.EngineName = engine.Name; await client.CreateAsync(request, cancellationToken: cancellationToken); + await Entities.UpdateAsync( + engine, + u => u.Set(e => e.IsInitialized, true), + cancellationToken: CancellationToken.None + ); } catch (RpcException rpcex) { @@ -165,7 +164,7 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellationTok { StartBuildRequest request; Dictionary? trainOn = build.TrainOn?.ToDictionary(c => c.ParallelCorpusRef!); - Dictionary? wordAlignOn = build.WordAlignOn?.ToDictionary(c => + Dictionary? wordAlignOn = build.WordAlignOn?.ToDictionary(c => c.ParallelCorpusRef! ); IReadOnlyList parallelCorpora = engine @@ -451,7 +450,7 @@ private Shared.Models.AlignedWordPair Map(V1.AlignedWordPair source) private V1.ParallelCorpus Map( Shared.Models.ParallelCorpus source, TrainingCorpus? trainingCorpus, - TrainingCorpus? wordAlignmentCorpus, + WordAlignmentCorpus? 
wordAlignmentCorpus, bool trainOnAllCorpora, bool wordAlignOnAllCorpora ) diff --git a/src/Serval/src/Serval.WordAlignment/Services/WordAlignmentPlatformServiceV1.cs b/src/Serval/src/Serval.WordAlignment/Services/WordAlignmentPlatformServiceV1.cs index e45cae0e..e998e5d2 100644 --- a/src/Serval/src/Serval.WordAlignment/Services/WordAlignmentPlatformServiceV1.cs +++ b/src/Serval/src/Serval.WordAlignment/Services/WordAlignmentPlatformServiceV1.cs @@ -278,8 +278,8 @@ await _engines.UpdateAsync( return Empty; } - public override async Task InsertInferences( - IAsyncStreamReader requestStream, + public override async Task InsertWordAlignments( + IAsyncStreamReader requestStream, ServerCallContext context ) { @@ -287,7 +287,7 @@ ServerCallContext context int nextModelRevision = 0; var batch = new List(); - await foreach (InsertInferencesRequest request in requestStream.ReadAllAsync(context.CancellationToken)) + await foreach (InsertWordAlignmentsRequest request in requestStream.ReadAllAsync(context.CancellationToken)) { if (request.EngineId != engineId) { @@ -328,4 +328,23 @@ ServerCallContext context return Empty; } + + public override async Task UpdateBuildExecutionData( + UpdateBuildExecutionDataRequest request, + ServerCallContext context + ) + { + await _builds.UpdateAsync( + b => b.Id == request.BuildId, + u => + { + // initialize ExecutionData if it's null + foreach (KeyValuePair entry in request.ExecutionData) + u.Set(b => b.ExecutionData[entry.Key], entry.Value); + }, + cancellationToken: context.CancellationToken + ); + + return new Empty(); + } } diff --git a/src/Serval/test/Serval.ApiServer.IntegrationTests/TranslationEngineTests.cs b/src/Serval/test/Serval.ApiServer.IntegrationTests/TranslationEngineTests.cs index 18b50a68..51515b53 100644 --- a/src/Serval/test/Serval.ApiServer.IntegrationTests/TranslationEngineTests.cs +++ b/src/Serval/test/Serval.ApiServer.IntegrationTests/TranslationEngineTests.cs @@ -86,7 +86,9 @@ public class TranslationEngineTests 
private const string DOES_NOT_EXIST_ENGINE_ID = "e00000000000000000000004"; private const string DOES_NOT_EXIST_CORPUS_ID = "c00000000000000000000001"; +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. private TestEnvironment _env; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. [SetUp] public async Task SetUp() @@ -601,7 +603,7 @@ public async Task AddCorpusToEngineByIdAsync(IEnumerable scope, int expe Assert.That(engine, Is.Not.Null); Assert.Multiple(() => { - Assert.That(engine.Corpora[0].SourceFiles[0].Filename, Is.EqualTo(FILE1_FILENAME)); + Assert.That(engine!.Corpora[0].SourceFiles[0].Filename, Is.EqualTo(FILE1_FILENAME)); Assert.That(engine.Corpora[0].TargetFiles[0].Filename, Is.EqualTo(FILE2_FILENAME)); }); break; @@ -663,7 +665,7 @@ string engineId Assert.That(engine, Is.Not.Null); Assert.Multiple(() => { - Assert.That(engine.Corpora[0].SourceFiles[0].Filename, Is.EqualTo(FILE2_FILENAME)); + Assert.That(engine!.Corpora[0].SourceFiles[0].Filename, Is.EqualTo(FILE2_FILENAME)); Assert.That(engine.Corpora[0].TargetFiles[0].Filename, Is.EqualTo(FILE1_FILENAME)); }); break; @@ -760,7 +762,7 @@ public async Task GetCorpusByIdForEngineByIdAsync( case 200: { Assert.That(result, Is.Not.Null); - TranslationCorpus resultAfterAdd = await client.GetCorpusAsync(engineId, result.Id); + TranslationCorpus resultAfterAdd = await client.GetCorpusAsync(engineId, result!.Id); Assert.Multiple(() => { Assert.That(resultAfterAdd.Name, Is.EqualTo(result.Name)); @@ -838,7 +840,7 @@ public async Task AddParallelCorpusToEngineByIdAsync() Assert.That(engine, Is.Not.Null); Assert.Multiple(() => { - Assert.That(engine.ParallelCorpora[0].SourceCorpora[0].Files[0].Filename, Is.EqualTo(FILE1_FILENAME)); + 
Assert.That(engine!.ParallelCorpora[0].SourceCorpora[0].Files[0].Filename, Is.EqualTo(FILE1_FILENAME)); Assert.That(engine.ParallelCorpora[0].TargetCorpora[0].Files[0].Filename, Is.EqualTo(FILE2_FILENAME)); }); } @@ -887,7 +889,7 @@ public async Task UpdateParallelCorpusByIdForEngineByIdAsync() Assert.That(engine, Is.Not.Null); Assert.Multiple(() => { - Assert.That(engine.ParallelCorpora[0].SourceCorpora[0].Files[0].Filename, Is.EqualTo(FILE1_FILENAME)); + Assert.That(engine!.ParallelCorpora[0].SourceCorpora[0].Files[0].Filename, Is.EqualTo(FILE1_FILENAME)); Assert.That(engine.ParallelCorpora[0].TargetCorpora[0].Files[0].Filename, Is.EqualTo(FILE2_FILENAME)); }); } @@ -1322,7 +1324,7 @@ public async Task GetBuildByIdForEngineByIdAsync( case 200: { Assert.That(build, Is.Not.Null); - TranslationBuild result = await client.GetBuildAsync(engineId, build.Id); + TranslationBuild result = await client.GetBuildAsync(engineId, build!.Id); Assert.Multiple(() => { Assert.That(result.Revision, Is.EqualTo(1)); @@ -1517,7 +1519,7 @@ public async Task GetCurrentBuildForEngineByIdAsync( { Assert.That(build, Is.Not.Null); TranslationBuild result = await client.GetCurrentBuildAsync(engineId); - Assert.That(result.Id, Is.EqualTo(build.Id)); + Assert.That(result.Id, Is.EqualTo(build!.Id)); break; } case 204: @@ -1587,53 +1589,6 @@ public async Task CancelCurrentBuildForEngineByIdAsync( } } - [Test] - public async Task StartBuild_ParallelCorpus() - { - TranslationEnginesClient client = _env.CreateTranslationEnginesClient(); - TranslationParallelCorpus addedCorpus = await client.AddParallelCorpusAsync( - NMT_ENGINE1_ID, - TestParallelCorpusConfig - ); - PretranslateCorpusConfig ptcc = - new() - { - ParallelCorpusId = addedCorpus.Id, - SourceFilters = [new() { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] }] - }; - TrainingCorpusConfig tcc = - new() - { - ParallelCorpusId = addedCorpus.Id, - SourceFilters = [new() { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] }], - TargetFilters 
= [new() { CorpusId = TARGET_CORPUS_ID, TextIds = ["all"] }] - }; - ; - TranslationBuildConfig tbc = new TranslationBuildConfig - { - Pretranslate = [ptcc], - TrainOn = [tcc], - Options = """ - {"max_steps":10, - "use_key_terms":false, - "some_double":10.5, - "some_nested": {"more_nested": {"other_double":10.5}}, - "some_string":"string"} - """ - }; - TranslationBuild resultAfterStart; - Assert.ThrowsAsync(async () => - { - resultAfterStart = await client.GetCurrentBuildAsync(NMT_ENGINE1_ID); - }); - - TranslationBuild build = await client.StartBuildAsync(NMT_ENGINE1_ID, tbc); - Assert.That(build, Is.Not.Null); - - build = await client.GetCurrentBuildAsync(NMT_ENGINE1_ID); - Assert.That(build, Is.Not.Null); - } - [Test] public async Task StartBuildAsync_ParallelCorpus() { @@ -1717,11 +1672,11 @@ public async Task StartBuildAsync_Corpus_NoFilter() TranslationBuild build = await client.StartBuildAsync(NMT_ENGINE1_ID, tbc); Assert.That(build, Is.Not.Null); Assert.That(build.TrainOn, Is.Not.Null); - Assert.That(build.TrainOn.Count, Is.EqualTo(1)); + Assert.That(build.TrainOn!.Count, Is.EqualTo(1)); Assert.That(build.TrainOn[0].TextIds, Is.Null); Assert.That(build.TrainOn[0].ScriptureRange, Is.Null); Assert.That(build.Pretranslate, Is.Not.Null); - Assert.That(build.Pretranslate.Count, Is.EqualTo(1)); + Assert.That(build.Pretranslate!.Count, Is.EqualTo(1)); Assert.That(build.Pretranslate[0].TextIds, Is.Null); Assert.That(build.Pretranslate[0].ScriptureRange, Is.Null); @@ -1768,11 +1723,11 @@ public async Task StartBuildAsync_ParallelCorpus_NoFilter() TranslationBuild build = await client.StartBuildAsync(NMT_ENGINE1_ID, tbc); Assert.That(build, Is.Not.Null); Assert.That(build.TrainOn, Is.Not.Null); - Assert.That(build.TrainOn.Count, Is.EqualTo(1)); + Assert.That(build.TrainOn!.Count, Is.EqualTo(1)); Assert.That(build.TrainOn[0].TextIds, Is.Null); Assert.That(build.TrainOn[0].ScriptureRange, Is.Null); Assert.That(build.Pretranslate, Is.Not.Null); - 
Assert.That(build.Pretranslate.Count, Is.EqualTo(1)); + Assert.That(build.Pretranslate!.Count, Is.EqualTo(1)); Assert.That(build.Pretranslate[0].TextIds, Is.Null); Assert.That(build.Pretranslate[0].ScriptureRange, Is.Null); @@ -1897,7 +1852,7 @@ public async Task TryToQueueMultipleBuildsPerSingleUser() build = await client.StartBuildAsync(engineId, tbc); }); Assert.That(ex, Is.Not.Null); - Assert.That(ex.StatusCode, Is.EqualTo(expectedStatusCode)); + Assert.That(ex!.StatusCode, Is.EqualTo(expectedStatusCode)); } [Test] @@ -1991,7 +1946,7 @@ public async Task GetPretranslatedUsfmAsync_BookDoesNotExist() public async Task GetQueueAsync(string engineType) { TranslationEngineTypesClient client = _env.CreateTranslationEngineTypesClient(); - Client.Queue queue = await client.GetQueueAsync(engineType); + Queue queue = await client.GetQueueAsync(engineType); Assert.That(queue.Size, Is.EqualTo(0)); } @@ -2001,10 +1956,10 @@ public void GetQueueAsync_NotAuthorized() TranslationEngineTypesClient client = _env.CreateTranslationEngineTypesClient([Scopes.ReadFiles]); ServalApiException? 
ex = Assert.ThrowsAsync(async () => { - Client.Queue queue = await client.GetQueueAsync("Echo"); + Queue queue = await client.GetQueueAsync("Echo"); }); Assert.That(ex, Is.Not.Null); - Assert.That(ex.StatusCode, Is.EqualTo(403)); + Assert.That(ex!.StatusCode, Is.EqualTo(403)); } [Test] @@ -2028,7 +1983,7 @@ public void GetLanguageInfo_Error() Client.LanguageInfo languageInfo = await client.GetLanguageInfoAsync("Nmt", "abc"); }); Assert.That(ex, Is.Not.Null); - Assert.That(ex.StatusCode, Is.EqualTo(403)); + Assert.That(ex!.StatusCode, Is.EqualTo(403)); } [Test] diff --git a/src/Serval/test/Serval.ApiServer.IntegrationTests/WordAlignmentEngineTests.cs b/src/Serval/test/Serval.ApiServer.IntegrationTests/WordAlignmentEngineTests.cs index 276608ec..bb021bbc 100644 --- a/src/Serval/test/Serval.ApiServer.IntegrationTests/WordAlignmentEngineTests.cs +++ b/src/Serval/test/Serval.ApiServer.IntegrationTests/WordAlignmentEngineTests.cs @@ -5,8 +5,6 @@ namespace Serval.ApiServer; -#pragma warning disable CS0612 // Type or member is obsolete - [TestFixture] [Category("Integration")] public class WordAlignmentEngineTests @@ -56,7 +54,9 @@ public class WordAlignmentEngineTests private const string DOES_NOT_EXIST_ENGINE_ID = "e00000000000000000000004"; private const string DOES_NOT_EXIST_CORPUS_ID = "c00000000000000000000001"; +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. private TestEnvironment _env; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. 
[SetUp] public async Task SetUp() @@ -188,7 +188,7 @@ public async Task GetAllAsync(IEnumerable scope, int expectedStatusCode) { case 200: ICollection results = await client.GetAllAsync(); - Assert.That(results, Has.Count.EqualTo(4)); + Assert.That(results, Has.Count.EqualTo(3)); //Only three are owned by client1 Assert.That(results.All(eng => eng.SourceLanguage.Equals("en"))); break; case 403: @@ -309,7 +309,7 @@ public async Task DeleteEngineByIdAsync(IEnumerable scope, int expectedS case 200: await client.DeleteAsync(engineId); ICollection results = await client.GetAllAsync(); - Assert.That(results, Has.Count.EqualTo(3)); + Assert.That(results, Has.Count.EqualTo(2)); //Only two are owned by client1 Assert.That(results.All(eng => eng.SourceLanguage.Equals("en"))); break; case 403: @@ -335,7 +335,7 @@ public async Task DeleteEngineByIdAsync(IEnumerable scope, int expectedS )] [TestCase(new[] { Scopes.ReadWordAlignmentEngines, Scopes.UpdateWordAlignmentEngines }, 409, ECHO_ENGINE1_ID)] [TestCase(new[] { Scopes.ReadFiles }, 403, ECHO_ENGINE1_ID)] //Arbitrary unrelated privilege - public async Task TranslateSegmentWithEngineByIdAsync( + public async Task GetWordAlignmentForSegmentPairWithEngineByIdAsync( IEnumerable scope, int expectedStatusCode, string engineId @@ -350,15 +350,19 @@ await _env.Builds.InsertAsync( ); Client.WordAlignmentResult result = await client.GetWordAlignmentAsync( engineId, - new WordAlignmentRequest { SourceSegment = "This is a test.", TargetSegment = "This is a test." }, - Arg.Any() + new WordAlignmentRequest { SourceSegment = "This is a test.", TargetSegment = "This is a test." 
} ); Assert.That(result.SourceTokens, Is.EqualTo("This is a test .".Split())); Assert.That(result.TargetTokens, Is.EqualTo("This is a test .".Split())); break; case 409: { - _env.EchoClient.GetWordAlignmentAsync(Arg.Any()) + _env.EchoClient.GetWordAlignmentAsync( + Arg.Any(), + null, + null, + Arg.Any() + ) .Returns(CreateAsyncUnaryCall(StatusCode.Aborted)); ServalApiException? ex = Assert.ThrowsAsync(async () => { @@ -906,7 +910,7 @@ public async Task GetBuildByIdForEngineByIdAsync( case 200: { Assert.That(build, Is.Not.Null); - WordAlignmentBuild result = await client.GetBuildAsync(engineId, build.Id); + WordAlignmentBuild result = await client.GetBuildAsync(engineId, build!.Id); Assert.Multiple(() => { Assert.That(result.Revision, Is.EqualTo(1)); @@ -930,7 +934,7 @@ public async Task GetBuildByIdForEngineByIdAsync( Assert.That(build, Is.Not.Null); ServalApiException? ex = Assert.ThrowsAsync(async () => { - await client.GetBuildAsync(engineId, build.Id, 3); + await client.GetBuildAsync(engineId, build!.Id, 3); }); Assert.That(ex?.StatusCode, Is.EqualTo(expectedStatusCode)); break; @@ -962,6 +966,7 @@ public async Task StartBuildForEngineByIdAsync(IEnumerable scope, int ex { WordAlignmentEnginesClient client = _env.CreateWordAlignmentEnginesClient(scope); TrainingCorpusConfig2 tcc; + WordAlignmentCorpusConfig wacc; WordAlignmentBuildConfig tbc; switch (expectedStatusCode) { @@ -973,12 +978,24 @@ public async Task StartBuildForEngineByIdAsync(IEnumerable scope, int ex tcc = new TrainingCorpusConfig2 { ParallelCorpusId = addedCorpus.Id, - SourceFilters = [new ParallelCorpusFilterConfig2 { TextIds = ["all"] }], - TargetFilters = [new ParallelCorpusFilterConfig2 { TextIds = ["all"] }] + SourceFilters = + [ + new ParallelCorpusFilterConfig2 { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] } + ], + TargetFilters = [new ParallelCorpusFilterConfig2 { CorpusId = TARGET_CORPUS_ID, TextIds = ["all"] }] + }; + wacc = new WordAlignmentCorpusConfig + { + ParallelCorpusId = 
addedCorpus.Id, + SourceFilters = + [ + new ParallelCorpusFilterConfig2 { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] } + ], + TargetFilters = [new ParallelCorpusFilterConfig2 { CorpusId = TARGET_CORPUS_ID, TextIds = ["all"] }] }; tbc = new WordAlignmentBuildConfig { - WordAlignOn = [tcc], + WordAlignOn = [wacc], TrainOn = [tcc], Options = """ {"max_steps":10, @@ -1010,10 +1027,28 @@ public async Task StartBuildForEngineByIdAsync(IEnumerable scope, int ex tcc = new TrainingCorpusConfig2 { ParallelCorpusId = "cccccccccccccccccccccccc", - SourceFilters = [new ParallelCorpusFilterConfig2 { TextIds = ["all"] }], - TargetFilters = [new ParallelCorpusFilterConfig2 { TextIds = ["all"] }] + SourceFilters = + [ + new ParallelCorpusFilterConfig2 { CorpusId = "ccccccccccccccccccccccc1", TextIds = ["all"] } + ], + TargetFilters = + [ + new ParallelCorpusFilterConfig2 { CorpusId = "ccccccccccccccccccccccc1", TextIds = ["all"] } + ] }; - tbc = new WordAlignmentBuildConfig { WordAlignOn = [tcc], TrainOn = [tcc] }; + wacc = new WordAlignmentCorpusConfig + { + ParallelCorpusId = "cccccccccccccccccccccccc", + SourceFilters = + [ + new ParallelCorpusFilterConfig2 { CorpusId = "ccccccccccccccccccccccc1", TextIds = ["all"] } + ], + TargetFilters = + [ + new ParallelCorpusFilterConfig2 { CorpusId = "ccccccccccccccccccccccc1", TextIds = ["all"] } + ] + }; + tbc = new WordAlignmentBuildConfig { WordAlignOn = [wacc], TrainOn = [tcc] }; ServalApiException? 
ex = Assert.ThrowsAsync(async () => { await client.StartBuildAsync(engineId, tbc); @@ -1038,13 +1073,20 @@ public async Task StartBuildForEngineAsync_UnparsableOptions() new() { ParallelCorpusId = addedCorpus.Id, - SourceFilters = [new ParallelCorpusFilterConfig2 { TextIds = ["all"] }], - TargetFilters = [new ParallelCorpusFilterConfig2 { TextIds = ["all"] }] + SourceFilters = [new ParallelCorpusFilterConfig2 { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] }], + TargetFilters = [new ParallelCorpusFilterConfig2 { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] }] + }; + WordAlignmentCorpusConfig wacc = + new() + { + ParallelCorpusId = addedCorpus.Id, + SourceFilters = [new ParallelCorpusFilterConfig2 { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] }], + TargetFilters = [new ParallelCorpusFilterConfig2 { CorpusId = SOURCE_CORPUS_ID_1, TextIds = ["all"] }] }; WordAlignmentBuildConfig tbc = new() { - WordAlignOn = [tcc], + WordAlignOn = [wacc], TrainOn = [tcc], Options = "unparsable json" }; @@ -1087,7 +1129,7 @@ public async Task GetCurrentBuildForEngineByIdAsync( { Assert.That(build, Is.Not.Null); WordAlignmentBuild result = await client.GetCurrentBuildAsync(engineId); - Assert.That(result.Id, Is.EqualTo(build.Id)); + Assert.That(result.Id, Is.EqualTo(build!.Id)); break; } case 204: @@ -1163,7 +1205,7 @@ public async Task CancelCurrentBuildForEngineByIdAsync( } [Test] - public async Task StartBuild_ParallelCorpus() + public async Task StartBuildAsync_ParallelCorpus() { WordAlignmentEnginesClient client = _env.CreateWordAlignmentEnginesClient(); WordAlignmentParallelCorpus addedCorpus = await client.AddParallelCorpusAsync( @@ -1178,40 +1220,7 @@ public async Task StartBuild_ParallelCorpus() TargetFilters = [new() { CorpusId = TARGET_CORPUS_ID, TextIds = ["all"] }] }; ; - WordAlignmentBuildConfig tbc = new WordAlignmentBuildConfig - { - WordAlignOn = [tcc], - TrainOn = [tcc], - Options = """ - {"max_steps":10, - "use_key_terms":false, - "some_double":10.5, - 
"some_nested": {"more_nested": {"other_double":10.5}}, - "some_string":"string"} - """ - }; - WordAlignmentBuild resultAfterStart; - Assert.ThrowsAsync(async () => - { - resultAfterStart = await client.GetCurrentBuildAsync(STATISTICAL_ENGINE_ID); - }); - - WordAlignmentBuild build = await client.StartBuildAsync(STATISTICAL_ENGINE_ID, tbc); - Assert.That(build, Is.Not.Null); - - build = await client.GetCurrentBuildAsync(STATISTICAL_ENGINE_ID); - Assert.That(build, Is.Not.Null); - } - - [Test] - public async Task StartBuildAsync_ParallelCorpus() - { - WordAlignmentEnginesClient client = _env.CreateWordAlignmentEnginesClient(); - WordAlignmentParallelCorpus addedCorpus = await client.AddParallelCorpusAsync( - STATISTICAL_ENGINE_ID, - TestParallelCorpusConfig - ); - TrainingCorpusConfig2 tcc = + WordAlignmentCorpusConfig wacc = new() { ParallelCorpusId = addedCorpus.Id, @@ -1221,7 +1230,7 @@ public async Task StartBuildAsync_ParallelCorpus() ; WordAlignmentBuildConfig tbc = new WordAlignmentBuildConfig { - WordAlignOn = [tcc], + WordAlignOn = [wacc], TrainOn = [tcc], Options = """ {"max_steps":10, @@ -1260,9 +1269,17 @@ public async Task StartBuildAsync_ParallelCorpus_NoFilter() TargetFilters = [new() { CorpusId = TARGET_CORPUS_ID }] }; ; + WordAlignmentCorpusConfig wacc = + new() + { + ParallelCorpusId = addedCorpus.Id, + SourceFilters = [new() { CorpusId = SOURCE_CORPUS_ID_1 }], + TargetFilters = [new() { CorpusId = TARGET_CORPUS_ID }] + }; + ; WordAlignmentBuildConfig tbc = new WordAlignmentBuildConfig { - WordAlignOn = [tcc], + WordAlignOn = [wacc], TrainOn = [tcc], Options = """ {"max_steps":10, @@ -1281,27 +1298,27 @@ public async Task StartBuildAsync_ParallelCorpus_NoFilter() WordAlignmentBuild build = await client.StartBuildAsync(STATISTICAL_ENGINE_ID, tbc); Assert.That(build, Is.Not.Null); Assert.That(build.TrainOn, Is.Not.Null); - Assert.That(build.TrainOn.Count, Is.EqualTo(1)); - Assert.That(build.TrainOn[0].SourceFilters, Is.Null); - 
Assert.That(build.TrainOn[0].TargetFilters, Is.Null); + Assert.That(build.TrainOn!.Count, Is.EqualTo(1)); + Assert.That(build.TrainOn[0].SourceFilters, Is.Not.Null); + Assert.That(build.TrainOn[0].TargetFilters, Is.Not.Null); Assert.That(build.WordAlignOn, Is.Not.Null); - Assert.That(build.WordAlignOn.Count, Is.EqualTo(1)); - Assert.That(build.WordAlignOn[0].SourceFilters, Is.Null); - Assert.That(build.WordAlignOn[0].TargetFilters, Is.Null); + Assert.That(build.WordAlignOn!.Count, Is.EqualTo(1)); + Assert.That(build.WordAlignOn[0].SourceFilters, Is.Not.Null); + Assert.That(build.WordAlignOn[0].TargetFilters, Is.Not.Null); build = await client.GetCurrentBuildAsync(STATISTICAL_ENGINE_ID); Assert.That(build, Is.Not.Null); } [Test] - public async Task StartBuildAsync_ParallelCorpus_PretranslateNoCorpusSpecified() + public async Task StartBuildAsync_ParallelCorpus_WordAlignNoCorpusSpecified() { WordAlignmentEnginesClient client = _env.CreateWordAlignmentEnginesClient(); WordAlignmentParallelCorpus addedParallelCorpus = await client.AddParallelCorpusAsync( STATISTICAL_ENGINE_ID, TestMixedParallelCorpusConfig ); - TrainingCorpusConfig2 wacc = new() { }; + WordAlignmentCorpusConfig wacc = new() { }; TrainingCorpusConfig2 tcc = new() { ParallelCorpusId = addedParallelCorpus.Id }; WordAlignmentBuildConfig tbc = new WordAlignmentBuildConfig { WordAlignOn = [wacc], TrainOn = [tcc] }; WordAlignmentBuild resultAfterStart; @@ -1312,14 +1329,14 @@ public async Task StartBuildAsync_ParallelCorpus_PretranslateNoCorpusSpecified() } [Test] - public async Task StartBuildAsync_ParallelCorpus_PretranslateFilterOnMultipleSources() + public async Task StartBuildAsync_ParallelCorpus_WordAlignFilterOnMultipleSources() { WordAlignmentEnginesClient client = _env.CreateWordAlignmentEnginesClient(); WordAlignmentParallelCorpus addedParallelCorpus = await client.AddParallelCorpusAsync( STATISTICAL_ENGINE_ID, TestParallelCorpusConfig ); - TrainingCorpusConfig2 wacc = + WordAlignmentCorpusConfig 
wacc = new() { ParallelCorpusId = addedParallelCorpus.Id, @@ -1345,7 +1362,7 @@ public async Task StartBuildAsync_ParallelCorpus_TrainOnNoCorpusSpecified() STATISTICAL_ENGINE_ID, TestParallelCorpusConfig ); - TrainingCorpusConfig2 wacc = new() { ParallelCorpusId = addedParallelCorpus.Id }; + WordAlignmentCorpusConfig wacc = new() { ParallelCorpusId = addedParallelCorpus.Id }; TrainingCorpusConfig2 tcc = new() { }; WordAlignmentBuildConfig tbc = new WordAlignmentBuildConfig { WordAlignOn = [wacc], TrainOn = [tcc] }; WordAlignmentBuild resultAfterStart; @@ -1365,7 +1382,7 @@ public async Task TryToQueueMultipleBuildsPerSingleUser() engineId, TestParallelCorpusConfig ); - TrainingCorpusConfig2 wacc = new() { ParallelCorpusId = addedCorpus.Id }; + WordAlignmentCorpusConfig wacc = new() { ParallelCorpusId = addedCorpus.Id }; var tbc = new WordAlignmentBuildConfig { WordAlignOn = [wacc] }; WordAlignmentBuild build = await client.StartBuildAsync(engineId, tbc); _env.StatisticalClient.StartBuildAsync(Arg.Any(), null, null, Arg.Any()) @@ -1375,7 +1392,7 @@ public async Task TryToQueueMultipleBuildsPerSingleUser() build = await client.StartBuildAsync(engineId, tbc); }); Assert.That(ex, Is.Not.Null); - Assert.That(ex.StatusCode, Is.EqualTo(expectedStatusCode)); + Assert.That(ex!.StatusCode, Is.EqualTo(expectedStatusCode)); } [Test] @@ -1422,35 +1439,13 @@ public void GetWordAlignmentsByTextId_EngineDoesNotExist() Assert.That(ex?.StatusCode, Is.EqualTo(404)); } - // [Test] - // [TestCase("Nmt")] - // [TestCase("EchoWordAlignment")] - // public async Task GetQueueAsync(string engineType) - // { - // TranslationEngineTypesClient client = _env.CreateTranslationEngineTypesClient(); - // Client.Queue queue = await client.GetQueueAsync(engineType); - // Assert.That(queue.Size, Is.EqualTo(0)); - // } - - // [Test] - // public void GetQueueAsync_NotAuthorized() - // { - // TranslationEngineTypesClient client = _env.CreateTranslationEngineTypesClient([Scopes.ReadFiles]); - // 
ServalApiException? ex = Assert.ThrowsAsync(async () => - // { - // Client.Queue queue = await client.GetQueueAsync("EchoWordAlignment"); - // }); - // Assert.That(ex, Is.Not.Null); - // Assert.That(ex.StatusCode, Is.EqualTo(403)); - // } - [Test] public async Task DataFileUpdate_Propagated() { - WordAlignmentEnginesClient translationClient = _env.CreateWordAlignmentEnginesClient(); + WordAlignmentEnginesClient client = _env.CreateWordAlignmentEnginesClient(); DataFilesClient dataFilesClient = _env.CreateDataFilesClient(); CorporaClient corporaClient = _env.CreateCorporaClient(); - await translationClient.AddParallelCorpusAsync(ECHO_ENGINE1_ID, TestParallelCorpusConfig); + await client.AddParallelCorpusAsync(ECHO_ENGINE1_ID, TestParallelCorpusConfig); // Get the original files DataFile orgFileFromClient = await dataFilesClient.GetAsync(FILE1_SRC_ID); @@ -1586,13 +1581,7 @@ public TestEnvironment() public WordAlignmentEnginesClient CreateWordAlignmentEnginesClient(IEnumerable? scope = null) { - scope ??= - [ - Scopes.CreateWordAlignmentEngines, - Scopes.ReadWordAlignmentEngines, - Scopes.UpdateWordAlignmentEngines, - Scopes.DeleteWordAlignmentEngines - ]; + scope ??= Scopes.All; HttpClient httpClient = Factory .WithWebHostBuilder(builder => { @@ -1722,5 +1711,3 @@ protected override void DisposeManagedResources() } } } - -#pragma warning restore CS0612 // Type or member is obsolete diff --git a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs index c0b77db0..9983f572 100644 --- a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs +++ b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs @@ -6,7 +6,9 @@ namespace Serval.E2ETests; [Category("E2E")] public class ServalApiTests { +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. 
private ServalClientHelper _helperClient; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. [OneTimeSetUp] public async Task OneTimeSetup() @@ -167,10 +169,10 @@ public async Task NmtBatch() var executionData = build.ExecutionData!; Assert.That(executionData, Contains.Key("trainCount")); - Assert.That(executionData, Contains.Key("pretranslateCount")); + Assert.That(executionData, Contains.Key("inferenceCount")); int trainCount = Convert.ToInt32(executionData["trainCount"], CultureInfo.InvariantCulture); - int pretranslateCount = Convert.ToInt32(executionData["pretranslateCount"], CultureInfo.InvariantCulture); + int pretranslateCount = Convert.ToInt32(executionData["inferenceCount"], CultureInfo.InvariantCulture); Assert.That(trainCount, Is.GreaterThan(0)); Assert.That(pretranslateCount, Is.GreaterThan(0)); @@ -315,7 +317,7 @@ public async Task CircuitousRouteGetWordGraphAsync() await _helperClient.TranslationEnginesClient.GetWordGraphAsync(smtEngineId, "verdad"); }); Assert.That(ex, Is.Not.Null); - Assert.That(ex.StatusCode, Is.EqualTo(409)); + Assert.That(ex!.StatusCode, Is.EqualTo(409)); //Add corpus var corpus1 = await _helperClient.MakeParallelTextCorpus(["2JN.txt", "3JN.txt"], "es", "en", false); @@ -497,10 +499,20 @@ public async Task ParatextProjectNmtJobAsync() public async Task GetWordAlignment() { string engineId = await _helperClient.CreateNewEngineAsync("Statistical", "es", "en", "STAT1"); - string[] books = ["1JN.txt", "2JN.txt", "3JN.txt"]; - ParallelCorpusConfig train_corpus = await _helperClient.MakeParallelTextCorpus(books, "es", "en", false); - await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, train_corpus, false); - await _helperClient.BuildEngineAsync(engineId); + string[] books = ["1JN.txt", "2JN.txt", "MAT.txt"]; + ParallelCorpusConfig trainCorpus = await _helperClient.MakeParallelTextCorpus(books, "es", 
"en", false); + ParallelCorpusConfig testCorpus = await _helperClient.MakeParallelTextCorpus(["3JN.txt"], "es", "en", false); + string trainCorpusId = await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, trainCorpus, false); + string corpusId = await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, testCorpus, true); + _helperClient.WordAlignmentBuildConfig.TrainOn = + [ + new TrainingCorpusConfig2() { ParallelCorpusId = trainCorpusId } + ]; + _helperClient.WordAlignmentBuildConfig.WordAlignOn = + [ + new WordAlignmentCorpusConfig() { ParallelCorpusId = corpusId } + ]; + string buildId = await _helperClient.BuildEngineAsync(engineId); WordAlignmentResult tResult = await _helperClient.WordAlignmentEnginesClient.GetWordAlignmentAsync( engineId, new WordAlignmentRequest() { SourceSegment = "espíritu verdad", TargetSegment = "spirit truth" } @@ -515,6 +527,24 @@ public async Task GetWordAlignment() } ) ); + + WordAlignmentBuild build = await _helperClient.WordAlignmentEnginesClient.GetBuildAsync(engineId, buildId); + Assert.That(build.ExecutionData, Is.Not.Null); + + var executionData = build.ExecutionData!; + + Assert.That(executionData, Contains.Key("trainCount")); + Assert.That(executionData, Contains.Key("inferenceCount")); + + int trainCount = Convert.ToInt32(executionData["trainCount"], CultureInfo.InvariantCulture); + int wordAlignmentCount = Convert.ToInt32(executionData["inferenceCount"], CultureInfo.InvariantCulture); + + Assert.That(trainCount, Is.GreaterThan(0)); + Assert.That(wordAlignmentCount, Is.GreaterThan(0)); + + IList wordAlignments = + await _helperClient.WordAlignmentEnginesClient.GetAllWordAlignmentsAsync(engineId, corpusId); + Assert.That(wordAlignments, Has.Count.EqualTo(14)); //Number of verses in 3JN } [TearDown] diff --git a/src/Serval/test/Serval.E2ETests/ServalClientHelper.cs b/src/Serval/test/Serval.E2ETests/ServalClientHelper.cs index e3a68fa9..0f37dd0a 100644 --- 
a/src/Serval/test/Serval.E2ETests/ServalClientHelper.cs +++ b/src/Serval/test/Serval.E2ETests/ServalClientHelper.cs @@ -524,7 +524,7 @@ await WordAlignmentEnginesClient.AddParallelCorpusAsync( if (inference) { WordAlignmentBuildConfig.WordAlignOn!.Add( - new TrainingCorpusConfig2 { ParallelCorpusId = parallelCorpus.Id } + new WordAlignmentCorpusConfig { ParallelCorpusId = parallelCorpus.Id } ); } } diff --git a/src/Serval/test/Serval.Translation.Tests/Services/PlatformServiceTests.cs b/src/Serval/test/Serval.Translation.Tests/Services/PlatformServiceTests.cs index 0e2bdf26..543d6bef 100644 --- a/src/Serval/test/Serval.Translation.Tests/Services/PlatformServiceTests.cs +++ b/src/Serval/test/Serval.Translation.Tests/Services/PlatformServiceTests.cs @@ -38,7 +38,7 @@ await env.PlatformService.BuildRestarting( Assert.That(env.Engines.Get("e0").IsBuilding, Is.False); Assert.That(env.Pretranslations.Count, Is.EqualTo(0)); - await env.PlatformService.InsertInferences(new MockAsyncStreamReader("e0"), env.ServerCallContext); + await env.PlatformService.InsertPretranslations(new MockAsyncStreamReader("e0"), env.ServerCallContext); Assert.That(env.Pretranslations.Count, Is.EqualTo(1)); await env.PlatformService.BuildFaulted(new BuildFaultedRequest() { BuildId = "b0" }, env.ServerCallContext); @@ -50,12 +50,12 @@ await env.PlatformService.BuildRestarting( new BuildRestartingRequest() { BuildId = "b0" }, env.ServerCallContext ); - await env.PlatformService.InsertInferences(new MockAsyncStreamReader("e0"), env.ServerCallContext); + await env.PlatformService.InsertPretranslations(new MockAsyncStreamReader("e0"), env.ServerCallContext); Assert.That(env.Pretranslations.Count, Is.EqualTo(1)); await env.PlatformService.BuildCompleted(new BuildCompletedRequest() { BuildId = "b0" }, env.ServerCallContext); Assert.That(env.Pretranslations.Count, Is.EqualTo(1)); await env.PlatformService.BuildStarted(new BuildStartedRequest() { BuildId = "b0" }, env.ServerCallContext); - await 
env.PlatformService.InsertInferences(new MockAsyncStreamReader("e0"), env.ServerCallContext); + await env.PlatformService.InsertPretranslations(new MockAsyncStreamReader("e0"), env.ServerCallContext); await env.PlatformService.BuildCompleted(new BuildCompletedRequest() { BuildId = "b0" }, env.ServerCallContext); Assert.That(env.Pretranslations.Count, Is.EqualTo(1)); } @@ -114,7 +114,7 @@ public async Task UpdateBuildExecutionData() ExecutionData = new Dictionary { { "trainCount", "0" }, - { "pretranslateCount", "0" }, + { "inferenceCount", "0" }, { "staticCount", "0" } } }; @@ -125,10 +125,10 @@ public async Task UpdateBuildExecutionData() var executionData = build.ExecutionData; Assert.That(executionData, Contains.Key("trainCount")); - Assert.That(executionData, Contains.Key("pretranslateCount")); + Assert.That(executionData, Contains.Key("inferenceCount")); int trainCount = Convert.ToInt32(executionData["trainCount"], CultureInfo.InvariantCulture); - int pretranslateCount = Convert.ToInt32(executionData["pretranslateCount"], CultureInfo.InvariantCulture); + int pretranslateCount = Convert.ToInt32(executionData["inferenceCount"], CultureInfo.InvariantCulture); int staticCount = Convert.ToInt32(executionData["staticCount"], CultureInfo.InvariantCulture); Assert.That(trainCount, Is.EqualTo(0)); @@ -137,7 +137,7 @@ public async Task UpdateBuildExecutionData() var updateRequest = new UpdateBuildExecutionDataRequest() { BuildId = "123", EngineId = engine.Id }; updateRequest.ExecutionData.Add( - new Dictionary { { "trainCount", "4" }, { "pretranslateCount", "5" } } + new Dictionary { { "trainCount", "4" }, { "inferenceCount", "5" } } ); await env.PlatformService.UpdateBuildExecutionData(updateRequest, env.ServerCallContext); @@ -147,7 +147,7 @@ public async Task UpdateBuildExecutionData() executionData = build!.ExecutionData; trainCount = Convert.ToInt32(executionData["trainCount"], CultureInfo.InvariantCulture); - pretranslateCount = 
Convert.ToInt32(executionData["pretranslateCount"], CultureInfo.InvariantCulture); + pretranslateCount = Convert.ToInt32(executionData["inferenceCount"], CultureInfo.InvariantCulture); staticCount = Convert.ToInt32(executionData["staticCount"], CultureInfo.InvariantCulture); Assert.That(trainCount, Is.GreaterThan(0)); @@ -220,12 +220,12 @@ public TestEnvironment() public TranslationPlatformServiceV1 PlatformService { get; } } - private class MockAsyncStreamReader(string engineId) : IAsyncStreamReader + private class MockAsyncStreamReader(string engineId) : IAsyncStreamReader { private bool _endOfStream = false; public string EngineId { get; } = engineId; - public InsertInferencesRequest Current => new() { EngineId = EngineId }; + public InsertPretranslationsRequest Current => new() { EngineId = EngineId }; public Task MoveNext(CancellationToken cancellationToken) { diff --git a/src/Serval/test/Serval.WordAlignment.Tests/Services/BuildCleanupServiceTests.cs b/src/Serval/test/Serval.WordAlignment.Tests/Services/BuildCleanupServiceTests.cs new file mode 100644 index 00000000..db16d100 --- /dev/null +++ b/src/Serval/test/Serval.WordAlignment.Tests/Services/BuildCleanupServiceTests.cs @@ -0,0 +1,56 @@ +namespace Serval.WordAlignment.Services; + +[TestFixture] +public class BuildCleanupServiceTests +{ + [Test] + public async Task CleanupAsync() + { + TestEnvironment env = new(); + Assert.That(env.Builds.Count, Is.EqualTo(2)); + await env.CheckBuildsAsync(); + Assert.That(env.Builds.Count, Is.EqualTo(1)); + Assert.That((await env.Builds.GetAllAsync())[0].Id, Is.EqualTo("build2")); + } + + private class TestEnvironment + { + public MemoryRepository Builds { get; } + + public TestEnvironment() + { + Builds = new MemoryRepository(); + Builds.Add( + new Build + { + Id = "build1", + EngineRef = "engine1", + IsInitialized = false, + DateCreated = DateTime.UtcNow.Subtract(TimeSpan.FromHours(10)) + } + ); + Builds.Add( + new Build + { + Id = "build2", + EngineRef = "engine2", + 
IsInitialized = true, + DateCreated = DateTime.UtcNow.Subtract(TimeSpan.FromHours(10)) + } + ); + + Service = new BuildCleanupService( + Substitute.For(), + Substitute.For>(), + TimeSpan.Zero + ); + } + + public BuildCleanupService Service { get; } + + public async Task CheckBuildsAsync() + { + await Service.CheckEntitiesAsync(Builds, CancellationToken.None); + } + } +} diff --git a/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineCleanupServiceTests.cs b/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineCleanupServiceTests.cs new file mode 100644 index 00000000..0ed467fb --- /dev/null +++ b/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineCleanupServiceTests.cs @@ -0,0 +1,64 @@ +namespace Serval.WordAlignment.Services; + +[TestFixture] +public class EngineCleanupServiceTests +{ + [Test] + public async Task CleanupAsync() + { + TestEnvironment env = new(); + Assert.That(env.Engines.Count, Is.EqualTo(2)); + await env.CheckEnginesAsync(); + Assert.That(env.Engines.Count, Is.EqualTo(1)); + Assert.That((await env.Engines.GetAllAsync())[0].Id, Is.EqualTo("engine2")); + } + + private class TestEnvironment + { + public MemoryRepository Engines { get; } + + public TestEnvironment() + { + Engines = new MemoryRepository(); + Engines.Add( + new Engine + { + Id = "engine1", + SourceLanguage = "en", + TargetLanguage = "es", + Type = "Nmt", + Owner = "client1", + IsInitialized = false, + DateCreated = DateTime.UtcNow.Subtract(TimeSpan.FromHours(10)), + ParallelCorpora = [] + } + ); + Engines.Add( + new Engine + { + Id = "engine2", + SourceLanguage = "en", + TargetLanguage = "es", + Type = "Nmt", + Owner = "client1", + IsInitialized = true, + DateCreated = DateTime.UtcNow.Subtract(TimeSpan.FromHours(10)), + ParallelCorpora = [] + } + ); + + Service = new EngineCleanupService( + Substitute.For(), + Substitute.For>(), + TimeSpan.Zero + ); + } + + public EngineCleanupService Service { get; } + + public async Task CheckEnginesAsync() + { + await 
Service.CheckEntitiesAsync(Engines, CancellationToken.None); + } + } +} diff --git a/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineServiceTests.cs b/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineServiceTests.cs index c5dd57ce..40c5e929 100644 --- a/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineServiceTests.cs +++ b/src/Serval/test/Serval.WordAlignment.Tests/Services/EngineServiceTests.cs @@ -414,7 +414,7 @@ await env.Service.StartBuildAsync( Id = BUILD1_ID, EngineRef = engineId, TrainOn = [new TrainingCorpus { ParallelCorpusRef = "corpus1" }], - WordAlignOn = [new TrainingCorpus { ParallelCorpusRef = "corpus1" }] + WordAlignOn = [new WordAlignmentCorpus { ParallelCorpusRef = "corpus1" }] } ); _ = env.WordAlignmentServiceClient.Received() @@ -490,7 +490,7 @@ await env.Service.StartBuildAsync( Id = BUILD1_ID, EngineRef = engineId, TrainOn = [new TrainingCorpus { ParallelCorpusRef = "corpus1" }], - WordAlignOn = [new TrainingCorpus { ParallelCorpusRef = "corpus2" }] + WordAlignOn = [new WordAlignmentCorpus { ParallelCorpusRef = "corpus2" }] } ); _ = env.WordAlignmentServiceClient.Received() diff --git a/src/Serval/test/Serval.WordAlignment.Tests/Services/PlatformServiceTests.cs b/src/Serval/test/Serval.WordAlignment.Tests/Services/PlatformServiceTests.cs index ba3b144f..d9ae3fc7 100644 --- a/src/Serval/test/Serval.WordAlignment.Tests/Services/PlatformServiceTests.cs +++ b/src/Serval/test/Serval.WordAlignment.Tests/Services/PlatformServiceTests.cs @@ -1,3 +1,4 @@ +using System.Globalization; using Serval.WordAlignment.V1; namespace Serval.WordAlignment.Services; @@ -37,7 +38,7 @@ await env.PlatformService.BuildRestarting( Assert.That(env.Engines.Get("e0").IsBuilding, Is.False); Assert.That(env.WordAlignments.Count, Is.EqualTo(0)); - await env.PlatformService.InsertInferences(new MockAsyncStreamReader("e0"), env.ServerCallContext); + await env.PlatformService.InsertWordAlignments(new MockAsyncStreamReader("e0"), 
env.ServerCallContext); Assert.That(env.WordAlignments.Count, Is.EqualTo(1)); await env.PlatformService.BuildFaulted(new BuildFaultedRequest() { BuildId = "b0" }, env.ServerCallContext); @@ -49,12 +50,12 @@ await env.PlatformService.BuildRestarting( new BuildRestartingRequest() { BuildId = "b0" }, env.ServerCallContext ); - await env.PlatformService.InsertInferences(new MockAsyncStreamReader("e0"), env.ServerCallContext); + await env.PlatformService.InsertWordAlignments(new MockAsyncStreamReader("e0"), env.ServerCallContext); Assert.That(env.WordAlignments.Count, Is.EqualTo(1)); await env.PlatformService.BuildCompleted(new BuildCompletedRequest() { BuildId = "b0" }, env.ServerCallContext); Assert.That(env.WordAlignments.Count, Is.EqualTo(1)); await env.PlatformService.BuildStarted(new BuildStartedRequest() { BuildId = "b0" }, env.ServerCallContext); - await env.PlatformService.InsertInferences(new MockAsyncStreamReader("e0"), env.ServerCallContext); + await env.PlatformService.InsertWordAlignments(new MockAsyncStreamReader("e0"), env.ServerCallContext); await env.PlatformService.BuildCompleted(new BuildCompletedRequest() { BuildId = "b0" }, env.ServerCallContext); Assert.That(env.WordAlignments.Count, Is.EqualTo(1)); } @@ -90,6 +91,70 @@ await env.PlatformService.UpdateBuildStatus( Assert.That(env.Builds.Get("b0").PercentCompleted, Is.EqualTo(0.5)); } + [Test] + public async Task UpdateBuildExecutionData() + { + var env = new TestEnvironment(); + + var engine = new Engine() + { + Id = "e0", + Owner = "owner1", + Type = "nmt", + SourceLanguage = "en", + TargetLanguage = "es", + ParallelCorpora = [] + }; + await env.Engines.InsertAsync(engine); + + var build = new Build() + { + Id = "123", + EngineRef = "e0", + ExecutionData = new Dictionary + { + { "trainCount", "0" }, + { "inferenceCount", "0" }, + { "staticCount", "0" } + } + }; + await env.Builds.InsertAsync(build); + + Assert.That(build.ExecutionData, Is.Not.Null); + + var executionData = build.ExecutionData; + 
+ Assert.That(executionData, Contains.Key("trainCount")); + Assert.That(executionData, Contains.Key("inferenceCount")); + + int trainCount = Convert.ToInt32(executionData["trainCount"], CultureInfo.InvariantCulture); + int wordAlignmentCount = Convert.ToInt32(executionData["inferenceCount"], CultureInfo.InvariantCulture); + int staticCount = Convert.ToInt32(executionData["staticCount"], CultureInfo.InvariantCulture); + + Assert.That(trainCount, Is.EqualTo(0)); + Assert.That(wordAlignmentCount, Is.EqualTo(0)); + Assert.That(staticCount, Is.EqualTo(0)); + + var updateRequest = new UpdateBuildExecutionDataRequest() { BuildId = "123", EngineId = engine.Id }; + updateRequest.ExecutionData.Add( + new Dictionary { { "trainCount", "4" }, { "inferenceCount", "5" } } + ); + + await env.PlatformService.UpdateBuildExecutionData(updateRequest, env.ServerCallContext); + + build = await env.Builds.GetAsync(c => c.Id == build.Id); + + executionData = build!.ExecutionData; + + trainCount = Convert.ToInt32(executionData["trainCount"], CultureInfo.InvariantCulture); + wordAlignmentCount = Convert.ToInt32(executionData["inferenceCount"], CultureInfo.InvariantCulture); + staticCount = Convert.ToInt32(executionData["staticCount"], CultureInfo.InvariantCulture); + + Assert.That(trainCount, Is.GreaterThan(0)); + Assert.That(wordAlignmentCount, Is.GreaterThan(0)); + Assert.That(staticCount, Is.EqualTo(0)); + } + [Test] public async Task IncrementCorpusSizeAsync() { @@ -155,12 +220,12 @@ public TestEnvironment() public WordAlignmentPlatformServiceV1 PlatformService { get; } } - private class MockAsyncStreamReader(string engineId) : IAsyncStreamReader + private class MockAsyncStreamReader(string engineId) : IAsyncStreamReader { private bool _endOfStream = false; public string EngineId { get; } = engineId; - public InsertInferencesRequest Current => new() { EngineId = EngineId }; + public InsertWordAlignmentsRequest Current => new() { EngineId = EngineId }; public Task 
MoveNext(CancellationToken cancellationToken) { diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs index 5e5fa959..32912734 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs @@ -5,7 +5,7 @@ public interface IParallelCorpusPreprocessingService Task PreprocessAsync( IReadOnlyList corpora, Func train, - Func pretranslate, + Func pretranslate, bool useKeyTerms = false ); } diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs index 670cb7ef..d1b063ca 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs @@ -28,7 +28,7 @@ internal int Seed public async Task PreprocessAsync( IReadOnlyList corpora, Func train, - Func pretranslate, + Func inference, bool useKeyTerms = false ) { @@ -59,6 +59,11 @@ public async Task PreprocessAsync( .Select(tc => FilterTrainingCorpora(tc.Corpus, tc.TextCorpus)) .ToArray(); + ITextCorpus targetPretranslateCorpus = targetCorpora + .Select(tc => FilterPretranslateCorpora(tc.Corpus, tc.TextCorpus)) + .ToArray() + .ChooseRandom(Seed); + ITextCorpus sourceTrainingCorpus = sourceTrainingCorpora.ChooseRandom(Seed); if (sourceTrainingCorpus.IsScripture()) { @@ -80,7 +85,7 @@ public async Task PreprocessAsync( foreach (Row row in CollapseRanges(trainingRows)) { await train(row); - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) + if (!parallelTrainingDataPresent && row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) { parallelTrainingDataPresent = true; } @@ 
-111,14 +116,16 @@ ParallelTextRow row in parallelKeyTermsCorpus.DistinctBy(row => } ITextCorpus sourcePretranslateCorpus = sourcePretranslateCorpora.ChooseFirst(); - IParallelTextCorpus pretranslateCorpus = sourcePretranslateCorpus.AlignRows( - targetCorpus, - allSourceRows: true - ); + INParallelTextCorpus pretranslateCorpus = new ITextCorpus[] + { + sourcePretranslateCorpus, + targetPretranslateCorpus, + targetCorpus + }.AlignMany([true, false, false]); - foreach (Row row in CollapseRanges(pretranslateCorpus.ToArray())) + foreach ((Row row, bool isInTrainingData) in CollapsePretranslateRanges(pretranslateCorpus.ToArray())) { - await pretranslate(row, corpus); + await inference(row, isInTrainingData, corpus); } } if (useKeyTerms && parallelTrainingDataPresent) @@ -130,20 +137,6 @@ ParallelTextRow row in parallelKeyTermsCorpus.DistinctBy(row => } } - private static IEnumerable<(CorpusFile File, Dictionary> Chapters)> GetChaptersPerFile( - MonolingualCorpus mc, - ITextCorpus tc - ) - { - Dictionary>? 
chapters = mc.TrainOnChapters; - if (chapters is null && mc.TrainOnTextIds is not null) - { - chapters = mc.TrainOnTextIds.Select(tid => (tid, new HashSet { })).ToDictionary(); - } - chapters ??= tc.Texts.Select(t => (t.Id, new HashSet() { })).ToDictionary(); - return mc.Files.Select(f => (f, chapters)); - } - private static ITextCorpus FilterPretranslateCorpora(MonolingualCorpus corpus, ITextCorpus textCorpus) { textCorpus = textCorpus.Transform(CleanSegment); @@ -234,6 +227,73 @@ private static IEnumerable CollapseRanges(ParallelTextRow[] rows) } } + private static IEnumerable<(Row, bool)> CollapsePretranslateRanges(NParallelTextRow[] rows) + { + StringBuilder srcSegBuffer = new(); + StringBuilder trgSegBuffer = new(); + List refs = []; + string textId = ""; + bool hasUnfinishedRange = false; + bool isInTrainingData = false; + + foreach (NParallelTextRow row in rows) + { + //row at 0 is source filtered for pretranslation, row at 1 is target filtered for pretranslation, row at 2 is target filtered for training + if ( + hasUnfinishedRange + && (!row.IsInRange(0) || row.IsRangeStart(0)) + && (!row.IsInRange(1) || row.IsRangeStart(1)) + && (!row.IsInRange(2) || row.IsRangeStart(2)) + ) + { + yield return ( + new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), + isInTrainingData + ); + + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + isInTrainingData = false; + hasUnfinishedRange = false; + } + + textId = row.TextId; + refs.AddRange(row.NRefs[2].Count > 0 ? 
row.NRefs[2] : row.NRefs[1]); + isInTrainingData = isInTrainingData || row.Text(2).Length > 0; + + if (row.Text(0).Length > 0) + { + if (srcSegBuffer.Length > 0) + srcSegBuffer.Append(' '); + srcSegBuffer.Append(row.Text(0)); + } + if (row.Text(1).Length > 0) + { + if (trgSegBuffer.Length > 0) + trgSegBuffer.Append(' '); + trgSegBuffer.Append(row.Text(1)); + } + + if (row.IsInRange(0) || row.IsInRange(1) || row.IsInRange(2)) + { + hasUnfinishedRange = true; + continue; + } + + yield return (new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), isInTrainingData); + + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + isInTrainingData = false; + } + if (hasUnfinishedRange) + { + yield return (new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), isInTrainingData); + } + } + private static bool IsScriptureRow(TextRow parallelTextRow) { return parallelTextRow.Ref is ScriptureRef sr && sr.IsVerse; diff --git a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs index cdd1884f..ebe1fa78 100644 --- a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs +++ b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs @@ -72,7 +72,7 @@ public async Task TestParallelCorpusPreprocessor() } ]; int trainCount = 0; - int pretranslateCount = 0; + int inferenceCount = 0; await processor.PreprocessAsync( corpora, row => @@ -81,18 +81,22 @@ await processor.PreprocessAsync( trainCount++; return Task.CompletedTask; }, - (row, _) => + (row, isInTrainingData, _) => { - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) - pretranslateCount++; + if (row.SourceSegment.Length > 0 && !isInTrainingData) + { + inferenceCount++; + } + return Task.CompletedTask; }, false ); + Assert.Multiple(() => 
{ Assert.That(trainCount, Is.EqualTo(2)); - Assert.That(pretranslateCount, Is.EqualTo(3)); + Assert.That(inferenceCount, Is.EqualTo(3)); }); } }