Skip to content

Commit

Permalink
Group AST roots by register
Browse files Browse the repository at this point in the history
  • Loading branch information
AndresTraks committed May 8, 2021
1 parent 966a1cd commit b3347c9
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 64 deletions.
6 changes: 6 additions & 0 deletions DirectXShaderModel/AsmWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ private void WriteInstruction(Instruction instruction)
case Opcode.EndLoop:
WriteLine("endloop");
break;
case Opcode.EndRep:
WriteLine("endrep");
break;
case Opcode.Exp:
WriteLine("exp{0} {1}, {2}", GetModifier(instruction), GetDestinationName(instruction),
GetSourceName(instruction, 1));
Expand Down Expand Up @@ -235,6 +238,9 @@ private void WriteInstruction(Instruction instruction)
WriteLine("pow{0} {1}, {2}, {3}", GetModifier(instruction), GetDestinationName(instruction),
GetSourceName(instruction, 1), GetSourceName(instruction, 2));
break;
case Opcode.Rep:
WriteLine("rep {0}", GetDestinationName(instruction));
break;
case Opcode.Rcp:
WriteLine("rcp{0} {1}, {2}", GetModifier(instruction), GetDestinationName(instruction),
GetSourceName(instruction, 1));
Expand Down
5 changes: 5 additions & 0 deletions Hlsl/Compiler/NodeCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ public NodeCompiler(RegisterState registers)
_matrixMultiplicationCompiler = new MatrixMultiplicationCompiler(this);
}

public string Compile(HlslTreeNode node)
{
return Compile(new List<HlslTreeNode>() { node });
}

