Skip to content

Commit 61a1a50

Browse files
authored
User/chrila/enable graph compile (#606)
* Add Graph option * Update optional tensor logic * Move json parser logic for DmlCompileType * Update version * Update DmlCompileType namespace, json def, and updated Guid.md * update spacing --------- Co-authored-by: Christian Larson <[email protected]>
1 parent 72ad224 commit 61a1a50

File tree

8 files changed

+245
-26
lines changed

8 files changed

+245
-26
lines changed

DxDispatch/doc/Guide.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,28 @@ Take note of the few odd cases that don't follow the usual rule exactly:
468468
- Enum values of type `DML_OPERATOR_TYPE` omit `_TYPE` from their prefix. It's `DML_OPERATOR_GEMM`, not `DML_OPERATOR_TYPE_GEMM`.
469469
- Flag values are singular and omit the "S". It's `DML_EXECUTION_FLAG_NONE`, not `DML_EXECUTION_FLAGS_NONE`.
470470

471+
### DirectML Compile Op vs Graph (dmlCompileType)
472+
The `dmlCompileType` field selects how a DirectML operator is compiled: either directly with IDMLDevice::CompileOperator, or by inserting the operator into a DML_GRAPH_DESC and compiling it with IDMLDevice1::CompileGraph.
473+
474+
| Enums for dmlCompileType | Description |
475+
| ------------------------------------------------ | ------------------------------------------------------------------------- |
476+
| <b><i>DmlCompileOp</i></b> (Default behavior) | Uses IDMLDevice::CompileOperator for the defined operator |
477+
| <b><i>DmlCompileGraph</i></b> | Inserts the operator into a DML_GRAPH_DESC and uses IDMLDevice1::CompileGraph |
478+
479+
Syntax:
480+
481+
```json
482+
"dmlOperator":
483+
{
484+
"type": "DML_OPERATOR_*",
485+
"dmlCompileType": "DmlCompileGraph",
486+
"Desc": { ... }
487+
}
488+
```
489+
490+
See full example in [dml_gemm_graph.json](../models/dml_gemm_graph.json).
491+
492+
471493
### DML_TENSOR_DESC
472494

473495
Since tensor descs are so common, the JSON parser provides default values for most fields.

DxDispatch/models/_schema.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@
7676
]
7777
},
7878

