Skip to content

Commit

Permalink
Support intersection and anyhit shaders
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 27, 2024
1 parent 13e4ad0 commit 08b3cd1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 5 deletions.
90 changes: 90 additions & 0 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ static size_t raymiss_shaders_size = 0;
static function *rayclosesthit_shaders[256];
static size_t rayclosesthit_shaders_size = 0;

static function *rayintersection_shaders[256];
static size_t rayintersection_shaders_size = 0;

static function *rayanyhit_shaders[256];
static size_t rayanyhit_shaders_size = 0;

static bool is_raygen_shader(function *f) {
for (size_t rayshader_index = 0; rayshader_index < raygen_shaders_size; ++rayshader_index) {
if (f == raygen_shaders[rayshader_index]) {
Expand Down Expand Up @@ -273,6 +279,24 @@ static bool is_rayclosesthit_shader(function *f) {
return false;
}

static bool is_rayintersection_shader(function *f) {
for (size_t rayshader_index = 0; rayshader_index < rayintersection_shaders_size; ++rayshader_index) {
if (f == rayintersection_shaders[rayshader_index]) {
return true;
}
}
return false;
}

static bool is_rayanyhit_shader(function *f) {
for (size_t rayshader_index = 0; rayshader_index < rayanyhit_shaders_size; ++rayshader_index) {
if (f == rayanyhit_shaders[rayshader_index]) {
return true;
}
}
return false;
}

static void write_functions(char *hlsl, size_t *offset, shader_stage stage, function *main, function **rayshaders, size_t rayshaders_count) {
function *functions[256];
size_t functions_size = 0;
Expand Down Expand Up @@ -522,6 +546,48 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
*offset += sprintf(&hlsl[*offset], "\t%s _%" PRIu64 " = _kong_triangle_intersection_attributes.barycentrics;\n",
type_string(f->parameter_types[1].type), parameter_ids[1]);
}
else if (is_rayintersection_shader(f)) {
debug_context context = {0};
check(f->parameters_size == 0, context, "intersection shader can not have any parameters");

*offset += sprintf(&hlsl[*offset], "[shader(\"intersection\")]\n");

*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset +=
sprintf(&hlsl[*offset], "inout %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
}
else if (is_rayanyhit_shader(f)) {
debug_context context = {0};
check(f->parameters_size == 2, context, "anyhit shader requires two arguments");
check(f->parameter_types[1].type == float2_id, context, "Second parameter of a rayanyhit shader needs to be a float2");

*offset += sprintf(&hlsl[*offset], "[shader(\"anyhit\")]\n");

*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset +=
sprintf(&hlsl[*offset], "inout %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else if (parameter_index == 1) {
*offset += sprintf(&hlsl[*offset], ", BuiltInTriangleIntersectionAttributes _kong_triangle_intersection_attributes");
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
*offset += sprintf(&hlsl[*offset], "\t%s _%" PRIu64 " = _kong_triangle_intersection_attributes.barycentrics;\n",
type_string(f->parameter_types[1].type), parameter_ids[1]);
}
else {
*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
Expand Down Expand Up @@ -927,6 +993,14 @@ static void hlsl_export_all_ray_shaders(char *directory) {
all_rayshaders[all_rayshaders_size] = rayclosesthit_shaders[rayshader_index];
all_rayshaders_size += 1;
}
for (size_t rayshader_index = 0; rayshader_index < rayintersection_shaders_size; ++rayshader_index) {
all_rayshaders[all_rayshaders_size] = rayintersection_shaders[rayshader_index];
all_rayshaders_size += 1;
}
for (size_t rayshader_index = 0; rayshader_index < rayanyhit_shaders_size; ++rayshader_index) {
all_rayshaders[all_rayshaders_size] = rayanyhit_shaders[rayshader_index];
all_rayshaders_size += 1;
}

write_types(hlsl, &offset, SHADER_STAGE_RAY_GENERATION, NO_TYPE, NO_TYPE, NULL, all_rayshaders, all_rayshaders_size);

Expand Down Expand Up @@ -1064,6 +1138,8 @@ void hlsl_export(char *directory, api_kind d3d) {
name_id raygen_shader_name = NO_NAME;
name_id raymiss_shader_name = NO_NAME;
name_id rayclosesthit_shader_name = NO_NAME;
name_id rayintersection_shader_name = NO_NAME;
name_id rayanyhit_shader_name = NO_NAME;

for (size_t j = 0; j < t->members.size; ++j) {
if (t->members.m[j].name == add_name("gen")) {
Expand All @@ -1075,6 +1151,12 @@ void hlsl_export(char *directory, api_kind d3d) {
else if (t->members.m[j].name == add_name("closest")) {
rayclosesthit_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("intersection")) {
rayintersection_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("any")) {
rayanyhit_shader_name = t->members.m[j].value.identifier;
}
}

debug_context context = {0};
Expand All @@ -1096,6 +1178,14 @@ void hlsl_export(char *directory, api_kind d3d) {
rayclosesthit_shaders[rayclosesthit_shaders_size] = f;
rayclosesthit_shaders_size += 1;
}
else if (f->name == rayintersection_shader_name) {
rayintersection_shaders[rayintersection_shaders_size] = f;
rayintersection_shaders_size += 1;
}
else if (f->name == rayanyhit_shader_name) {
rayanyhit_shaders[rayanyhit_shaders_size] = f;
rayanyhit_shaders_size += 1;
}
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion Sources/shader_stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ typedef enum shader_stage {
SHADER_STAGE_COMPUTE,
SHADER_STAGE_RAY_GENERATION,
SHADER_STAGE_RAY_MISS,
SHADER_STAGE_RAY_CLOSEST_HIT
SHADER_STAGE_RAY_CLOSEST_HIT,
SHADER_STAGE_RAY_INTERSECTION,
SHADER_STAGE_RAY_ANY_HIT
} shader_stage;
14 changes: 10 additions & 4 deletions tests/in/test.kong
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fun comp(): void {

// based on https://landelare.github.io/2023/02/18/dxr-tutorial.html

/*struct Payload {
struct Payload {
color: float3;
allow_reflection: bool;
missed: bool;
Expand Down Expand Up @@ -94,14 +94,20 @@ fun closesthit(payload: Payload, uv: float2): void {
payload.color = float3(1, 0, 1);
}

fun intersect(): void {}

fun anyhit(payload: Payload, uv: float2): void {}

#[raypipe]
struct RayPipe {
gen = sendrays;
miss = raymissed;
closest = closesthit;
}*/
intersection = intersect;
any = anyhit;
}

struct FragmentIn {
/*struct FragmentIn {
position: float4;
}

Expand Down Expand Up @@ -143,4 +149,4 @@ struct Pipe {
amplification = amplify;
mesh = meshy;
fragment = pixel;
}
}*/

0 comments on commit 08b3cd1

Please sign in to comment.