Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mhochsteger committed Oct 16, 2024
1 parent b00a094 commit 55e8e58
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 107 deletions.
15 changes: 0 additions & 15 deletions shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,6 @@ struct Edge { v: vec2<u32> };
struct Segment { v: vec2<u32>, index: i32 };
struct Trig { v: vec3<u32>, index: i32 };

struct ClippingPlane {
normal: vec3<f32>,
dist: f32,
};

struct Colormap {
min: f32,
max: f32,
}

struct Complex {
re: f32,
imag: f32,
};

struct Uniforms {
mat: mat4x4<f32>,
clipping_plane: vec4<f32>,
Expand Down
181 changes: 89 additions & 92 deletions webgpu.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,50 @@
import ctypes as ct
import math
import sys

import js
import pyodide.ffi
from pyodide.ffi import create_proxy


class ClippingPlaneUniform(ct.Structure):
_fields_ = [("normal", ct.c_float * 3), ("dist", ct.c_float)]


class ComplexUniform(ct.Structure):
_fields_ = [("re", ct.c_float), ("im", ct.c_float)]


class ColormapUniform(ct.Structure):
_fields_ = [("min", ct.c_float), ("max", ct.c_float)]


class Uniforms(ct.Structure):
_fields_ = [
("mat", ct.c_float * 16),
("clipping_plane", ClippingPlaneUniform),
("colormap", ColormapUniform),
("scaling", ComplexUniform),
("aspect", ct.c_float),
("eval_mode", ct.c_uint32),
("do_clipping", ct.c_uint32),
("padding", ct.c_uint32),
]


class WebGPU:
"""WebGPU management class, handles "global" state, like device, canvas, colormap and uniforms"""

def __init__(self, device, canvas, format):
def __init__(self, device, canvas):
self.device = device
self.format = format
self.format = js.navigator.gpu.getPreferredCanvasFormat()
self.canvas = canvas

self.context = canvas.getContext("webgpu")
self.context.configure(
to_js(
{
"device": device,
"format": format,
"format": self.format,
"alphaMode": "premultiplied",
}
)
Expand All @@ -33,7 +57,6 @@ def __init__(self, device, canvas, format):
}
)
)

