Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update appstack to show new create trainspace endpoint #1114

Closed
wants to merge 9 commits into from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ frontend/src/backend_outputs/*
!frontend/src/backend_outputs/my_deep_learning_model.onnx

# Terraform Files
dlp-terraform/dynamodb/terraform.tfstate
dlp-terraform/.terraform
dlp-terraform/.terraform.lock.hcl
dlp-terraform/terraform.tfstate
Expand Down
16 changes: 6 additions & 10 deletions dlp-terraform/dynamodb/dynamodb.tf
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ terraform {
}

provider "aws" {
region = "us-west-2"
region = "us-east-1"
}

resource "aws_dynamodb_table" "execution-table" {
Expand Down Expand Up @@ -93,26 +93,22 @@ resource "aws_dynamodb_table" "userprogress_table" {
}

resource "aws_dynamodb_table" "trainspace" {
name = "trainspace"
hash_key = "trainspace_id"
name = "TrainspaceTable"
billing_mode = "PROVISIONED"
hash_key = "trainspace_id"
write_capacity = 10
read_capacity = 10
attribute {
name = "trainspace_id"
type = "S"
}
attribute {
name = "uid"
name = "user_id"
type = "S"
}
ttl {
enabled = true
attribute_name = "expiryPeriod"
}
global_secondary_index {
name = "uid"
hash_key = "uid"
name = "user_id_index"
hash_key = "user_id"
write_capacity = 10
read_capacity = 10
projection_type = "ALL"
Expand Down
102 changes: 102 additions & 0 deletions serverless/packages/functions/src/trainspace/create_trainspace.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import { APIGatewayProxyHandlerV2 } from "aws-lambda";
import parseJwt from "@dlp-sst-app/core/parseJwt";
import { v4 as uuidv4 } from 'uuid';
import { DynamoDBClient } from '@aws-sdk/client-dynamodb';
import { DynamoDBDocumentClient, PutCommand, PutCommandInput } from '@aws-sdk/lib-dynamodb';
import { TrainStatus } from './constants';

export const handler: APIGatewayProxyHandlerV2 = async (event) => {
if (event) {
const user_id: string = parseJwt(event.headers.authorization ?? "")["user_id"];
const eventBody = JSON.parse(event.body? event.body : "");
const trainspaceId = uuidv4();

const datasetArray = eventBody['datasets'];
const modelArray = eventBody['models'];
const blockArray = eventBody['blocks'];

let putCommandInput: PutCommandInput = {
TableName: "TrainspaceTable",
Item:
{
trainspace_id: trainspaceId,
created: Date.now().toString(),
uid: user_id,

datasets: datasetArray.map((eventBody: {[x: string] : any;}) => ({
dataset: removeUndefinedValues(eventBody['dataset_source'] == 's3' ? {
data_source: eventBody['dataset_source']?.trim(),
dataset_id: eventBody['dataset_id']?.trim(),
s3_url: eventBody['s3_url']?.trim(),
} : {
data_source: eventBody['dataset_source']?.trim(),
dataset_id: eventBody['dataset_id']?.trim()
})
})),
models: modelArray.map((eventBody: { [x: string]: any; }) => ({
model: removeUndefinedValues({
model_name: eventBody['model_name']?.trim(),
model_id: eventBody['model_id']?.trim(),
//dataset_id: eventBody['dataset_id']?.trim(),
model_versions: {
model_name: eventBody['model_name']?.trim(),
s3_url: eventBody['model_s3_url']?.trim()
}
})
})),
blocks: blockArray.map((eventBody: {[x: string] : any;}) => ({
block: removeUndefinedValues({
block_id: eventBody['block_id']?.trim(),
block_type: eventBody['block_type']?.trim(),
//not sure what to do with "other info based on the type of block"
run_data: {
s3_url: eventBody['block_s3_url']?.trim(),
string: eventBody['run_string']?.trim(),
number: eventBody['run_number']?.trim()
}
})
})),
status: TrainStatus.QUEUED
}
}

if (putCommandInput == null)
{
return {
statusCode: 400,
body: JSON.stringify({ message: "Invalid request body" })
}
}

const client = new DynamoDBClient({});
const docClient = DynamoDBDocumentClient.from(client);

const command = new PutCommand(putCommandInput);
const response = await docClient.send(command);

if (response.$metadata.httpStatusCode != 200) {
return {
statusCode: 500,
body: JSON.stringify({ message: "Internal server error."})
};
}

return {
statusCode: 200,
body: JSON.stringify({ trainspaceId: trainspaceId, message: "Successfully created a new trainspace."})
};
}
return {
statusCode: 404,
body: JSON.stringify({ message: "Not Found" }),
};
};
function removeUndefinedValues(obj: { [key: string]: any }) {
const newObj: { [key: string]: any } = {};
for (const key in obj) {
if (obj[key] !== undefined) {
newObj[key] = obj[key];
}
}
return newObj;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const handler: APIGatewayProxyHandlerV2 = async (event) => {

do {
const getCommand: QueryCommand = new QueryCommand({
TableName: "trainspace",
TableName: "TrainspaceTable",
IndexName: "uid",
KeyConditionExpression: "uid = :uid",
ExpressionAttributeValues: {
Expand Down
2 changes: 1 addition & 1 deletion serverless/sst.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export default {
config(_input) {
return {
name: "dlp-sst-app",
region: "us-west-2",
region: "us-east-1",
};
},
stacks(app) {
Expand Down
9 changes: 9 additions & 0 deletions serverless/stacks/AppStack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ export function AppStack({ stack }: StackContext) {
permissions: ["dynamodb:PutItem"]
}
},
//general trainspace
"POST /trainspace/create": {
function: {
handler: "packages/functions/src/trainspace/create_trainspace.handler",
permissions: ["dynamodb:PutItem"]
}
},
"GET /trainspace/{id}": {
function: {
handler: "packages/functions/src/trainspace/get_trainspace.handler",
Expand Down Expand Up @@ -87,6 +94,8 @@ export function AppStack({ stack }: StackContext) {
api.getFunction("POST /trainspace/tabular")?.functionName ?? "",
PutImageTrainspaceFunctionName:
api.getFunction("POST /trainspace/tabular")?.functionName ?? "",
CreateTrainspaceFunctionName:
api.getFunction("POST /trainspace/create")?.functionName ?? "",
GetAllTrainspaceIdsFunctionName:
api.getFunction("GET /trainspace")?.functionName ?? "",
GetTrainspaceByIdFunctionName:
Expand Down
Loading
Loading