diff --git a/Sources/backends/d3d11.c b/Sources/backends/d3d11.c index fe34cad..2bdb096 100644 --- a/Sources/backends/d3d11.c +++ b/Sources/backends/d3d11.c @@ -24,14 +24,8 @@ static const char *shaderString(shader_stage stage, int version) { return "vs_4_0"; case SHADER_STAGE_FRAGMENT: return "ps_4_0"; - /*case EShLangGeometry: - return "gs_4_0"; - case EShLangTessControl: - return "unsupported"; - case EShLangTessEvaluation: - return "unsupported"; - case EShLangCompute: - return "cs_4_0";*/ + case SHADER_STAGE_COMPUTE: + return "cs_4_0"; } } else if (version == 5) { @@ -40,14 +34,8 @@ static const char *shaderString(shader_stage stage, int version) { return "vs_5_0"; case SHADER_STAGE_FRAGMENT: return "ps_5_0"; - /*case EShLangGeometry: - return "gs_5_0"; - case EShLangTessControl: - return "hs_5_0"; - case EShLangTessEvaluation: - return "ds_5_0"; - case EShLangCompute: - return "cs_5_0";*/ + case SHADER_STAGE_COMPUTE: + return "cs_5_0"; } } diff --git a/Sources/backends/d3d12.cpp b/Sources/backends/d3d12.cpp index a063f42..6367730 100644 --- a/Sources/backends/d3d12.cpp +++ b/Sources/backends/d3d12.cpp @@ -21,14 +21,8 @@ static const wchar_t *shader_string(shader_stage stage) { return L"vs_6_0"; case SHADER_STAGE_FRAGMENT: return L"ps_6_0"; - /*case EShLangGeometry: - return L"gs_6_0"; - case EShLangTessControl: - return L"hs_6_0"; - case EShLangTessEvaluation: - return L"ds_6_0"; - case EShLangCompute: - return L"cs_6_0";*/ + case SHADER_STAGE_COMPUTE: + return L"cs_6_0"; default: { debug_context context = {0}; error(context, "Unsupported shader stage/version combination"); diff --git a/Sources/backends/hlsl.c b/Sources/backends/hlsl.c index 1311338..3269936 100644 --- a/Sources/backends/hlsl.c +++ b/Sources/backends/hlsl.c @@ -298,11 +298,11 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func *offset += sprintf(&hlsl[*offset], "%s main(", type_string(f->return_type.type)); for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) { if (parameter_index == 0) { - *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 ") {\n", type_string(f->parameter_types[parameter_index].type), + *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]); } else { - *offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64 ") {\n", type_string(f->parameter_types[parameter_index].type), + *offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]); } } @@ -344,6 +344,20 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func *offset += sprintf(&hlsl[*offset], ") : SV_Target0 {\n"); } } + else if (stage == SHADER_STAGE_COMPUTE) { + *offset += sprintf(&hlsl[*offset], "[numthreads(64, 1, 1)] %s main(", type_string(f->return_type.type)); + for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) { + if (parameter_index == 0) { + *offset += + sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]); + } + else { + *offset += + sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]); + } + } + *offset += sprintf(&hlsl[*offset], ") {\n"); + } else { debug_context context = {0}; error(context, "Unsupported shader stage"); @@ -537,6 +551,47 @@ static void hlsl_export_fragment(char *directory, api_kind d3d, function *main) write_bytecode(hlsl, directory, filename, var_name, output, output_size); } +static void hlsl_export_compute(char *directory, api_kind d3d, function *main) { + char *hlsl = (char *)calloc(1024 * 1024, 1); + size_t offset = 0; + + write_types(hlsl, &offset, SHADER_STAGE_COMPUTE, NO_TYPE, NO_TYPE, main); + + write_globals(hlsl, &offset, main); + + write_functions(hlsl, &offset, SHADER_STAGE_COMPUTE, main); + + debug_context context = {0}; + + uint8_t *output = NULL; + size_t output_size = 0; + int result = 1; + switch (d3d) { + case API_DIRECT3D9: + error(context, "Compute shaders are not supported in Direct3D 9"); + break; + case API_DIRECT3D11: + result = compile_hlsl_to_d3d11(hlsl, &output, &output_size, SHADER_STAGE_COMPUTE, false); + break; + case API_DIRECT3D12: + result = compile_hlsl_to_d3d12(hlsl, &output, &output_size, SHADER_STAGE_COMPUTE, false); + break; + default: + error(context, "Unsupported API for HLSL"); + } + check(result == 0, context, "HLSL compilation failed"); + + char *name = get_name(main->name); + + char filename[512]; + sprintf(filename, "kong_%s", name); + + char var_name[256]; + sprintf(var_name, "%s_code", name); + + write_bytecode(hlsl, directory, filename, var_name, output, output_size); +} + void hlsl_export(char *directory, api_kind d3d) { int cbuffer_index = 0; int texture_index = 0; @@ -601,6 +656,17 @@ void hlsl_export(char *directory, api_kind d3d) { } } + function *compute_shaders[256]; + size_t compute_shaders_size = 0; + + for (function_id i = 0; get_function(i) != NULL; ++i) { + function *f = get_function(i); + if (f->attribute == add_name("compute")) { + compute_shaders[compute_shaders_size] = f; + compute_shaders_size += 1; + } + } + for (size_t i = 0; i < vertex_shaders_size; ++i) { hlsl_export_vertex(directory, d3d, vertex_shaders[i]); } @@ -608,4 +674,8 @@ void hlsl_export(char *directory, api_kind d3d) { for (size_t i = 0; i < fragment_shaders_size; ++i) { hlsl_export_fragment(directory, d3d, fragment_shaders[i]); } + + for (size_t i = 0; i < compute_shaders_size; ++i) { + hlsl_export_compute(directory, d3d, compute_shaders[i]); + } } diff --git a/Sources/integrations/kinc.c b/Sources/integrations/kinc.c index 0f614a4..6c4ef72 100644 --- a/Sources/integrations/kinc.c +++ b/Sources/integrations/kinc.c @@ -409,6 +409,13 @@ void kinc_export(char *directory, api_kind api) { } } + for (function_id i = 0; get_function(i) != NULL; ++i) { + function *f = get_function(i); + if (f->attribute == add_name("compute")) { + fprintf(output, "extern kinc_g4_compute_shader %s;\n\n", get_name(f->name)); + } + } + fprintf(output, "#endif\n"); fclose(output); @@ -445,6 +452,13 @@ void kinc_export(char *directory, api_kind api) { } } } + + for (function_id i = 0; get_function(i) != NULL; ++i) { + function *f = get_function(i); + if (f->attribute == add_name("compute")) { + fprintf(output, "#include \"kong_%s.h\"\n", get_name(f->name)); + } + } } fprintf(output, "\n#include \n\n"); @@ -531,6 +545,13 @@ void kinc_export(char *directory, api_kind api) { } } + for (function_id i = 0; get_function(i) != NULL; ++i) { + function *f = get_function(i); + if (f->attribute == add_name("compute")) { + fprintf(output, "kinc_g4_compute_shader %s;\n", get_name(f->name)); + } + } + if (api == API_WEBGPU) { fprintf(output, "\nvoid kinc_g5_internal_webgpu_create_shader_module(const void *source, size_t length);\n"); } @@ -677,11 +698,11 @@ void kinc_export(char *directory, api_kind api) { for (size_t j = 0; j < t->members.size; ++j) { if (api == API_OPENGL) { fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s_%s\", %s);\n", get_name(t->name), get_name(t->name), - get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type)); + get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type)); } else { fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s\", %s);\n", get_name(t->name), - get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type)); + get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type)); } } fprintf(output, "\n"); @@ -723,6 +744,14 @@ void kinc_export(char *directory, api_kind api) { } } } + + for (function_id i = 0; get_function(i) != NULL; ++i) { + function *f = get_function(i); + if (f->attribute == add_name("compute")) { + fprintf(output, "\tkinc_g4_compute_shader_init(&%s, %s_code, %s_code_size);\n", get_name(f->name), get_name(f->name), get_name(f->name)); + } + } + fprintf(output, "}\n"); fclose(output); diff --git a/Sources/shader_stage.h b/Sources/shader_stage.h index 7d214d8..71af35f 100644 --- a/Sources/shader_stage.h +++ b/Sources/shader_stage.h @@ -1,3 +1,3 @@ #pragma once -typedef enum shader_stage { SHADER_STAGE_VERTEX, SHADER_STAGE_FRAGMENT } shader_stage; +typedef enum shader_stage { SHADER_STAGE_VERTEX, SHADER_STAGE_FRAGMENT, SHADER_STAGE_COMPUTE } shader_stage;