Skip to content

Commit

Permalink
Merge pull request #1052 from benptc/splat-training-authentication
Browse files Browse the repository at this point in the history
adds authentication and error handling to SplatTasks
  • Loading branch information
benptc authored Apr 24, 2024
2 parents 2f8a933 + cc8ea58 commit 7cdf59c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 13 deletions.
13 changes: 9 additions & 4 deletions controllers/object.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ const path = require('path');
const formidable = require('formidable');
const utilities = require('../libraries/utilities');
const {fileExists, unlinkIfExists, mkdirIfNotExists} = utilities;
const {startSplatTask} = require('./object/SplatTask.js');
const {startSplatTask, splatTasks} = require('./object/SplatTask.js');
const { beatPort } = require('../config.js');

// Variables populated from server.js with setup()
Expand Down Expand Up @@ -387,18 +387,23 @@ const generateXml = async function(objectID, body, callback) {

/**
* @param {string} objectId
* @param {string|undefined} credentials - Bearer <JWT>
* @return {{done: boolean, gaussianSplatRequestId: string|undefined}} result
*/
async function requestGaussianSplatting(objectId) {
async function requestGaussianSplatting(objectId, credentials) {
const object = utilities.getObject(objects, objectId);
if (!object) {
throw new Error('Object not found');
}

let splatTask = await startSplatTask(object);
let splatTask = await startSplatTask(object, credentials);
let status = splatTask.getStatus();
if (status.error && !status.gaussianSplatRequestId) {
delete splatTasks[object.objectId]; // tasks resulting in error shouldn't be saved as oldTasks that can be resumed
}
// Starting splat task can modify object
await utilities.writeObjectToFile(objects, objectId, globalVariables.saveToDisk);
return splatTask.getStatus();
return status;
}

/**
Expand Down
44 changes: 37 additions & 7 deletions controllers/object/SplatTask.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@ const {fileExists} = utilities;

const SPLAT_HOST = 'change me:3000';

if (SPLAT_HOST.includes('change me')) {
console.warn('Edit the SPLAT_HOST if you want to enable Gaussian Splat training');
}

/**
* A class for starting and monitoring the progress of an area target to splat
* conversion problem, persisting the resulting file if successful
*/
class SplatTask {
/**
* @param {ObjectModel} object
* @param {string|null} credentials - Bearer <JWT>
*/
constructor(object) {
constructor(object, credentials) {
this.object = object;
this.credentials = credentials;
this.gaussianSplatRequestId = null;
this.error = null;
if (this.object.gaussianSplatRequestId) {
this.gaussianSplatRequestId = this.object.gaussianSplatRequestId;
}
Expand Down Expand Up @@ -62,13 +69,34 @@ class SplatTask {
method: 'POST',
headers: {
...form.getHeaders(),
Authorization: this.credentials,
},
body: form,
});

const gaussianSplatRequestId = await res.text();
this.object.gaussianSplatRequestId = gaussianSplatRequestId;
this.gaussianSplatRequestId = gaussianSplatRequestId;
let responseText = null;
try {
responseText = await res.text();
} catch (e) {
console.warn(`error parsing SplatTask /upload response (status ${res.status})`);
this.error = 'Unable to process the training server\'s response.';
return;
}

if (res.status === 200) {
const gaussianSplatRequestId = responseText;
this.object.gaussianSplatRequestId = gaussianSplatRequestId;
this.gaussianSplatRequestId = gaussianSplatRequestId;
} else {
// response is a string if successful, a JSON {error: 'reason'} if not
try {
this.error = JSON.parse(responseText).error;
} catch (e) {
console.warn(`error parsing SplatTask /upload response (status ${res.status})`);
this.error = 'Unable to process the training server\'s response.';
}
return null;
}
}

this.openSocket();
Expand Down Expand Up @@ -102,12 +130,13 @@ class SplatTask {
}

/**
* @return {{done: boolean, gaussianSplatRequestId: string|undefined}} splat status
* @return {{done: boolean, gaussianSplatRequestId: string|null, error: string|null}} splat status
*/
getStatus() {
return {
done: this.done,
gaussianSplatRequestId: this.gaussianSplatRequestId,
error: this.error,
};
}

Expand Down Expand Up @@ -149,15 +178,16 @@ module.exports.splatTasks = splatTasks;

/**
* @param {ObjectModel} object
* @param {string|null} credentials
*/
module.exports.startSplatTask = async function startSplatTask(object) {
module.exports.startSplatTask = async function startSplatTask(object, credentials) {
const objectId = object.objectId;
const oldTask = splatTasks[object.objectId];
if (oldTask) {
return oldTask;
}

splatTasks[objectId] = new SplatTask(object);
splatTasks[objectId] = new SplatTask(object, credentials);
// Kick off the splatting to the point where we get a request id
await splatTasks[objectId].start();
return splatTasks[objectId];
Expand Down
5 changes: 3 additions & 2 deletions routers/object.js
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,10 @@ router.post('/:objectName/requestGaussianSplatting/', async function (req, res)
return;
}
// splat status (commonly referred to as "splattus") is
// {done: boolean, gaussianSplatRequestId: string|undefined}
// {done: boolean, gaussianSplatRequestId: string|undefined, error: string|null}
try {
const splatStatus = await objectController.requestGaussianSplatting(req.params.objectName);
// req.headers.authorization is expected to be `Bearer ${JWT}`
const splatStatus = await objectController.requestGaussianSplatting(req.params.objectName, req.headers.authorization);
res.json(splatStatus);
} catch (e) {
console.error(e);
Expand Down

0 comments on commit 7cdf59c

Please sign in to comment.