From fd30e0c67dbea269fa2f266c3ad71ca86bfbf048 Mon Sep 17 00:00:00 2001 From: JMS55 <47158642+JMS55@users.noreply.github.com> Date: Mon, 10 Jun 2024 13:18:43 -0700 Subject: [PATCH] Fix meshlet vertex attribute interpolation (#13775) # Objective - Mikktspace requires that we normalize world normals/tangents _before_ interpolation across vertices, and then do _not_ normalize after. I had it backwards. - We do not (am not supposed to?) need a second set of barycentrics for motion vectors. If you think about the typical raster pipeline, in the vertex shader we calculate previous_world_position, and then it gets interpolated using the current triangle's barycentrics. ## Solution - Fix normal/tangent processing - Reuse barycentrics for motion vector calculations - Not implementing this for 0.14, but long term I aim to remove explicit vertex tangents and calculate them in the shader on the fly. ## Testing - I tested out some of the normal maps we have in repo. Didn't seem to make a difference, but mikktspace is all about correctness across various baking tools. I probably just didn't have any of the ones that would cause it to break. - Didn't test motion vectors as there's a known bug with the depth buffer and meshlets that I'm waiting on the render graph rewrite to fix. --- .../meshlet/visibility_buffer_resolve.wgsl | 77 +++++++++++-------- .../bevy_pbr/src/render/mesh_functions.wgsl | 8 +- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl index 3af1c1a506084..baf72afcc4cab 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl @@ -13,8 +13,8 @@ unpack_meshlet_vertex, }, mesh_view_bindings::view, - mesh_functions::mesh_position_local_to_world, - mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT, + mesh_functions::{mesh_position_local_to_world, sign_determinant_model_3x3m}, + mesh_types::{Mesh, MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT}, view_transformations::{position_world_to_clip, frag_coord_to_ndc}, } #import bevy_render::maths::{affine3_to_square, mat2x4_f32_to_mat3x3_unpack} @@ -99,6 +99,7 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { let cluster_id = packed_ids >> 6u; let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; let meshlet = meshlets[meshlet_id]; + let triangle_id = extractBits(packed_ids, 0u, 6u); let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u); let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z)); @@ -108,9 +109,9 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]); let instance_id = meshlet_cluster_instance_ids[cluster_id]; - let instance_uniform = meshlet_instance_uniforms[instance_id]; - let world_from_local = affine3_to_square(instance_uniform.world_from_local); + var instance_uniform = meshlet_instance_uniforms[instance_id]; + let world_from_local = affine3_to_square(instance_uniform.world_from_local); let world_position_1 = mesh_position_local_to_world(world_from_local, vec4(vertex_1.position, 1.0)); let world_position_2 = mesh_position_local_to_world(world_from_local, vec4(vertex_2.position, 1.0)); let world_position_3 = mesh_position_local_to_world(world_from_local, vec4(vertex_3.position, 1.0)); @@ -126,27 +127,19 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { ); let world_position = mat3x4(world_position_1, world_position_2, world_position_3) * partial_derivatives.barycentrics; - let vertex_normal = mat3x3(vertex_1.normal, vertex_2.normal, vertex_3.normal) * partial_derivatives.barycentrics; - let world_normal = normalize( - mat2x4_f32_to_mat3x3_unpack( - instance_uniform.local_from_world_transpose_a, - instance_uniform.local_from_world_transpose_b, - ) * vertex_normal - ); + let world_normal = mat3x3( + normal_local_to_world(vertex_1.normal, &instance_uniform), + normal_local_to_world(vertex_2.normal, &instance_uniform), + normal_local_to_world(vertex_3.normal, &instance_uniform), + ) * partial_derivatives.barycentrics; let uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.barycentrics; let ddx_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddx; let ddy_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddy; - let vertex_tangent = mat3x4(vertex_1.tangent, vertex_2.tangent, vertex_3.tangent) * partial_derivatives.barycentrics; - let world_tangent = vec4( - normalize( - mat3x3( - world_from_local[0].xyz, - world_from_local[1].xyz, - world_from_local[2].xyz - ) * vertex_tangent.xyz - ), - vertex_tangent.w * (f32(bool(instance_uniform.flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0) - ); + let world_tangent = mat3x4( + tangent_local_to_world(vertex_1.tangent, world_from_local, instance_uniform.flags), + tangent_local_to_world(vertex_2.tangent, world_from_local, instance_uniform.flags), + tangent_local_to_world(vertex_3.tangent, world_from_local, instance_uniform.flags), + ) * partial_derivatives.barycentrics; #ifdef PREPASS_FRAGMENT #ifdef MOTION_VECTOR_PREPASS @@ -154,15 +147,7 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { let previous_world_position_1 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_1.position, 1.0)); let previous_world_position_2 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_2.position, 1.0)); let previous_world_position_3 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_3.position, 1.0)); - let previous_clip_position_1 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_1.xyz, 1.0); - let previous_clip_position_2 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_2.xyz, 1.0); - let previous_clip_position_3 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_3.xyz, 1.0); - let previous_partial_derivatives = compute_partial_derivatives( - array(previous_clip_position_1, previous_clip_position_2, previous_clip_position_3), - frag_coord_ndc, - view.viewport.zw, - ); - let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * previous_partial_derivatives.barycentrics; + let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * partial_derivatives.barycentrics; let motion_vector = calculate_motion_vector(world_position, previous_world_position); #endif #endif @@ -184,4 +169,34 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { #endif ); } + +fn normal_local_to_world(vertex_normal: vec3, instance_uniform: ptr) -> vec3 { + if any(vertex_normal != vec3(0.0)) { + return normalize( + mat2x4_f32_to_mat3x3_unpack( + (*instance_uniform).local_from_world_transpose_a, + (*instance_uniform).local_from_world_transpose_b, + ) * vertex_normal + ); + } else { + return vertex_normal; + } +} + +fn tangent_local_to_world(vertex_tangent: vec4, world_from_local: mat4x4, mesh_flags: u32) -> vec4 { + if any(vertex_tangent != vec4(0.0)) { + return vec4( + normalize( + mat3x3( + world_from_local[0].xyz, + world_from_local[1].xyz, + world_from_local[2].xyz, + ) * vertex_tangent.xyz + ), + vertex_tangent.w * sign_determinant_model_3x3m(mesh_flags) + ); + } else { + return vertex_tangent; + } +} #endif diff --git a/crates/bevy_pbr/src/render/mesh_functions.wgsl b/crates/bevy_pbr/src/render/mesh_functions.wgsl index 2ebc96dd331e1..b58004cadf1e9 100644 --- a/crates/bevy_pbr/src/render/mesh_functions.wgsl +++ b/crates/bevy_pbr/src/render/mesh_functions.wgsl @@ -55,11 +55,11 @@ fn mesh_normal_local_to_world(vertex_normal: vec3, instance_index: u32) -> // Calculates the sign of the determinant of the 3x3 model matrix based on a // mesh flag -fn sign_determinant_model_3x3m(instance_index: u32) -> f32 { +fn sign_determinant_model_3x3m(mesh_flags: u32) -> f32 { // bool(u32) is false if 0u else true // f32(bool) is 1.0 if true else 0.0 // * 2.0 - 1.0 remaps 0.0 or 1.0 to -1.0 or 1.0 respectively - return f32(bool(mesh[instance_index].flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0; + return f32(bool(mesh_flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0; } fn mesh_tangent_local_to_world(world_from_local: mat4x4, vertex_tangent: vec4, instance_index: u32) -> vec4 { @@ -76,12 +76,12 @@ fn mesh_tangent_local_to_world(world_from_local: mat4x4, vertex_tangent: ve mat3x3( world_from_local[0].xyz, world_from_local[1].xyz, - world_from_local[2].xyz + world_from_local[2].xyz, ) * vertex_tangent.xyz ), // NOTE: Multiplying by the sign of the determinant of the 3x3 model matrix accounts for // situations such as negative scaling. - vertex_tangent.w * sign_determinant_model_3x3m(instance_index) + vertex_tangent.w * sign_determinant_model_3x3m(mesh[instance_index].flags) ); } else { return vertex_tangent;