public string Compile(IEnumerable<HlslTreeNode> group, int promoteToVectorSize = PromoteToAnyVectorSize)
{
return Compile(group.ToList(), promoteToVectorSize);
Expand Down
8 changes: 4 additions & 4 deletions Hlsl/HlslAst.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ namespace HlslDecompiler.Hlsl
{
public class HlslAst
{
public Dictionary<RegisterComponentKey, HlslTreeNode> Roots { get; private set; }
public Dictionary<RegisterComponentKey, HlslTreeNode> NoOutputInstructions { get; private set; }
public Dictionary<RegisterKey, HlslTreeNode> Roots { get; private set; }
public Dictionary<RegisterKey, HlslTreeNode> NoOutputInstructions { get; private set; }

public HlslAst(Dictionary<RegisterComponentKey, HlslTreeNode> roots,
Dictionary<RegisterComponentKey, HlslTreeNode> noOutputInstructions)
public HlslAst(Dictionary<RegisterKey, HlslTreeNode> roots,
Dictionary<RegisterKey, HlslTreeNode> noOutputInstructions)
{
Roots = roots;
NoOutputInstructions = noOutputInstructions;
Expand Down
23 changes: 7 additions & 16 deletions Hlsl/HlslAstWriter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using HlslDecompiler.DirectXShaderModel;
using HlslDecompiler.Hlsl;
using System;
using System.Collections.Generic;
using System.Linq;

Expand Down Expand Up @@ -32,39 +33,29 @@ private void WriteAst(HlslAst ast)
{
var compiler = new NodeCompiler(_registers);

var noOutputInstructionRoots = ast.NoOutputInstructions.GroupBy(r => r.Key.RegisterKey);
foreach (var rootGroup in noOutputInstructionRoots)
foreach (var rootGroup in ast.NoOutputInstructions)
{
string statement = CompileRootStatement(compiler, rootGroup);
string statement = compiler.Compile(rootGroup.Value);
WriteLine($"{statement};");
}

var rootGroups = ast.Roots.GroupBy(r => r.Key.RegisterKey);
if (_registers.MethodOutputRegisters.Count == 1)
if (ast.Roots.Count == 1)
{
string statement = CompileRootStatement(compiler, rootGroups.Single());
string statement = compiler.Compile(ast.Roots.Single().Value);
WriteLine($"return {statement};");
}
else
{
foreach (var rootGroup in rootGroups)
foreach (var rootGroup in ast.Roots)
{
RegisterDeclaration outputRegister = _registers.MethodOutputRegisters[rootGroup.Key];
string statement = CompileRootStatement(compiler, rootGroup);
string statement = compiler.Compile(rootGroup.Value);
WriteLine($"o.{outputRegister.Name} = {statement};");
}

WriteLine();
WriteLine($"return o;");
}
}

private static string CompileRootStatement(NodeCompiler compiler,
IGrouping<RegisterKey, KeyValuePair<RegisterComponentKey, HlslTreeNode>> rootGroup)
{
var registerKey = rootGroup.Key;
var roots = rootGroup.OrderBy(r => r.Key.ComponentIndex).Select(r => r.Value).ToList();
return compiler.Compile(roots, roots.Count);
}
}
}
5 changes: 3 additions & 2 deletions Hlsl/HlslSimpleWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,14 @@ private void WriteInstruction(Instruction instruction)
uint end = intRegister.Value[0];
uint start = intRegister.Value[1];
uint stride = intRegister.Value[2];
string loopVariable = "i0";
if (stride == 1)
{
WriteLine("for (int i = {0}; i < {1}; i++) {{", start, end);
WriteLine("for (int {2} = {0}; {2} < {1}; {2}++) {{", start, end, loopVariable);
}
else
{
WriteLine("for (int i = {0}; i < {1}; i += {2}) {{", start, end, stride);
WriteLine("for (int {3} = {0}; {3} < {1}; {3} += {2}) {{", start, end, stride, loopVariable);
}
indent += "\t";
break;
Expand Down
98 changes: 56 additions & 42 deletions Hlsl/InstructionParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,57 +8,74 @@ namespace HlslDecompiler.Hlsl
class BytecodeParser
{
private Dictionary<RegisterComponentKey, HlslTreeNode> _activeOutputs;
private Dictionary<RegisterComponentKey, HlslTreeNode> _noOutputInstructions;
private Dictionary<RegisterKey, HlslTreeNode> _noOutputInstructions;
private Dictionary<RegisterKey, HlslTreeNode> _samplers;

public HlslAst Parse(ShaderModel shader)
{
_activeOutputs = new Dictionary<RegisterComponentKey, HlslTreeNode>();
_noOutputInstructions = new Dictionary<RegisterComponentKey, HlslTreeNode>();
_noOutputInstructions = new Dictionary<RegisterKey, HlslTreeNode>();
_samplers = new Dictionary<RegisterKey, HlslTreeNode>();

LoadConstantOutputs(shader);

int instructionPointer = 0;
bool ifBlock = false;
while (instructionPointer < shader.Instructions.Count)
{
var instruction = shader.Instructions[instructionPointer];
if (ifBlock)
if (instruction.HasDestination)
{
if (instruction.Opcode == Opcode.Else)
{
ifBlock = false;
}
ParseAssignmentInstruction(instruction);
}
else
{
if (instruction.Opcode == Opcode.IfC)
switch (instruction.Opcode)
{
ifBlock = true;
case Opcode.If:
case Opcode.IfC:
case Opcode.Else:
case Opcode.Loop:
case Opcode.Rep:
case Opcode.End:
case Opcode.Endif:
case Opcode.EndLoop:
case Opcode.EndRep:
ParseControlInstruction(instruction);
break;
case Opcode.Comment:
break;
default:
throw new NotImplementedException();
}
ParseInstruction(instruction);
}
instructionPointer++;
}

Dictionary<RegisterComponentKey, HlslTreeNode> roots;
if (shader.Type == ShaderType.Pixel)
{
roots = _activeOutputs
.Where(o => o.Key.Type == RegisterType.ColorOut)
.ToDictionary(o => o.Key, o => o.Value);
}
else
{
roots = _activeOutputs
.Where(o => o.Key.Type == RegisterType.Output)
.ToDictionary(o => o.Key, o => o.Value);
}

Dictionary<RegisterKey, HlslTreeNode> roots = GroupOutputs(shader);
return new HlslAst(roots, _noOutputInstructions);
}

public Dictionary<RegisterKey, HlslTreeNode> GroupOutputs(ShaderModel shader)
{
var registerType = shader.Type == ShaderType.Pixel
? RegisterType.ColorOut
: RegisterType.Output;
var outputsByRegister = _activeOutputs
.Where(o => o.Key.Type == registerType)
.OrderBy(o => o.Key.ComponentIndex)
.GroupBy(o => o.Key.RegisterKey);
var groupsByRegister = outputsByRegister
.ToDictionary(
o => o.Key,
o => (HlslTreeNode)new GroupNode(o.Select(o => o.Value).ToArray()));
return groupsByRegister;
}

private void ParseControlInstruction(Instruction instruction)
{
// TODO
}

private void LoadConstantOutputs(ShaderModel shader)
{
IList<ConstantDeclaration> constantTable = shader.ParseConstantTable();
Expand Down Expand Up @@ -108,29 +125,26 @@ private void LoadConstantOutputs(ShaderModel shader)
}
}

private void ParseInstruction(Instruction instruction)
private void ParseAssignmentInstruction(Instruction instruction)
{
if (instruction.HasDestination)
var newOutputs = new Dictionary<RegisterComponentKey, HlslTreeNode>();

RegisterComponentKey[] destinationKeys = GetDestinationKeys(instruction).ToArray();
foreach (RegisterComponentKey destinationKey in destinationKeys)
{
var newOutputs = new Dictionary<RegisterComponentKey, HlslTreeNode>();
HlslTreeNode instructionTree = CreateInstructionTree(instruction, destinationKey);
newOutputs[destinationKey] = instructionTree;
}

RegisterComponentKey[] destinationKeys = GetDestinationKeys(instruction).ToArray();
foreach (RegisterComponentKey destinationKey in destinationKeys)
foreach (var output in newOutputs)
{
if (instruction.Opcode == Opcode.TexKill)
{
HlslTreeNode instructionTree = CreateInstructionTree(instruction, destinationKey);
newOutputs[destinationKey] = instructionTree;
_noOutputInstructions[output.Key.RegisterKey] = output.Value;
}

foreach (var output in newOutputs)
else
{
if (instruction.Opcode == Opcode.TexKill)
{
_noOutputInstructions[output.Key] = output.Value;
}
else
{
_activeOutputs[output.Key] = output.Value;
}
_activeOutputs[output.Key] = output.Value;
}
}
}
Expand Down

0 comments on commit b3347c9

Please sign in to comment.