Skip to content

Commit

Permalink
Add compute support
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 22, 2024
1 parent 85401b9 commit 7a623e2
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 29 deletions.
20 changes: 4 additions & 16 deletions Sources/backends/d3d11.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,8 @@ static const char *shaderString(shader_stage stage, int version) {
return "vs_4_0";
case SHADER_STAGE_FRAGMENT:
return "ps_4_0";
/*case EShLangGeometry:
return "gs_4_0";
case EShLangTessControl:
return "unsupported";
case EShLangTessEvaluation:
return "unsupported";
case EShLangCompute:
return "cs_4_0";*/
case SHADER_STAGE_COMPUTE:
return "cs_4_0";
}
}
else if (version == 5) {
Expand All @@ -40,14 +34,8 @@ static const char *shaderString(shader_stage stage, int version) {
return "vs_5_0";
case SHADER_STAGE_FRAGMENT:
return "ps_5_0";
/*case EShLangGeometry:
return "gs_5_0";
case EShLangTessControl:
return "hs_5_0";
case EShLangTessEvaluation:
return "ds_5_0";
case EShLangCompute:
return "cs_5_0";*/
case SHADER_STAGE_COMPUTE:
return "cs_5_0";
}
}

Expand Down
10 changes: 2 additions & 8 deletions Sources/backends/d3d12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@ static const wchar_t *shader_string(shader_stage stage) {
return L"vs_6_0";
case SHADER_STAGE_FRAGMENT:
return L"ps_6_0";
/*case EShLangGeometry:
return L"gs_6_0";
case EShLangTessControl:
return L"hs_6_0";
case EShLangTessEvaluation:
return L"ds_6_0";
case EShLangCompute:
return L"cs_6_0";*/
case SHADER_STAGE_COMPUTE:
return L"cs_6_0";
default: {
debug_context context = {0};
error(context, "Unsupported shader stage/version combination");
Expand Down
74 changes: 72 additions & 2 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,11 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
*offset += sprintf(&hlsl[*offset], "%s main(", type_string(f->return_type.type));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 ") {\n", type_string(f->parameter_types[parameter_index].type),
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type),
parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64 ") {\n", type_string(f->parameter_types[parameter_index].type),
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type),
parameter_ids[parameter_index]);
}
}
Expand Down Expand Up @@ -344,6 +344,20 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
*offset += sprintf(&hlsl[*offset], ") : SV_Target0 {\n");
}
}
else if (stage == SHADER_STAGE_COMPUTE) {
*offset += sprintf(&hlsl[*offset], "[numthreads(64, 1, 1)] %s main(", type_string(f->return_type.type));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset +=
sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset +=
sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
}
else {
debug_context context = {0};
error(context, "Unsupported shader stage");
Expand Down Expand Up @@ -537,6 +551,47 @@ static void hlsl_export_fragment(char *directory, api_kind d3d, function *main)
write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

static void hlsl_export_compute(char *directory, api_kind d3d, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;

write_types(hlsl, &offset, SHADER_STAGE_COMPUTE, NO_TYPE, NO_TYPE, main);

write_globals(hlsl, &offset, main);

write_functions(hlsl, &offset, SHADER_STAGE_COMPUTE, main);

debug_context context = {0};

uint8_t *output = NULL;
size_t output_size = 0;
int result = 1;
switch (d3d) {
case API_DIRECT3D9:
error(context, "Compute shaders are not supported in Direct3D 9");
break;
case API_DIRECT3D11:
result = compile_hlsl_to_d3d11(hlsl, &output, &output_size, SHADER_STAGE_COMPUTE, false);
break;
case API_DIRECT3D12:
result = compile_hlsl_to_d3d12(hlsl, &output, &output_size, SHADER_STAGE_COMPUTE, false);
break;
default:
error(context, "Unsupported API for HLSL");
}
check(result == 0, context, "HLSL compilation failed");

char *name = get_name(main->name);

char filename[512];
sprintf(filename, "kong_%s", name);

char var_name[256];
sprintf(var_name, "%s_code", name);

write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

void hlsl_export(char *directory, api_kind d3d) {
int cbuffer_index = 0;
int texture_index = 0;
Expand Down Expand Up @@ -601,11 +656,26 @@ void hlsl_export(char *directory, api_kind d3d) {
}
}

function *compute_shaders[256];
size_t compute_shaders_size = 0;

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->attribute == add_name("compute")) {
compute_shaders[compute_shaders_size] = f;
compute_shaders_size += 1;
}
}

for (size_t i = 0; i < vertex_shaders_size; ++i) {
hlsl_export_vertex(directory, d3d, vertex_shaders[i]);
}

for (size_t i = 0; i < fragment_shaders_size; ++i) {
hlsl_export_fragment(directory, d3d, fragment_shaders[i]);
}

for (size_t i = 0; i < compute_shaders_size; ++i) {
hlsl_export_compute(directory, d3d, compute_shaders[i]);
}
}
33 changes: 31 additions & 2 deletions Sources/integrations/kinc.c
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ void kinc_export(char *directory, api_kind api) {
}
}

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->attribute == add_name("compute")) {
fprintf(output, "extern kinc_g4_compute_shader %s;\n\n", get_name(f->name));
}
}

fprintf(output, "#endif\n");

fclose(output);
Expand Down Expand Up @@ -445,6 +452,13 @@ void kinc_export(char *directory, api_kind api) {
}
}
}

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->attribute == add_name("compute")) {
fprintf(output, "#include \"kong_%s.h\"\n", get_name(f->name));
}
}
}

fprintf(output, "\n#include <kinc/graphics4/graphics.h>\n\n");
Expand Down Expand Up @@ -531,6 +545,13 @@ void kinc_export(char *directory, api_kind api) {
}
}

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->attribute == add_name("compute")) {
fprintf(output, "kinc_g4_compute_shader %s;\n", get_name(f->name));
}
}

if (api == API_WEBGPU) {
fprintf(output, "\nvoid kinc_g5_internal_webgpu_create_shader_module(const void *source, size_t length);\n");
}
Expand Down Expand Up @@ -677,11 +698,11 @@ void kinc_export(char *directory, api_kind api) {
for (size_t j = 0; j < t->members.size; ++j) {
if (api == API_OPENGL) {
fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s_%s\", %s);\n", get_name(t->name), get_name(t->name),
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
}
else {
fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s\", %s);\n", get_name(t->name),
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
}
}
fprintf(output, "\n");
Expand Down Expand Up @@ -723,6 +744,14 @@ void kinc_export(char *directory, api_kind api) {
}
}
}

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->attribute == add_name("compute")) {
fprintf(output, "\tkinc_g4_compute_shader_init(&%s, %s_code, %s_code_size);\n", get_name(f->name), get_name(f->name), get_name(f->name));
}
}

fprintf(output, "}\n");

fclose(output);
Expand Down
2 changes: 1 addition & 1 deletion Sources/shader_stage.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#pragma once

typedef enum shader_stage { SHADER_STAGE_VERTEX, SHADER_STAGE_FRAGMENT } shader_stage;
typedef enum shader_stage { SHADER_STAGE_VERTEX, SHADER_STAGE_FRAGMENT, SHADER_STAGE_COMPUTE } shader_stage;

0 comments on commit 7a623e2

Please sign in to comment.