Skip to content

Commit

Permalink
User/chrila/enable graph compile (#606)
Browse files Browse the repository at this point in the history
* Add Graph option

* Update optional tensor logic

* Move json parser logic for DmlCompileType

* Update version

* Update DmlCompileType namespace, JSON def, and Guide.md

* update spacing

---------

Co-authored-by: Christian Larson <[email protected]>
  • Loading branch information
chrilaMSFT and chrilaMSFT authored Jul 12, 2024
1 parent 72ad224 commit 61a1a50
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 26 deletions.
22 changes: 22 additions & 0 deletions DxDispatch/doc/Guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,28 @@ Take note of the few odd cases that don't follow the usual rule exactly:
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.

### DirectML Compile Op vs Graph (dmlCompileType)
The `dmlCompileType` field selects how a DirectML operator is compiled: either directly with IDMLDevice::CompileOperator, or by inserting the operator into a DML_GRAPH_DESC and compiling it with IDMLDevice1::CompileGraph. If the field is omitted, `DmlCompileOp` is used.

| Enums for dmlCompileType | Description |
| ------------------------------------------------ | ------------------------------------------------------------------------- |
| <b><i>DmlCompileOp</i></b> (Default behavior) | Uses IDMLDevice::CompileOperator for defined operator |
| <b><i>DmlCompileGraph</i></b> | Inserts Operator into a DML_GRAPH_DESC and uses IDMLDevice1::CompileGraph |

Syntax:

```json
"dmlOperator":
{
"type": "DML_OPERATOR_*",
"dmlCompileType": "DmlCompileGraph",
"Desc": { ... }
}
```

See full example in [dml_gemm_graph.json](../models/dml_gemm_graph.json).


### DML_TENSOR_DESC

Since tensor descs are so common, the JSON parser provides default values for most fields.
Expand Down
9 changes: 9 additions & 0 deletions DxDispatch/models/_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@
]
},

"dmlCompileType":
{
"enum":
[
"DmlCompileOp",
"DmlCompileGraph"
]
},

"arrayInitializer":
{
"type": "array",
Expand Down
57 changes: 57 additions & 0 deletions DxDispatch/models/dml_gemm_graph.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"$schema": "./_schema.json",

"resources":
{
"A": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
},
"B": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
},
"output": {
"initialValuesDataType": "FLOAT32",
"initialValues": { "valueCount": 1024, "value": 1 }
}
},

"dispatchables":
{
"gemm":
{
"type": "DML_OPERATOR_GEMM",
"desc":
{
"ATensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
"BTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32], "Flags": "DML_TENSOR_FLAG_OWNED_BY_DML" },
"OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
"TransA": "DML_MATRIX_TRANSFORM_NONE",
"TransB": "DML_MATRIX_TRANSFORM_NONE",
"Alpha": 1.0,
"Beta": 1.0
},
"dmlCompileType": "DmlCompileGraph",
"executionFlags": "DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION",
"bindings":
{
"BTensor": "B"
}
}
},

