Mobile NeRF general fixes (KhronosGroup#1092)
* Generic updates to the mobile nerf sample

Signed-off-by: Rodrigo Holztrattner <[email protected]>

* Remove unnecessary #ifdefs

Signed-off-by: Rodrigo Holztrattner <[email protected]>

---------

Signed-off-by: Rodrigo Holztrattner <[email protected]>
RodrigoHolztrattner-QuIC authored Jul 15, 2024
1 parent b76f0f0 commit 4edd653
Showing 14 changed files with 897 additions and 456 deletions.
491 changes: 332 additions & 159 deletions samples/general/mobile_nerf/mobile_nerf.cpp

Large diffs are not rendered by default.

47 changes: 28 additions & 19 deletions samples/general/mobile_nerf/mobile_nerf.h
@@ -61,6 +61,11 @@ class MobileNerf : public ApiVulkanSample
alignas(4) float tan_half_fov;
} global_uniform;

struct PushConstants
{
unsigned int weight_idx;
} push_constants;

#define WEIGHTS_0_COUNT (176)
#define WEIGHTS_1_COUNT (256)
// The third layer weights' size is changed from 48 to 64 to ensure 16-byte alignment
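
For context on the new PushConstants member: a per-draw index such as weight_idx is normally handed to the shader with vkCmdPushConstants while the command buffer is recorded. The sketch below is illustrative only; the helper name, pipeline layout handle, and stage flags are assumptions, not the sample's actual code.

#include <vulkan/vulkan.h>

// Hypothetical helper (not the sample's actual code): push the per-model
// weight index so the fragment shader can select the matching MLP weights.
// The stage flags and offset must match the VkPushConstantRange declared in
// the pipeline layout.
static void push_weight_index(VkCommandBuffer cmd, VkPipelineLayout layout, uint32_t weight_idx)
{
	struct PushConstants
	{
		uint32_t weight_idx;
	} push_constants{weight_idx};

	vkCmdPushConstants(cmd, layout, VK_SHADER_STAGE_FRAGMENT_BIT, 0,
	                   sizeof(PushConstants), &push_constants);
}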
@@ -78,13 +83,8 @@ class MobileNerf : public ApiVulkanSample

struct Vertex
{
alignas(4) glm::vec3 position;
alignas(4) glm::vec2 tex_coord;

bool operator==(const Vertex &other) const
{
return position == other.position && tex_coord == other.tex_coord;
}
glm::vec3 position;
glm::vec2 tex_coord;
};

struct InstancingInfo
@@ -130,10 +130,6 @@ class MobileNerf : public ApiVulkanSample
// Deferred mode will only have one set of descriptors per model
std::vector<VkDescriptorSet> descriptor_set_first_pass{VK_NULL_HANDLE};

// Stores references to each models mlp weights and uniform buffers
std::unique_ptr<vkb::core::Buffer> *weights_buffer_ref;
std::unique_ptr<vkb::core::Buffer> *uniform_buffer_ref;

int sub_model_num;
int model_index;
};
@@ -147,7 +143,7 @@ class MobileNerf : public ApiVulkanSample
// Uniform buffer for each model
std::vector<std::unique_ptr<vkb::core::Buffer>> uniform_buffers;

// buffer to store instance data
// Buffer to store instance data
std::unique_ptr<vkb::core::Buffer> instance_buffer{nullptr};

// Common
@@ -160,6 +156,7 @@ class MobileNerf : public ApiVulkanSample
void load_scene(int model_index, int sub_model_index, int models_entry);
void initialize_mlp_uniform_buffers(int model_index);
void update_uniform_buffers();
void update_weights_buffers();

void create_texture(int model_index, int sub_model_index, int models_entry);
void create_texture_helper(std::string const &texturePath, Texture &texture);
@@ -181,9 +178,16 @@ class MobileNerf : public ApiVulkanSample
void create_descriptor_sets_first_pass(Model &model);
void prepare_pipelines();

unsigned int color_attach_0_idx;
unsigned int color_attach_1_idx;
unsigned int color_attach_2_idx;
unsigned int color_attach_3_idx;
unsigned int depth_attach_idx;
unsigned int swapchain_attach_idx;

struct Attachments_baseline
{
FrameBufferAttachment feature_0, feature_1, feature_2;
FrameBufferAttachment feature_0, feature_1, feature_2, weights_idx;
};

std::vector<Attachments_baseline> frameAttachments;
@@ -209,20 +213,25 @@ class MobileNerf : public ApiVulkanSample
// For loading nerf assets map
json asset_map;
std::vector<std::string> model_path;
bool combo_mode;
std::vector<bool> using_original_nerf_models;
bool use_deferred;
bool do_rotation;
bool combo_mode = false;
bool use_deferred = false;
bool do_rotation = false;

glm::vec3 camera_pos = glm::vec3(-2.2f, 2.2f, 2.2f);

// Currently, combo mode translations are hard-coded
glm::mat4x4 combo_model_transform[4] = {
glm::translate(glm::vec3(0.5, 0.75, 0)), glm::translate(glm::vec3(0.5, 0.25, 0)),
glm::translate(glm::vec3(0, -0.25, 0.5)), glm::translate(glm::vec3(0, -0.75, -0.5))};

// For instancing
InstancingInfo instancing_info;

// Viewport Setting
float fov = 60.0f;
uint32_t view_port_width;
uint32_t view_port_height;
float fov = 60.0f;
uint32_t view_port_width = width;
uint32_t view_port_height = height;
bool use_native_screen_size = false;
};

37 changes: 16 additions & 21 deletions shaders/mobile_nerf/merged.frag
@@ -17,7 +17,7 @@
* ------------------------------------------------------------------------
*
* THIS IS A MODIFIED VERSION OF THE ORIGINAL FILE
*
* The original file, along with the original Apache-2.0 LICENSE can be found at:
* https://github.com/google-research/jax3d/tree/main/jax3d/projects/mobilenerf
*
@@ -49,19 +49,19 @@ precision highp float;
#define BIAS_2_COUNT (4)
layout(binding = 3) uniform mlp_weights
{
vec4 data[(WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT +
BIAS_0_COUNT + BIAS_1_COUNT + BIAS_2_COUNT)/4]; // Array of floats
} weights;
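
A rough size check on this uniform block, assuming WEIGHTS_2_COUNT is 64 (after the padding noted in mobile_nerf.h) and that BIAS_0_COUNT and BIAS_1_COUNT are 16 each to match the 16-wide hidden layers used below: 176 + 256 + 64 + 16 + 16 + 4 = 532 floats, i.e. 532 / 4 = 133 vec4 entries in weights.data. The 176 first-layer weights correspond to 11 network inputs (f0.rgba, f1.rgba, and three view-direction components) times 16 hidden units.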


vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
{

vec3 res;

int bias_0_ind = WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT;
vec4 intermediate_one[4] = vec4[](
weights.data[bias_0_ind/4],
weights.data[bias_0_ind/4 + 1],
weights.data[bias_0_ind/4 + 2],
weights.data[bias_0_ind/4 + 3]
@@ -72,7 +72,7 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
intermediate_one[ 0] += (multiplier) * weights.data[ weightFirstInd/4]; \
intermediate_one[ 1] += (multiplier) * weights.data[ weightFirstInd/4 + 1]; \
intermediate_one[ 2] += (multiplier) * weights.data[ weightFirstInd/4 + 2]; \
intermediate_one[ 3] += (multiplier) * weights.data[ weightFirstInd/4 + 3];

APPLY_WEIGHTS_0( f0.r, 0)
APPLY_WEIGHTS_0( f0.g, 16)
@@ -87,10 +87,10 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
APPLY_WEIGHTS_0( -viewdir.b, 144)
APPLY_WEIGHTS_0( viewdir.g, 160)

int bias_1_ind = WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT +
BIAS_0_COUNT;
vec4 intermediate_two[4] = vec4[](
weights.data[bias_1_ind/4],
weights.data[bias_1_ind/4 + 1],
weights.data[bias_1_ind/4 + 2],
weights.data[bias_1_ind/4 + 3]
@@ -104,7 +104,7 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
intermediate_two[ 2] += intermediate * weights.data[ WEIGHTS_0_COUNT/4 + oneInd * 4 + 2]; \
intermediate_two[ 3] += intermediate * weights.data[ WEIGHTS_0_COUNT/4 + oneInd * 4 + 3]; \
}

APPLY_WEIGHTS_1( intermediate_one[0].r, 0)
APPLY_WEIGHTS_1( intermediate_one[0].g, 1)
APPLY_WEIGHTS_1( intermediate_one[0].b, 2)
@@ -122,15 +122,15 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
APPLY_WEIGHTS_1( intermediate_one[3].b, 14)
APPLY_WEIGHTS_1( intermediate_one[3].a, 15)

int bias_2_ind = WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT +
BIAS_0_COUNT + BIAS_1_COUNT;
vec4 result = weights.data[bias_2_ind/4];

#define APPLY_WEIGHTS_2(intermediate, oneInd) \
if(intermediate > 0.0f){ \
result += intermediate * weights.data[ WEIGHTS_0_COUNT/4 + WEIGHTS_1_COUNT/4 + oneInd]; \
}

APPLY_WEIGHTS_2(intermediate_two[0].r, 0)
APPLY_WEIGHTS_2(intermediate_two[0].g, 1)
APPLY_WEIGHTS_2(intermediate_two[0].b, 2)
@@ -147,9 +147,9 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
APPLY_WEIGHTS_2(intermediate_two[3].g,13)
APPLY_WEIGHTS_2(intermediate_two[3].b,14)
APPLY_WEIGHTS_2(intermediate_two[3].a,15)

result = 1.0 / (1.0 + exp(-result));
return vec3(result * viewdir.a+(1.0-viewdir.a));
}

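The final two statements of evaluateNetwork apply the logistic sigmoid, 1 / (1 + exp(-x)), to the last layer's output and then blend toward white using the alpha carried in viewdir.a, returning result * a + (1 - a).
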
//////////////////////////////////////////////////////////////
@@ -164,7 +164,7 @@ float Convert_sRGB_ToLinear(float value)
: pow((value + 0.055) / 1.055, 2.4);
}

vec3 Convert_sRGB_ToLinear(vec3 value)
{
return vec3(Convert_sRGB_ToLinear(value.x), Convert_sRGB_ToLinear(value.y), Convert_sRGB_ToLinear(value.z));
}
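
Convert_sRGB_ToLinear is the standard sRGB-to-linear transfer function; the low-range branch (division by 12.92 for values at or below roughly 0.04045) is collapsed in this truncated hunk, while the visible line handles the remaining range via ((v + 0.055) / 1.055)^2.4. The vec3 overload applies the scalar conversion per channel.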
@@ -181,20 +181,15 @@ void main(void)
vec4 pixel_1 = texture(textureInput_1, flipped);
vec4 feature_0 = pixel_0;
vec4 feature_1 = pixel_1;

vec4 rayDirection = vec4(normalize(rayDirectionIn), 1.0f);

// For debugging only
// o_color_0 = vec4( texCoord_frag.x, 0.0f, 0.0f, 1.0f);

// deal with iphone
feature_0.a = feature_0.a*2.0-1.0;
feature_1.a = feature_1.a*2.0-1.0;
rayDirection.a = rayDirection.a*2.0-1.0;

// Original
o_color.rgb = Convert_sRGB_ToLinear(evaluateNetwork(feature_0,feature_1,rayDirection));
//o_color.rgb = feature_0.rgb;
o_color.a = 1.0;
}
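
In main(), the "deal with iphone" block remaps the alpha channels from the [0, 1] range in which they are stored back to the signed [-1, 1] range the MLP expects (x -> x * 2 - 1) before the network is evaluated and the result is converted to linear color.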
36 changes: 16 additions & 20 deletions shaders/mobile_nerf/merged_morpheus.frag
@@ -17,7 +17,7 @@
* ------------------------------------------------------------------------
*
* THIS IS A MODIFIED VERSION OF THE ORIGINAL FILE
*
* The original file, along with the original Apache-2.0 LICENSE can be found at:
* https://github.com/google-research/jax3d/tree/main/jax3d/projects/mobilenerf
*
@@ -49,19 +49,19 @@ precision highp float;
#define BIAS_2_COUNT (4)
layout(binding = 3) uniform mlp_weights
{
vec4 data[(WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT +
BIAS_0_COUNT + BIAS_1_COUNT + BIAS_2_COUNT)/4]; // Array of floats
} weights;


vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
{

vec3 res;

int bias_0_ind = WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT;
vec4 intermediate_one[4] = vec4[](
weights.data[bias_0_ind/4],
weights.data[bias_0_ind/4 + 1],
weights.data[bias_0_ind/4 + 2],
weights.data[bias_0_ind/4 + 3]
@@ -72,7 +72,7 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
intermediate_one[ 0] += (multiplier) * weights.data[ weightFirstInd/4]; \
intermediate_one[ 1] += (multiplier) * weights.data[ weightFirstInd/4 + 1]; \
intermediate_one[ 2] += (multiplier) * weights.data[ weightFirstInd/4 + 2]; \
intermediate_one[ 3] += (multiplier) * weights.data[ weightFirstInd/4 + 3];

APPLY_WEIGHTS_0( f0.r, 0)
APPLY_WEIGHTS_0( f0.g, 16)
@@ -87,10 +87,10 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
APPLY_WEIGHTS_0( (-viewdir.b + 1.0 )/2, 144)
APPLY_WEIGHTS_0( (viewdir.g + 1.0 )/2, 160)

int bias_1_ind = WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT +
BIAS_0_COUNT;
vec4 intermediate_two[4] = vec4[](
weights.data[bias_1_ind/4],
weights.data[bias_1_ind/4 + 1],
weights.data[bias_1_ind/4 + 2],
weights.data[bias_1_ind/4 + 3]
@@ -104,7 +104,7 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
intermediate_two[ 2] += intermediate * weights.data[ WEIGHTS_0_COUNT/4 + oneInd * 4 + 2]; \
intermediate_two[ 3] += intermediate * weights.data[ WEIGHTS_0_COUNT/4 + oneInd * 4 + 3]; \
}

APPLY_WEIGHTS_1( intermediate_one[0].r, 0)
APPLY_WEIGHTS_1( intermediate_one[0].g, 1)
APPLY_WEIGHTS_1( intermediate_one[0].b, 2)
@@ -122,15 +122,15 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
APPLY_WEIGHTS_1( intermediate_one[3].b, 14)
APPLY_WEIGHTS_1( intermediate_one[3].a, 15)

int bias_2_ind = WEIGHTS_0_COUNT + WEIGHTS_1_COUNT + WEIGHTS_2_COUNT +
BIAS_0_COUNT + BIAS_1_COUNT;
vec4 result = weights.data[bias_2_ind/4];

#define APPLY_WEIGHTS_2(intermediate, oneInd) \
if(intermediate > 0.0f){ \
result += intermediate * weights.data[ WEIGHTS_0_COUNT/4 + WEIGHTS_1_COUNT/4 + oneInd]; \
}

APPLY_WEIGHTS_2(intermediate_two[0].r, 0)
APPLY_WEIGHTS_2(intermediate_two[0].g, 1)
APPLY_WEIGHTS_2(intermediate_two[0].b, 2)
@@ -147,9 +147,9 @@ vec3 evaluateNetwork( vec4 f0, vec4 f1, vec4 viewdir)
APPLY_WEIGHTS_2(intermediate_two[3].g,13)
APPLY_WEIGHTS_2(intermediate_two[3].b,14)
APPLY_WEIGHTS_2(intermediate_two[3].a,15)

result = 1.0 / (1.0 + exp(-result));
return vec3(result * viewdir.a+(1.0-viewdir.a));
}

//////////////////////////////////////////////////////////////
Expand All @@ -164,7 +164,7 @@ float Convert_sRGB_ToLinear(float value)
: pow((value + 0.055) / 1.055, 2.4);
}

vec3 Convert_sRGB_ToLinear(vec3 value)
{
return vec3(Convert_sRGB_ToLinear(value.x), Convert_sRGB_ToLinear(value.y), Convert_sRGB_ToLinear(value.z));
}
@@ -181,19 +181,15 @@ void main(void)
vec4 pixel_1 = texture(textureInput_1, flipped);
vec4 feature_0 = pixel_0;
vec4 feature_1 = pixel_1;

vec4 rayDirection = vec4(normalize(rayDirectionIn), 1.0f);

// For debugging only
// o_color_0 = vec4( texCoord_frag.x, 0.0f, 0.0f, 1.0f);

// deal with iphone
feature_0.a = feature_0.a*2.0-1.0;
feature_1.a = feature_1.a*2.0-1.0;
rayDirection.a = rayDirection.a*2.0-1.0;

// Original
o_color.rgb = Convert_sRGB_ToLinear(evaluateNetwork(feature_0,feature_1,rayDirection));
// o_color.rgb = feature_0.rgb;
o_color.a = 1.0;
}
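
The notable difference from merged.frag in this variant is that the view-direction components are passed to the first layer remapped from [-1, 1] to [0, 1] (x -> (x + 1) / 2) rather than used directly; the rest of the network evaluation appears identical in these hunks.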