79+
"dmlCompileType":
80+
{
81+
"enum":
82+
[
83+
"DmlCompileOp",
84+
"DmlCompileGraph"
85+
]
86+
},
87+
7988
"arrayInitializer":
8089
{
8190
"type": "array",

DxDispatch/models/dml_gemm_graph.json

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
{
2+
"$schema": "./_schema.json",
3+
4+
"resources":
5+
{
6+
"A": {
7+
"initialValuesDataType": "FLOAT32",
8+
"initialValues": { "valueCount": 1024, "value": 1 }
9+
},
10+
"B": {
11+
"initialValuesDataType": "FLOAT32",
12+
"initialValues": { "valueCount": 1024, "value": 1 }
13+
},
14+
"output": {
15+
"initialValuesDataType": "FLOAT32",
16+
"initialValues": { "valueCount": 1024, "value": 1 }
17+
}
18+
},
19+
20+
"dispatchables":
21+
{
22+
"gemm":
23+
{
24+
"type": "DML_OPERATOR_GEMM",
25+
"desc":
26+
{
27+
"ATensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
28+
"BTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32], "Flags": "DML_TENSOR_FLAG_OWNED_BY_DML" },
29+
"OutputTensor": { "DataType": "FLOAT32", "Sizes": [1,1,32,32] },
30+
"TransA": "DML_MATRIX_TRANSFORM_NONE",
31+
"TransB": "DML_MATRIX_TRANSFORM_NONE",
32+
"Alpha": 1.0,
33+
"Beta": 1.0
34+
},
35+
"dmlCompileType": "DmlCompileGraph",
36+
"executionFlags": "DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION",
37+
"bindings":
38+
{
39+
"BTensor": "B"
40+
}
41+
}
42+
},
43+
44+
"commands":
45+
[
46+
{
47+
"type": "dispatch",
48+
"dispatchable": "gemm",
49+
"bindings":
50+
{
51+
"ATensor": "A",
52+
"OutputTensor": "output"
53+
}
54+
},
55+
{ "type": "print", "resource": "output" }
56+
]
57+
}

DxDispatch/src/dxdispatch/DmlDispatchable.cpp

Lines changed: 98 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
1111
std::string_view name,
1212
std::shared_ptr<Device> device,
1313
const Model::DmlDispatchableDesc& desc,
14-
const Dispatchable::Bindings& initBindings
15-
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings))
14+
const Dispatchable::Bindings& initBindings,
15+
IDxDispatchLogger* logger
16+
) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
1617
{
1718
THROW_IF_FAILED(device->DML()->CreateOperator(desc.desc, IID_PPV_ARGS(&m_operator)));
1819
}
@@ -28,7 +29,8 @@ void FillBindingData(
2829
const Dispatchable::Bindings* initializeBindings,
2930
const Dispatchable::Bindings* executeBindings,
3031
BindingData& bindingData,
31-
bool bindingForInitialization = false)
32+
bool bindingForInitialization,
33+
Model::DmlDispatchableDesc::DmlCompileType compileType)
3234
{
3335
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
3436

@@ -47,22 +49,23 @@ void FillBindingData(
4749

4850
if (bindingIterator == bindings.end())
4951
{
50-
if (bindPoints[i].required && !bindingForInitialization)
52+
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
5153
{
52-
if (!initializeBindings || initializeBindings->find(bindPointName) == initializeBindings->end())
54+
if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i].requiredBinding)
5355
{
54-
throw std::invalid_argument(fmt::format("Nothing bound for required tensor '{}'.", bindPointName));
56+
// Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
57+
bindingData.bindingDescs.pop_back();
58+
bindingData.bufferBindings.pop_back();
59+
}
60+
else
61+
{
62+
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
63+
bindingData.bufferBindings[bufferIndex].Offset = 0;
64+
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
65+
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
66+
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
67+
bufferIndex++;
5568
}
56-
}
57-
58-
for (size_t j = 0; j < bindPoints[i].resourceCount; j++)
59-
{
60-
bindingData.bufferBindings[bufferIndex].Buffer = nullptr;
61-
bindingData.bufferBindings[bufferIndex].Offset = 0;
62-
bindingData.bufferBindings[bufferIndex].SizeInBytes = 0;
63-
bindingData.bindingDescs[bufferIndex].Type = DML_BINDING_TYPE_NONE;
64-
bindingData.bindingDescs[bufferIndex].Desc = nullptr;
65-
bufferIndex++;
6669
}
6770
}
6871
else
@@ -103,11 +106,82 @@ void FillBindingData(
103106

104107
void DmlDispatchable::Initialize()
105108
{
106-
THROW_IF_FAILED(m_device->DML()->CompileOperator(
107-
m_operator.Get(),
108-
m_desc.executionFlags,
109-
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
110-
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
109+
if(m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
110+
{
111+
m_logger->LogInfo("Compile Op");
112+
THROW_IF_FAILED(m_device->DML()->CompileOperator(
113+
m_operator.Get(),
114+
m_desc.executionFlags,
115+
IID_PPV_ARGS(m_operatorCompiled.ReleaseAndGetAddressOf())));
116+
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(m_name).data());
117+
}
118+
else
119+
{
120+
m_logger->LogInfo("Compiling op using IDMLDevice1::CompileGraph");
121+
DML_GRAPH_DESC dmlGraphDesc = {};
122+
std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
123+
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
124+
125+
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
126+
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
127+
DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
128+
DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};
129+
130+
nodeDesc.Operator = m_operator.Get();
131+
nodeDesc.Name = m_name.c_str();
132+
133+
{
134+
dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
135+
dmlGraphNodeDesc.Desc = &nodeDesc;
136+
}
137+
138+
dmlInputGraphEdges.resize(m_desc.bindPoints.inputs.size());
139+
for( size_t i = 0; i < m_desc.bindPoints.inputs.size(); i++)
140+
{
141+
if (m_desc.bindPoints.inputs[i].requiredBinding)
142+
{
143+
DML_INPUT_GRAPH_EDGE_DESC desc = {};
144+
desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
145+
desc.ToNodeIndex = 0;
146+
desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
147+
desc.Name = m_desc.bindPoints.inputs[i].name.c_str();
148+
dmlInputGraphEdges[i] = desc;
149+
dmlInputEdges.push_back({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
150+
}
151+
}
152+
153+
dmlOutputGraphEdges.resize(m_desc.bindPoints.outputs.size());
154+
for( size_t i = 0; i < m_desc.bindPoints.outputs.size(); i++)
155+
{
156+
if (m_desc.bindPoints.outputs[i].requiredBinding)
157+
{
158+
DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
159+
desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
160+
desc.FromNodeIndex = 0;
161+
desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
162+
desc.Name = m_desc.bindPoints.outputs[i].name.c_str();
163+
dmlOutputGraphEdges[i] = desc;
164+
dmlOutputEdges.push_back({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
165+
}
166+
}
167+
168+
dmlGraphDesc.InputCount = static_cast<uint32_t>(dmlInputEdges.size());
169+
dmlGraphDesc.InputEdges = dmlInputEdges.data();
170+
dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount;
171+
172+
dmlGraphDesc.OutputCount = static_cast<uint32_t>(dmlOutputEdges.size());
173+
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
174+
dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount;
175+
176+
dmlGraphDesc.IntermediateEdgeCount = 0;
177+
dmlGraphDesc.IntermediateEdges = nullptr;
178+
179+
dmlGraphDesc.NodeCount = 1;
180+
dmlGraphDesc.Nodes = &dmlGraphNodeDesc;
181+
182+
THROW_IF_FAILED(m_device->DML()->CompileGraph(&dmlGraphDesc, m_desc.executionFlags, IID_PPV_ARGS(&m_operatorCompiled)));
183+
m_operatorCompiled->SetName(std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(fmt::format("Graph_{}", m_name)).data());
184+
}
111185

112186
ComPtr<IDMLOperatorInitializer> initializer;
113187
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get() };
@@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
145219
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
146220
// be bound using a separate buffer array binding.
147221
BindingData inputBindingData = {};
148-
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true);
222+
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, nullptr, inputBindingData, true, m_desc.compileType);
149223

150224
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
151225
if (inputBindingData.bufferBindings.size() > std::numeric_limits<uint32_t>::max())
@@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
193267
auto bindingProps = m_operatorCompiled->GetBindingProperties();
194268

195269
BindingData inputBindingData = {};
196-
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData);
270+
FillBindingData(m_desc.bindPoints.inputs, &m_initBindings, &bindings, inputBindingData, false, m_desc.compileType);
197271

198272
BindingData outputBindingData = {};
199-
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData);
273+
FillBindingData(m_desc.bindPoints.outputs, &m_initBindings, &bindings, outputBindingData, false, m_desc.compileType);
200274

201275
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
202276
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;

DxDispatch/src/dxdispatch/DmlDispatchable.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ class DmlDispatchable : public Dispatchable
77
std::string_view name,
88
std::shared_ptr<Device> device,
99
const Model::DmlDispatchableDesc& desc,
10-
const Dispatchable::Bindings& initBindings);
10+
const Dispatchable::Bindings& initBindings,
11+
IDxDispatchLogger* logger);
1112

