Skip to content

Commit

Permalink
Split HLSL program into statements
Browse files Browse the repository at this point in the history
  • Loading branch information
AndresTraks committed Dec 15, 2024
1 parent 86e9b74 commit 02bff80
Show file tree
Hide file tree
Showing 17 changed files with 174 additions and 127 deletions.
5 changes: 0 additions & 5 deletions Hlsl/Compiler/NodeGrouper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,6 @@ public bool CanGroupComponents(HlslTreeNode node1, HlslTreeNode node2, bool allo
CanGroupComponents(compare1.LessValue, compare2.LessValue) &&
CanGroupComponents(compare1.GreaterEqualValue, compare2.GreaterEqualValue);
}
else if (operation1 is ClipOperation clip1 &&
operation2 is ClipOperation clip2)
{
return CanGroupComponents(clip1.Value, clip2.Value);
}
else if (operation1 is LengthOperation length1 &&
operation2 is LengthOperation length2)
{
Expand Down
6 changes: 6 additions & 0 deletions Hlsl/FlowControl/BreakStatement.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace HlslDecompiler.Hlsl.FlowControl
{
public class BreakStatement : IStatement
{
}
}
12 changes: 12 additions & 0 deletions Hlsl/FlowControl/ClipStatement.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace HlslDecompiler.Hlsl.FlowControl
{
public class ClipStatement : IStatement
{
public HlslTreeNode Value { get; }

public ClipStatement(HlslTreeNode value)
{
Value = value;
}
}
}
6 changes: 6 additions & 0 deletions Hlsl/FlowControl/IStatement.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace HlslDecompiler.Hlsl.FlowControl
{
public interface IStatement
{
}
}
65 changes: 65 additions & 0 deletions Hlsl/FlowControl/StatementSequence.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using HlslDecompiler.DirectXShaderModel;
using System;
using System.Collections.Generic;
using System.Linq;

namespace HlslDecompiler.Hlsl.FlowControl
{
public class StatementSequence : IStatement
{
public Dictionary<RegisterComponentKey, HlslTreeNode> Outputs { get; set; }

public Dictionary<RegisterKey, HlslTreeNode> GroupAssignments()
{
return GroupComponents(Outputs.Where(IsAssignment));
}

public Dictionary<RegisterKey, HlslTreeNode> GroupOutputs()
{
return GroupComponents(Outputs.Where(IsOutput));
}

private static Dictionary<RegisterKey, HlslTreeNode> GroupComponents(IEnumerable<KeyValuePair<RegisterComponentKey, HlslTreeNode>> outputsByComponent)
{
return outputsByComponent
.OrderBy(o => o.Key.ComponentIndex)
.GroupBy(o => o.Key.RegisterKey)
.ToDictionary(
o => o.Key,
o => (HlslTreeNode)new GroupNode(o.Select(o => o.Value).ToArray()));
}

private static bool IsAssignment(KeyValuePair<RegisterComponentKey, HlslTreeNode> operation)
{
if (operation.Key.RegisterKey is D3D9RegisterKey key9)
{
return key9.Type == RegisterType.Temp;
}
else if (operation.Key.RegisterKey is D3D10RegisterKey key10)
{
return key10.OperandType == OperandType.Temp;
}
else
{
throw new NotImplementedException();
}
}

private static bool IsOutput(KeyValuePair<RegisterComponentKey, HlslTreeNode> operation)
{
if (operation.Key.RegisterKey is D3D9RegisterKey key9)
{
RegisterType type = key9.Type;
return type == RegisterType.Output || type == RegisterType.ColorOut || type == RegisterType.DepthOut;
}
else if (operation.Key.RegisterKey is D3D10RegisterKey key10)
{
return key10.OperandType == OperandType.Output;
}
else
{
throw new NotImplementedException();
}
}
}
}
51 changes: 3 additions & 48 deletions Hlsl/HlslAst.cs
Original file line number Diff line number Diff line change
@@ -1,62 +1,17 @@
using HlslDecompiler.DirectXShaderModel;
using HlslDecompiler.Hlsl.TemplateMatch;
using System;
using HlslDecompiler.Hlsl.FlowControl;
using System.Collections.Generic;
using System.Linq;

namespace HlslDecompiler.Hlsl
{
public class HlslAst
{
public List<StatementSequence> Statements { get; private set; }
public List<IStatement> Statements { get; private set; }
public RegisterState RegisterState { get; private set; }

public HlslAst(List<StatementSequence> statements, RegisterState registerState)
public HlslAst(List<IStatement> statements, RegisterState registerState)
{
Statements = statements;
RegisterState = registerState;
}

public List<Dictionary<RegisterKey, HlslTreeNode>> ReduceTree(NodeGrouper nodeGrouper)
{
var templateMatcher = new TemplateMatcher(nodeGrouper);
return Statements
.Select(GroupOutputs)
.Select(outputs => outputs.ToDictionary(r => r.Key, r => templateMatcher.Reduce(r.Value)))
.ToList();
}

public static Dictionary<RegisterKey, HlslTreeNode> GroupOutputs(StatementSequence statements)
{
IEnumerable<KeyValuePair<RegisterComponentKey, HlslTreeNode>> outputsByComponent =
statements.Outputs.Where(o =>
{
if (o.Value is ClipOperation)
{
return true;
}

if (o.Key.RegisterKey is D3D9RegisterKey key9)
{
RegisterType type = key9.Type;
return type == RegisterType.Output || type == RegisterType.ColorOut || type == RegisterType.DepthOut;
}
else if (o.Key.RegisterKey is D3D10RegisterKey key10)
{
return key10.OperandType == OperandType.Output;
}
else
{
throw new NotImplementedException();
}
});
var outputsByRegister = outputsByComponent
.OrderBy(o => o.Key.ComponentIndex)
.GroupBy(o => o.Key.RegisterKey)
.ToDictionary(
o => o.Key,
o => (HlslTreeNode)new GroupNode(o.Select(o => o.Value).ToArray()));
return outputsByRegister;
}
}
}
58 changes: 40 additions & 18 deletions Hlsl/HlslAstWriter.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using HlslDecompiler.DirectXShaderModel;
using HlslDecompiler.Hlsl.FlowControl;
using HlslDecompiler.Hlsl.TemplateMatch;
using System.Collections.Generic;
using System.Linq;

Expand Down Expand Up @@ -26,34 +28,54 @@ protected override void WriteMethodBody()
private void WriteAst(HlslAst ast)
{
var compiler = new NodeCompiler(_registers);
var nodeGrouper = new NodeGrouper(_registers);
var templateMatcher = new TemplateMatcher(nodeGrouper);

List<Dictionary<RegisterKey, HlslTreeNode>> sequenceRoots = ast.ReduceTree(new NodeGrouper(_registers));
foreach (var roots in sequenceRoots)
for (int i = 0; i < ast.Statements.Count; i++)
{
if (roots.Count == 1)
IStatement statement = ast.Statements[i];
if (statement is StatementSequence sequence)
{
if (roots.First().Value.Inputs.First() is ClipOperation)
bool isLastStatement = i == ast.Statements.Count - 1;

Dictionary<RegisterKey, HlslTreeNode> roots = isLastStatement
? sequence.GroupOutputs()
: sequence.GroupAssignments();
roots = roots.ToDictionary(r => r.Key, r => templateMatcher.Reduce(r.Value));

if (isLastStatement)
{
string statement = compiler.Compile(roots.First().Value);
WriteLine($"{statement};");
if (roots.Count == 1)
{
string compiled = compiler.Compile(roots.Single().Value);
WriteLine($"return {compiled};");
}
else
{
foreach (var rootGroup in roots)
{
RegisterDeclaration outputRegister = _registers.MethodOutputRegisters[rootGroup.Key];
string compiled = compiler.Compile(rootGroup.Value);
WriteLine($"o.{outputRegister.Name} = {compiled};");
}
WriteLine();
WriteLine($"return o;");
}
}
else
{
string statement = compiler.Compile(roots.Single().Value);
WriteLine($"return {statement};");
foreach (var rootGroup in roots)
{
string registerName = _registers.GetRegisterName(rootGroup.Key);
string compiled = compiler.Compile(rootGroup.Value);
WriteLine($"float4 {registerName} = {compiled};");
}
}
}
else
else if (statement is ClipStatement clip)
{
foreach (var rootGroup in roots)
{
RegisterDeclaration outputRegister = _registers.MethodOutputRegisters[rootGroup.Key];
string statement = compiler.Compile(rootGroup.Value);
WriteLine($"o.{outputRegister.Name} = {statement};");
}

WriteLine();
WriteLine($"return o;");
string compiled = compiler.Compile(clip.Value);
WriteLine($"clip({compiled});");
}
}
}
Expand Down
52 changes: 27 additions & 25 deletions Hlsl/InstructionParser.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using HlslDecompiler.DirectXShaderModel;
using HlslDecompiler.Hlsl.FlowControl;
using HlslDecompiler.Util;
using System;
using System.Collections.Generic;
Expand All @@ -12,7 +13,7 @@ class InstructionParser
{
private Dictionary<RegisterComponentKey, HlslTreeNode> _activeOutputs;
private RegisterState _registerState;
private List<StatementSequence> _statementSequences;
private List<IStatement> _sequences;
private StatementSequence _currentStatementSequence;

public static HlslAst Parse(ShaderModel shader)
Expand All @@ -25,10 +26,10 @@ private HlslAst ParseToAst(ShaderModel shader)
{
_activeOutputs = new Dictionary<RegisterComponentKey, HlslTreeNode>();
_registerState = new RegisterState(shader);
_statementSequences = new List<StatementSequence>();
_sequences = new List<IStatement>();

_currentStatementSequence = new StatementSequence();
_statementSequences.Add(_currentStatementSequence);
_sequences.Add(_currentStatementSequence);

LoadConstantOutputs(shader);

Expand All @@ -53,14 +54,22 @@ private HlslAst ParseToAst(ShaderModel shader)

_currentStatementSequence.Outputs = new Dictionary<RegisterComponentKey, HlslTreeNode>(_activeOutputs);

return new HlslAst(_statementSequences, _registerState);
return new HlslAst(_sequences, _registerState);
}

private void ParseInstruction(D3D9Instruction instruction)
{
if (instruction.HasDestination)
{
ParseAssignmentInstruction(instruction);
if (instruction.Opcode == Opcode.TexKill)
{
InsertClip(instruction);
EndStatementSequence(_activeOutputs);
}
else
{
ParseAssignmentInstruction(instruction);
}
}
else
{
Expand Down Expand Up @@ -101,14 +110,8 @@ private void ParseInstruction(D3D10Instruction instruction)
{
case D3D10Opcode.Discard:
{
RegisterKey registerKey = instruction.GetParamRegisterKey(0);
_currentStatementSequence.Outputs = new Dictionary<RegisterComponentKey, HlslTreeNode>();
RegisterComponentKey registerComponentKey = new RegisterComponentKey(registerKey, 0);
ClipOperation clip = new ClipOperation(new RegisterInputNode(registerComponentKey));
_currentStatementSequence.Outputs.Add(registerComponentKey, clip);

_currentStatementSequence = new StatementSequence();
_statementSequences.Add(_currentStatementSequence);
EndStatementSequence(_activeOutputs);
InsertClip(instruction);
break;
}
case D3D10Opcode.DclTemps:
Expand Down Expand Up @@ -290,16 +293,9 @@ private void ParseAssignmentInstruction(D3D9Instruction instruction)
newOutputs[destinationKey] = instructionTree;
}

if (instruction.Opcode == Opcode.TexKill)
{
EndStatementSequence(newOutputs);
}
else
foreach (var output in newOutputs)
{
foreach (var output in newOutputs)
{
_activeOutputs[output.Key] = output.Value;
}
_activeOutputs[output.Key] = output.Value;
}
}

Expand All @@ -308,7 +304,15 @@ private void EndStatementSequence(Dictionary<RegisterComponentKey, HlslTreeNode>
_currentStatementSequence.Outputs = new Dictionary<RegisterComponentKey, HlslTreeNode>(newOutputs);

_currentStatementSequence = new StatementSequence();
_statementSequences.Add(_currentStatementSequence);
_sequences.Add(_currentStatementSequence);
}

private void InsertClip(Instruction instruction)
{
RegisterKey registerKey = instruction.GetParamRegisterKey(0);
var registerComponentKey = new RegisterComponentKey(registerKey, 0);
var clip = new ClipStatement(new RegisterInputNode(registerComponentKey));
_sequences.Add(clip);
}

private void ParseAssignmentInstruction(D3D10Instruction instruction)
Expand Down Expand Up @@ -445,8 +449,6 @@ private HlslTreeNode CreateInstructionTree(D3D9Instruction instruction, Register
return new SignGreaterOrEqualOperation(inputs[0], inputs[1]);
case Opcode.Slt:
return new SignLessOperation(inputs[0], inputs[1]);
case Opcode.TexKill:
return new ClipOperation(inputs[0]);
default:
throw new NotImplementedException();
}
Expand Down
12 changes: 0 additions & 12 deletions Hlsl/Operations/ClipOperation.cs

This file was deleted.

10 changes: 0 additions & 10 deletions Hlsl/StatementSequence.cs

This file was deleted.

1 change: 1 addition & 0 deletions HlslDecompiler.Tests/DecompileDxbcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using HlslDecompiler.DirectXShaderModel;
using NUnit.Framework;
using System.IO;
using NUnit.Framework.Legacy;

namespace HlslDecompiler.Tests
{
Expand Down
Loading

0 comments on commit 02bff80

Please sign in to comment.