"commands":
[
{
"type": "dispatch",
"dispatchable": "gemm",
"bindings":
{
"ATensor": "A",
"OutputTensor": "output"
}
},
{ "type": "print", "resource": "output" }
]
}
122 changes: 98 additions & 24 deletions DxDispatch/src/dxdispatch/DmlDispatchable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
std::string_view name,
std::shared_ptr<Device> device,
const Model::DmlDispatchableDesc& desc,
const Dispatchable::Bindings& initBindings
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings))
const Dispatchable::Bindings& initBindings,
IDxDispatchLogger* logger
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
{
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
}
Expand All @@ -28,7 +29,8 @@ void FillBindingData(
const Dispatchable::Bindings* initializeBindings,
const Dispatchable::Bindings* executeBindings,
BindingData& bindingData,
bool bindingForInitialization = false)
bool bindingForInitialization,
Model::DmlDispatchableDesc::DmlCompileType compileType)
{
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;

Expand All @@ -47,22 +49,23 @@ void FillBindingData(

if (bindingIterator == bindings.end())
{
if (bindPoints[i].required && !bindingForInitialization)
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
{
if (!initializeBindings || initializeBindings->find(bindPointName) == initializeBindings->end())
if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i].requiredBinding)
{
throw std::invalid_argument(fmt::format("Nothing bound for required tensor '{}'.", bindPointName));
// Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
bindingData.bindingDescs.pop_back();
bindingData.bufferBindings.pop_back();
}
else
{
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
bindingData.bufferBindings[bufferIndex].Offset = 0;
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
bufferIndex++;
}
}

for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
{
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
bindingData.bufferBindings[bufferIndex].Offset = 0;
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
bufferIndex++;
}
}
else
Expand Down Expand Up @@ -103,11 +106,82 @@ void FillBindingData(

void DmlDispatchable::Initialize()
{
THROW_IF_FAILED(m_device->DML()->CompileOperator(
m_operator.Get(),
m_desc.executionFlags,
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
if(m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
{
m_logger->LogInfo("Compile Op");
THROW_IF_FAILED(m_device->DML()->CompileOperator(
m_operator.Get(),
m_desc.executionFlags,
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
}
else
{
m_logger->LogInfo("Compiling op using IDMLDevice1::CompileGraph");
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;

std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};

nodeDesc.Operator = m_operator.Get();
nodeDesc.Name = m_name.c_str();

{
dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
dmlGraphNodeDesc.Desc = &nodeDesc;
}

dmlInputGraphEdges.resize(m_desc.bindPoints.inputs.size());
for( size_t i = 0; i < m_desc.bindPoints.inputs.size(); i++)
{
if (m_desc.bindPoints.inputs[i].requiredBinding)
{
DML_INPUT_GRAPH_EDGE_DESC desc = {};
desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
desc.ToNodeIndex = 0;
desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
desc.Name = m_desc.bindPoints.inputs[i].name.c_str();
dmlInputGraphEdges[i] = desc;
dmlInputEdges.push_back({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
}
}

dmlOutputGraphEdges.resize(m_desc.bindPoints.outputs.size());
for( size_t i = 0; i < m_desc.bindPoints.outputs.size(); i++)
{
if (m_desc.bindPoints.outputs[i].requiredBinding)
{
DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
desc.FromNodeIndex = 0;
desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
desc.Name = m_desc.bindPoints.outputs[i].name.c_str();
dmlOutputGraphEdges[i] = desc;
dmlOutputEdges.push_back({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
}
}

dmlGraphDesc.InputCount = static_cast<uint32_t>(dmlInputEdges.size());
dmlGraphDesc.InputEdges = dmlInputEdges.data();
dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount;

dmlGraphDesc.OutputCount = static_cast<uint32_t>(dmlOutputEdges.size());
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount;

dmlGraphDesc.IntermediateEdgeCount = 0;
dmlGraphDesc.IntermediateEdges = nullptr;

dmlGraphDesc.NodeCount = 1;
dmlGraphDesc.Nodes = &dmlGraphNodeDesc;

THROW_IF_FAILED(m_device->DML()->CompileGraph(&dmlGraphDesc, m_desc.executionFlags, IID_PPV_ARGS(&m_operatorCompiled)));
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(fmt::format("Graph_{}", m_name)).data());
}

ComPtr<IDMLOperatorInitializer> initializer;
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
Expand Down Expand Up @@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
// be bound using a separate buffer array binding.
BindingData inputBindingData = {};
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true);
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true, m_desc.compileType);

DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
Expand Down Expand Up @@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
auto bindingProps = m_operatorCompiled->GetBindingProperties();

BindingData inputBindingData = {};
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData);
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData, false, m_desc.compileType);

BindingData outputBindingData = {};
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData);
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData, false, m_desc.compileType);

D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
Expand Down
4 changes: 3 additions & 1 deletion DxDispatch/src/dxdispatch/DmlDispatchable.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ class DmlDispatchable : public Dispatchable
std::string_view name,
std::shared_ptr<Device> device,
const Model::DmlDispatchableDesc& desc,
const Dispatchable::Bindings& initBindings);
const Dispatchable::Bindings& initBindings,
IDxDispatchLogger* logger);

void Initialize() final;
void Bind(const Bindings& bindings, uint32_t iteration) final;
Expand All @@ -23,4 +24,5 @@ class DmlDispatchable : public Dispatchable
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
Microsoft::WRL::ComPtr<IDxDispatchLogger> m_logger;
};
2 changes: 1 addition & 1 deletion DxDispatch/src/dxdispatch/Executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ Executor::Executor(Model& model, std::shared_ptr<Device> device, const CommandLi
return;
}

m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings);
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings, m_logger.Get());
}
}
catch(const std::exception& e)
Expand Down
47 changes: 47 additions & 0 deletions DxDispatch/src/model/JsonParsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1402,11 +1402,58 @@ std::vector<Model::BufferBindingSource> ParseBindingSource(const rapidjson::Valu
return sourceResources;
}

