@@ -11,8 +11,9 @@ DmlDispatchable::DmlDispatchable(
11
11
std::string_view name,
12
12
std::shared_ptr<Device> device,
13
13
const Model::DmlDispatchableDesc& desc,
14
- const Dispatchable::Bindings& initBindings
15
- ) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings))
14
+ const Dispatchable::Bindings& initBindings,
15
+ IDxDispatchLogger* logger
16
+ ) : m_name(name), m_device(device), m_desc(desc), m_initBindings(std::move(initBindings)), m_logger(logger)
16
17
{
17
18
THROW_IF_FAILED (device->DML ()->CreateOperator (desc.desc , IID_PPV_ARGS (&m_operator)));
18
19
}
@@ -28,7 +29,8 @@ void FillBindingData(
28
29
const Dispatchable::Bindings* initializeBindings,
29
30
const Dispatchable::Bindings* executeBindings,
30
31
BindingData& bindingData,
31
- bool bindingForInitialization = false )
32
+ bool bindingForInitialization,
33
+ Model::DmlDispatchableDesc::DmlCompileType compileType)
32
34
{
33
35
const Dispatchable::Bindings& bindings = bindingForInitialization ? *initializeBindings : *executeBindings;
34
36
@@ -47,22 +49,23 @@ void FillBindingData(
47
49
48
50
if (bindingIterator == bindings.end ())
49
51
{
50
- if ( bindPoints[i].required && !bindingForInitialization )
52
+ for ( size_t j = 0 ; j < bindPoints[i].resourceCount ; j++ )
51
53
{
52
- if (!initializeBindings || initializeBindings-> find (bindPointName) == initializeBindings-> end () )
54
+ if (compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileGraph && !bindPoints[i]. requiredBinding )
53
55
{
54
- throw std::invalid_argument (fmt::format (" Nothing bound for required tensor '{}'." , bindPointName));
56
+ // Dml Graph will fail if given DML_BINDING_TYPE_NONE for optional bindings not described in the graph.
57
+ bindingData.bindingDescs .pop_back ();
58
+ bindingData.bufferBindings .pop_back ();
59
+ }
60
+ else
61
+ {
62
+ bindingData.bufferBindings [bufferIndex].Buffer = nullptr ;
63
+ bindingData.bufferBindings [bufferIndex].Offset = 0 ;
64
+ bindingData.bufferBindings [bufferIndex].SizeInBytes = 0 ;
65
+ bindingData.bindingDescs [bufferIndex].Type = DML_BINDING_TYPE_NONE;
66
+ bindingData.bindingDescs [bufferIndex].Desc = nullptr ;
67
+ bufferIndex++;
55
68
}
56
- }
57
-
58
- for (size_t j = 0 ; j < bindPoints[i].resourceCount ; j++)
59
- {
60
- bindingData.bufferBindings [bufferIndex].Buffer = nullptr ;
61
- bindingData.bufferBindings [bufferIndex].Offset = 0 ;
62
- bindingData.bufferBindings [bufferIndex].SizeInBytes = 0 ;
63
- bindingData.bindingDescs [bufferIndex].Type = DML_BINDING_TYPE_NONE;
64
- bindingData.bindingDescs [bufferIndex].Desc = nullptr ;
65
- bufferIndex++;
66
69
}
67
70
}
68
71
else
@@ -103,11 +106,82 @@ void FillBindingData(
103
106
104
107
void DmlDispatchable::Initialize ()
105
108
{
106
- THROW_IF_FAILED (m_device->DML ()->CompileOperator (
107
- m_operator.Get (),
108
- m_desc.executionFlags ,
109
- IID_PPV_ARGS (m_operatorCompiled.ReleaseAndGetAddressOf ())));
110
- m_operatorCompiled->SetName (std::wstring_convert<std::codecvt_utf8<wchar_t >>().from_bytes (m_name).data ());
109
+ if (m_desc.compileType == Model::DmlDispatchableDesc::DmlCompileType::DmlCompileOp)
110
+ {
111
+ m_logger->LogInfo (" Compile Op" );
112
+ THROW_IF_FAILED (m_device->DML ()->CompileOperator (
113
+ m_operator.Get (),
114
+ m_desc.executionFlags ,
115
+ IID_PPV_ARGS (m_operatorCompiled.ReleaseAndGetAddressOf ())));
116
+ m_operatorCompiled->SetName (std::wstring_convert<std::codecvt_utf8<wchar_t >>().from_bytes (m_name).data ());
117
+ }
118
+ else
119
+ {
120
+ m_logger->LogInfo (" Compiling op using IDMLDevice1::CompileGraph" );
121
+ DML_GRAPH_DESC dmlGraphDesc = {};
122
+ std::vector<DML_INPUT_GRAPH_EDGE_DESC> dmlInputGraphEdges;
123
+ std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
124
+
125
+ std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> dmlOutputGraphEdges;
126
+ std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
127
+ DML_GRAPH_NODE_DESC dmlGraphNodeDesc = {};
128
+ DML_OPERATOR_GRAPH_NODE_DESC nodeDesc{};
129
+
130
+ nodeDesc.Operator = m_operator.Get ();
131
+ nodeDesc.Name = m_name.c_str ();
132
+
133
+ {
134
+ dmlGraphNodeDesc.Type = DML_GRAPH_NODE_TYPE_OPERATOR;
135
+ dmlGraphNodeDesc.Desc = &nodeDesc;
136
+ }
137
+
138
+ dmlInputGraphEdges.resize (m_desc.bindPoints .inputs .size ());
139
+ for ( size_t i = 0 ; i < m_desc.bindPoints .inputs .size (); i++)
140
+ {
141
+ if (m_desc.bindPoints .inputs [i].requiredBinding )
142
+ {
143
+ DML_INPUT_GRAPH_EDGE_DESC desc = {};
144
+ desc.GraphInputIndex = gsl::narrow_cast<UINT>(i);
145
+ desc.ToNodeIndex = 0 ;
146
+ desc.ToNodeInputIndex = gsl::narrow_cast<UINT>(i);
147
+ desc.Name = m_desc.bindPoints .inputs [i].name .c_str ();
148
+ dmlInputGraphEdges[i] = desc;
149
+ dmlInputEdges.push_back ({ DML_GRAPH_EDGE_TYPE_INPUT, &dmlInputGraphEdges[i] });
150
+ }
151
+ }
152
+
153
+ dmlOutputGraphEdges.resize (m_desc.bindPoints .outputs .size ());
154
+ for ( size_t i = 0 ; i < m_desc.bindPoints .outputs .size (); i++)
155
+ {
156
+ if (m_desc.bindPoints .outputs [i].requiredBinding )
157
+ {
158
+ DML_OUTPUT_GRAPH_EDGE_DESC desc = {};
159
+ desc.GraphOutputIndex = gsl::narrow_cast<UINT>(i);
160
+ desc.FromNodeIndex = 0 ;
161
+ desc.FromNodeOutputIndex = gsl::narrow_cast<UINT>(i);
162
+ desc.Name = m_desc.bindPoints .outputs [i].name .c_str ();
163
+ dmlOutputGraphEdges[i] = desc;
164
+ dmlOutputEdges.push_back ({ DML_GRAPH_EDGE_TYPE_OUTPUT, &dmlOutputGraphEdges[i] });
165
+ }
166
+ }
167
+
168
+ dmlGraphDesc.InputCount = static_cast <uint32_t >(dmlInputEdges.size ());
169
+ dmlGraphDesc.InputEdges = dmlInputEdges.data ();
170
+ dmlGraphDesc.InputEdgeCount = dmlGraphDesc.InputCount ;
171
+
172
+ dmlGraphDesc.OutputCount = static_cast <uint32_t >(dmlOutputEdges.size ());
173
+ dmlGraphDesc.OutputEdges = dmlOutputEdges.data ();
174
+ dmlGraphDesc.OutputEdgeCount = dmlGraphDesc.OutputCount ;
175
+
176
+ dmlGraphDesc.IntermediateEdgeCount = 0 ;
177
+ dmlGraphDesc.IntermediateEdges = nullptr ;
178
+
179
+ dmlGraphDesc.NodeCount = 1 ;
180
+ dmlGraphDesc.Nodes = &dmlGraphNodeDesc;
181
+
182
+ THROW_IF_FAILED (m_device->DML ()->CompileGraph (&dmlGraphDesc, m_desc.executionFlags , IID_PPV_ARGS (&m_operatorCompiled)));
183
+ m_operatorCompiled->SetName (std::wstring_convert<std::codecvt_utf8<wchar_t >>().from_bytes (fmt::format (" Graph_{}" , m_name)).data ());
184
+ }
111
185
112
186
ComPtr<IDMLOperatorInitializer> initializer;
113
187
IDMLCompiledOperator* ops[] = { m_operatorCompiled.Get () };
@@ -145,7 +219,7 @@ void DmlDispatchable::Initialize()
145
219
// Initializers can initialize multiple inputs simultaneously, so each compiled op's inputs must
146
220
// be bound using a separate buffer array binding.
147
221
BindingData inputBindingData = {};
148
- FillBindingData (m_desc.bindPoints .inputs , &m_initBindings, nullptr , inputBindingData, true );
222
+ FillBindingData (m_desc.bindPoints .inputs , &m_initBindings, nullptr , inputBindingData, true , m_desc. compileType );
149
223
150
224
DML_BUFFER_ARRAY_BINDING bufferArrayBindings = {};
151
225
if (inputBindingData.bufferBindings .size () > std::numeric_limits<uint32_t >::max ())
@@ -193,10 +267,10 @@ void DmlDispatchable::Bind(const Bindings& bindings, uint32_t iteration)
193
267
auto bindingProps = m_operatorCompiled->GetBindingProperties ();
194
268
195
269
BindingData inputBindingData = {};
196
- FillBindingData (m_desc.bindPoints .inputs , &m_initBindings, &bindings, inputBindingData);
270
+ FillBindingData (m_desc.bindPoints .inputs , &m_initBindings, &bindings, inputBindingData, false , m_desc. compileType );
197
271
198
272
BindingData outputBindingData = {};
199
- FillBindingData (m_desc.bindPoints .outputs , &m_initBindings, &bindings, outputBindingData);
273
+ FillBindingData (m_desc.bindPoints .outputs , &m_initBindings, &bindings, outputBindingData, false , m_desc. compileType );
200
274
201
275
D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc = {};
202
276
descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
0 commit comments