Skip to content

Commit

Permalink
Improve auto injected disable code to be more optimizer friendly (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmoebaChant authored Oct 8, 2024
1 parent 828cbcc commit 04aaf9d
Show file tree
Hide file tree
Showing 44 changed files with 530 additions and 314 deletions.
40 changes: 0 additions & 40 deletions packages/core/src/blocks/disableableBlock.ts

This file was deleted.

100 changes: 100 additions & 0 deletions packages/core/src/blocks/disableableShaderBlock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import type { SmartFilter } from "../smartFilter.js";
import type { ConnectionPoint } from "../connection/connectionPoint.js";

import { ConnectionPointType } from "../connection/connectionPointType.js";
import { createStrongRef } from "../runtime/strongRef.js";
import { ShaderBlock } from "./shaderBlock.js";
import { injectAutoSampleDisableCode } from "../utils/shaderCodeUtils.js";

/**
* The interface that describes the disableable block.
*/
export interface IDisableableBlock {
/**
* The disabled connection point of the block.
*/
disabled: ConnectionPoint<ConnectionPointType.Boolean>;
}

/**
* The strategy to use for making a block disableable.
*/
export enum BlockDisableStrategy {
/**
* The shader code is responsible for defining and consulting a uniform named disabled
* and no-oping (returning texture2D(mainInputTexture, vUV)) if the value is true.
*/
Manual = 0,

/**
* The Smart Filter system will automatically add code to sample the mainInputTexture and return immediately if disabled,
* and otherwise use the value within the block's shader code. If you need to modify UVs before sampling the default input texture,
* you'll need to use the Manual strategy instead.
*/
AutoSample = 1,
}

/**
* A ShaderBlock that can be disabled. The optimizer can optionally remove disabled blocks from the graph,
* or they can be controlled by the disabled connection point at runtime. If disabled, they pass the
* mainInputTexture through to the output connection point.
*/
export abstract class DisableableShaderBlock extends ShaderBlock implements IDisableableBlock {
/**
* The disabled connection point of the block.
*/
public readonly disabled = this._registerOptionalInput(
"disabled",
ConnectionPointType.Boolean,
createStrongRef(false)
);

/**
* The strategy to use for making this block disableable.
*/
public readonly blockDisableStrategy: BlockDisableStrategy;

// The shader code is a static per block type. When an instance of a block is created, we may need to alter
// that code based on the block's disable strategy. We only want to do this once per block type, or we could
// incorrectly modify the shader code multiple times (once per block instance). Here we use a static boolean
// which will be per block type to track if we've already modified the shader code for this block type.
// This is more memory efficient than the alternative of making a copy of the shader code for each block instance
// and modifying each copy.
private static _HasModifiedShaderCode = false;
private get _hasModifiedShaderCode() {
return (this.constructor as typeof DisableableShaderBlock)._HasModifiedShaderCode;
}
private set _hasModifiedShaderCode(value: boolean) {
(this.constructor as typeof DisableableShaderBlock)._HasModifiedShaderCode = value;
}

/**
* Instantiates a new block.
* @param smartFilter - Defines the smart filter the block belongs to
* @param name - Defines the name of the block
* @param disableOptimization - Defines if the block should not be optimized (default: false)
* @param disableStrategy - Defines the strategy to use for making this block disableable (default: BlockDisableStrategy.AutoSample)
*/
constructor(
smartFilter: SmartFilter,
name: string,
disableOptimization = false,
disableStrategy = BlockDisableStrategy.AutoSample
) {
super(smartFilter, name, disableOptimization);
this.blockDisableStrategy = disableStrategy;

// If we haven't already modified the shader code for this block type, do so now
if (!this._hasModifiedShaderCode) {
this._hasModifiedShaderCode = true;

// Apply the disable strategy
const shaderProgram = this.getShaderProgram();
switch (this.blockDisableStrategy) {
case BlockDisableStrategy.AutoSample:
injectAutoSampleDisableCode(shaderProgram);
break;
}
}
}
}
6 changes: 3 additions & 3 deletions packages/core/src/blocks/inputBlock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { SmartFilter } from "../smartFilter";
import type { ConnectionPointValue } from "../connection/connectionPointType";
import type { RuntimeData } from "../connection/connectionPoint";
import type { ConnectionPointWithDefault } from "../connection/connectionPointWithDefault";
import type { DisableableBlock } from "./disableableBlock";
import type { DisableableShaderBlock } from "./disableableShaderBlock";
import { BaseBlock } from "../blocks/baseBlock.js";
import { createStrongRef } from "../runtime/strongRef.js";
import { ConnectionPointType } from "../connection/connectionPointType.js";
Expand Down Expand Up @@ -33,8 +33,8 @@ export function isTextureInputBlock(block: BaseBlock): block is InputBlock<Conne
* @param block - The block to check
* @returns true if the block is a disableable block, otherwise false
*/
export function isDisableableBlock(block: BaseBlock): block is DisableableBlock {
return (block as DisableableBlock).disabled !== undefined;
export function isDisableableShaderBlock(block: BaseBlock): block is DisableableShaderBlock {
return (block as DisableableShaderBlock).disabled !== undefined;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/blocks/outputBlock.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { InitializationData, SmartFilter } from "../smartFilter";
import { ConnectionPointType } from "../connection/connectionPointType.js";
import { BaseBlock } from "./baseBlock.js";
import { Binding, ShaderRuntime } from "../runtime/shaderRuntime.js";
import { ShaderBinding, ShaderRuntime } from "../runtime/shaderRuntime.js";
import type { Nullable } from "@babylonjs/core/types";
import type { RenderTargetWrapper } from "@babylonjs/core/Engines/renderTargetWrapper";
import { registerFinalRenderCommand } from "../utils/renderTargetUtils.js";
Expand Down Expand Up @@ -95,7 +95,7 @@ export class OutputBlock extends BaseBlock {
/**
* Shader binding to use when the OutputBlock is directly connected to a texture InputBlock.
*/
class OutputShaderBinding extends Binding {
class OutputShaderBinding extends ShaderBinding {
private readonly _inputTexture: RuntimeData<ConnectionPointType.Texture>;

/**
Expand Down
8 changes: 4 additions & 4 deletions packages/core/src/blocks/shaderBlock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ import "@babylonjs/core/Engines/Extensions/engine.renderTarget.js";

import type { InitializationData, SmartFilter } from "../smartFilter";
import type { ShaderProgram } from "../utils/shaderCodeUtils";
import type { Binding } from "../runtime/shaderRuntime";
import type { ShaderBinding } from "../runtime/shaderRuntime";
import type { ConnectionPoint } from "../connection/connectionPoint";
import { ShaderRuntime } from "../runtime/shaderRuntime.js";
import { ConnectionPointType } from "../connection/connectionPointType.js";
import { createCommand } from "../command/command.js";
import { DisableableBlock } from "./disableableBlock.js";
import { undecorateSymbol } from "../utils/shaderCodeUtils.js";
import { getRenderTargetWrapper, registerFinalRenderCommand } from "../utils/renderTargetUtils.js";
import { BaseBlock } from "./baseBlock.js";

/**
* This is the base class for all shader blocks.
Expand All @@ -19,7 +19,7 @@ import { getRenderTargetWrapper, registerFinalRenderCommand } from "../utils/ren
*
* The only required function to implement is the bind function.
*/
export abstract class ShaderBlock extends DisableableBlock {
export abstract class ShaderBlock extends BaseBlock {
/**
* The class name of the block.
*/
Expand All @@ -30,7 +30,7 @@ export abstract class ShaderBlock extends DisableableBlock {
* It should throw an error if required inputs are missing.
* @returns The class instance that binds the data to the effect
*/
public abstract getShaderBinding(): Binding;
public abstract getShaderBinding(): ShaderBinding;

/**
* The shader program (vertex and fragment code) to use to render the block
Expand Down
7 changes: 4 additions & 3 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ export { BaseBlock } from "./blocks/baseBlock.js";
export { InputBlock, type InputBlockEditorData } from "./blocks/inputBlock.js";
export { type AnyInputBlock } from "./blocks/inputBlock.js";
export { ShaderBlock } from "./blocks/shaderBlock.js";
export { DisableableShaderBlock, BlockDisableStrategy as DisableStrategy } from "./blocks/disableableShaderBlock.js";
export { AggregateBlock } from "./blocks/aggregateBlock.js";
export { ShaderBinding, ShaderRuntime } from "./runtime/shaderRuntime.js";
export { type ShaderProgram, injectDisableUniform } from "./utils/shaderCodeUtils.js";
export { type IDisableableBlock } from "./blocks/disableableBlock.js";
export { DisableableShaderBinding, ShaderBinding, ShaderRuntime } from "./runtime/shaderRuntime.js";
export { type ShaderProgram } from "./utils/shaderCodeUtils.js";
export { type IDisableableBlock } from "./blocks/disableableShaderBlock.js";

export { type SmartFilterRuntime } from "./runtime/smartFilterRuntime.js";
export { InternalSmartFilterRuntime } from "./runtime/smartFilterRuntime.js";
Expand Down
14 changes: 7 additions & 7 deletions packages/core/src/optimization/optimizedShaderBlock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import type { SmartFilter } from "../smartFilter";
import type { ShaderProgram } from "../utils/shaderCodeUtils";
import type { RuntimeData } from "../connection/connectionPoint";
import { ShaderBlock } from "../blocks/shaderBlock.js";
import { Binding } from "../runtime/shaderRuntime.js";
import { ShaderBinding } from "../runtime/shaderRuntime.js";
import { ConnectionPointType } from "../connection/connectionPointType.js";

/**
* The shader bindings for the OptimizedShader block.
* @internal
*/
export class OptimizedShaderBinding extends Binding {
private _shaderBindings: Binding[];
export class OptimizedShaderBinding extends ShaderBinding {
private _shaderBindings: ShaderBinding[];
private _inputTextures: { [name: string]: RuntimeData<ConnectionPointType.Texture> };

/**
Expand All @@ -22,7 +22,7 @@ export class OptimizedShaderBinding extends Binding {
* @param inputTextures - The list of input textures to bind
*/
constructor(
shaderBindings: Binding[],
shaderBindings: ShaderBinding[],
inputTextures: { [name: string]: RuntimeData<ConnectionPointType.Texture> }
) {
super();
Expand Down Expand Up @@ -56,7 +56,7 @@ export class OptimizedShaderBinding extends Binding {
* @internal
*/
export class OptimizedShaderBlock extends ShaderBlock {
private _shaderBindings: Nullable<Binding[]>;
private _shaderBindings: Nullable<ShaderBinding[]>;
private _inputTextures: { [name: string]: RuntimeData<ConnectionPointType.Texture> } = {};
private _shaderProgram: ShaderProgram;

Expand Down Expand Up @@ -104,15 +104,15 @@ export class OptimizedShaderBlock extends ShaderBlock {
* Sets the list of shader bindings to use to render the block.
* @param shaderBindings - The list of shader bindings to use to render the block
*/
public setShaderBindings(shaderBindings: Binding[]): void {
public setShaderBindings(shaderBindings: ShaderBinding[]): void {
this._shaderBindings = shaderBindings;
}

/**
* Get the class instance that binds all the required data to the shader (effect) when rendering.
* @returns The class instance that binds the data to the effect
*/
public getShaderBinding(): Binding {
public getShaderBinding(): ShaderBinding {
if (this._shaderBindings === null) {
throw new Error("Shader bindings not set!");
}
Expand Down
48 changes: 43 additions & 5 deletions packages/core/src/optimization/smartFilterOptimizer.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import type { Nullable } from "@babylonjs/core/types";

import type { ConnectionPoint } from "../connection/connectionPoint";
import type { Binding } from "../runtime/shaderRuntime";
import type { ShaderBinding } from "../runtime/shaderRuntime";
import type { InputBlock } from "../blocks/inputBlock";
import type { BaseBlock } from "../blocks/baseBlock";
import { SmartFilter } from "../smartFilter.js";
import { ConnectionPointType } from "../connection/connectionPointType.js";
import { ShaderBlock } from "../blocks/shaderBlock.js";
import { isTextureInputBlock } from "../blocks/inputBlock.js";
import { OptimizedShaderBlock } from "./optimizedShaderBlock.js";
import { decorateChar, decorateSymbol, getShaderFragmentCode, undecorateSymbol } from "../utils/shaderCodeUtils.js";
import {
AutoDisableMainInputColorName,
decorateChar,
decorateSymbol,
getShaderFragmentCode,
undecorateSymbol,
} from "../utils/shaderCodeUtils.js";
import { DependencyGraph } from "./dependencyGraph.js";
import { DisableableBlock } from "../blocks/disableableBlock.js";
import { DisableableShaderBlock, BlockDisableStrategy } from "../blocks/disableableShaderBlock.js";

const showDebugData = false;

Expand Down Expand Up @@ -132,6 +138,9 @@ export class SmartFilterOptimizer {
const connectionsToReconnect: [ConnectionPoint, ConnectionPoint][] = [];

if (this._options.removeDisabledBlocks) {
// Need to propagate runtime data to ensure we can tell if a block is disabled
this._sourceSmartFilter.output.ownerBlock.propagateRuntimeData();

const alreadyVisitedBlocks = new Set<BaseBlock>();
this._disconnectDisabledBlocks(
this._sourceSmartFilter.output.connectedTo.ownerBlock,
Expand Down Expand Up @@ -203,7 +212,7 @@ export class SmartFilterOptimizer {
this._disconnectDisabledBlocks(input.connectedTo.ownerBlock, alreadyVisitedBlocks, inputsToReconnect);
}

if (block instanceof DisableableBlock && block.disabled.runtimeData.value) {
if (block instanceof DisableableShaderBlock && block.disabled.runtimeData.value) {
block.disconnectFromGraph(inputsToReconnect);
}
}
Expand Down Expand Up @@ -508,6 +517,14 @@ export class SmartFilterOptimizer {
throw `The connection point corresponding to the input named "${samplerName}" in block named "${block.name}" is not connected!`;
}

// If we are using the AutoSample strategy, we must preprocess the code that samples the texture
if (
block instanceof DisableableShaderBlock &&
block.blockDisableStrategy === BlockDisableStrategy.AutoSample
) {
code = this._applyAutoSampleStrategy(code, sampler);
}

const parentBlock = input.connectedTo.ownerBlock;

if (isTextureInputBlock(parentBlock)) {
Expand Down Expand Up @@ -629,7 +646,7 @@ export class SmartFilterOptimizer {
});

// Sets the remapping of the shader variables
const blockOwnerToShaderBinding = new Map<ShaderBlock, Binding>();
const blockOwnerToShaderBinding = new Map<ShaderBlock, ShaderBinding>();

let codeUniforms = "";
let codeConsts = "";
Expand Down Expand Up @@ -703,4 +720,25 @@ export class SmartFilterOptimizer {

return optimizedBlock;
}

/**
* If this block used DisableStrategy.AutoSample, find all the sampleTexture calls which just pass the vUV,
* skip the first one, and for all others replace with the local variable created by the DisableStrategy.AutoSample
*
* @param code - The shader code to process
* @param sampler - The name of the sampler
*
* @returns The processed code
*/
private _applyAutoSampleStrategy(code: string, sampler: string): string {
let isFirstMatch = true;
const rx = new RegExp(`sampleTexture\\s*\\(\\s*${sampler}\\s*,\\s*vUV\\s*\\)`, "g");
return code.replace(rx, (match) => {
if (isFirstMatch) {
isFirstMatch = false;
return match;
}
return decorateSymbol(AutoDisableMainInputColorName);
});
}
}
Loading

0 comments on commit 04aaf9d

Please sign in to comment.