Skip to content

Commit

Permalink
Call 'GetTasksById' once per DoWork (#504)
Browse files Browse the repository at this point in the history
* Call 'GetTasksById' once per DoWork
Fix a suspected bug (not verified by tests): some build jobs appear to have been double-counted.

* Get queue by ID, reducing the time per query from 3-5 seconds to under 100 ms.

* Address review comments

* Address reviewer comments

---------

Co-authored-by: John Lambert <[email protected]>
  • Loading branch information
Enkidu93 and johnml1135 authored Oct 8, 2024
1 parent 52dc1ba commit 083d68e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,36 +51,32 @@ protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken
if (trainingEngines.Count == 0)
return;

Dictionary<string, ClearMLTask> tasks = new();
Dictionary<string, int> queuePositions = new();
Dictionary<string, ClearMLTask> tasks = (
await _clearMLService.GetTasksByIdAsync(
trainingEngines.Select(e => e.CurrentBuild!.JobId),
cancellationToken
)
).ToDictionary(t => t.Id);
Dictionary<TranslationEngineType, Dictionary<string, int>> queuePositionsPerEngineType = new();

foreach (TranslationEngineType engineType in _queuePerEngineType.Keys)
foreach ((TranslationEngineType engineType, string queueName) in _queuePerEngineType)
{
var tasksPerEngineType = (
await _clearMLService.GetTasksByIdAsync(
trainingEngines.Select(e => e.CurrentBuild!.JobId),
cancellationToken
)
)
.UnionBy(
await _clearMLService.GetTasksForQueueAsync(_queuePerEngineType[engineType], cancellationToken),
t => t.Id
var tasksPerEngineType = tasks
.Where(kvp =>
trainingEngines.Where(te => te.CurrentBuild?.JobId == kvp.Key).FirstOrDefault()?.Type
== engineType
)
.Select(kvp => kvp.Value)
.UnionBy(await _clearMLService.GetTasksForQueueAsync(queueName, cancellationToken), t => t.Id)
.ToDictionary(t => t.Id);
// add new keys to dictionary
foreach (KeyValuePair<string, ClearMLTask> kvp in tasksPerEngineType)
tasks.TryAdd(kvp.Key, kvp.Value);

var queuePositionsPerEngineType = tasksPerEngineType
queuePositionsPerEngineType[engineType] = tasksPerEngineType
.Values.Where(t => t.Status is ClearMLTaskStatus.Queued or ClearMLTaskStatus.Created)
.OrderBy(t => t.Created)
.Select((t, i) => (Position: i, Task: t))
.ToDictionary(e => e.Task.Name, e => e.Position);
// add new keys to dictionary
foreach (KeyValuePair<string, int> kvp in queuePositionsPerEngineType)
queuePositions.TryAdd(kvp.Key, kvp.Value);

_queueSizePerEngineType[engineType] = queuePositionsPerEngineType.Count;
_queueSizePerEngineType[engineType] = queuePositionsPerEngineType[engineType].Count;
}

var dataAccessContext = scope.ServiceProvider.GetRequiredService<IDataAccessContext>();
Expand All @@ -100,7 +96,7 @@ await UpdateTrainJobStatus(
engine.CurrentBuild.BuildId,
new ProgressStatus(step: 0, percentCompleted: 0.0),
//CurrentBuild.BuildId should always equal the corresponding task.Name
queuePositions[engine.CurrentBuild.BuildId] + 1,
queuePositionsPerEngineType[engine.Type][engine.CurrentBuild.BuildId] + 1,
cancellationToken
);
}
Expand Down
43 changes: 39 additions & 4 deletions src/Machine/src/Serval.Machine.Shared/Services/ClearMLService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ ILogger<ClearMLService> logger

private readonly IClearMLAuthenticationService _clearMLAuthService = clearMLAuthService;
private readonly ILogger<ClearMLService> _logger = logger;
private readonly AsyncLock _lock = new AsyncLock();
private ImmutableDictionary<string, string>? _queueNamesToIds = null;

public async Task<string?> GetProjectIdAsync(string name, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -145,13 +147,43 @@ public async Task<IReadOnlyList<ClearMLTask>> GetTasksForQueueAsync(
CancellationToken cancellationToken = default
)
{
var body = new JsonObject { ["name"] = queue };
JsonObject? result = await CallAsync("queues", "get_all_ex", body, cancellationToken);
var tasks = (JsonArray?)result?["data"]?["queues"]?[0]?["entries"];
IDictionary<string, string> queueNamesToIds = await PopulateQueueNamesToIdsAsync(
cancellationToken: cancellationToken
);
if (!queueNamesToIds.TryGetValue(queue, out string? queueId))
{
queueNamesToIds = await PopulateQueueNamesToIdsAsync(refresh: true, cancellationToken);
if (!queueNamesToIds.TryGetValue(queue, out queueId))
{
throw new InvalidOperationException($"Queue {queue} does not exist");
}
}
var body = new JsonObject { ["queue"] = queueId };
JsonObject? result = await CallAsync("queues", "get_by_id", body, cancellationToken);
var tasks = (JsonArray?)result?["data"]?["queue"]?["entries"];
IEnumerable<string> taskIds = tasks?.Select(t => (string)t?["id"]!) ?? new List<string>();
return await GetTasksByIdAsync(taskIds, cancellationToken);
}

/// <summary>
/// Returns the map of queue names to queue IDs, loading it through the
/// <c>queues</c>/<c>get_all</c> call when nothing is cached yet or when
/// <paramref name="refresh"/> is requested.
/// </summary>
/// <param name="refresh">
/// When <c>true</c>, the cached map is ignored and rebuilt from a new server call.
/// </param>
/// <param name="cancellationToken">Token used for acquiring the lock and for the server call.</param>
/// <returns>The queue-name to queue-ID map that was cached or just built by this call.</returns>
/// <exception cref="InvalidOperationException">
/// The response has no <c>data.queues</c> array, or a queue entry has no <c>name</c> or <c>id</c>.
/// </exception>
private async Task<IDictionary<string, string>> PopulateQueueNamesToIdsAsync(
    bool refresh = false,
    CancellationToken cancellationToken = default
)
{
    // The cache check, the server call and the cache update all happen while holding _lock.
    using (await _lock.LockAsync(cancellationToken))
    {
        if (!refresh && _queueNamesToIds is { } cached)
            return cached;

        JsonObject? result = await CallAsync("queues", "get_all", new JsonObject(), cancellationToken);
        if (result?["data"]?["queues"] is not JsonArray queues)
            throw new InvalidOperationException("Malformed response from ClearML server.");

        ImmutableDictionary<string, string>.Builder builder = ImmutableDictionary.CreateBuilder<string, string>();
        foreach (JsonNode? queue in queues)
        {
            string? name = (string?)queue?["name"];
            string? id = (string?)queue?["id"];
            // Reject incomplete entries here so a null name or id never reaches the cache,
            // and so the failure matches the missing-array case above.
            if (name is null || id is null)
                throw new InvalidOperationException("Malformed response from ClearML server.");
            builder.Add(name, id);
        }

        // Return the snapshot built under the lock instead of re-reading the field after
        // the lock has been released.
        ImmutableDictionary<string, string> queueNamesToIds = builder.ToImmutable();
        _queueNamesToIds = queueNamesToIds;
        return queueNamesToIds;
    }
}

public async Task<ClearMLTask?> GetTaskByNameAsync(string name, CancellationToken cancellationToken = default)
{
IReadOnlyList<ClearMLTask> tasks = await GetTasksAsync(new JsonObject { ["name"] = name }, cancellationToken);
Expand All @@ -165,7 +197,10 @@ public Task<IReadOnlyList<ClearMLTask>> GetTasksByIdAsync(
CancellationToken cancellationToken = default
)
{
return GetTasksAsync(new JsonObject { ["id"] = JsonValue.Create(ids.ToArray()) }, cancellationToken);
string[] idArray = ids.ToArray();
if (!idArray.Any())
return Task.FromResult(Array.Empty<ClearMLTask>() as IReadOnlyList<ClearMLTask>);
return GetTasksAsync(new JsonObject { ["id"] = JsonValue.Create(idArray) }, cancellationToken);
}

private async Task<IReadOnlyList<ClearMLTask>> GetTasksAsync(
Expand Down
1 change: 1 addition & 0 deletions src/Machine/src/Serval.Machine.Shared/Usings.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
global using System.Collections.Concurrent;
global using System.Collections.Immutable;
global using System.Data;
global using System.Diagnostics;
global using System.Formats.Tar;
Expand Down

0 comments on commit 083d68e

Please sign in to comment.