Skip to content

Commit

Permalink
fix: rpc round trip
Browse files Browse the repository at this point in the history
  • Loading branch information
kevmo314 committed Sep 13, 2024
1 parent bdaeeb9 commit 662cf6a
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 80 deletions.
10 changes: 10 additions & 0 deletions api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef SCUDA_API_H
#define SCUDA_API_H

/*
 * Numeric operation codes for the RPC protocol.
 *
 * A code is passed as the `op` argument of send_rpc_message() on the client
 * and is matched by the `switch (op)` in the server's client_handler().
 * The values travel over the wire, so existing numbers must not be reused
 * or renumbered.
 */

/* Library lifecycle. */
#define RPC_nvmlInitWithFlags   0
#define RPC_nvmlInit_v2         1
#define RPC_nvmlShutdown        2

/* Device queries. */
#define RPC_nvmlDeviceGetName   3

#endif /* SCUDA_API_H */
115 changes: 69 additions & 46 deletions client.cu
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#define _GNU_SOURCE

#include <arpa/inet.h>
#include <stdio.h>
#include <dlfcn.h>
#include <string.h>
#include <nvml.h>
#include <unistd.h>
#include <pthread.h>
#include <vector>

#include "api.h"

int sockfd;

int open_rpc_client()
{
int connfd;
struct sockaddr_in servaddr, cli;
struct sockaddr_in servaddr;

if (sockfd != 0)
{
Expand All @@ -39,7 +39,11 @@ int open_rpc_client()
return sockfd;
}

nvmlReturn_t send_rpc_message(void **response, int *len, const char *op, const void *args, const int argslen)
// TODO: can this be done via a template?
nvmlReturn_t send_rpc_message(
const unsigned int op,
std::vector<std::pair<const void *, const int>> requests = {},
std::vector<std::pair<void *, const int>> responses = {})
{
static int next_request_id = 0, active_response_id = -1;
static pthread_mutex_t mutex;
Expand All @@ -50,57 +54,76 @@ nvmlReturn_t send_rpc_message(void **response, int *len, const char *op, const v

int request_id = next_request_id++;

uint8_t oplen = (uint8_t)strlen(op);
if (write(sockfd, (void *)&request_id, sizeof(int)) < 0)
printf("Sending request %d\n", request_id);
if (write(sockfd, &request_id, sizeof(int)) < 0 ||
write(sockfd, &op, sizeof(unsigned int)) < 0)
{
pthread_mutex_unlock(&mutex);
return NVML_ERROR_GPU_IS_LOST;
}

if (write(sockfd, (void *)&oplen, sizeof(uint8_t)) < 0)
return NVML_ERROR_GPU_IS_LOST;
if (write(sockfd, op, oplen) < 0)
return NVML_ERROR_GPU_IS_LOST;
if (write(sockfd, (void *)&argslen, sizeof(int)) < 0)
return NVML_ERROR_GPU_IS_LOST;
if (write(sockfd, args, argslen) < 0)
return NVML_ERROR_GPU_IS_LOST;
printf("Sending %lu requests\n", requests.size());

for (auto r : requests)
if (write(sockfd, r.first, r.second) < 0)
{
pthread_mutex_unlock(&mutex);
return NVML_ERROR_GPU_IS_LOST;
}

// wait for the response
while (true)
{
printf("Waiting for response %d\n", request_id);
while (active_response_id != request_id && active_response_id != -1)
pthread_cond_wait(&cond, &mutex);

printf("Got response active %d\n", active_response_id);

// we currently own mutex. if active response id is -1, read the response id
if (active_response_id == -1)
{
if (read(sockfd, (void *)&active_response_id, sizeof(int)) < 0)
return NVML_ERROR_GPU_IS_LOST;
continue;
}
else
{
// it's our turn to read the response.
nvmlReturn_t ret;
if (read(sockfd, (void *)&ret, sizeof(nvmlReturn_t)) < 0)
return NVML_ERROR_GPU_IS_LOST;
if (ret != NVML_SUCCESS || response == NULL)
if (read(sockfd, &active_response_id, sizeof(int)) < 0)
{
pthread_mutex_unlock(&mutex);
return ret;
return NVML_ERROR_GPU_IS_LOST;
}

if (read(sockfd, (void *)len, sizeof(int)) < 0)
return NVML_ERROR_GPU_IS_LOST;
if (*len > 0)
printf("Got response id %d\n", active_response_id);

if (active_response_id != request_id)
{
*response = malloc(*len);
if (read(sockfd, *response, *len) < 0)
return NVML_ERROR_GPU_IS_LOST;
pthread_cond_broadcast(&cond);
continue;
}
}

active_response_id = -1;

// we are done, unlock and return.
printf("Reading response %d\n", request_id);

// it's our turn to read the response.
nvmlReturn_t ret;
if (read(sockfd, &ret, sizeof(nvmlReturn_t)) < 0 || ret != NVML_SUCCESS)
{
pthread_mutex_unlock(&mutex);
return ret;
return NVML_ERROR_GPU_IS_LOST;
}

printf("Reading %lu responses\n", responses.size());

for (auto r : responses)
if (read(sockfd, r.first, r.second) < 0)
{
pthread_mutex_unlock(&mutex);
return NVML_ERROR_GPU_IS_LOST;
}

printf("done!\n");

// we are done, unlock and return.
pthread_mutex_unlock(&mutex);
return ret;
}
}

Expand All @@ -113,39 +136,39 @@ void close_rpc_client()
// Client-side replacement for nvmlInitWithFlags.
// Calls open_rpc_client() and then forwards `flags` to the server through
// send_rpc_message(), returning whatever nvmlReturn_t that call yields.
//
// NOTE(review): two return statements are present. The first uses the
// five-argument, string-named call form; the second uses the
// RPC_nvmlInitWithFlags op code with a request list and is unreachable as
// written. This looks like the removed and added sides of a diff rendered
// together -- confirm which one is meant to survive.
nvmlReturn_t nvmlInitWithFlags(unsigned int flags)
{
open_rpc_client();
return send_rpc_message(NULL, NULL, "nvmlInitWithFlags", (void *)&flags, sizeof(unsigned int));
// Sends a single request payload: the raw bytes of `flags`.
return send_rpc_message(RPC_nvmlInitWithFlags, {{&flags, sizeof(unsigned int)}});
}

// Client-side replacement for nvmlInit_v2.
// Calls open_rpc_client() and then issues an RPC with no argument payload,
// returning the nvmlReturn_t produced by send_rpc_message().
//
// NOTE(review): two return statements are present; the second (op-code form,
// RPC_nvmlInit_v2) is unreachable as written. Likely both sides of a diff
// rendered together -- confirm which one is intended.
nvmlReturn_t nvmlInit_v2()
{
open_rpc_client();
return send_rpc_message(NULL, NULL, "nvmlInit_v2", NULL, 0);
return send_rpc_message(RPC_nvmlInit_v2);
}

// Client-side replacement for nvmlShutdown.
// Calls open_rpc_client() and then issues an RPC with no argument payload,
// returning the nvmlReturn_t produced by send_rpc_message().
//
// NOTE(review): two return statements are present; the second (op-code form,
// RPC_nvmlShutdown) is unreachable as written. Likely both sides of a diff
// rendered together -- confirm which one is intended.
nvmlReturn_t nvmlShutdown()
{
open_rpc_client();
return send_rpc_message(NULL, NULL, "nvmlShutdown", NULL, 0);
return send_rpc_message(RPC_nvmlShutdown);
}

// Client-side replacement for nvmlDeviceGetName.
// Calls open_rpc_client() and then asks the server for the device name via
// send_rpc_message(), returning the nvmlReturn_t that call yields.
//
// device: handle forwarded to the server as raw bytes.
// name:   caller-provided buffer that receives the response bytes.
// length: size of `name`; also sent to the server in the op-code call form.
//
// NOTE(review): two return statements are present; the second (op-code form,
// RPC_nvmlDeviceGetName) is unreachable as written. Likely both sides of a
// diff rendered together -- confirm which one is intended.
nvmlReturn_t nvmlDeviceGetName(nvmlDevice_t device, char *name, unsigned int length)
{
open_rpc_client();
return send_rpc_message((void **)&name, (int *)&length, "nvmlDeviceGetName", (void *)&device, sizeof(nvmlDevice_t));
// Requests: `device` then `length` (sent as sizeof(int) bytes of an unsigned int).
// Response: exactly `length` bytes are requested into `name`.
return send_rpc_message(RPC_nvmlDeviceGetName, {{&device, sizeof(nvmlDevice_t)}, {&length, sizeof(int)}}, {{name, length}});
}

void *dlsym(void *handle, const char *name)
void *dlsym(void *handle, const char *name) __THROW
{
printf("Resolving symbol: %s\n", name);

if (!strcmp(name, "nvmlInitWithFlags"))
return (void *)nvmlInitWithFlags;
if (!strcmp(name, "nvmlInit_v2"))
return (void *)nvmlInit_v2;
if (!strcmp(name, "nvmlShutdown"))
return (void *)nvmlShutdown;
if (!strcmp(name, "nvmlDeviceGetName"))
return (void *)nvmlDeviceGetName;
// if (!strcmp(name, "nvmlInit_v2"))
// return (void *)nvmlInit_v2;
// if (!strcmp(name, "nvmlShutdown"))
// return (void *)nvmlShutdown;
// if (!strcmp(name, "nvmlDeviceGetName"))
// return (void *)nvmlDeviceGetName;

static void *(*real_dlsym)(void *, const char *) = NULL;
if (real_dlsym == NULL)
Expand Down
2 changes: 1 addition & 1 deletion local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

libscuda_path="$(pwd)/libscuda.so"
client_path="$(pwd)/client.cu"
server_path="$(pwd)/server.c"
server_path="$(pwd)/server.cu"
server_out_path="$(pwd)/server"

build() {
Expand Down
90 changes: 57 additions & 33 deletions server.c → server.cu
Original file line number Diff line number Diff line change
@@ -1,78 +1,90 @@
#define _GNU_SOURCE

#include <arpa/inet.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <pthread.h>
#include <unistd.h>
#include <nvml.h>
#include <sys/socket.h>

#include "api.h"

#define PORT 14833
#define MAX_CLIENTS 10

int sockfd;

// Runs nvmlInitWithFlags(flags) on the server and writes the raw nvmlReturn_t
// result to connfd. The return value of write() is not checked.
//
// NOTE(review): the signature appears twice (brace on the same line, then
// brace on its own line). This looks like both sides of a brace-style
// reformatting diff -- keep only one.
void nvmlInitWithFlagsHandler(int connfd, unsigned int flags) {
void nvmlInitWithFlagsHandler(int connfd, unsigned int flags)
{
nvmlReturn_t result = nvmlInitWithFlags(flags);
write(connfd, &result, sizeof(result));
}

// Runs nvmlShutdown() on the server and writes the raw nvmlReturn_t result to
// connfd. The return value of write() is not checked.
//
// NOTE(review): the signature appears twice (brace on the same line, then
// brace on its own line). This looks like both sides of a brace-style
// reformatting diff -- keep only one.
void nvmlShutdownHandler(int connfd) {
void nvmlShutdownHandler(int connfd)
{
nvmlReturn_t result = nvmlShutdown();
write(connfd, &result, sizeof(result));
}

// Queries the device name with nvmlDeviceGetName() into a local buffer of
// NVML_DEVICE_NAME_BUFFER_SIZE bytes and replies on connfd.
//
// Wire format written here: the nvmlReturn_t result, and only when the result
// is NVML_SUCCESS, an int length followed by that many bytes of the name
// (NUL terminator included). None of the write() return values are checked.
//
// NOTE(review): the signature and the `if` line each appear twice (brace on
// the same line, then brace on its own line). This looks like both sides of a
// brace-style reformatting diff -- keep only one of each.
void nvmlDeviceGetNameHandler(int connfd, nvmlDevice_t device) {
void nvmlDeviceGetNameHandler(int connfd, nvmlDevice_t device)
{
char name[NVML_DEVICE_NAME_BUFFER_SIZE];
nvmlReturn_t result = nvmlDeviceGetName(device, name, NVML_DEVICE_NAME_BUFFER_SIZE);
write(connfd, &result, sizeof(result));
if (result == NVML_SUCCESS) {
if (result == NVML_SUCCESS)
{
int len = strlen(name) + 1; // including null terminator
write(connfd, &len, sizeof(len));
write(connfd, name, len);
}
}

// Per-connection thread entry point.
//
// arg: pointer to a heap-allocated int holding the connected socket fd; the
//      value is copied out and the allocation is freed immediately.
//
// Reads requests from the socket in a loop, dispatches them, and on exit
// closes the socket and ends the thread with pthread_exit(NULL).
//
// NOTE(review): this span contains two interleaved protocol versions -- a
// string-named one (oplen / operation / strcmp chain) and a numeric one
// (request_id / op / switch). It looks like the removed and added sides of a
// diff rendered together; confirm which lines are live before editing.
void *client_handler(void *arg) {
void *client_handler(void *arg)
{
int connfd = *(int *)arg;
free(arg);

char operation[64];
int oplen;
unsigned int op;
int request_id;
int argslen;

while (read(connfd, &oplen, sizeof(oplen)) > 0) {
// Read the operation name
if (read(connfd, operation, oplen) <= 0) {
break;
}
operation[oplen] = '\0'; // Null-terminate operation string
// Numeric protocol: each request starts with an int request id followed by an
// unsigned int op code.
// NOTE(review): read() returns 0 at end-of-file, which `>= 0` treats as
// success, so a closed peer does not end this loop here -- verify intent.
while (read(connfd, &request_id, sizeof(int)) >= 0)
{
if (read(connfd, &op, sizeof(unsigned int)) < 0)
goto exit;

// Echo the request id back before the op-specific reply.
if (write(connfd, &request_id, sizeof(int)) < 0)
goto exit;

// Handle different NVML operations
if (strcmp(operation, "nvmlInitWithFlags") == 0) {
// NOTE(review): only RPC_nvmlInitWithFlags has a case in this switch; any
// other op code falls through with no reply beyond the echoed request id.
switch (op)
{
case RPC_nvmlInitWithFlags:
{
unsigned int flags;
read(connfd, &argslen, sizeof(argslen));
read(connfd, &flags, argslen);
nvmlInitWithFlagsHandler(connfd, flags);
} else if (strcmp(operation, "nvmlShutdown") == 0) {
nvmlShutdownHandler(connfd);
} else if (strcmp(operation, "nvmlDeviceGetName") == 0) {
nvmlDevice_t device;
read(connfd, &argslen, sizeof(argslen));
read(connfd, &device, argslen);
nvmlDeviceGetNameHandler(connfd, device);
// Request payload is the raw unsigned int flags; the reply is the raw
// nvmlReturn_t from nvmlInitWithFlags().
if (read(connfd, &flags, sizeof(unsigned int)) < 0)
goto exit;
printf("Received nvmlInitWithFlags request %d %d\n", request_id, flags);
nvmlReturn_t result = nvmlInitWithFlags(flags);
printf("Sending nvmlInitWithFlags response %d %d\n", request_id, result);
if (write(connfd, &result, sizeof(nvmlReturn_t)) < 0)
goto exit;
break;
}
}
}

// Common teardown for the loop ending and for every I/O error path above.
exit:
close(connfd);
pthread_exit(NULL);
}

int main() {
int main()
{
struct sockaddr_in servaddr, cli;
sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd == -1) {
if (sockfd == -1)
{
printf("Socket creation failed.\n");
exit(EXIT_FAILURE);
}
Expand All @@ -82,25 +94,37 @@ int main() {
servaddr.sin_family = AF_INET;
servaddr.sin_addr.s_addr = INADDR_ANY;
servaddr.sin_port = htons(PORT);
if (bind(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) != 0) {

const int enable = 1;
if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0)
{
printf("Socket bind failed.\n");
exit(EXIT_FAILURE);
}

if (bind(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) != 0)
{
printf("Socket bind failed.\n");
exit(EXIT_FAILURE);
}

// Listen for clients
if (listen(sockfd, MAX_CLIENTS) != 0) {
if (listen(sockfd, MAX_CLIENTS) != 0)
{
printf("Listen failed.\n");
exit(EXIT_FAILURE);
}

printf("Server listening on port %d...\n", PORT);

// Server loop
while (1) {
while (1)
{
socklen_t len = sizeof(cli);
int *connfd = malloc(sizeof(int));
int *connfd = (int *)malloc(sizeof(int));
*connfd = accept(sockfd, (struct sockaddr *)&cli, &len);
if (*connfd < 0) {
if (*connfd < 0)
{
printf("Server accept failed.\n");
free(connfd);
continue;
Expand Down

0 comments on commit 662cf6a

Please sign in to comment.