Skip to content

Commit

Permalink
Add configure command
Browse files Browse the repository at this point in the history
Useful for benchmarking mostly
  • Loading branch information
jelmervdl committed May 4, 2022
1 parent 50db857 commit cf92a04
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
10 changes: 10 additions & 0 deletions scripts/native_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ async def translate(self, text, src=None, trg=None, *, model=None, pivot=None, h
async def download_model(self, model_id, *, update=lambda data: None):
    """Send a "DownloadModel" request for the given model.

    The model id is converted to a string before it is sent as the
    "modelID" field. The ``update`` callable is passed through unchanged to
    ``self.request``; by default it ignores whatever it is called with.

    Returns whatever ``self.request`` resolves to.
    """
    return await self.request("DownloadModel", {"modelID": str(model_id)}, update=update)

async def configure(self, *, threads:int = None, cache_size:int = None):
    """Send a "Configure" request with the given resource settings.

    Only the settings that are actually provided (i.e. not ``None``) are
    included in the request payload:

    - ``threads`` is sent as the "threads" field.
    - ``cache_size`` is sent as the "cacheSize" field.

    Both values are coerced with ``int()`` before sending. Returns whatever
    ``self.request`` resolves to.
    """
    options = {}
    if threads is not None:
        options["threads"] = int(threads)
    if cache_size is not None:
        options["cacheSize"] = int(cache_size)
    return await self.request("Configure", options)


def first(iterable, *default):
"""Returns the first value of anything iterable, or throws StopIteration
Expand Down Expand Up @@ -186,6 +194,8 @@ async def test():
if model["id"] == selected_model["id"]
)

await tl.configure(threads=1, cache_size=0)

# Perform some translations, switching between the models
translations = await asyncio.gather(
tl.translate("Hello world!", "en", "de"),
Expand Down
23 changes: 21 additions & 2 deletions src/cli/NativeMsgIface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ void NativeMsgIface::handleRequest(ListRequest request) {
writeResponse(request, modelsJson);
}

/**
 * @brief Handles a ConfigureRequest: replaces the translation service with a
 * new one built from the requested worker count and cache size, then
 * acknowledges the request with an empty data object.
 * @param request ConfigureRequest carrying `threads` and `cacheSize`.
 */
void NativeMsgIface::handleRequest(ConfigureRequest request) {
    // Both values originate from client-supplied JSON and are signed ints, so
    // keep them in a sane range before handing them to the service config.
    // NOTE(review): assumes a service needs at least one worker to make
    // progress and that a negative cache size is never meaningful — confirm
    // against AsyncService::Config.
    const int threads = request.threads < 1 ? 1 : request.threads;
    const int cacheSize = request.cacheSize < 0 ? 0 : request.cacheSize;

    marian::bergamot::AsyncService::Config serviceConfig;
    serviceConfig.numWorkers = threads;
    serviceConfig.cacheSize = cacheSize;

    // Release the current service before constructing its replacement
    // (same order as before: old one is gone before the new one exists).
    service_.reset();
    service_ = std::make_shared<marian::bergamot::AsyncService>(serviceConfig);

    // The Configure command has nothing to report back; reply with an empty
    // object as the response data.
    QJsonObject response{};
    writeResponse(request, response);
}

void NativeMsgIface::handleRequest(DownloadRequest request) {
// Edge case: client issued a DownloadRequest before fetching the list of
// remote models because it knows the model ID from a previous run. We still
Expand Down Expand Up @@ -284,7 +295,7 @@ request_variant NativeMsgIface::parseJsonInput(QByteArray input) {

// Define what are mandatory and what are optional request keys
static const QStringList mandatoryKeys({"command", "id", "data"}); // Expected in every message
static const QSet<QString> commandTypes({"ListModels", "DownloadModel", "Translate"});
static const QSet<QString> commandTypes({"ListModels", "DownloadModel", "Translate", "Configure"});
// Json doesn't have schema validation, so validate here, in place:
QString command;
int id;
Expand Down Expand Up @@ -359,13 +370,21 @@ request_variant NativeMsgIface::parseJsonInput(QByteArray input) {
ret.id = id;
for (auto&& key : mandatoryKeysDownload) {
QJsonValueRef val = data[key];
if (val.isNull()) {
if (val.isNull() || val.isUndefined()) {
return MalformedRequest{id, QString("data field key %1 cannot be null!").arg(key)};
} else {
ret.modelID = val.toString();
}
}
return ret;
} else if (command == "Configure") {
ConfigureRequest ret;
ret.id = id;
if (!data["threads"].isUndefined())
ret.threads = data["threads"].toInt();
if (!data["cacheSize"].isUndefined())
ret.cacheSize = data["cacheSize"].toInt();
return ret;
} else {
return MalformedRequest{id, QString("Developer error. We shouldn't ever be here! Command: %1").arg(command)};
}
Expand Down
33 changes: 32 additions & 1 deletion src/cli/NativeMsgIface.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,41 @@ struct DownloadRequest : Request {

Q_DECLARE_METATYPE(DownloadRequest);

/**
* Change TranslateLocally resource usage for this session.
*
* Request:
* {
* "id": int,
* "command": "Configure",
* "data": {
* "threads": int, (optional)
* "cacheSize": int (optional, 0 means disabled)
* }
* }
*
* Successful response:
* {
* "id": int,
* "success": true,
* "data": {}
* }
*/
struct ConfigureRequest : Request {
    // Both fields are only assigned by the parser when the corresponding key
    // is present in the request, so they need defined fallback values;
    // otherwise the handler would read an uninitialised int.
    // NOTE(review): 1 thread / cache disabled were picked as conservative
    // fallbacks — confirm they match the intended service defaults.
    int threads = 1;
    int cacheSize = 0;
};

Q_DECLARE_METATYPE(ConfigureRequest);

/**
* Internal structure to handle a request that is missing a required field.
*/
struct MalformedRequest : Request {
// Human-readable description of what was wrong with the incoming message;
// filled in by the JSON parser when it rejects a request.
QString error;
};

using request_variant = std::variant<TranslationRequest, ListRequest, DownloadRequest, MalformedRequest>;
using request_variant = std::variant<TranslationRequest, ListRequest, DownloadRequest, ConfigureRequest, MalformedRequest>;

/**
* Internal structure to cache a loaded direct model (i.e. no pivoting)
Expand Down Expand Up @@ -410,6 +437,10 @@ private slots:
*/
void handleRequest(DownloadRequest myJsonInput);

/**
* @brief handleRequest handles a request type ConfigureRequest: recreates the
* translation service with the requested number of worker threads and cache
* size, then writes an empty response.
* @param myJsonInput ConfigureRequest
*/
void handleRequest(ConfigureRequest myJsonInput);

/**
* @brief handleRequest handles a request type MalformedRequest and writes to stdout
* @param myJsonInput MalformedRequest
Expand Down

0 comments on commit cf92a04

Please sign in to comment.