self.colormap_texture, self.colormap_sampler = create_colormap(device)
self.depth_format = "depth24plus"
self.depth_stencil = {
Expand Down Expand Up @@ -81,6 +104,41 @@ def update_uniform_buffer(self):
buffer = js.Uint8Array.new(bytes(self.uniforms))
self.device.queue.writeBuffer(self.uniform_buffer, 0, buffer)

def get_bindings(self):
"""Returns layout and resource arrays used to create binding layout and binding groups
Current entires are: Uniforms, colormap texture, colormap sampler"""

FRAGMENT = js.GPUShaderStage.FRAGMENT
BOTH = js.GPUShaderStage.VERTEX | FRAGMENT

layouts = [
{
"visibility": BOTH,
"buffer": {"type": "uniform"},
},
{
"visibility": FRAGMENT,
"texture": {
"sampleType": "float",
"viewDimension": "1d",
"multisamples": False,
},
},
{
"visibility": FRAGMENT,
"sampler": {"type": "filtering"},
},
]
resources = [
{"resource": res}
for res in [
{"buffer": self.uniform_buffer},
self.colormap_texture.createView(),
self.colormap_sampler,
]
]
return layouts, resources

def begin_render_pass(self, command_encoder):
render_pass_encoder = command_encoder.beginRenderPass(
to_js(
Expand Down Expand Up @@ -118,10 +176,8 @@ def __init__(self, mesh, gpu, shader_code):
self.gpu = gpu

self._create_buffers()
self._create_bind_group_layout()
self._create_bind_group()

self._create_pipeline_layout()
self._create_bind_group()
self._create_pipelines(shader_code)

def _create_buffers(self):
Expand Down Expand Up @@ -165,70 +221,36 @@ def _create_buffers(self):
self.gpu.device.queue.writeBuffer(buffers[name], 0, values)
self._buffers = buffers

def _create_bind_group_layout(self):
self._bind_group_layout = self.gpu.device.createBindGroupLayout(
to_js(
def _create_bind_group(self):
"""Get binding data from WebGPU class and add values used for mesh rendering"""
VERTEX = js.GPUShaderStage.VERTEX
FRAGMENT = js.GPUShaderStage.FRAGMENT
BOTH = VERTEX | FRAGMENT

layouts, resources = self.gpu.get_bindings()

for name in ["vertices", "edges", "trigs"]:
layouts.append(
{
"entries": [
{
"binding": 0,
"visibility": js.GPUShaderStage.VERTEX
| js.GPUShaderStage.FRAGMENT,
"buffer": {"type": "uniform"},
},
{
"binding": 1,
"visibility": js.GPUShaderStage.FRAGMENT,
"texture": {
"sampleType": "float",
"viewDimension": "1d",
"multisamples": False,
},
},
{
"binding": 2,
"visibility": js.GPUShaderStage.FRAGMENT,
"sampler": {"type": "filtering"},
},
{
"binding": 3,
"visibility": js.GPUShaderStage.VERTEX,
"buffer": {"type": "read-only-storage"},
},
{
"binding": 4,
"visibility": js.GPUShaderStage.VERTEX,
"buffer": {"type": "read-only-storage"},
},
{
"binding": 5,
"visibility": js.GPUShaderStage.VERTEX,
"buffer": {"type": "read-only-storage"},
},
],
"visibility": BOTH,
"buffer": {"type": "read-only-storage"},
}
)
resources.append({"resource": {"buffer": self._buffers[name]}})

for i in range(len(layouts)):
layouts[i]["binding"] = i
resources[i]["binding"] = i

self._bind_group_layout = self.gpu.device.createBindGroupLayout(
to_js({"entries": layouts})
)

def _create_bind_group(self):
self._bind_group = self.gpu.device.createBindGroup(
to_js(
{
"layout": self._bind_group_layout,
"entries": [
{"binding": 0, "resource": {"buffer": self.gpu.uniform_buffer}},
{
"binding": 1,
"resource": self.gpu.colormap_texture.createView(),
},
{"binding": 2, "resource": self.gpu.colormap_sampler},
{
"binding": 3,
"resource": {"buffer": self._buffers["vertices"]},
},
{"binding": 4, "resource": {"buffer": self._buffers["edges"]}},
{"binding": 5, "resource": {"buffer": self._buffers["trigs"]}},
],
"entries": resources,
}
)
)
Expand All @@ -239,6 +261,7 @@ def _create_pipeline_layout(self):
)

def _create_pipelines(self, shader_code):
self._create_pipeline_layout()
shader_module = self.gpu.device.createShaderModule(to_js({"code": shader_code}))
edges_pipeline = self.gpu.device.createRenderPipeline(
to_js(
Expand Down Expand Up @@ -293,31 +316,6 @@ def draw(self, encoder):
encoder.draw(3, self.n_trigs, 0, 0)


class ClippingPlane(ct.Structure):
_fields_ = [("normal", ct.c_float * 3), ("dist", ct.c_float)]


class Complex(ct.Structure):
_fields_ = [("re", ct.c_float), ("im", ct.c_float)]


class Colormap(ct.Structure):
_fields_ = [("min", ct.c_float), ("max", ct.c_float)]


class Uniforms(ct.Structure):
_fields_ = [
("mat", ct.c_float * 16),
("clipping_plane", ClippingPlane),
("colormap", Colormap),
("scaling", Complex),
("aspect", ct.c_float),
("eval_mode", ct.c_uint32),
("do_clipping", ct.c_uint32),
("padding", ct.c_uint32),
]


def to_js(value):
return pyodide.ffi.to_js(value, dict_converter=js.Object.fromEntries)

Expand Down Expand Up @@ -377,7 +375,7 @@ def on_mousemove(ev):
js.requestAnimationFrame(_render_function)


async def init_canvas(canvas):
def init_canvas(canvas):
if not js.navigator.gpu:
abort()

Expand Down Expand Up @@ -437,16 +435,15 @@ def create_colormap(device):


async def main(canvas=None, shader_url="./shader.wgsl"):
canvas = await init_canvas(canvas)
canvas = init_canvas(canvas)
adapter = await js.navigator.gpu.requestAdapter()

if not adapter:
abort()

device = await adapter.requestDevice()
format = js.navigator.gpu.getPreferredCanvasFormat()

gpu = WebGPU(device, canvas, format)
gpu = WebGPU(device, canvas)

from netgen.occ import unit_square

Expand Down

0 comments on commit 55e8e58

Please sign in to comment.