From 02bff8022922657946c4f3b739fde5544c468063 Mon Sep 17 00:00:00 2001 From: Andres Traks Date: Sun, 15 Dec 2024 11:40:33 +0200 Subject: [PATCH] Split HLSL program into statements --- Hlsl/Compiler/NodeGrouper.cs | 5 -- Hlsl/FlowControl/BreakStatement.cs | 6 ++ Hlsl/FlowControl/ClipStatement.cs | 12 ++++ Hlsl/FlowControl/IStatement.cs | 6 ++ Hlsl/FlowControl/StatementSequence.cs | 65 +++++++++++++++++++ Hlsl/HlslAst.cs | 51 +-------------- Hlsl/HlslAstWriter.cs | 58 ++++++++++++----- Hlsl/InstructionParser.cs | 52 ++++++++------- Hlsl/Operations/ClipOperation.cs | 12 ---- Hlsl/StatementSequence.cs | 10 --- HlslDecompiler.Tests/DecompileDxbcTests.cs | 1 + HlslDecompiler.Tests/DecompileTests.cs | 3 +- .../HlslDecompiler.Tests.csproj | 9 ++- .../ShaderSources/ps_3_0/clip.fx | 3 +- .../ShaderSources/ps_3_0/loop.fx | 6 +- .../{loop_instruction.fx => loop.fx} | 0 HlslDecompiler.csproj | 2 +- 17 files changed, 174 insertions(+), 127 deletions(-) create mode 100644 Hlsl/FlowControl/BreakStatement.cs create mode 100644 Hlsl/FlowControl/ClipStatement.cs create mode 100644 Hlsl/FlowControl/IStatement.cs create mode 100644 Hlsl/FlowControl/StatementSequence.cs delete mode 100644 Hlsl/Operations/ClipOperation.cs delete mode 100644 Hlsl/StatementSequence.cs rename HlslDecompiler.Tests/ShaderSources/ps_3_0_instruction/{loop_instruction.fx => loop.fx} (100%) diff --git a/Hlsl/Compiler/NodeGrouper.cs b/Hlsl/Compiler/NodeGrouper.cs index aa1dc3a..48bf68c 100644 --- a/Hlsl/Compiler/NodeGrouper.cs +++ b/Hlsl/Compiler/NodeGrouper.cs @@ -169,11 +169,6 @@ public bool CanGroupComponents(HlslTreeNode node1, HlslTreeNode node2, bool allo CanGroupComponents(compare1.LessValue, compare2.LessValue) && CanGroupComponents(compare1.GreaterEqualValue, compare2.GreaterEqualValue); } - else if (operation1 is ClipOperation clip1 && - operation2 is ClipOperation clip2) - { - return CanGroupComponents(clip1.Value, clip2.Value); - } else if (operation1 is LengthOperation length1 && operation2 is LengthOperation length2) { diff --git a/Hlsl/FlowControl/BreakStatement.cs b/Hlsl/FlowControl/BreakStatement.cs new file mode 100644 index 0000000..5ebeeea --- /dev/null +++ b/Hlsl/FlowControl/BreakStatement.cs @@ -0,0 +1,6 @@ +namespace HlslDecompiler.Hlsl.FlowControl +{ + public class BreakStatement : IStatement + { + } +} diff --git a/Hlsl/FlowControl/ClipStatement.cs b/Hlsl/FlowControl/ClipStatement.cs new file mode 100644 index 0000000..7260374 --- /dev/null +++ b/Hlsl/FlowControl/ClipStatement.cs @@ -0,0 +1,12 @@ +namespace HlslDecompiler.Hlsl.FlowControl +{ + public class ClipStatement : IStatement + { + public HlslTreeNode Value { get; } + + public ClipStatement(HlslTreeNode value) + { + Value = value; + } + } +} diff --git a/Hlsl/FlowControl/IStatement.cs b/Hlsl/FlowControl/IStatement.cs new file mode 100644 index 0000000..7f8fff4 --- /dev/null +++ b/Hlsl/FlowControl/IStatement.cs @@ -0,0 +1,6 @@ +namespace HlslDecompiler.Hlsl.FlowControl +{ + public interface IStatement + { + } +} diff --git a/Hlsl/FlowControl/StatementSequence.cs b/Hlsl/FlowControl/StatementSequence.cs new file mode 100644 index 0000000..ec14d23 --- /dev/null +++ b/Hlsl/FlowControl/StatementSequence.cs @@ -0,0 +1,65 @@ +using HlslDecompiler.DirectXShaderModel; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace HlslDecompiler.Hlsl.FlowControl +{ + public class StatementSequence : IStatement + { + public Dictionary Outputs { get; set; } + + public Dictionary GroupAssignments() + { + return GroupComponents(Outputs.Where(IsAssignment)); + } + + public Dictionary GroupOutputs() + { + return GroupComponents(Outputs.Where(IsOutput)); + } + + private static Dictionary GroupComponents(IEnumerable> outputsByComponent) + { + return outputsByComponent + .OrderBy(o => o.Key.ComponentIndex) + .GroupBy(o => o.Key.RegisterKey) + .ToDictionary( + o => o.Key, + o => (HlslTreeNode)new GroupNode(o.Select(o => o.Value).ToArray())); + } + + private static bool IsAssignment(KeyValuePair operation) + { + if (operation.Key.RegisterKey is D3D9RegisterKey key9) + { + return key9.Type == RegisterType.Temp; + } + else if (operation.Key.RegisterKey is D3D10RegisterKey key10) + { + return key10.OperandType == OperandType.Temp; + } + else + { + throw new NotImplementedException(); + } + } + + private static bool IsOutput(KeyValuePair operation) + { + if (operation.Key.RegisterKey is D3D9RegisterKey key9) + { + RegisterType type = key9.Type; + return type == RegisterType.Output || type == RegisterType.ColorOut || type == RegisterType.DepthOut; + } + else if (operation.Key.RegisterKey is D3D10RegisterKey key10) + { + return key10.OperandType == OperandType.Output; + } + else + { + throw new NotImplementedException(); + } + } + } +} diff --git a/Hlsl/HlslAst.cs b/Hlsl/HlslAst.cs index b0383fe..63fad2b 100644 --- a/Hlsl/HlslAst.cs +++ b/Hlsl/HlslAst.cs @@ -1,62 +1,17 @@ -using HlslDecompiler.DirectXShaderModel; -using HlslDecompiler.Hlsl.TemplateMatch; -using System; +using HlslDecompiler.Hlsl.FlowControl; using System.Collections.Generic; -using System.Linq; namespace HlslDecompiler.Hlsl { public class HlslAst { - public List Statements { get; private set; } + public List Statements { get; private set; } public RegisterState RegisterState { get; private set; } - public HlslAst(List statements, RegisterState registerState) + public HlslAst(List statements, RegisterState registerState) { Statements = statements; RegisterState = registerState; } - - public List> ReduceTree(NodeGrouper nodeGrouper) - { - var templateMatcher = new TemplateMatcher(nodeGrouper); - return Statements - .Select(GroupOutputs) - .Select(outputs => outputs.ToDictionary(r => r.Key, r => templateMatcher.Reduce(r.Value))) - .ToList(); - } - - public static Dictionary GroupOutputs(StatementSequence statements) - { - IEnumerable> outputsByComponent = - statements.Outputs.Where(o => - { - if (o.Value is ClipOperation) - { - return true; - } - - if (o.Key.RegisterKey is D3D9RegisterKey key9) - { - RegisterType type = key9.Type; - return type == RegisterType.Output || type == RegisterType.ColorOut || type == RegisterType.DepthOut; - } - else if (o.Key.RegisterKey is D3D10RegisterKey key10) - { - return key10.OperandType == OperandType.Output; - } - else - { - throw new NotImplementedException(); - } - }); - var outputsByRegister = outputsByComponent - .OrderBy(o => o.Key.ComponentIndex) - .GroupBy(o => o.Key.RegisterKey) - .ToDictionary( - o => o.Key, - o => (HlslTreeNode)new GroupNode(o.Select(o => o.Value).ToArray())); - return outputsByRegister; - } } } diff --git a/Hlsl/HlslAstWriter.cs b/Hlsl/HlslAstWriter.cs index f344017..f2dee84 100644 --- a/Hlsl/HlslAstWriter.cs +++ b/Hlsl/HlslAstWriter.cs @@ -1,4 +1,6 @@ using HlslDecompiler.DirectXShaderModel; +using HlslDecompiler.Hlsl.FlowControl; +using HlslDecompiler.Hlsl.TemplateMatch; using System.Collections.Generic; using System.Linq; @@ -26,34 +28,54 @@ protected override void WriteMethodBody() private void WriteAst(HlslAst ast) { var compiler = new NodeCompiler(_registers); + var nodeGrouper = new NodeGrouper(_registers); + var templateMatcher = new TemplateMatcher(nodeGrouper); - List> sequenceRoots = ast.ReduceTree(new NodeGrouper(_registers)); - foreach (var roots in sequenceRoots) + for (int i = 0; i < ast.Statements.Count; i++) { - if (roots.Count == 1) + IStatement statement = ast.Statements[i]; + if (statement is StatementSequence sequence) { - if (roots.First().Value.Inputs.First() is ClipOperation) + bool isLastStatement = i == ast.Statements.Count - 1; + + Dictionary roots = isLastStatement + ? sequence.GroupOutputs() + : sequence.GroupAssignments(); + roots = roots.ToDictionary(r => r.Key, r => templateMatcher.Reduce(r.Value)); + + if (isLastStatement) { - string statement = compiler.Compile(roots.First().Value); - WriteLine($"{statement};"); + if (roots.Count == 1) + { + string compiled = compiler.Compile(roots.Single().Value); + WriteLine($"return {compiled};"); + } + else + { + foreach (var rootGroup in roots) + { + RegisterDeclaration outputRegister = _registers.MethodOutputRegisters[rootGroup.Key]; + string compiled = compiler.Compile(rootGroup.Value); + WriteLine($"o.{outputRegister.Name} = {compiled};"); + } + WriteLine(); + WriteLine($"return o;"); + } } else { - string statement = compiler.Compile(roots.Single().Value); - WriteLine($"return {statement};"); + foreach (var rootGroup in roots) + { + string registerName = _registers.GetRegisterName(rootGroup.Key); + string compiled = compiler.Compile(rootGroup.Value); + WriteLine($"float4 {registerName} = {compiled};"); + } } } - else + else if (statement is ClipStatement clip) { - foreach (var rootGroup in roots) - { - RegisterDeclaration outputRegister = _registers.MethodOutputRegisters[rootGroup.Key]; - string statement = compiler.Compile(rootGroup.Value); - WriteLine($"o.{outputRegister.Name} = {statement};"); - } - - WriteLine(); - WriteLine($"return o;"); + string compiled = compiler.Compile(clip.Value); + WriteLine($"clip({compiled});"); } } } diff --git a/Hlsl/InstructionParser.cs b/Hlsl/InstructionParser.cs index aac24b9..11ca607 100644 --- a/Hlsl/InstructionParser.cs +++ b/Hlsl/InstructionParser.cs @@ -1,4 +1,5 @@ using HlslDecompiler.DirectXShaderModel; +using HlslDecompiler.Hlsl.FlowControl; using HlslDecompiler.Util; using System; using System.Collections.Generic; @@ -12,7 +13,7 @@ class InstructionParser { private Dictionary _activeOutputs; private RegisterState _registerState; - private List _statementSequences; + private List _sequences; private StatementSequence _currentStatementSequence; public static HlslAst Parse(ShaderModel shader) @@ -25,10 +26,10 @@ private HlslAst ParseToAst(ShaderModel shader) { _activeOutputs = new Dictionary(); _registerState = new RegisterState(shader); - _statementSequences = new List(); + _sequences = new List(); _currentStatementSequence = new StatementSequence(); - _statementSequences.Add(_currentStatementSequence); + _sequences.Add(_currentStatementSequence); LoadConstantOutputs(shader); @@ -53,14 +54,22 @@ private HlslAst ParseToAst(ShaderModel shader) _currentStatementSequence.Outputs = new Dictionary(_activeOutputs); - return new HlslAst(_statementSequences, _registerState); + return new HlslAst(_sequences, _registerState); } private void ParseInstruction(D3D9Instruction instruction) { if (instruction.HasDestination) { - ParseAssignmentInstruction(instruction); + if (instruction.Opcode == Opcode.TexKill) + { + InsertClip(instruction); + EndStatementSequence(_activeOutputs); + } + else + { + ParseAssignmentInstruction(instruction); + } } else { @@ -101,14 +110,8 @@ private void ParseInstruction(D3D10Instruction instruction) { case D3D10Opcode.Discard: { - RegisterKey registerKey = instruction.GetParamRegisterKey(0); - _currentStatementSequence.Outputs = new Dictionary(); - RegisterComponentKey registerComponentKey = new RegisterComponentKey(registerKey, 0); - ClipOperation clip = new ClipOperation(new RegisterInputNode(registerComponentKey)); - _currentStatementSequence.Outputs.Add(registerComponentKey, clip); - - _currentStatementSequence = new StatementSequence(); - _statementSequences.Add(_currentStatementSequence); + EndStatementSequence(_activeOutputs); + InsertClip(instruction); break; } case D3D10Opcode.DclTemps: @@ -290,16 +293,9 @@ private void ParseAssignmentInstruction(D3D9Instruction instruction) newOutputs[destinationKey] = instructionTree; } - if (instruction.Opcode == Opcode.TexKill) - { - EndStatementSequence(newOutputs); - } - else + foreach (var output in newOutputs) { - foreach (var output in newOutputs) - { - _activeOutputs[output.Key] = output.Value; - } + _activeOutputs[output.Key] = output.Value; } } @@ -308,7 +304,15 @@ private void EndStatementSequence(Dictionary _currentStatementSequence.Outputs = new Dictionary(newOutputs); _currentStatementSequence = new StatementSequence(); - _statementSequences.Add(_currentStatementSequence); + _sequences.Add(_currentStatementSequence); + } + + private void InsertClip(Instruction instruction) + { + RegisterKey registerKey = instruction.GetParamRegisterKey(0); + var registerComponentKey = new RegisterComponentKey(registerKey, 0); + var clip = new ClipStatement(new RegisterInputNode(registerComponentKey)); + _sequences.Add(clip); } private void ParseAssignmentInstruction(D3D10Instruction instruction) @@ -445,8 +449,6 @@ private HlslTreeNode CreateInstructionTree(D3D9Instruction instruction, Register return new SignGreaterOrEqualOperation(inputs[0], inputs[1]); case Opcode.Slt: return new SignLessOperation(inputs[0], inputs[1]); - case Opcode.TexKill: - return new ClipOperation(inputs[0]); default: throw new NotImplementedException(); } diff --git a/Hlsl/Operations/ClipOperation.cs b/Hlsl/Operations/ClipOperation.cs deleted file mode 100644 index 008da46..0000000 --- a/Hlsl/Operations/ClipOperation.cs +++ /dev/null @@ -1,12 +0,0 @@ -namespace HlslDecompiler.Hlsl -{ - public class ClipOperation : ConsumerOperation - { - public ClipOperation(HlslTreeNode value) - { - AddInput(value); - } - - public override string Mnemonic => "clip"; - } -} diff --git a/Hlsl/StatementSequence.cs b/Hlsl/StatementSequence.cs deleted file mode 100644 index 353d237..0000000 --- a/Hlsl/StatementSequence.cs +++ /dev/null @@ -1,10 +0,0 @@ -using HlslDecompiler.DirectXShaderModel; -using System.Collections.Generic; - -namespace HlslDecompiler.Hlsl -{ - public class StatementSequence - { - public Dictionary Outputs { get; set; } - } -} diff --git a/HlslDecompiler.Tests/DecompileDxbcTests.cs b/HlslDecompiler.Tests/DecompileDxbcTests.cs index 986178b..f7c793c 100644 --- a/HlslDecompiler.Tests/DecompileDxbcTests.cs +++ b/HlslDecompiler.Tests/DecompileDxbcTests.cs @@ -2,6 +2,7 @@ using HlslDecompiler.DirectXShaderModel; using NUnit.Framework; using System.IO; +using NUnit.Framework.Legacy; namespace HlslDecompiler.Tests { diff --git a/HlslDecompiler.Tests/DecompileTests.cs b/HlslDecompiler.Tests/DecompileTests.cs index c9b16ae..7b654f3 100644 --- a/HlslDecompiler.Tests/DecompileTests.cs +++ b/HlslDecompiler.Tests/DecompileTests.cs @@ -2,6 +2,7 @@ using HlslDecompiler.DirectXShaderModel; using NUnit.Framework; using System.IO; +using NUnit.Framework.Legacy; namespace HlslDecompiler.Tests { @@ -74,8 +75,8 @@ public void DecompileShaderTest(string profile, string baseFilename) hlslWriter.Write(hlslOutputFilename); FileAssert.AreEqual(asmExpectedFilename, asmOutputFilename, "Assembly not equal at " + asmOutputFilename); - FileAssert.AreEqual(hlslExpectedFilename, hlslOutputFilename, "HLSL not equal at " + hlslOutputFilename); FileAssert.AreEqual(hlslInstructionExpectedFilename, hlslInstructionOutputFilename, "HLSL not equal at " + hlslInstructionOutputFilename); + FileAssert.AreEqual(hlslExpectedFilename, hlslOutputFilename, "HLSL not equal at " + hlslOutputFilename); } } } diff --git a/HlslDecompiler.Tests/HlslDecompiler.Tests.csproj b/HlslDecompiler.Tests/HlslDecompiler.Tests.csproj index 4d0a7c2..7fa6aba 100644 --- a/HlslDecompiler.Tests/HlslDecompiler.Tests.csproj +++ b/HlslDecompiler.Tests/HlslDecompiler.Tests.csproj @@ -7,9 +7,9 @@ - - - + + + @@ -494,6 +494,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/HlslDecompiler.Tests/ShaderSources/ps_3_0/clip.fx b/HlslDecompiler.Tests/ShaderSources/ps_3_0/clip.fx index c61200b..3daded8 100644 --- a/HlslDecompiler.Tests/ShaderSources/ps_3_0/clip.fx +++ b/HlslDecompiler.Tests/ShaderSources/ps_3_0/clip.fx @@ -1,5 +1,6 @@ float4 main(float4 texcoord : TEXCOORD) : COLOR { - clip(-1); + float4 r0 = -1; + clip(r0.x); return texcoord; } diff --git a/HlslDecompiler.Tests/ShaderSources/ps_3_0/loop.fx b/HlslDecompiler.Tests/ShaderSources/ps_3_0/loop.fx index d6c0e3f..1f8ecad 100644 --- a/HlslDecompiler.Tests/ShaderSources/ps_3_0/loop.fx +++ b/HlslDecompiler.Tests/ShaderSources/ps_3_0/loop.fx @@ -2,9 +2,9 @@ float count; float4 main(float4 texcoord : TEXCOORD) : COLOR { - float4 o = 0; + float4 r0 = 0; for (int i = 3; i < count; i++) { - o += texcoord; + r0 += texcoord; } - return o; + return r0; } diff --git a/HlslDecompiler.Tests/ShaderSources/ps_3_0_instruction/loop_instruction.fx b/HlslDecompiler.Tests/ShaderSources/ps_3_0_instruction/loop.fx similarity index 100% rename from HlslDecompiler.Tests/ShaderSources/ps_3_0_instruction/loop_instruction.fx rename to HlslDecompiler.Tests/ShaderSources/ps_3_0_instruction/loop.fx diff --git a/HlslDecompiler.csproj b/HlslDecompiler.csproj index 4e343bd..5b789b7 100644 --- a/HlslDecompiler.csproj +++ b/HlslDecompiler.csproj @@ -12,7 +12,7 @@ - +