// Converts a JSON string value into the matching DmlCompileType enumerator.
// Throws std::invalid_argument if the value is not a JSON string or if the
// string does not name a known enumerator.
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileType(const rapidjson::Value& value)
{
    using CompileType = Model::DmlDispatchableDesc::DmlCompileType;

    if (value.GetType() != rapidjson::Type::kStringType)
    {
        throw std::invalid_argument("Expected a string.");
    }

    // View over the null-terminated JSON string; comparisons are exact and case-sensitive.
    const std::string_view text = value.GetString();

    if (text == "DmlCompileOp")
    {
        return CompileType::DmlCompileOp;
    }
    if (text == "DmlCompileGraph")
    {
        return CompileType::DmlCompileGraph;
    }

    throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DmlCompileType.", text));
}

// Reads the field named `fieldName` from `object` as a DmlCompileType.
// The lookup itself is delegated to ParseFieldHelper, which receives `required` and
// `defaultValue` (presumably to handle a missing field; see ParseFieldHelper to confirm).
// The field's value is converted with ParseDmlCompileType.
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, Model::DmlDispatchableDesc::DmlCompileType defaultValue)
{
    using CompileType = Model::DmlDispatchableDesc::DmlCompileType;

    const auto parseValue = [](auto& fieldValue) { return ParseDmlCompileType(fieldValue); };

    return ParseFieldHelper<CompileType>(object, fieldName, required, defaultValue, parseValue);
}

Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
{
Model::DmlDispatchableDesc desc;
desc.desc = ParseDmlOperatorDesc(object, false, allocator);
desc.bindPoints = GetBindPoints(*desc.desc);

// DirectML requires a binding for an optional tensor whenever the DML_OPERATOR_DESC declares that tensor.
// A bind point therefore needs a binding if it is required or if its tensor is declared in "desc".
auto UpdateBindingPoints = [](const rapidjson::Value& object, std::vector<Model::DmlDispatchableDesc::BindPoint>& bindPoints) {
for (auto& bindPoint : bindPoints)
{
if (bindPoint.required || object.HasMember(bindPoint.name.c_str()))
{
bindPoint.requiredBinding = true;
}
else
{
bindPoint.requiredBinding = false;
}
}};

auto descMember = object.FindMember("Desc");
if (descMember == object.MemberEnd())
{
descMember = object.FindMember("desc");
}
if (descMember != object.MemberEnd())
{
UpdateBindingPoints(descMember->value, desc.bindPoints.inputs);
UpdateBindingPoints(descMember->value, desc.bindPoints.outputs);
}
desc.compileType = ParseDmlCompileTypeField(object, "dmlCompileType", false, Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp);

desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);

auto bindingsField = object.FindMember("bindings");
Expand Down
8 changes: 8 additions & 0 deletions DxDispatch/src/model/Model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <DirectML.h>
#include "BucketAllocator.h"


class Model
{
public:
Expand Down Expand Up @@ -58,11 +59,17 @@ class Model

struct DmlDispatchableDesc
{
// Selects how a DML dispatchable's operator is compiled.
enum class DmlCompileType
{
// Compile the operator directly with IDMLDevice::CompileOperator.
// This is the value the JSON parser uses when "dmlCompileType" is omitted.
DmlCompileOp,
// Insert the operator as the single node of a DML_GRAPH_DESC and compile it with IDMLDevice1::CompileGraph.
DmlCompileGraph
};
// Describes one named tensor bind point of a DML operator.
// All scalar members carry default initializers so that a BindPoint that is
// default-constructed, or aggregate-initialized with fewer than four values,
// never exposes indeterminate values to the code that reads these flags.
struct BindPoint
{
    // Bind point name; bindings are looked up by this name.
    std::string name;

    // Number of buffer binding slots this bind point occupies.
    uint32_t resourceCount = 0;

    // Whether the operator itself requires this tensor
    // (NOTE(review): populated by GetBindPoints, which is outside this view - confirm).
    bool required = false;

    // Whether a binding must be supplied for this bind point: the JSON parser sets this
    // when the tensor is required or is declared in the operator's "desc" object.
    // Defaults to false because the parser only assigns it when a "desc" member exists.
    bool requiredBinding = false;
};

struct BindPoints
Expand All @@ -74,6 +81,7 @@ class Model
DML_OPERATOR_DESC* desc;
BindPoints bindPoints;
DML_EXECUTION_FLAGS executionFlags;
DmlCompileType compileType;
Bindings initBindings;
};

Expand Down

0 comments on commit 61a1a50

Please sign in to comment.