1213
void Initialize() final;
1314
void Bind(const Bindings& bindings, uint32_t iteration) final;
@@ -23,4 +24,5 @@ class DmlDispatchable : public Dispatchable
2324
Microsoft::WRL::ComPtr<ID3D12Resource> m_persistentBuffer;
2425
Microsoft::WRL::ComPtr<IDMLBindingTable> m_bindingTable;
2526
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorHeap;
27+
Microsoft::WRL::ComPtr<IDxDispatchLogger> m_logger;
2628
};

DxDispatch/src/dxdispatch/Executor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ Executor::Executor(Model& model, std::shared_ptr<Device> device, const CommandLi
162162
return;
163163
}
164164

165-
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings);
165+
m_dispatchables[desc.name] = std::make_unique<DmlDispatchable>(desc.name, device, dmlDispatchableDesc, initBindings, m_logger.Get());
166166
}
167167
}
168168
catch(const std::exception& e)

DxDispatch/src/model/JsonParsers.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,11 +1402,58 @@ std::vector<Model::BufferBindingSource> ParseBindingSource(const rapidjson::Valu
14021402
return sourceResources;
14031403
}
14041404

1405+
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileType(const rapidjson::Value& value)
1406+
{
1407+
if (value.GetType() != rapidjson::Type::kStringType)
1408+
{
1409+
throw std::invalid_argument("Expected a string.");
1410+
}
1411+
auto valueString = value.GetString();
1412+
if (!strcmp(valueString, "DmlCompileOp")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp; }
1413+
if (!strcmp(valueString, "DmlCompileGraph")) { return Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph; }
1414+
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DmlCompileType.", valueString));
1415+
}
1416+
1417+
Model::DmlDispatchableDesc::DmlCompileType ParseDmlCompileTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, Model::DmlDispatchableDesc::DmlCompileType defaultValue)
1418+
{
1419+
return ParseFieldHelper<Model::DmlDispatchableDesc::DmlCompileType>(object, fieldName, required, defaultValue, [](auto& value) {
1420+
return ParseDmlCompileType(value);
1421+
});
1422+
}
1423+
14051424
Model::DmlDispatchableDesc ParseModelDmlDispatchableDesc(const rapidjson::Value& object, BucketAllocator& allocator)
14061425
{
14071426
Model::DmlDispatchableDesc desc;
14081427
desc.desc = ParseDmlOperatorDesc(object, false, allocator);
14091428
desc.bindPoints = GetBindPoints(*desc.desc);
1429+
1430+
// DirectML requires optional bindings if DML_OPERATOR_DESC declares that binding for optional operator tensors.
1431+
// A bind point needs a binding if the operator marks it as required or if its tensor is declared in the model's "desc" object.
1432+
auto UpdateBindingPoints = [](const rapidjson::Value& object, std::vector<Model::DmlDispatchableDesc::BindPoint>& bindPoints) {
1433+
for (auto& bindPoint : bindPoints)
1434+
{
1435+
if (bindPoint.required || object.HasMember(bindPoint.name.c_str()))
1436+
{
1437+
bindPoint.requiredBinding = true;
1438+
}
1439+
else
1440+
{
1441+
bindPoint.requiredBinding = false;
1442+
}
1443+
}};
1444+
1445+
auto descMember = object.FindMember("Desc");
1446+
if (descMember == object.MemberEnd())
1447+
{
1448+
descMember = object.FindMember("desc");
1449+
}
1450+
if (descMember != object.MemberEnd())
1451+
{
1452+
UpdateBindingPoints(descMember->value, desc.bindPoints.inputs);
1453+
UpdateBindingPoints(descMember->value, desc.bindPoints.outputs);
1454+
}
1455+
desc.compileType = ParseDmlCompileTypeField(object, "dmlCompileType", false, Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp);
1456+
14101457
desc.executionFlags = ParseDmlExecutionFlagsField(object, "executionFlags", false, DML_EXECUTION_FLAG_NONE);
14111458

14121459
auto bindingsField = object.FindMember("bindings");

DxDispatch/src/model/Model.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <DirectML.h>
1111
#include "BucketAllocator.h"
1212

13+
1314
class Model
1415
{
1516
public:
@@ -58,11 +59,17 @@ class Model
5859

5960
struct DmlDispatchableDesc
6061
{
62+
enum class DmlCompileType
63+
{
64+
DmlCompileOp,
65+
DmlCompileGraph
66+
};
6167
struct BindPoint
6268
{
6369
std::string name;
6470
uint32_t resourceCount;
6571
bool required;
72+
bool requiredBinding;
6673
};
6774

6875
struct BindPoints
@@ -74,6 +81,7 @@ class Model
7481
DML_OPERATOR_DESC* desc;
7582
BindPoints bindPoints;
7683
DML_EXECUTION_FLAGS executionFlags;
84+
DmlCompileType compileType;
7785
Bindings initBindings;
7886
};
7987

0 commit comments

Comments
 (0)