/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Graph/Graph.h"
#include "glow/Backend/Backend.h"
#include "glow/Flags/Flags.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Graph/TensorLayout.h"
#include "glow/Graph/VerifierHelper.h"
#include "glow/Quantization/Base/Base.h"
#include "glow/Support/Support.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"

#ifdef WIN32
#include <corecrt_math_defines.h>
#endif
#include <float.h>
#include <fstream>
#include <unordered_set>

using namespace glow;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

namespace {
/// A helper function to log the deletion of a constant/placeholder \p s of a
/// module into the log contexts of the given \p functions.
/// Note: The reason we don't log the deletion of a constant in the function
/// that uses or creates it is that constants/placeholders do not have a
/// function parent (and we cannot rely on a user's function either, because
/// the users might already be removed). It is therefore best to log
/// constants/placeholders in a Module-level log context and copy the entry
/// over to all of the module's functions.
void logStorageDeletion(std::list<Function *> functions, Storage *s) {
  for (auto *F : functions) {
    F->getLogContext()->logNodeDeletion(*s);
  }
  if (functions.size() > 0) {
    auto *F = *(functions.begin());
    F->getLogContext()->logNodeDeletion(*s, /* logIntoModule */ true);
  }
}

/// A helper function to log the creation of a constant/placeholder \p s of a
/// module into the log contexts of the given \p functions.
/// Same note as for logStorageDeletion().
void logStorageCreation(std::list<Function *> functions, Storage *s) {
  for (auto *F : functions) {
    F->getLogContext()->logNodeCreation(*s);
  }
  if (functions.size() > 0) {
    auto *F = *(functions.begin());
    F->getLogContext()->logNodeCreation(*s, /* logIntoModule */ true);
  }
}
} // namespace

/// Merge shape \p shape into \p mergeShape, following multidirectional
/// broadcasting rules.
static void mergeMultidirectionalBroadcastHelper(std::vector<dim_t> &mergeShape,
                                                 llvm::ArrayRef<dim_t> shape) {
  size_t shift = mergeShape.size() - shape.size();
  for (size_t i = 0, e = shape.size(); i < e; i++) {
    if (shape[i] == 1) {
      // Just leave mergeShape[i] as it is.
      continue;
    }

    assert(
        ((shape[i] == mergeShape[shift + i]) || (mergeShape[shift + i] == 1)) &&
        "Incompatible dimension for the broadcast");
    mergeShape[shift + i] = shape[i];
  }
}

/// Utility function which computes the resulting shape in case of
/// multidirectional broadcasting.
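/// E.g. (illustrative values): shape0 = {3, 4, 1} and shape1 = {4, 5} yield
/// the broadcast shape {3, 4, 5}.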
static std::vector<dim_t>
computeMultidirectionalBroadcastHelper(llvm::ArrayRef<dim_t> shape0,
                                       llvm::ArrayRef<dim_t> shape1) {
  size_t numDims0 = shape0.size();
  size_t numDims1 = shape1.size();
  size_t newNumDims = std::max(numDims0, numDims1);
  std::vector<dim_t> reshapeDims(newNumDims, 1);

  mergeMultidirectionalBroadcastHelper(reshapeDims, shape0);
  mergeMultidirectionalBroadcastHelper(reshapeDims, shape1);

  return reshapeDims;
}

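/// Broadcast \p inputs so they all share one shape: with \p axis == -1 the
/// common shape follows the multidirectional rules above (e.g., illustratively,
/// inputs of shapes {3, 1} and {1, 4} are both broadcast to {3, 4}); otherwise
/// inputs[1] is unidirectionally broadcast to inputs[0]'s shape at \p axis.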
std::vector<NodeValue>
Function::broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs) {
  dim_t numInputs = inputs.size();

  if (axis > -1) {
    assert(numInputs == 2 &&
           "If axis is specified (not -1), unidirectional broadcast is used "
           "and the input size must be 2.");
    return {inputs[0],
            createBroadcast("broadcast_" + inputs[1].getNode()->getName().str(),
                            inputs[1], inputs[0].dims(), axis)};
  }

  assert(numInputs >= 2 && "Invalid input passed in to broadcastInputs.");

  std::vector<dim_t> targetDim = computeMultidirectionalBroadcastHelper(
      inputs[0].dims(), inputs[1].dims());

  for (size_t i = 2; i < numInputs; ++i) {
    targetDim =
        computeMultidirectionalBroadcastHelper(targetDim, inputs[i].dims());
  }

  std::vector<NodeValue> out(numInputs);
  for (size_t i = 0; i < numInputs; ++i) {
    NodeValue n = inputs[i];
    auto dims = n.dims();
    if (dims != llvm::ArrayRef<dim_t>(targetDim)) {
      unsigned axis = targetDim.size() - dims.size();
      out[i] = createBroadcast("broadcast_" + n.getNode()->getName().str(), n,
                               targetDim, axis);
    } else {
      out[i] = inputs[i];
    }
  }
  return out;
}

bool Module::hasFunction(llvm::StringRef name) { return getFunction(name); }

void Module::clearFunctions() {
  for (auto *F : functions_) {
    F->clear();
  }
}

void Function::clear() {
  nodes_.clear();
  uniqueNodeNames_.clear();
}

Function *Module::getFunction(llvm::StringRef name) {
  for (auto *F : functions_) {
    if (F->getName() == name) {
      return F;
    }
  }
  return nullptr;
}

Function *Module::createFunction(llvm::StringRef name) {
  assert(!hasFunction(name) && "A function with this name already exists");
  Function *F = new Function(this, name);
  functions_.push_back(F);
  return F;
}

void Module::strip() {
  for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
    Constant *v = *it;
    v->clearPayload();
  }
}

void Module::clear() {
  for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
    Constant *v = *it;
    logStorageDeletion(functions_, v);
    delete v;
  }

  constants_.clear();

  for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e;
       it++) {
    Placeholder *p = *it;
    logStorageDeletion(functions_, p);
    delete p;
  }

  eraseFunctions();

  placeholders_.clear();
}

Module::~Module() { clear(); }

bool Module::verify() const {
  bool isValid = true;
  for (auto *F : functions_) {
    isValid &= F->verify();
  }
  // Check that all types used by constants or placeholders belong to the
  // module.
  auto &types = getTypes();
  for (const auto *PH : getPlaceholders()) {
    bool foundType =
        std::find(types.begin(), types.end(), *PH->getType()) != types.end();
    isValid &=
        expectCompareTrue("Every type used by placeholders should be part of "
                          "the graph",
                          foundType, true, PH);
  }
  for (const auto *C : getConstants()) {
    bool foundType =
        std::find(types.begin(), types.end(), *C->getType()) != types.end();
    isValid &=
        expectCompareTrue("Every type used by constants should be part of "
                          "the graph",
                          foundType, true, C);
  }
  return isValid;
}

void Module::dump() const {
  llvm::outs() << "Module structure:\n";
  for (auto *C : getConstants()) {
    llvm::outs() << C->getDebugDesc() << "\n";
  }

  for (auto *P : getPlaceholders()) {
    llvm::outs() << P->getDebugDesc() << "\n";
  }

  for (auto *F : functions_) {
    llvm::outs() << "Function:" << F->getName() << "\n";
  }
}

std::string Module::toString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dump(os);
  return os.str();
}

/// Creates a std::set copy of \p unsorted, sorted based on name of each
/// element, and \returns it.
template <class T>
static std::set<T *, SortNamed> getNamedSorted(const std::list<T *> &unsorted) {
  return std::set<T *, SortNamed>(unsorted.begin(), unsorted.end());
}

void Module::dump(llvm::raw_ostream &os) const {
  os << "Module structure:\n";
  for (auto *C : getNamedSorted(constants_)) {
    os << C->getDebugDesc() << "\n";
  }
  for (auto *P : getNamedSorted(placeholders_)) {
    os << P->getDebugDesc() << "\n";
  }
  for (auto *F : getNamedSorted(functions_)) {
    os << "Function : " << F->getName() << "\n";
  }
}

/// A helper class for visiting and generating the dotty graph file.
class AbstractDottyPrinter {
protected:
  // List of generated vertices.
  std::vector<std::string> vertices_{};
  // List of generated edges.
  std::unordered_set<std::string> edges_{};
  // Map node addresses to unique numbers.
  using VertexNumberMap = std::unordered_map<void *, unsigned>;
  VertexNumberMap vertex_numbers{};

  /// Dumps the label for an input/output row, given port names.
  /// E.g. {"LHS", "RHS"} will produce {<LHS>LHS|<RHS>RHS}
  void dumpLabelForRow(llvm::ArrayRef<std::string> names, std::ostream &os) {
    os << "{";
    for (size_t i = 0; i < names.size(); i++) {
      if (i) {
        os << "|";
      }
      os << "<" << names[i] << ">" << names[i];
    }
    os << "}";
  }

  void dumpLabel(Node *N, std::ostream &os) {
    os << "{";
    if (N->getNumInputs()) {
      std::vector<std::string> names(N->getNumInputs());
      for (size_t i = 0; i < names.size(); i++) {
        names[i] = N->getInputName(i);
      }
      dumpLabelForRow(names, os);
      os << "|";
    }
    os << "{" << escapeDottyString(N->getDebugDesc()) << "}";
    if (N->getNumResults()) {
      os << "|";
      std::vector<std::string> names(N->getNumResults());
      for (size_t i = 0; i < names.size(); i++) {
        names[i] = N->getOutputName(i).str();
      }
      dumpLabelForRow(names, os);
    }
    os << "}";
  }

  void dumpNode(Node *N, bool uniqueNames) {
    if (!N) {
      return;
    }
    std::ostringstream os;
    // Print a node descriptor that looks like this:
    if (uniqueNames) {
      // vNNNN [ shape = "record" label = "{...}" ];
      os << uniqueVertexName(N) << "[\n";
    } else {
      // <name> [ shape = "record" label = "{...}" ];
      os << N->getName().str() << "[\n";
    }
    os << "\tlabel = \"";
    dumpLabel(N, os);
    os << "\"\n";
    os << "\tshape = \"record\"\n";
    os << "\tstyle=\"filled,rounded\"\n";

    // Pick a color based on the node kind.
    unsigned colorIdx = llvm::hash_value(llvm::StringRef(N->getKindName()));
    auto nodeColor = getDotFileNodeColor(colorIdx);

    if (isa<Constant>(N)) {
      os << "\tfillcolor=Snow3 color=DeepSkyBlue4\n";
    } else {
      os << "\tfillcolor=" << nodeColor << "\n";
    }
    os << "penwidth = 2];\n";

    vertices_.push_back(os.str());
  }

  void dumpEdgeStyle(const Node *N, size_t i, Node *to, std::ostream &os) {
    if (N->isOverwrittenNthInput(i)) {
      os << " [dir=\"both\"]";
    }
  }

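  /// \returns a stable name of the form "v%04u" for vertex \p N, assigning
  /// numbers in the order in which vertices are first seen.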
  std::string uniqueVertexName(void *N) {
    VertexNumberMap::iterator i;
    bool inserted;
    std::tie(i, inserted) = vertex_numbers.insert(std::make_pair(N, 0u));
    if (inserted) {
      i->second = vertex_numbers.size() - 1;
    }

    std::string buffer;
    llvm::raw_string_ostream stream(buffer);
    stream << llvm::format("v%04u", i->second);
    return stream.str();
  }

public:
  void dumpAll(std::ostream &os) {
    CHECK(os) << "Failed to create file to dump the Graph";

    os << "digraph DAG {\n\trankdir=TB;\n";

    // Dump vertices:
    for (auto &v : vertices_) {
      os << v << "\n";
    }

    // Dump edges:
    for (auto &e : edges_) {
      os << e << ";\n";
    }

    os << "}";
  }
};

class ModuleDottyPrinter : public AbstractDottyPrinter {
  /// Dump the Function as a vertex. Then iterate through the constants used in
  /// the function and create the corresponding edges.
  void visitFunction(Function *F) {
    std::ostringstream os;
    // Print a Function descriptor that looks like this:
    // vNNNN [ label = "{...}" ];
    os << uniqueVertexName(F) << "[\n"
       << "\tlabel = \"Function\\l"
       << "name : " << F->getName().str() << "\\l"
       << "node count : " << F->getNodes().size() << "\"\n"
       << "\tshape = box\n"
       << "\tfillcolor=gray89, style=\"filled,rounded\"\n"
       << "\t\n"
       << "];\n";
    vertices_.push_back(os.str());

    for (auto &N : F->getNodes()) {
      for (size_t i = 0; i < N.getNumInputs(); i++) {
        Node *to = N.getNthInput(i).getNode();
        size_t resNo = N.getNthInput(i).getResNo();

        if (!isa<Constant>(to))
          continue;

        std::ostringstream edge;
        edge << uniqueVertexName(to) << ":" << to->getOutputName(resNo).str()
             << " -> " << uniqueVertexName(F);
        dumpEdgeStyle(&N, i, to, edge);
        edges_.insert(edge.str());
      }
    }
  }

public:
  void visitModule(Module *M) {
    for (auto N : M->getConstants()) {
      dumpNode(N, true);
    }

    for (auto F : M->getFunctions()) {
      visitFunction(F);
    }
  }
};

// TODO: consider refactoring boilerplate code to new trait: DottyPrintable<ADP>
void Module::dumpDAG() {
  llvm::SmallString<64> dotPath;
  llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
  dumpDAG(dotPath);
}

void Module::dumpDAG(llvm::StringRef dotFilename) {
  llvm::outs() << "Writing dotty graph for Module to: " << dotFilename << '\n';

  ModuleDottyPrinter DP;

  DP.visitModule(this);

  std::ofstream myfile;
  myfile.open(dotFilename.str());
  if (myfile.fail()) {
    LOG(ERROR) << "Unable to open " << dotFilename.str()
               << ", reason: " << strerror(errno);
  } else {
    DP.dumpAll(myfile);
  }
  myfile.close();
}

void Module::dumpDAG(const char *dotFilename) {
  dumpDAG(llvm::StringRef(dotFilename));
}

void Module::eraseFunctions() {
  while (!functions_.empty()) {
    eraseFunction(*functions_.begin());
  }
}

void Module::eraseFunction(Function *F) {
  auto it = std::find(functions_.begin(), functions_.end(), F);
  assert(it != functions_.end() && "Function is not part of a module");
  functions_.erase(it);
  delete F;
}

uint64_t Module::getConstantsSize() {
  uint64_t size = 0;
  for (auto *constant : constants_) {
    size += constant->getPayload().getSizeInBytes();
  }
  return size;
}

/// \returns an Error if any results from \p N are non-fused quantized with
/// scale == or != dummyScale, depending on \p expectDummy.
static Error verifyDummyQParamResults(const Node &N, bool expectDummy) {
  for (size_t i = 0, e = N.getNumResults(); i < e; i++) {
    TypeRef T = N.getType(i);
    if (T->isQuantizedType() && !T->isFusedQuantizedType()) {
      const bool isDummy = T->getScale() == dummyScale;
      if (expectDummy) {
        RETURN_ERR_IF_NOT(
            isDummy, strFormat("Expected all dummy scales, but found non-dummy "
                               "inside Function %s: %s",
                               N.getParent()->getName().data(),
                               N.getDebugDesc().data()));
      } else {
        RETURN_ERR_IF_NOT(!isDummy,
                          strFormat("Expected no dummy scales, but found one "
                                    "inside Function %s: %s",
                                    N.getParent()->getName().data(),
                                    N.getDebugDesc().data()));
      }
    }
  }
  return Error::success();
}

Error Module::verifyDummyQParams(bool expectDummies) {
  for (const Function *F : getFunctions()) {
    for (const Node &N : F->getNodes()) {
      RETURN_IF_ERR(verifyDummyQParamResults(N, expectDummies));
    }
  }
  for (const Placeholder *PH : getPlaceholders()) {
    RETURN_IF_ERR(verifyDummyQParamResults(*PH, expectDummies));
  }
  for (const Constant *C : getConstants()) {
    RETURN_IF_ERR(verifyDummyQParamResults(*C, expectDummies));
  }
  return Error::success();
}

Function::~Function() {
  // Delete all of the nodes.
  for (auto it = nodes_.begin(), e = nodes_.end(); it != e;) {
    auto cur = it++;
    eraseNode(&*cur);
  }
}

TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) {
  return uniqueType(Type(elemTy, dims));
}

TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
                           float scale, int32_t offset) {
  return uniqueType(Type(elemTy, dims, scale, offset));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims) {
  return uniqueType(Type::newShape(*T, dims));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                       llvm::ArrayRef<dim_t> alignments) {
  return uniqueType(Type::newShape(*T, dims, alignments));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType) {
  return uniqueType(Type::newShape(*T, shapeType));
}

TypeRef Module::uniqueTypeWithNewStrides(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                         llvm::ArrayRef<dim_t> strides) {
  return uniqueType(Type::newStrides(*T, strides));
}

TypeRef Module::uniqueTypeWithNewQuantParams(TypeRef T,
                                             TypeRef quantParamType) {
  return uniqueType(Type::newQuantparams(*T, quantParamType->getScale(),
                                         quantParamType->getOffset()));
}

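/// Types are interned per Module: looking up an equal type returns the same
/// TypeRef, so e.g. two uniqueType(Type(ElemKind::FloatTy, {2, 3})) calls
/// yield pointers that compare equal.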
TypeRef Module::uniqueType(const Type &T) {
  for (auto &tp : types_) {
    if (T.isEqual(tp)) {
      return &tp;
    }
  }

  return &*types_.insert(types_.begin(), T);
}

TypeRef Module::getVoidTy() { return uniqueType(Type()); }

/// \returns a ShapeVector of rank axes.size() less than the input \p dims,
/// where the provided \p axes dimensions are removed from the shape.
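/// E.g. (illustrative) dims = {2, 3, 4} with axes = {1} yields {2, 4}.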
static ShapeVector getNewShapeWithoutAxes(llvm::ArrayRef<dim_t> dims,
                                          llvm::ArrayRef<unsigned_t> axes) {
  assert(axes.size() <= dims.size() &&
         "Cannot remove more dimensions than exist.");
  ShapeVector newDims(dims.begin(), dims.end());
  ShapeVector shapeAxes(axes.begin(), axes.end());

  // Sort in descending order so the erases in the loop below don't invalidate
  // the positions of the axes still to be removed.
  std::sort(shapeAxes.rbegin(), shapeAxes.rend());

  for (const auto &axis : shapeAxes) {
    assert(axis < dims.size() &&
           "Axis to remove must fit inside dimensions of the provided dims.");
    newDims.erase(newDims.begin() + axis);
  }
  return newDims;
}

//===----------------------------------------------------------------------===//
//                            Node builders
//===----------------------------------------------------------------------===//

Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name,
                                       bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(*T);
  auto *ph = new Placeholder(name, FT, isTrainable, layout);
  ph->setName(uniqueName(ph->getName(), usedNodeNames_, usedStorageNames_,
                         originalNames_));
  placeholders_.push_back(ph);
  logStorageCreation(functions_, ph);
  return ph;
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                       llvm::StringRef name, bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(T, dims);
  return createPlaceholder(FT, name, isTrainable, layout);
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                       float scale, int32_t offset,
                                       llvm::StringRef name, bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(T, dims, scale, offset);
  return createPlaceholder(FT, name, isTrainable, layout);
}

Constant *Module::createConstant(TypeRef T, llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(*T);
  return addConstant(new Constant(name, FT, layout));
}

Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(T, dims);
  return createConstant(FT, name, layout);
}

Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 float scale, int32_t offset,
                                 llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(T, dims, scale, offset);
  return createConstant(FT, name, layout);
}

Constant *Module::createConstant(llvm::StringRef name, const Tensor &tensor,
                                 const std::string &layout) {
  auto *V = createConstant(&tensor.getType(), name, layout);
  V->assign(&tensor);
  return V;
}

Constant *Module::createConstant(llvm::StringRef name, Tensor &&tensor,
                                 const std::string &layout) {
  return addConstant(new Constant(name, std::move(tensor), layout));
}

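/// \returns the prefix of \p name with any trailing "__<digits>" suffix
/// stripped; e.g. (illustrative) "relu__2" yields "relu", while "relu" is
/// returned unchanged.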
std::string Module::getPrefix(llvm::StringRef name) {
  std::string prefix = name.str();
  size_t delim = name.rfind("__");
  if (delim != std::string::npos &&
      std::all_of(name.begin() + (delim + 2), name.end(),
                  [](unsigned char c) { return ::isdigit(c); })) {
    prefix = prefix.substr(0, delim);
  }
  return prefix;
}

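/// \returns a legalized name derived from \p name that collides with no entry
/// in \p stringTable, registering the result in \p updateTable; e.g.
/// (illustrative) requesting "conv" twice yields "conv" and then "conv__1".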
llvm::StringRef Module::uniqueName(llvm::StringRef name,
                                   const llvm::StringSet<> &stringTable,
                                   llvm::StringSet<> &updateTable,
                                   const llvm::StringSet<> &originalNames) {
  std::string legalName = legalizeName(name);
  if (stringTable.find(legalName) == stringTable.end()) {
    auto it = updateTable.insert(legalName);
    if (it.second) {
      return it.first->first();
    }
  }
  // Retain the trailing "__[0-9]+" if it is in the original name.
  std::string prefix = (originalNames.find(legalName) == originalNames.end())
                           ? Module::getPrefix(legalName)
                           : legalName;
  for (unsigned i = 1; i < 10000; i++) {
    auto suffix = std::to_string(i);
    std::string fullName = prefix + "__" + suffix;
    if (stringTable.find(fullName) != stringTable.end()) {
      continue;
    }

    auto it = updateTable.insert(fullName);
    if (it.second) {
      return it.first->first();
    }
  }
  llvm_unreachable("Unable to find a unique name.");
}

Constant *Module::addConstant(Constant *V) {
  V->setName(uniqueName(V->getName(), usedNodeNames_, usedStorageNames_,
                        originalNames_));
  // Replace the Constant's output type with the equivalent unique type for
  // this Module to maintain the invariant that each type in the Module is
  // unique.
  V->setType(Constant::ResultIndices::OutputIdx, uniqueType(*V->getType()));
  constants_.push_back(V);
  logStorageCreation(functions_, V);
  return V;
}

/// Check that the 'pads' array has the right size.
static void assertPadsSize(NodeValue input, llvm::ArrayRef<int> pads) {
  assert((pads.size() == 2 * input.dims().size()) &&
         "the pads array must contain 2 values per dimension");
}

PadNode *Function::createPad(llvm::StringRef name, NodeValue input,
                             TypeRef outTy, unsigned_t mode,
                             llvm::ArrayRef<int> pads, float value) {
  assertPadsSize(input, pads);
  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new PadNode(name, OT, input, mode, pads, value));
}

/// Check the kernel size for Conv/Pooling ops.
static void checkKernelSize(ShapeNHWC idim, llvm::ArrayRef<unsigned_t> kernels,
                            llvm::ArrayRef<unsigned_t> pads) {
  PaddingTLBR pdim(pads);
  (void)pdim;
  ShapeHW kdim(kernels);
  (void)kdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         "Kernel size is too large");
}

/// Check the kernel size for 3D Conv/Pooling ops.
static void check3DKernelSize(ShapeNTHWC idim,
                              llvm::ArrayRef<unsigned_t> kernels,
                              llvm::ArrayRef<unsigned_t> pads) {
  PaddingNFTBLR pdim(pads);
  (void)pdim;
  ShapeTHW kdim(kernels);
  (void)kdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         (idim.t + pdim.near + pdim.far) >= kdim.temporal_frames &&
         "Kernel size is too large");
}

/// Check that the dimensions that are passed in when the ConvTranspose is
/// constructed are correct.
static void assertConvTransposeDims(NodeValue input, NodeValue filter,
                                    NodeValue bias,
                                    llvm::ArrayRef<unsigned_t> kernels,
                                    llvm::ArrayRef<unsigned_t> strides,
                                    llvm::ArrayRef<unsigned_t> pads,
                                    unsigned_t group) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  (void)idim;
  ShapeHW kdim(kernels);
  (void)kdim;
  assert(idim.c % group == 0 &&
         "Number of channels must be divisible by the number of groups");

  // NOTE: here the N in NHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.h == kdim.height && filterDims.w == kdim.width &&
         filterDims.c == idim.c && "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n * group && "Invalid bias size");
}

/// Check that the dimensions that are passed in when the convolution is
/// constructed are correct.
static void assertConvDims(NodeValue input, NodeValue filter, NodeValue bias,
                           llvm::ArrayRef<unsigned_t> kernels,
                           llvm::ArrayRef<unsigned_t> strides,
                           llvm::ArrayRef<unsigned_t> pads, unsigned_t group) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  ShapeHW kdim(kernels);
  (void)kdim;
  checkKernelSize(idim, kernels, pads);
  assert(idim.c % group == 0 &&
         "Number of channels must be divisible by the number of groups");

  // NOTE: here the N in NHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.c == idim.c / group &&
         "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

/// Check that the dimensions that are passed in when the 3D convolution is
/// constructed are correct.
static void assertConv3DDims(NodeValue input, NodeValue filter, NodeValue bias,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             unsigned_t group) {
  ShapeNTHWC idim(input.dims());
  ShapeTHW kdim(kernels);
  (void)kdim;
  check3DKernelSize(idim, kernels, pads);
  assert(idim.c % group == 0 &&
         "Number of channels must be divisible by the number of groups");

  // NOTE: here the N in NTHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the 3d conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNTHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.t == kdim.temporal_frames &&
         filterDims.c == idim.c / group && "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

ConvolutionNode *Function::createConv(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, llvm::ArrayRef<unsigned_t> dilation,
    ConvolutionLayout layout) {
  assertConvDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);

  // If the input is quantized but the bias is not then auto-quantize the
  // bias.
  if (input.getType()->isQuantizedType()) {
    auto biasType = bias.getElementType();
    if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy) {
      // Nothing to do
    } else if (biasType == ElemKind::FloatTy) {
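      // The float bias is quantized to Int32 with offset 0 and
      // scale = inputScale * filterScale, the scale that a quantized
      // convolution's Int32 accumulator conventionally carries.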
      auto biasTy = getParent()->uniqueType(
          glow::ElemKind::Int32QTy, bias.dims(),
          input.getType()->getScale() * filter.getType()->getScale(),
          /* offset */ 0);
      bias = createQuantize("quantized_bias", bias, biasTy);
    } else {
      LOG(DFATAL)
          << "Unsupported element type for bias of quantized convolution: "
          << Type::getElementName(biasType).str();
    }
  }

  return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
                                     strides, pads, group, dilation, layout,
                                     FusedActivation::NONE, {}));
}

ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
                                      NodeValue filter, NodeValue bias,
                                      TypeRef outTy, unsigned_t kernel,
                                      unsigned_t stride, unsigned_t pad,
                                      unsigned_t group,
                                      llvm::ArrayRef<unsigned_t> dilation,
                                      ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConv(name, input, filter, bias, outTy, kernels, strides, pads,
                    group, dilation, layout);
}

Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
                                          NodeValue filter, NodeValue bias,
                                          TypeRef outTy,
                                          llvm::ArrayRef<unsigned_t> kernels,
                                          llvm::ArrayRef<unsigned_t> strides,
                                          llvm::ArrayRef<unsigned_t> pads,
                                          unsigned_t group) {
  assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);

  // If the input is quantized but the bias is not then auto-quantize the
  // bias.
  if (input.getType()->isQuantizedType()) {
    auto biasType = bias.getElementType();
    if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy ||
        biasType == ElemKind::Int16QTy) {
      // Nothing to do
    } else if (biasType == ElemKind::FloatTy) {
      auto biasTy = getParent()->uniqueType(
          glow::ElemKind::Int32QTy, bias.dims(),
          input.getType()->getScale() * filter.getType()->getScale(),
          /* offset */ 0);
      bias = createQuantize("quantized_bias", bias, biasTy);
    } else {
      LOG(DFATAL)
          << "Unsupported element type for bias of quantized convolution: "
          << Type::getElementName(biasType).str();
    }
  }
  return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group));
}

Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
                                          NodeValue filter, NodeValue bias,
                                          TypeRef outTy, unsigned_t kernel,
                                          unsigned_t stride, unsigned_t pad,
                                          unsigned_t group) {
  llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
  llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
  return createConv3D(name, input, filter, bias, outTy, kernels, strides, pads,
                      group);
}

ConvTransposeNode *Function::createConvTranspose(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) {
  assertConvTransposeDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group, dilation));
}

ConvTransposeNode *Function::createConvTranspose(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
    unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConvTranspose(name, input, filter, bias, outTy, kernels,
                             strides, pads, group, dilation);
}

MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ElemKind elemTyAMT,
                                     ConvolutionLayout layout) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  checkKernelSize(idim, kernels, pads);

  auto outSz =
      calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
  auto OT = getParent()->uniqueTypeWithNewShape(
      input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
  auto AMT = getParent()->uniqueType(
      elemTyAMT, {idim.n, outSz.first, outSz.second, idim.c});

  return addNode(
      new MaxPoolNode(name, OT, AMT, input, kernels, strides, pads, layout));
}

MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
                                     unsigned_t kernel, unsigned_t stride,
                                     unsigned_t pad, ElemKind elemTyAMT,
                                     ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createMaxPool(name, input, kernels, strides, pads, elemTyAMT, layout);
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ConvolutionLayout layout,
                                     bool countIncludePads) {
  if (!is3DData(layout)) {
    ShapeNHWC idim = ShapeNHWC(input.dims());
    checkKernelSize(idim, kernels, pads);

    auto outSz =
        calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
    auto OT = getParent()->uniqueTypeWithNewShape(
        input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
    return addNode(new AvgPoolNode(name, OT, input, kernels, strides, pads,
                                   layout, countIncludePads));
  } else {
    ShapeNTHWC idim = ShapeNTHWC(input.dims());
    check3DKernelSize(idim, kernels, pads);

    auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
                                               strides, pads);
    auto OT = getParent()->uniqueTypeWithNewShape(
        input.getType(),
        {idim.n, outSz.temporal_frames, outSz.height, outSz.width, idim.c});
    return addNode(new AvgPoolNode(name, OT, input, kernels, strides, pads,
                                   layout, countIncludePads));
  }
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     TypeRef outTy,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ConvolutionLayout layout,
                                     bool countIncludePads) {
  if (!is3DData(layout)) {
    ShapeNHWC idim = ShapeNHWC(input.dims());
    ShapeHW kdim(kernels);
    (void)kdim;
    checkKernelSize(idim, kernels, pads);
    return addNode(new AvgPoolNode(name, outTy, input, kernels, strides, pads,
                                   layout, countIncludePads));
  } else {
    ShapeNTHWC idim = ShapeNTHWC(input.dims());
    ShapeTHW kdim(kernels);
    (void)kdim;
    check3DKernelSize(idim, kernels, pads);
    return addNode(new AvgPoolNode(name, outTy, input, kernels, strides, pads,
                                   layout, countIncludePads));
  }
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     unsigned_t kernel, unsigned_t stride,
                                     unsigned_t pad, ConvolutionLayout layout,
                                     bool countIncludePads) {
  if (!is3DData(layout)) {
    llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
    llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
    llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
    return createAvgPool(name, input, kernels, strides, pads, layout,
                         countIncludePads);
  } else {
    llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
    llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
    llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
    return createAvgPool(name, input, kernels, strides, pads, layout,
                         countIncludePads);
  }
}

AdaptiveAvgPoolNode *Function::createAdaptiveAvgPool(llvm::StringRef name,
                                                     NodeValue input,
                                                     TypeRef outTy) {
  return addNode(new AdaptiveAvgPoolNode(name, outTy, input));
}

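/// Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op(.) transposes
/// its operand when the corresponding transpose flag is set; the output shape
/// below is derived from the (possibly transposed) A and B operands.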
GemmNode *Function::createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
                               NodeValue C, float alpha, float beta,
                               bool transposeA, bool transposeB) {
  std::vector<dim_t> outDims(2);
  outDims[0] = transposeA ? A.dims()[1] : A.dims()[0];
  outDims[1] = transposeB ? B.dims()[0] : B.dims()[1];
  TypeRef outTy = getParent()->uniqueTypeWithNewShape(A.getType(), outDims);
  return createGemm(name, outTy, A, B, C, alpha, beta, transposeA, transposeB);
}

GemmNode *Function::createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
                               NodeValue B, NodeValue C, float alpha,
                               float beta, bool transposeA, bool transposeB) {
  // If C operand is not given then we create a 1D splat with 0.
  if (!C.getNode()) {
    TypeRef splatTy =
        getParent()->uniqueTypeWithNewShape(outTy, {outTy->dims()[1]});
    C = createSplat(name.str() + ".SplatC", splatTy, 0.0f);
  }
  // If C operand is a 2D constant we check if it is a broadcasted version of
  // a 1D tensor. If yes then we slice and reshape the C operand to 1D.
  if (auto *constC = llvm::dyn_cast<Constant>(C.getNode())) {
    if ((constC->dims().size() == 2) && (constC->getPayload().isTiled(0))) {
      // Slice and reshape to 1D.
      dim_t lengthC = constC->dims()[1];
      C = createSlice(name.str() + ".SliceC", C, {0, 0}, {1, lengthC});
      C = createReshape(name.str() + ".ReshapeC", C, {lengthC});
    }
  }
  TypeRef OT = getParent()->uniqueType(*outTy);
  return addNode(
      new GemmNode(name, OT, A, B, C, alpha, beta, transposeA, transposeB));
}

DynamicQuantizedFullyConnectedNode *
Function::createDynamicQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, NodeValue W,
                                               NodeValue B, bool isSymmetric,
                                               bool isPerBatchElement) {
  TypeRef T = input.getType();
  TypeRef OT =
      getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});
  return addNode(new DynamicQuantizedFullyConnectedNode(
      name, OT, input, W, B, isSymmetric, isPerBatchElement));
}

DynamicRowwiseQuantizedFullyConnectedNode *
Function::createDynamicRowwiseQuantizedFullyConnected(
    llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B,
    NodeValue scales, NodeValue offsets, bool isSymmetric,
    bool isPerBatchElement) {
  TypeRef T = input.getType();
  TypeRef OT =
      getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});
  return addNode(new DynamicRowwiseQuantizedFullyConnectedNode(
      name, OT, input, W, B, scales, offsets, isSymmetric, isPerBatchElement));
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, Storage *W,
                                                   Storage *B,
                                                   unsigned_t axis) {
  TypeRef T = input.getType();
  TypeRef OT = getParent()->uniqueTypeWithNewShape(
      T, {input.dims()[0], B->getType()->dims()[0]});

  return createFullyConnected(name, input, W, B, OT, axis);
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, NodeValue W,
                                                   NodeValue B,
                                                   unsigned_t axis) {
  TypeRef T = input.getType();
  TypeRef OT =
      getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});

  return createFullyConnected(name, input, W, B, OT, axis);
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, NodeValue W,
                                                   NodeValue B, TypeRef outTy,
                                                   unsigned_t axis) {
  assert(outTy->dims().size() == 2 && "Invalid number of dimensions");

  // FC always uses 2D input; flatten if necessary.
  if (input.dims().size() != 2) {
    input = createFlatten(name.str() + ".reshape2D", input, axis);
  }
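  // E.g. (illustrative): an input of shape {N, H, W, C} with axis = 1 is
  // flattened to {N, H * W * C} before the FullyConnected node is created.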

  TypeRef OT = getParent()->uniqueType(*outTy);
  return addNode(new FullyConnectedNode(name, OT, input, W, B));
}

RowwiseQuantizedFullyConnectedNode *
Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, NodeValue W,
                                               Constant *scales,
                                               Constant *offsets, NodeValue B,
                                               TypeRef outTy) {
  return addNode(new RowwiseQuantizedFullyConnectedNode(name, outTy, input, W,
                                                        scales, offsets, B));
}

RowwiseQuantizedFullyConnectedNode *
Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, NodeValue W,
                                               NodeValue B, TypeRef outTy,
                                               quantization::Schema schema,
                                               bool transposeWeight) {
  // Since W is constant, quantize it at compile time. The quantized data is
  // stored in qWeights, the scale of each row in scales, and the offset of
  // each row in offsets.
  Constant *weights = llvm::cast<Constant>(W);
  CHECK(weights)
      << "Expected RowwiseQuantizedFullyConnected weights to be a Constant";

  dim_t numRows = transposeWeight ? weights->getType()->dims()[1]
                                  : weights->getType()->dims()[0];
  dim_t numCols = transposeWeight ? weights->getType()->dims()[0]
                                  : weights->getType()->dims()[1];

  // So far, if we want to create a storage with Int8QTy/Int16QTy, it is
  // assumed to be quantized data and the scale and offset should be provided.
  // But for rowwise quantization, the scales and offsets are stored in
  // separate vectors, so we add a dummy scale and offset here.
  auto *qWeights = getParent()->createConstant(
      ElemKind::Int8QTy, {numRows, numCols}, 0.0, 0, "weights.rwqfc");
  auto *scales =
      getParent()->createConstant(ElemKind::FloatTy, {numRows}, "scales.rwqfc");
  auto *offsets = getParent()->createConstant(ElemKind::Int32ITy, {numRows},
                                              "offsets.rwqfc");

  Tensor wt;
  if (transposeWeight) {
    // This happens when the RowwiseQuantizedFullyConnected node is converted
    // from a quantized FullyConnected node in Glow's quantization procedure.
    // In FC the weights are stored transposed (i.e. I * W + B), but in
    // RowwiseQuantizedFullyConnected the weights are stored as-is (i.e.
    // I * W(T) + B).
    weights->getPayloadMutable().transpose(&wt, {1, 0});
  } else {
    wt.assign(&(weights->getPayload()));
  }

  // Note: Using int32_t offset here as that is what RWQ-FC expects.
  quantization::tensorRowwiseQuantization<float, int32_t, int8_t>(
      wt, qWeights->getPayloadMutable(), scales->getPayloadMutable(),
      offsets->getPayloadMutable(), schema);

  return addNode(new RowwiseQuantizedFullyConnectedNode(
      name, outTy, input, qWeights, scales, offsets, B));
}

ReluNode *Function::createRelu(llvm::StringRef name, TypeRef outTy,
                               NodeValue input) {
  return addNode(new ReluNode(name, outTy, input));
}

ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input,
                               TypeRef outTy) {
  return createRelu(name, outTy, input);
}

ReluNode *Function::createRelu(llvm::StringRef name, NodeValue input) {
  return createRelu(name, input.getType(), input);
}

ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input) {
  return createRelu(name, input);
}

GeluNode *Function::createGelu(llvm::StringRef name, NodeValue input) {
  return addNode(new GeluNode(name, input.getType(), input));
}

GeluNode *Function::createGELU(llvm::StringRef name, NodeValue input) {
  return createGelu(name, input);
}

PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
                                 NodeValue slope, TypeRef outTy) {
  return addNode(new PReluNode(name, outTy, input, slope));
}

PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
                                 NodeValue slope) {
  return addNode(new PReluNode(name, input.getType(), input, slope));
}

SigmoidNode *Function::createSigmoid(llvm::StringRef name, TypeRef outTy,
                                     NodeValue input) {
  return addNode(new SigmoidNode(name, outTy, input));
}

SigmoidNode *Function::createSigmoid(llvm::StringRef name, NodeValue input) {
  return createSigmoid(name, input.getType(), input);
}

SwishNode *Function::createSwish(llvm::StringRef name, NodeValue input) {
  return createSwish(name, getParent()->uniqueType(*input.getType()), input);
}

SwishNode *Function::createSwish(llvm::StringRef name, TypeRef OT,
                                 NodeValue input) {
  return addNode(new SwishNode(name, OT, input));
}

SwishNode *Function::createSwish(llvm::StringRef name, NodeValue input,
                                 TypeRef OT) {
  return createSwish(name, OT, input);
}

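/// HardSigmoid computes max(0, min(1, alpha * x + beta)); for reference, ONNX
/// defaults to alpha = 0.2 and beta = 0.5, though the values here are entirely
/// caller-provided.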
ClipNode *Function::createHardSigmoid(llvm::StringRef name, TypeRef outTy,
                                      NodeValue input, float alpha,
                                      float beta) {
  auto ty = input.getType();

  // max(0, min(1, alpha * x + beta))
  auto *alphaSplat = createSplat(name.str() + ".alpha", ty, alpha);
  auto *betaSplat = createSplat(name.str() + ".beta", ty, beta);

  auto *mul = createMul(name.str() + ".mul", alphaSplat, input);
  auto *add = createAdd(name.str() + ".add", mul, betaSplat);

  return createClip(name.str() + ".clip", add, outTy, 0, 1);
}

ClipNode *Function::createHardSigmoid(llvm::StringRef name, NodeValue input,
                                      float alpha, float beta) {
  return createHardSigmoid(name, input.getType(), input, alpha, beta);
}

TanhNode *Function::createTanh(llvm::StringRef name, TypeRef outTy,
                               NodeValue input) {
  return addNode(new TanhNode(name, outTy, input));
}

TanhNode *Function::createTanh(llvm::StringRef name, NodeValue input) {
  return createTanh(name, input.getType(), input);
}

SoftPlusNode *Function::createSoftPlus(llvm::StringRef name, NodeValue input,
                                       TypeRef outTy) {
  if (!outTy) {
    outTy = getParent()->uniqueType(*input.getType());
  }
  return addNode(new SoftPlusNode(name, outTy, input));
}

SoftMaxNode *Function::createSoftMax(llvm::StringRef name, NodeValue input,
                                     NodeValue selected, TypeRef outTy,
                                     float beta) {
  // Scale the input by beta when it is not the default 1.0.
  if (beta != 1.0) {
    auto *splat = createSplat(name, input.getType(), beta);
    input = createMul(name, input, splat);
  }
  // By default, pick the input type.
  if (!outTy) {
    outTy = getParent()->uniqueType(*input.getType());
  }
  return addNode(new SoftMaxNode(name, outTy, input, selected));
}

LogSoftMaxNode *Function::createLogSoftMax(llvm::StringRef name,
                                           NodeValue input, NodeValue selected,
                                           TypeRef outTy, float beta) {
  // Scale the input by beta when it is not the default 1.0.
  if (beta != 1.0) {
    auto *splat = createSplat(name, input.getType(), beta);
    input = createMul(name, input, splat);
  }
  // By default, pick the input type.
  if (!outTy) {
    outTy = getParent()->uniqueType(*input.getType());
  }
  return addNode(new LogSoftMaxNode(name, outTy, input, selected));
}

CrossEntropyLossNode *Function::createCrossEntropyLoss(llvm::StringRef name,
                                                       NodeValue input,
                                                       NodeValue labels) {
  auto ty = getParent()->uniqueTypeWithNewShape(input.getType(), {1});
  return addNode(new CrossEntropyLossNode(name, ty, input, labels));
}

RegressionNode *Function::createRegression(llvm::StringRef name,
                                           NodeValue input,
                                           NodeValue expected) {
  return addNode(new RegressionNode(name, input, expected));
}

SigmoidCrossEntropyWithLogitsNode *
Function::createSigmoidCrossEntropyWithLogits(llvm::StringRef name,
                                              NodeValue logits,
                                              NodeValue targets) {
  assert(logits.dims().size() > 1);
  std::vector<dim_t> outDims(logits.dims().begin(), logits.dims().end() - 1);
  auto ty = getParent()->uniqueTypeWithNewShape(logits.getType(), outDims);
  return addNode(
      new SigmoidCrossEntropyWithLogitsNode(name, ty, logits, targets));
}

ReshapeNode *Function::createReshape(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<dim_t> shape,
                                     llvm::StringRef layout) {
  auto TR = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
  DCHECK_EQ(TR->size(), input.getType()->size())
      << "Reshape to a different size";
  return addNode(
      new ReshapeNode(name.str(), TR, input, shape.vec(), layout.str()));
}

TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input,
                                         llvm::ArrayRef<unsigned_t> shuffle,
                                         const std::string &layout) {
  ShapeVector shape;
  auto dims = input.dims();
  for (size_t i = 0; i < dims.size(); i++) {
    shape.push_back(dims[shuffle[i]]);
  }

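  // Note: output dimension i is input dimension shuffle[i]; e.g. permuting an
  // NCHW input to NHWC uses the shuffle {0, 2, 3, 1}.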
1399 | // If the layout is known, check that it matches the shuffle: |
1400 | auto compareShuffle = [&](const std::vector<unsigned_t> targetShuffle) { |
1401 | auto shuffleVec = shuffle.vec(); |
1402 | return targetShuffle.size() == dims.size() && |
1403 | std::equal(shuffleVec.begin(), shuffleVec.end(), |
1404 | targetShuffle.begin()); |
1405 | }; |
1406 | |
1407 | auto currLayout = layout; |
1408 | if (currLayout == ANY_LAYOUT) { |
1409 | // If layout got a default value, change it based on shuffle: |
1410 | // TODO: remove the shuffle and replace it with layout. |
1411 | if (compareShuffle(NCHW2NHWC) || compareShuffle(HWCN2NHWC)) { |
1412 | currLayout = "NHWC" ; |
1413 | } else if (compareShuffle(NCTHW2NTHWC)) { |
1414 | currLayout = "NTHWC" ; |
1415 | } else if (compareShuffle(NHWC2NCHW)) { |
1416 | currLayout = "NCHW" ; |
1417 | } else if (compareShuffle(NTHWC2NCTHW)) { |
1418 | currLayout = "NCTHW" ; |
1419 | } else if (compareShuffle(NHWC2HWNC)) { |
1420 | currLayout = "HWNC" ; |
1421 | } else if (compareShuffle(CNHW2NHWC)) { |
1422 | currLayout = "NHWC" ; |
1423 | } |
1424 | } |
1425 | |
1426 | auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape); |
1427 | return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout)); |
1428 | } |
1429 | |
1430 | FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input, |
1431 | unsigned_t axis) { |
1432 | auto OT = getParent()->uniqueType(*input.getType()); |
1433 | return addNode(new FlipNode(name, OT, input, axis)); |
1434 | } |
1435 | |
1436 | BroadcastNode *Function::createBroadcast(llvm::StringRef name, NodeValue input, |
1437 | UnsignedArrayRef newShape, |
1438 | unsigned_t axis) { |
1439 | auto OT = getParent()->uniqueTypeWithNewShape(input.getType(), newShape); |
1440 | return addNode(new BroadcastNode(name, OT, input, axis, newShape.vec())); |
1441 | } |
1442 | |
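/// Builds RMS normalization out of elementwise primitives, following the
/// NumPy recipe in the comments below: rrms = 1.0 / np.sqrt(np.mean(
/// np.square(X), axis=1) + epsilon) and Y = X * rrms * gamma + beta, with
/// gamma and beta broadcast along axis 1. \returns the pair {Y, rrms}.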
1443 | std::array<Node *, 2> Function::createRMSNorm(llvm::StringRef name, NodeValue X, |
1444 | NodeValue gamma, NodeValue beta, |
1445 | float epsilon) { |
1446 | // np.square(X) |
  auto square = createSquare(name.str() + ".square", X);

  // np.mean(np.square(X), axis=1)
  auto mean = createBatchedReduceMean(name.str() + ".mean", square, {1});

  // np.mean(np.square(X), axis=1) + eps
  auto eps = getParent()->createConstant(ElemKind::FloatTy, 1, "eps");
  eps->getPayloadMutable() = {epsilon};
  auto bcastEps =
      createBroadcast(name.str() + ".bcastEps", eps, {X.dims()[0]}, 0);
  auto addEps = createAdd(name.str() + ".addEps", mean, bcastEps);

  // np.sqrt(np.mean(np.square(X), axis=1) + eps)
  auto sqrt = createPow(name.str() + ".sqrt", addEps, 0.5f);

  // rrms = 1.0 / np.sqrt(np.mean(np.square(X), axis=1) + eps)
  auto one = getParent()->createConstant(ElemKind::FloatTy, 1, "one");
  one->getPayloadMutable() = {1};
  auto bcastOne =
      createBroadcast(name.str() + ".bcastOne", one, {X.dims()[0]}, 0);
  auto rrms = createDiv(name.str() + ".rrms", bcastOne, sqrt);

  // np.expand_dims(rrms, axis=1)
  auto reshape = createReshape(name.str() + ".expandD", rrms, {X.dims()[0], 1});
  auto bcastReshape =
      createBroadcast(name.str() + ".bcastReshape", reshape, X.dims(), 0);

  // X * np.expand_dims(rrms, axis=1)
  auto mul = createMul(name.str() + ".mul", X, bcastReshape);

  // X * np.expand_dims(rrms, axis=1) * gamma
  auto bcastGamma =
      createBroadcast(name.str() + ".bcastGamma", gamma, X.dims(), 1);
  auto mulGamma = createMul(name.str() + ".mulGamma", mul, bcastGamma);

  // Y = X * np.expand_dims(rrms, axis=1) * gamma + beta
  auto bcastBeta =
      createBroadcast(name.str() + ".bcastBeta", beta, X.dims(), 1);
  auto Y = createAdd(name.str() + ".Y", mulGamma, bcastBeta);
1486 | |
1487 | return {Y, rrms}; |
1488 | } |
1489 | |
/// \returns true if \p T1 and \p T2 have exactly the same type except for
/// dimension \p dim. It will log an error when returning false.
static bool sameSameShapeExceptDim(TypeRef T1, TypeRef T2, unsigned dim) {
1493 | if (T1->getElementType() != T2->getElementType()) { |
1494 | LOG(ERROR) << "Different types " << (int)T1->getElementType() << " " |
1495 | << (int)T2->getElementType(); |
1496 | return false; |
1497 | } |
1498 | |
1499 | auto D1 = T1->dims(); |
1500 | auto D2 = T2->dims(); |
1501 | |
1502 | if (D1.size() != D2.size()) { |
1503 | LOG(ERROR) << "Different size " << D1.size() << " " << D2.size(); |
1504 | return false; |
1505 | } |
1506 | |
1507 | for (unsigned i = 0, e = D1.size(); i < e; i++) { |
1508 | // Ignore the dimension \p dim. |
1509 | if (i == dim) { |
1510 | continue; |
1511 | } |
1512 | |
1513 | if (D1[i] != D2[i]) { |
1514 | LOG(ERROR) << "Different dimension at " << i << " " << D1[i] << " " |
1515 | << D2[i]; |
1516 | return false; |
1517 | } |
1518 | } |
1519 | |
1520 | return true; |
1521 | } |
1522 | |
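/// The concat output shape matches the inputs except along \p dimension,
/// where it is the sum of the inputs' sizes; e.g., concatenating {2, 3} and
/// {5, 3} tensors along dimension 0 yields a {7, 3} result.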
1523 | ConcatNode *Function::createConcat(llvm::StringRef name, |
1524 | llvm::ArrayRef<NodeValue> inputs, |
1525 | unsigned_t dimension) { |
1526 | for (int i = 1, e = inputs.size(); i < e; i++) { |
    assert(sameSameShapeExceptDim(inputs[i].getType(), inputs[0].getType(),
                                  dimension) &&
           "Invalid type");
1530 | (void)sameSameShapeExceptDim; |
1531 | } |
1532 | auto inDim = inputs[0].dims(); |
1533 | |
1534 | ShapeVector shape(inDim.begin(), inDim.end()); |
1535 | |
  // We are concatenating the tensors along a specific dimension, so the
  // output size along that dimension is the sum of the inputs' sizes there.
1538 | shape[dimension] = 0; |
1539 | for (auto I : inputs) { |
1540 | shape[dimension] += I.getType()->dims()[dimension]; |
1541 | } |
1542 | |
1543 | auto NT = getParent()->uniqueTypeWithNewShape(inputs[0].getType(), shape); |
1544 | std::vector<NodeValue> ops; |
1545 | ops.reserve(inputs.size()); |
1546 | for (auto I : inputs) { |
1547 | ops.emplace_back(I); |
1548 | } |
1549 | return addNode(new ConcatNode(name, NT, ops, dimension)); |
1550 | } |
1551 | |
1552 | ConcatNode *Function::createConcat(llvm::StringRef name, |
1553 | llvm::ArrayRef<NodeValue> inputs, |
1554 | unsigned_t dimension, TypeRef outTy) { |
1555 | std::vector<NodeValue> ops; |
1556 | ops.reserve(inputs.size()); |
1557 | for (auto I : inputs) { |
1558 | ops.emplace_back(I); |
1559 | } |
1560 | |
1561 | TypeRef OT = getParent()->uniqueType(*outTy); |
1562 | return addNode(new ConcatNode(name, OT, ops, dimension)); |
1563 | } |
1564 | |
1565 | TileNode *Function::createTile(llvm::StringRef name, NodeValue input, |
1566 | unsigned_t tiles, unsigned_t axis, |
1567 | TypeRef outTy) { |
  assert(tiles > 0 && "Tiles must be non-zero.");
  // Note: axis is unsigned, so only the upper bound needs checking.
  assert(axis < input.dims().size() &&
         "Axis must fall in range of source dims.");
1571 | |
1572 | if (outTy == nullptr) { |
1573 | ShapeVector outShape(input.dims().begin(), input.dims().end()); |
1574 | outShape[axis] *= tiles; |
1575 | outTy = getParent()->uniqueTypeWithNewShape(input.getType(), outShape); |
1576 | } |
1577 | |
1578 | return addNode(new TileNode(name, outTy, input, tiles, axis)); |
1579 | } |
1580 | |
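/// Tiles along several axes by chaining single-axis TileNodes; e.g., tiles
/// {2, 3} with axes {0, 1} turn a {4, 5} input into an {8, 15} output.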
1581 | TileNode *Function::createTile(llvm::StringRef name, NodeValue input, |
1582 | llvm::ArrayRef<unsigned_t> tiles, |
1583 | llvm::ArrayRef<unsigned_t> axes) { |
  assert(tiles.size() && "The array of tiles is empty!");
  assert(axes.size() && "The array of axes is empty!");
  assert(tiles.size() == axes.size() &&
         "The arrays of tiles and axes must be equal in length!");
1588 | TileNode *tileNode = nullptr; |
1589 | for (size_t idx = 0; idx < tiles.size(); ++idx) { |
1590 | tileNode = createTile(name.str() + "." + std::to_string(idx), |
1591 | tileNode ? tileNode->getResult() : input, tiles[idx], |
1592 | axes[idx]); |
1593 | } |
1594 | return tileNode; |
1595 | } |
1596 | |
1597 | InsertTensorNode *Function::createInsertTensor(llvm::StringRef name, |
1598 | NodeValue big, NodeValue small, |
1599 | llvm::ArrayRef<dim_t> start, |
1600 | unsigned_t count, |
1601 | unsigned_t axis) { |
1602 | return addNode(new InsertTensorNode(name, big, small, start, count, axis)); |
1603 | } |
1604 | |
1605 | SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input, |
1606 | llvm::ArrayRef<dim_t> start, TypeRef outTy) { |
  assert(input.dims().size() == start.size() &&
         "Start and input dims should match");
  assert(outTy->dims().size() == start.size() &&
         "Output and start dims should match");

  for (unsigned i = 0, e = input.dims().size(); i < e; i++) {
    assert(start[i] + outTy->dims()[i] <= input.dims()[i] &&
           "Input/Output/Start dims mismatch");
1615 | } |
1616 | |
1617 | TypeRef OT = getParent()->uniqueType(*outTy); |
1618 | return addNode(new SliceNode(name, OT, input, start)); |
1619 | } |
1620 | |
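/// Slice described by per-dimension \p begin and \p end offsets; dimension i
/// of the result has size end[i] - begin[i]. E.g., begin {1, 2} and end
/// {3, 5} on a {4, 6} input produce a {2, 3} slice starting at {1, 2}.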
1621 | SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input, |
1622 | llvm::ArrayRef<dim_t> begin, |
1623 | llvm::ArrayRef<dim_t> end) { |
1624 | std::vector<dim_t> beginV, shape; |
1625 | auto dims = input.dims(); |
  assert(begin.size() == end.size() && "Begin and End dimensions should match");
  assert(begin.size() == dims.size() &&
         "Begin and Input dimensions should match");
  for (unsigned i = 0; i < dims.size(); i++) {
    dim_t beginI = begin[i];
    dim_t endI = end[i];
    dim_t dimI = dims[i];
    (void)dimI;
    // Note: dim_t is unsigned, so begin/end are non-negative by construction.
    assert(endI > 0 && "Illegal End indices");
    assert(beginI < dimI && "Illegal Begin indices");
    assert(endI <= dimI && "Illegal End indices");
    assert(endI > beginI && "Illegal Begin and End indices");
1639 | beginV.push_back(beginI); |
1640 | shape.push_back(endI - beginI); |
1641 | } |
1642 | |
1643 | auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape); |
1644 | return addNode(new SliceNode(name, NT, input, beginV)); |
1645 | } |
1646 | |
1647 | Node *Function::createChannelShuffle(llvm::StringRef name, NodeValue input, |
1648 | size_t group, size_t kernel) { |
1649 | return addNode( |
1650 | new ChannelShuffleNode(name, input.getType(), input, group, kernel)); |
1651 | } |
1652 | |
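/// Removes the size-1 dimensions listed in \p axes via a reshape; e.g.,
/// squeezing a {1, 3, 1, 5} input with axes {0, 2} yields a {3, 5} output.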
1653 | ReshapeNode *Function::createSqueeze(llvm::StringRef name, NodeValue input, |
1654 | llvm::ArrayRef<dim_t> axes) { |
  assert(!axes.empty() && "Parameter `axes` must be provided.");
1656 | |
1657 | ShapeVector shapeAxes(axes.begin(), axes.end()); |
1658 | |
1659 | // Sort and unique the values in axes to |
1660 | // 1. make sure each dim is only removed once; |
1661 | // 2. check if the size and value of dimensions to squeeze are valid. |
1662 | std::sort(shapeAxes.begin(), shapeAxes.end()); |
1663 | shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()), |
1664 | shapeAxes.end()); |
1665 | auto inDims = input.dims(); |
  assert(shapeAxes.back() < inDims.size() &&
         "Axes to squeeze must be within the rank of the input.");
1669 | |
1670 | ShapeVector newDims; |
1671 | size_t j = 0; |
1672 | for (size_t i = 0, e = inDims.size(); i < e; i++) { |
1673 | if (j < shapeAxes.size() && shapeAxes[j] == i) { |
      assert(inDims[i] == 1 && "The dimension to squeeze must be 1.");
1675 | j++; |
1676 | } else { |
1677 | newDims.push_back(inDims[i]); |
1678 | } |
1679 | } |
  return createReshape(name.str() + ".reshape", input, newDims);
1681 | } |
1682 | |
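/// Inserts size-1 dimensions at the positions listed in \p axes, where the
/// positions refer to the output shape; e.g., a {3, 4} input with axes
/// {0, 2} becomes {1, 3, 1, 4}.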
1683 | ReshapeNode *Function::createExpandDims(llvm::StringRef name, NodeValue input, |
1684 | llvm::ArrayRef<dim_t> axes) { |
  assert(!axes.empty() && "Parameter `axes` must be provided.");
1686 | |
1687 | // Dimensions provided in axes are for the output tensor, so we sort them |
1688 | // and unique them to make sure they are processed correctly and in the |
1689 | // right order. |
1690 | ShapeVector shapeAxes(axes.begin(), axes.end()); |
1691 | std::sort(shapeAxes.begin(), shapeAxes.end()); |
1692 | shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()), |
1693 | shapeAxes.end()); |
1694 | |
1695 | const auto inDims = input.dims(); |
1696 | |
1697 | // The total number of dimensions in the new shape is equal to the original |
1698 | // shape size plus the uniqued new shape axes, which represents where to |
1699 | // insert dimensions of 1 into the output tensor's shape. |
1700 | const size_t totalNumNewDims = shapeAxes.size() + inDims.size(); |
  assert(totalNumNewDims <= max_tensor_dimensions &&
         "New expanded shape has too many dimensions.");
  assert(shapeAxes.back() < totalNumNewDims &&
         "Specified axis expands outside size of output tensor shape.");
1705 | ShapeVector newDims; |
1706 | for (size_t i = 0, j = 0, k = 0; k < totalNumNewDims; k++) { |
1707 | if (j < shapeAxes.size() && shapeAxes[j] == k) { |
1708 | newDims.push_back(1); |
1709 | j++; |
1710 | } else { |
      assert(i < inDims.size() && "Somehow overflowing inDims.");
1712 | newDims.push_back(inDims[i]); |
1713 | i++; |
1714 | } |
1715 | } |
1716 | |
1717 | // Create a reshape of the original data with the newly determined |
1718 | // dimensions. |
  return createReshape(name.str() + ".expanddims", input, newDims);
1720 | } |
1721 | |
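/// Flattens to 2D: the first \p axis dimensions collapse into the row count
/// and the remaining ones into the column count; e.g., a {2, 3, 4, 5} input
/// with axis 2 becomes {6, 20}, and axis 0 gives {1, 120}.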
1722 | ReshapeNode *Function::createFlatten(llvm::StringRef name, NodeValue input, |
1723 | unsigned_t axis) { |
1724 | std::pair<dim_t, dim_t> xDim; |
1725 | if (axis == 0) { |
1726 | dim_t d = 1; |
1727 | for (auto dim : input.getType()->dims()) { |
1728 | d *= dim; |
1729 | } |
1730 | xDim = {1, d}; |
1731 | } else { |
1732 | xDim = flattenCdr(input.getType()->dims(), axis); |
1733 | } |
1734 | return createReshape(name, input, {xDim.first, xDim.second}); |
1735 | } |
1736 | |
1737 | ReshapeNode *Function::createFlattenV1(llvm::StringRef name, NodeValue input, |
1738 | unsigned_t axis) { |
1739 | auto xDim = collapseShape(input.getType()->dims(), axis); |
1740 | return createReshape(name, input, {xDim.first, xDim.second}); |
1741 | } |
1742 | |
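/// Splits \p input into \p outputNum slices along \p axis. With an empty
/// \p split the pieces are equal-sized; e.g., splitting a {6, 4} input into
/// 3 outputs along axis 0 yields three {2, 4} slices.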
1743 | void Function::createSplit(llvm::StringRef name, NodeValue input, |
1744 | unsigned_t outputNum, unsigned_t axis, |
1745 | llvm::ArrayRef<dim_t> split, |
1746 | std::vector<SliceNode *> &outputs) { |
1747 | auto inDims = input.dims(); |
  if (split.empty()) {
    assert(inDims[axis] % outputNum == 0 &&
           "Dimension to split must be divisible by the number of outputs.");
  } else {
    assert(outputNum == split.size() &&
           "Number of split sizes must be equal to the number of outputs.");
  }
1755 | |
1756 | ShapeVector start(inDims.size(), 0); |
1757 | ShapeVector end(inDims.begin(), inDims.end()); |
1758 | end[axis] = 0; |
1759 | |
1760 | outputs.resize(outputNum); |
1761 | for (size_t i = 0; i < outputNum; i++) { |
1762 | size_t curLength = split.empty() ? inDims[axis] / outputNum : split[i]; |
1763 | end[axis] += curLength; |
1764 | outputs[i] = |
1765 | createSlice(name.str() + ".out" + std::to_string(i), input, start, end); |
1766 | start[axis] = end[axis]; |
1767 | } |
1768 | |
  assert(end[axis] == inDims[axis] &&
         "Total size of results must be equal to input size.");
1771 | } |
1772 | |
1773 | BatchNormalizationNode *Function::createBatchNormalization( |
1774 | llvm::StringRef name, TypeRef resType, NodeValue input, NodeValue beta, |
1775 | NodeValue scale, NodeValue mean, NodeValue var, unsigned_t channelIdx, |
1776 | float epsilon, float momentum) { |
1777 | return addNode(new BatchNormalizationNode(name, resType, input, scale, beta, |
1778 | mean, var, channelIdx, epsilon, |
1779 | momentum)); |
1780 | } |
1781 | |
1782 | InstanceNormalizationNode * |
1783 | Function::createInstanceNormalization(llvm::StringRef name, NodeValue input, |
1784 | NodeValue beta, NodeValue scale, |
1785 | unsigned_t channelIdx, float epsilon) { |
1786 | return addNode(new InstanceNormalizationNode(name, input, scale, beta, |
1787 | channelIdx, epsilon)); |
1788 | } |
1789 | |
1790 | LayerNormalizationNode * |
1791 | Function::createLayerNormalization(llvm::StringRef name, TypeRef outTy, |
1792 | NodeValue input, NodeValue scale, |
1793 | NodeValue bias, float epsilon) { |
1794 | return addNode( |
1795 | new LayerNormalizationNode(name, outTy, input, scale, bias, epsilon)); |
1796 | } |
1797 | |
1798 | BucketizeNode *Function::createBucketizeNode(llvm::StringRef name, |
1799 | NodeValue input, |
1800 | llvm::ArrayRef<float> boundaries) { |
1801 | auto OT = getParent()->uniqueType(ElemKind::Int32ITy, input.dims()); |
1802 | return addNode(new BucketizeNode(name, OT, input, boundaries)); |
1803 | } |
1804 | |
1805 | LocalResponseNormalizationNode *Function::createLocalResponseNormalization( |
1806 | llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize, |
1807 | float alpha, float beta, float k) { |
1808 | // The output tensor is of the same shape as the input tensor. |
1809 | return addNode(new LocalResponseNormalizationNode(name, input, halfWindowSize, |
1810 | alpha, beta, k)); |
1811 | } |
1812 | |
1813 | ModuloNode *Function::createModulo(llvm::StringRef name, NodeValue input, |
1814 | int64_t divisor, bool signFollowDivisor) { |
1815 | // The output tensor is of the same shape as the input tensor. |
1816 | auto OT = getParent()->uniqueType(*input.getType()); |
1817 | return addNode(new ModuloNode(name, OT, input, divisor, signFollowDivisor)); |
1818 | } |
1819 | |
1820 | NotNode *Function::createNot(llvm::StringRef name, NodeValue input) { |
1821 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims()); |
1822 | return addNode(new NotNode(name, OT, input)); |
1823 | } |
1824 | |
1825 | BitwiseNotNode *Function::createBitwiseNot(llvm::StringRef name, |
1826 | NodeValue input) { |
1827 | TypeRef OT = getParent()->uniqueType(*input.getType()); |
1828 | return addNode(new BitwiseNotNode(name, OT, input)); |
1829 | } |
1830 | |
1831 | #define UNARY_ARITHMETIC_FUN_DEF(NODE_NAME_) \ |
1832 | NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \ |
1833 | NodeValue input) { \ |
1834 | return create##NODE_NAME_(name, input.getType(), input); \ |
1835 | } \ |
1836 | NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \ |
1837 | TypeRef T, NodeValue input) { \ |
1838 | TypeRef OT = getParent()->uniqueType(*T); \ |
1839 | return addNode(new NODE_NAME_##Node(name, OT, input)); \ |
1840 | } |
1841 | UNARY_ARITHMETIC_FUN_DEF(Abs) |
1842 | UNARY_ARITHMETIC_FUN_DEF(Neg) |
1843 | UNARY_ARITHMETIC_FUN_DEF(Floor) |
1844 | UNARY_ARITHMETIC_FUN_DEF(Sign) |
1845 | UNARY_ARITHMETIC_FUN_DEF(Ceil) |
1846 | UNARY_ARITHMETIC_FUN_DEF(Round) |
1847 | UNARY_ARITHMETIC_FUN_DEF(Sqrt) |
1848 | UNARY_ARITHMETIC_FUN_DEF(Rsqrt) |
1849 | UNARY_ARITHMETIC_FUN_DEF(Reciprocal) |
1850 | UNARY_ARITHMETIC_FUN_DEF(Sin) |
1851 | UNARY_ARITHMETIC_FUN_DEF(Cos) |
1852 | UNARY_ARITHMETIC_FUN_DEF(Erf) |
1853 | UNARY_ARITHMETIC_FUN_DEF(Truncate) |
1854 | UNARY_ARITHMETIC_FUN_DEF(HardSwish) |
1855 | #undef UNARY_ARITHMETIC_FUN_DEF |
1856 | |
1857 | #define ARITHMETIC_FUN_DEF(NODE_NAME_) \ |
1858 | NODE_NAME_##Node *Function::create##NODE_NAME_( \ |
1859 | llvm::StringRef name, NodeValue LHS, NodeValue RHS) { \ |
1860 | return create##NODE_NAME_(name, LHS.getType(), LHS, RHS); \ |
1861 | } \ |
1862 | NODE_NAME_##Node *Function::create##NODE_NAME_( \ |
1863 | llvm::StringRef name, TypeRef T, NodeValue LHS, NodeValue RHS) { \ |
1864 | DCHECK(LHS.dims() == RHS.dims()) \ |
1865 | << "Invalid operand shapes LHS:" << LHS.getNode()->getName().str() \ |
1866 | << " RHS: " << RHS.getNode()->getName().str() << " " << LHS.dims() \ |
1867 | << " vs " << RHS.dims(); \ |
1868 | TypeRef OT = getParent()->uniqueType(*T); \ |
1869 | return addNode(new NODE_NAME_##Node(name, OT, LHS, RHS)); \ |
1870 | } |
1871 | ARITHMETIC_FUN_DEF(Add); |
1872 | ARITHMETIC_FUN_DEF(Mul); |
1873 | ARITHMETIC_FUN_DEF(Sub); |
1874 | ARITHMETIC_FUN_DEF(Div); |
1875 | ARITHMETIC_FUN_DEF(Max); |
1876 | ARITHMETIC_FUN_DEF(Min); |
1877 | ARITHMETIC_FUN_DEF(Pow); |
1878 | ARITHMETIC_FUN_DEF(And); |
1879 | ARITHMETIC_FUN_DEF(Or); |
1880 | ARITHMETIC_FUN_DEF(Xor); |
1881 | ARITHMETIC_FUN_DEF(BitwiseAnd); |
1882 | ARITHMETIC_FUN_DEF(BitwiseOr); |
1883 | ARITHMETIC_FUN_DEF(BitwiseXor); |
1884 | ARITHMETIC_FUN_DEF(Fmod); |
1885 | #undef ARITHMETIC_FUN_DEF |
1886 | |
1887 | #define TRIGONOMETRIC_FUN_DEF(NODE_NAME_) \ |
1888 | NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \ |
1889 | NodeValue input) { \ |
1890 | return create##NODE_NAME_(name, input.getType(), input); \ |
1891 | } \ |
1892 | NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \ |
1893 | TypeRef T, NodeValue input) { \ |
1894 | TypeRef OT = getParent()->uniqueType(*T); \ |
1895 | return addNode(new NODE_NAME_##Node(name, OT, input)); \ |
1896 | } |
1897 | |
1898 | TRIGONOMETRIC_FUN_DEF(Acos) |
1899 | TRIGONOMETRIC_FUN_DEF(Asin) |
1900 | TRIGONOMETRIC_FUN_DEF(Atan) |
1901 | #undef TRIGONOMETRIC_FUN_DEF |
1902 | |
1903 | FloorDivNode *Function::createFloorDiv(llvm::StringRef name, NodeValue LHS, |
1904 | NodeValue RHS, bool truncate) { |
1905 | return createFloorDiv(name, LHS.getType(), LHS, RHS, truncate); |
1906 | } |
1907 | |
1908 | FloorDivNode *Function::createFloorDiv(llvm::StringRef name, TypeRef outTy, |
1909 | NodeValue LHS, NodeValue RHS, |
1910 | bool truncate) { |
1911 | DCHECK(LHS.dims() == RHS.dims()) |
1912 | << "Invalid operand shapes LHS:" << LHS.getNode()->getName().str() |
1913 | << " RHS: " << RHS.getNode()->getName().str() << " " << LHS.dims() |
1914 | << " vs " << RHS.dims(); |
1915 | TypeRef OT = getParent()->uniqueType(*outTy); |
1916 | return addNode(new FloorDivNode(name, OT, LHS, RHS, truncate)); |
1917 | } |
1918 | |
1919 | FloorDivNode *Function::createFloorDivWithBroadcast(llvm::StringRef name, |
1920 | int axis, NodeValue LHS, |
1921 | NodeValue RHS, |
1922 | bool truncate) { |
1923 | std::vector<NodeValue> inputs = broadcastInputs(axis, {LHS, RHS}); |
1924 | return createFloorDiv(name, inputs[0].getType(), inputs[0], inputs[1], |
1925 | truncate); |
1926 | } |
1927 | |
1928 | FloorDivNode *Function::createFloorDivWithBroadcast(llvm::StringRef name, |
1929 | int axis, TypeRef outTy, |
1930 | NodeValue LHS, |
1931 | NodeValue RHS, |
1932 | bool truncate) { |
1933 | std::vector<NodeValue> inputs = broadcastInputs(axis, {LHS, RHS}); |
1934 | return createFloorDiv(name, outTy, inputs[0], inputs[1], truncate); |
1935 | } |
1936 | |
1937 | CmpLTENode *Function::createCmpLTE(llvm::StringRef name, NodeValue LHS, |
1938 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1940 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims()); |
1941 | return addNode(new CmpLTENode(name, OT, LHS, RHS)); |
1942 | } |
1943 | |
1944 | CmpLTNode *Function::createCmpLT(llvm::StringRef name, NodeValue LHS, |
1945 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1947 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims()); |
1948 | return addNode(new CmpLTNode(name, OT, LHS, RHS)); |
1949 | } |
1950 | |
1951 | CmpLTENode *Function::createCmpGTE(llvm::StringRef name, NodeValue LHS, |
1952 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1954 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims()); |
1955 | return addNode(new CmpLTENode(name, OT, RHS, LHS)); |
1956 | } |
1957 | |
1958 | CmpLTNode *Function::createCmpGT(llvm::StringRef name, NodeValue LHS, |
1959 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1961 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims()); |
1962 | return addNode(new CmpLTNode(name, OT, RHS, LHS)); |
1963 | } |
1964 | |
1965 | CmpEQNode *Function::createCmpEQ(llvm::StringRef name, NodeValue LHS, |
1966 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1968 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims()); |
1969 | return addNode(new CmpEQNode(name, OT, LHS, RHS)); |
1970 | } |
1971 | |
1972 | CmpNEQNode *Function::createCmpNEQ(llvm::StringRef name, NodeValue LHS, |
1973 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1975 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims()); |
1976 | return addNode(new CmpNEQNode(name, OT, LHS, RHS)); |
1977 | } |
1978 | |
1979 | MulNode *Function::createSquare(llvm::StringRef name, NodeValue input) { |
1980 | return createMul(name, input, input); |
1981 | } |
1982 | |
1983 | MulNode *Function::createSquare(llvm::StringRef name, TypeRef outTy, |
1984 | NodeValue input) { |
1985 | return createMul(name, outTy, input, input); |
1986 | } |
1987 | |
1988 | LeakyReluNode *Function::createLeakyRELU(llvm::StringRef name, NodeValue input, |
1989 | float alpha) { |
1990 | return addNode(new LeakyReluNode(name, input.getType(), input, alpha)); |
1991 | } |
1992 | |
1993 | LeakyReluNode *Function::createLeakyRELU(llvm::StringRef name, TypeRef outTy, |
1994 | NodeValue input, float alpha) { |
1995 | return addNode(new LeakyReluNode(name, outTy, input, alpha)); |
1996 | } |
1997 | |
1998 | IsNaNNode *Function::createIsNaN(llvm::StringRef name, NodeValue input) { |
1999 | TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims()); |
2000 | return addNode(new IsNaNNode(name, OT, input)); |
2001 | } |
2002 | |
2003 | ReplaceNaNNode *Function::createReplaceNaN(llvm::StringRef name, |
2004 | NodeValue input, float value) { |
2005 | return addNode(new ReplaceNaNNode(name, input.getType(), input, value)); |
2006 | } |
2007 | |
2008 | PowNode *Function::createPow(llvm::StringRef name, NodeValue base, float exp) { |
2009 | auto *SP = createSplat(name, base.getType(), exp); |
2010 | return createPow(name, base, SP); |
2011 | } |
2012 | |
2013 | LogNode *Function::createLog(llvm::StringRef name, NodeValue input) { |
2014 | return createLog(name, input.getType(), input); |
2015 | } |
2016 | |
2017 | LogNode *Function::createLog(llvm::StringRef name, TypeRef outTy, |
2018 | NodeValue input) { |
2019 | return addNode(new LogNode(name, outTy, input)); |
2020 | } |
2021 | |
2022 | LogNode *Function::createLog(llvm::StringRef name, NodeValue input, |
2023 | TypeRef outTy) { |
2024 | return createLog(name, outTy, input); |
2025 | } |
2026 | |
2027 | ExpNode *Function::createExp(llvm::StringRef name, NodeValue input) { |
2028 | return addNode(new ExpNode(name, input.getType(), input)); |
2029 | } |
2030 | |
2031 | ExpNode *Function::createExp(llvm::StringRef name, TypeRef outTy, |
2032 | NodeValue input) { |
2033 | return addNode(new ExpNode(name, outTy, input)); |
2034 | } |
2035 | |
2036 | LogitNode *Function::createLogit(llvm::StringRef name, NodeValue input, |
2037 | float eps) { |
2038 | return addNode(new LogitNode(name, input.getType(), input, eps)); |
2039 | } |
2040 | |
2041 | NonZeroNode *Function::createNonZero(llvm::StringRef name, NodeValue Cond) { |
2042 | auto outTy = getParent()->uniqueType(ElemKind::Int32ITy, {Cond.dims()[0], 1}); |
2043 | return addNode(new NonZeroNode(name, outTy, Cond)); |
2044 | } |
2045 | |
2046 | SelectNode *Function::createSelect(llvm::StringRef name, TypeRef outTy, |
2047 | NodeValue Cond, NodeValue LHS, |
2048 | NodeValue RHS) { |
  assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
  assert(LHS.dims() == Cond.dims() && "Invalid operand shapes");
  assert(LHS.dims() == outTy->dims() && "Invalid result shape");
2052 | auto OT = getParent()->uniqueType(*outTy); |
2053 | return addNode(new SelectNode(name, OT, Cond, LHS, RHS)); |
2054 | } |
2055 | |
2056 | SelectNode *Function::createSelect(llvm::StringRef name, NodeValue Cond, |
2057 | NodeValue LHS, NodeValue RHS) { |
2058 | auto inDims = LHS.dims(); |
  assert(inDims.size() > 0 && "Invalid operand shape");
2060 | ShapeVector outDims(inDims.begin(), inDims.end()); |
2061 | auto OT = getParent()->uniqueTypeWithNewShape(LHS.getType(), outDims); |
2062 | return createSelect(name, OT, Cond, LHS, RHS); |
2063 | } |
2064 | |
2065 | SplatNode *Function::createSplat(llvm::StringRef name, TypeRef ty, |
2066 | float value) { |
2067 | return addNode(new SplatNode(name, getParent()->uniqueType(*ty), value)); |
2068 | } |
2069 | |
2070 | TouchNode *Function::createTouch(llvm::StringRef name, TypeRef ty) { |
2071 | return addNode(new TouchNode(name, getParent()->uniqueType(*ty))); |
2072 | } |
2073 | |
2074 | MatMulNode *Function::createMatMul(llvm::StringRef name, TypeRef outTy, |
2075 | NodeValue lhs, NodeValue rhs) { |
2076 | return addNode( |
2077 | new MatMulNode(name, getParent()->uniqueType(*outTy), lhs, rhs)); |
2078 | } |
2079 | |
2080 | MatMulNode *Function::createMatMul(llvm::StringRef name, NodeValue lhs, |
2081 | NodeValue rhs) { |
2082 | auto LT = lhs.getType(); |
2083 | auto RT = rhs.getType(); |
2084 | auto LDims = LT->dims(); |
2085 | auto RDims = RT->dims(); |
2086 | assert(lhs.getType()->getElementType() == rhs.getType()->getElementType()); |
2087 | |
2088 | auto ty = |
2089 | getParent()->uniqueTypeWithNewShape(lhs.getType(), {LDims[0], RDims[1]}); |
2090 | return createMatMul(name, ty, lhs, rhs); |
2091 | } |
2092 | |
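/// Batched matrix multiply. Operands with more than three dimensions are
/// first flattened to 3D, keeping the two innermost dimensions and folding
/// the rest into the batch; e.g., a {2, 3, 4, 5} LHS becomes {6, 4, 5}.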
2093 | BatchMatMulNode *Function::createBatchMatMul(llvm::StringRef name, |
2094 | NodeValue LHS, NodeValue RHS) { |
2095 | const size_t numDimsLHS = LHS.dims().size(); |
2096 | if (numDimsLHS > 3) { |
2097 | const size_t numDimsRHS = RHS.dims().size(); |
2098 | std::vector<dim_t> newLHSShape = {0, LHS.dims()[numDimsLHS - 2], |
2099 | LHS.dims()[numDimsLHS - 1]}; |
2100 | newLHSShape[0] = LHS.getType()->size() / (newLHSShape[1] * newLHSShape[2]); |
    LHS = createReshape(name.str() + ".reshapeLHS3D", LHS, newLHSShape);
2102 | std::vector<dim_t> newRHSShape = {0, RHS.dims()[numDimsRHS - 2], |
2103 | RHS.dims()[numDimsRHS - 1]}; |
2104 | newRHSShape[0] = RHS.getType()->size() / (newRHSShape[1] * newRHSShape[2]); |
    RHS = createReshape(name.str() + ".reshapeRHS3D", RHS, newRHSShape);
2106 | } |
2107 | |
2108 | const size_t numDimsRHS = RHS.dims().size(); |
  assert(LHS.dims().size() == 3 && "LHS must be 3 dimensional.");
  assert((numDimsRHS == 2 || numDimsRHS == 3) &&
         "RHS must be 2 or 3 dimensional.");
2112 | // If necessary, expand the RHS input to be 3D by adding initial leading |
2113 | // dim. |
2114 | if (numDimsRHS == 2) { |
    RHS = createExpandDims(name.str() + ".reshapeRHS", RHS, {0});
2116 | } |
  // If necessary, tile the RHS input so it matches the numBatches of LHS.
2118 | if (RHS.dims()[0] == 1 && LHS.dims()[0] != 1) { |
    RHS = createTile(name.str() + ".tileRHS", RHS, LHS.dims()[0], /* axis */ 0);
2120 | } |
2121 | |
2122 | // LHS = {numBatches, N, M} |
2123 | // RHS = {numBatches, M, P} |
2124 | // Result = {numBatches, N, P} |
2125 | const dim_t numBatches = LHS.dims()[0]; |
2126 | const dim_t N = LHS.dims()[1]; |
2127 | const dim_t M = LHS.dims()[2]; |
2128 | (void)M; |
2129 | const dim_t P = RHS.dims()[2]; |
  assert((RHS.dims()[0] == numBatches) && "Batch sizes are invalid.");
  assert((RHS.dims()[1] == M) && "Batch matmul dimensions are invalid.");
2132 | |
2133 | auto OT = |
2134 | getParent()->uniqueTypeWithNewShape(LHS.getType(), {numBatches, N, P}); |
2135 | return addNode(new BatchMatMulNode(name, OT, LHS, RHS)); |
2136 | } |
2137 | |
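/// Reduce-add over a single axis: the output holds batch.size() /
/// batch.dims()[axis] elements; e.g., reducing a {3, 4} batch along axis 0
/// yields a {4} result.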
2138 | BatchedReduceAddNode * |
2139 | Function::createBatchedReduceAdd(llvm::StringRef name, TypeRef outTy, |
2140 | NodeValue batch, |
2141 | llvm::ArrayRef<unsigned_t> axes) { |
  assert(axes.size() == 1 && "Only supporting single reduction for now.");
2143 | auto axis = axes[0]; |
2144 | |
2145 | // Calculate the expected total number of elements in the output tensor |
2146 | // based on the number of elements in the batch divided by the axis |
2147 | // dimension. |
2148 | const size_t outNumElements = batch.getType()->size() / batch.dims()[axis]; |
2149 | (void)outNumElements; |
  assert(outTy->size() == outNumElements &&
         "Incorrect number of elements in the output type.");
2152 | auto OT = getParent()->uniqueType(*outTy); |
2153 | return addNode(new BatchedReduceAddNode(name, OT, batch, axis)); |
2154 | } |
2155 | |
2156 | BatchedReduceSumSquareNode * |
2157 | Function::createBatchedReduceSumSquare(llvm::StringRef name, NodeValue batch, |
2158 | llvm::ArrayRef<unsigned_t> axes) { |
2159 | auto outDims = getNewShapeWithoutAxes(batch.dims(), axes); |
2160 | auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims); |
2161 | return createBatchedReduceSumSquare(name, OT, batch, axes); |
2162 | } |
2163 | |
2164 | BatchedReduceSumSquareNode * |
2165 | Function::createBatchedReduceSumSquare(llvm::StringRef name, TypeRef outTy, |
2166 | NodeValue batch, |
2167 | llvm::ArrayRef<unsigned_t> axes) { |
  assert(axes.size() == 1 && "Only supporting single reduction for now.");
2169 | auto axis = axes[0]; |
2170 | |
2171 | // Calculate the expected total number of elements in the output tensor |
2172 | // based on the number of elements in the batch divided by the axis |
2173 | // dimension. |
2174 | const size_t outNumElements = batch.getType()->size() / batch.dims()[axis]; |
2175 | (void)outNumElements; |
  assert(outTy->size() == outNumElements &&
         "Incorrect number of elements in the output type.");
2178 | auto OT = getParent()->uniqueType(*outTy); |
2179 | return addNode(new BatchedReduceSumSquareNode(name, OT, batch, axis)); |
2180 | } |
2181 | |
2182 | BatchedReduceAddNode * |
2183 | Function::createBatchedReduceAdd(llvm::StringRef name, NodeValue batch, |
2184 | llvm::ArrayRef<unsigned_t> axes) { |
2185 | auto outDims = getNewShapeWithoutAxes(batch.dims(), axes); |
2186 | auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims); |
2187 | return createBatchedReduceAdd(name, OT, batch, axes); |
2188 | } |
2189 | |
2190 | BatchedReduceMeanNode * |
2191 | Function::createBatchedReduceMean(llvm::StringRef name, TypeRef outTy, |
2192 | NodeValue batch, |
2193 | llvm::ArrayRef<unsigned_t> axes) { |
2194 | auto OT = getParent()->uniqueType(*outTy); |
2195 | return addNode(new BatchedReduceMeanNode(name, OT, batch, axes)); |
2196 | } |
2197 | |
2198 | BatchedReduceMeanNode * |
2199 | Function::createBatchedReduceMean(llvm::StringRef name, NodeValue batch, |
2200 | llvm::ArrayRef<unsigned_t> axes) { |
2201 | // Create new shape with specified dimensions either reduced or removed. |
2202 | auto outDims = getNewShapeWithoutAxes(batch.dims(), axes); |
2203 | auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims); |
2204 | return createBatchedReduceMean(name, OT, batch, axes); |
2205 | } |
2206 | |
2207 | BatchedReduceMinNode * |
2208 | Function::createBatchedReduceMin(llvm::StringRef name, TypeRef outTy, |
2209 | NodeValue batch, |
2210 | llvm::ArrayRef<unsigned_t> axes) { |
2211 | auto OT = getParent()->uniqueType(*outTy); |
2212 | return addNode(new BatchedReduceMinNode(name, OT, batch, axes)); |
2213 | } |
2214 | |
2215 | BatchedReduceMinNode * |
2216 | Function::createBatchedReduceMin(llvm::StringRef name, NodeValue batch, |
2217 | llvm::ArrayRef<unsigned_t> axes) { |
2218 | // Create new shape with specified dimensions either reduced or removed. |
2219 | auto outDims = getNewShapeWithoutAxes(batch.dims(), axes); |
2220 | auto OT = getParent()->uniqueType(batch.getType()->getElementType(), outDims); |
2221 | return addNode(new BatchedReduceMinNode(name, OT, batch, axes)); |
2222 | } |
2223 | |
2224 | BatchedReduceMaxNode * |
2225 | Function::createBatchedReduceMax(llvm::StringRef name, TypeRef outTy, |
2226 | NodeValue batch, |
2227 | llvm::ArrayRef<unsigned_t> axes) { |
2228 | auto OT = getParent()->uniqueType(*outTy); |
2229 | return addNode(new BatchedReduceMaxNode(name, OT, batch, axes)); |
2230 | } |
2231 | |
2232 | BatchedReduceMaxNode * |
2233 | Function::createBatchedReduceMax(llvm::StringRef name, NodeValue batch, |
2234 | llvm::ArrayRef<unsigned_t> axes) { |
2235 | // Create new shape with specified dimensions either reduced or removed. |
2236 | auto outDims = getNewShapeWithoutAxes(batch.dims(), axes); |
2237 | auto OT = getParent()->uniqueType(batch.getType()->getElementType(), outDims); |
2238 | return addNode(new BatchedReduceMaxNode(name, OT, batch, axes)); |
2239 | } |
2240 | |
2241 | BatchedReduceProdNode * |
2242 | Function::createBatchedReduceProd(llvm::StringRef name, TypeRef outTy, |
2243 | NodeValue batch, |
2244 | llvm::ArrayRef<unsigned_t> axes) { |
  assert(axes.size() == 1 && "Only supporting single reduction for now.");
2246 | auto axis = axes[0]; |
2247 | |
2248 | // Calculate the expected total number of elements in the output tensor |
2249 | // based on the number of elements in the batch divided by the axis |
2250 | // dimension. |
2251 | const size_t outNumElements = batch.getType()->size() / batch.dims()[axis]; |
2252 | (void)outNumElements; |
  assert(outTy->size() == outNumElements &&
         "Incorrect number of elements in the output type.");
2255 | auto OT = getParent()->uniqueType(*outTy); |
2256 | return addNode(new BatchedReduceProdNode(name, OT, batch, axis)); |
2257 | } |
2258 | |
2259 | BatchedReduceProdNode * |
2260 | Function::createBatchedReduceProd(llvm::StringRef name, NodeValue batch, |
2261 | llvm::ArrayRef<unsigned_t> axes) { |
2262 | auto outDims = getNewShapeWithoutAxes(batch.dims(), axes); |
2263 | auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims); |
2264 | return createBatchedReduceProd(name, OT, batch, axes); |
2265 | } |
2266 | |
2267 | BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name, |
2268 | NodeValue batch, NodeValue slice) { |
2269 | return addNode(new BatchedAddNode(name, batch.getType(), batch, slice)); |
2270 | } |
2271 | |
2272 | BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name, TypeRef outTy, |
2273 | NodeValue batch, NodeValue slice) { |
2274 | return addNode( |
2275 | new BatchedAddNode(name, getParent()->uniqueType(*outTy), batch, slice)); |
2276 | } |
2277 | |
2278 | BatchedMulNode *Function::createBatchedMul(llvm::StringRef name, |
2279 | NodeValue batch, NodeValue slice) { |
2280 | return addNode(new BatchedMulNode(name, batch.getType(), batch, slice)); |
2281 | } |
2282 | |
2283 | BatchedMulNode *Function::createBatchedMul(llvm::StringRef name, TypeRef outTy, |
2284 | NodeValue batch, NodeValue slice) { |
2285 | return addNode( |
2286 | new BatchedMulNode(name, getParent()->uniqueType(*outTy), batch, slice)); |
2287 | } |
2288 | |
2289 | CumSumNode *Function::createCumSum(llvm::StringRef name, NodeValue input, |
2290 | int64_t dim, bool exclusive, bool reverse) { |
2291 | return addNode( |
2292 | new CumSumNode(name, input.getType(), input, dim, exclusive, reverse)); |
2293 | } |
2294 | |
2295 | LengthsSumNode *Function::createLengthsSum(llvm::StringRef name, NodeValue data, |
2296 | NodeValue lengths) { |
2297 | ShapeVector outDims(data.dims().begin(), data.dims().end()); |
2298 | outDims[0] = lengths.dims()[0]; |
2299 | auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims); |
2300 | return addNode(new LengthsSumNode(name, outTy, data, lengths)); |
2301 | } |
2302 | |
2303 | SparseLengthsSumNode * |
2304 | Function::createSparseLengthsSum(llvm::StringRef name, NodeValue data, |
2305 | NodeValue indices, NodeValue lengths, |
2306 | LengthsMode lengthsMode, float avgLength) { |
2307 | auto inDims = data.dims(); |
2308 | ShapeVector outDims(inDims.begin(), inDims.end()); |
2309 | outDims[0] = lengths.dims()[0]; |
2310 | auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims); |
2311 | return addNode(new SparseLengthsSumNode(name, outTy, data, indices, lengths, |
2312 | lengthsMode, avgLength)); |
2313 | } |
2314 | |
2315 | SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum( |
2316 | llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices, |
2317 | NodeValue lengths, LengthsMode lengthsMode, float avgLength) { |
2318 | auto inDims = data.dims(); |
2319 | ShapeVector outDims(inDims.begin(), inDims.end()); |
2320 | outDims[0] = lengths.dims()[0]; |
2321 | auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims); |
2322 | return addNode(new SparseLengthsWeightedSumNode( |
2323 | name, outTy, data, weights, indices, lengths, lengthsMode, avgLength)); |
2324 | } |
2325 | |
2326 | SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum( |
2327 | llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights, |
2328 | NodeValue indices, NodeValue lengths, LengthsMode lengthsMode, |
2329 | float avgLength) { |
2330 | return addNode(new SparseLengthsWeightedSumNode( |
2331 | name, outTy, data, weights, indices, lengths, lengthsMode, avgLength)); |
2332 | } |
2333 | |
2334 | RowwiseQuantizedSparseLengthsWeightedSumNode * |
2335 | Function::createRowwiseQuantizedSparseLengthsWeightedSum( |
2336 | llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets, |
2337 | NodeValue weights, NodeValue indices, NodeValue lengths, ElemKind precision, |
2338 | bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) { |
2339 | auto inDims = data->dims(); |
2340 | ShapeVector outDims(inDims.begin(), inDims.end()); |
2341 | outDims[0] = lengths.dims()[0]; |
2342 | auto outTy = getParent()->uniqueType(precision, outDims); |
2343 | return addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode( |
2344 | name, outTy, data, scales, offsets, weights, indices, lengths, |
2345 | useFP16Accumulation, lengthsMode, avgLength)); |
2346 | } |
2347 | |
2348 | RowwiseQuantizedSparseLengthsWeightedSumNode * |
2349 | Function::createRowwiseQuantizedSparseLengthsSum( |
2350 | llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets, |
2351 | NodeValue indices, NodeValue lengths, ElemKind precision, |
2352 | bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) { |
2353 | auto ty = getParent()->uniqueType(precision, {indices.dims()[0]}); |
  auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2355 | return createRowwiseQuantizedSparseLengthsWeightedSum( |
2356 | name, data, scales, offsets, ones, indices, lengths, precision, |
2357 | useFP16Accumulation, lengthsMode, avgLength); |
2358 | } |
2359 | |
/// Helper to create a RowwiseQuantizedSparseLengthsWeightedSumNode in the
/// Function \p F with \p name, using \p data, \p weights, \p indices, and \p
/// lengths as inputs. The float data provided in the \p data Tensor is rowwise
/// quantized, creating Constants for the rowwise quantized data as well as the
/// Scales and Offsets, in the Module containing \p F.
2365 | static RowwiseQuantizedSparseLengthsWeightedSumNode * |
2366 | quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum( |
2367 | Function *F, llvm::StringRef name, Tensor &data, NodeValue weights, |
2368 | NodeValue indices, NodeValue lengths, quantization::Schema schema, |
2369 | ElemKind precision, bool useFP16Accumulation, LengthsMode lengthsMode, |
2370 | float avgLength) { |
2371 | auto inDims = data.dims(); |
2372 | |
2373 | // Note: In rwqData, we are using a quantized type, however the scale/offset |
2374 | // are set to dummy values 0.0/0. This is because the actually used |
2375 | // scale/offset come from dataScales and dataOffsets. |
  Constant *rwqData = F->getParent()->createConstant(ElemKind::UInt8QTy, inDims,
                                                     0.0, 0, "data");
  Constant *dataScales =
      F->getParent()->createConstant(precision, {inDims[0]}, "dataScales");
  Constant *dataOffsets =
      F->getParent()->createConstant(precision, {inDims[0]}, "dataOffsets");
2382 | |
2383 | // Note: Using floating point offset here as that is what RWQ-SLWS expects. |
2384 | switch (precision) { |
2385 | case ElemKind::FloatTy: |
2386 | quantization::tensorRowwiseQuantization<float, float, uint8_t>( |
2387 | data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(), |
2388 | dataOffsets->getPayloadMutable(), schema); |
2389 | break; |
2390 | case ElemKind::Float16Ty: |
2391 | quantization::tensorRowwiseQuantization<float16_t, float16_t, uint8_t>( |
2392 | data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(), |
2393 | dataOffsets->getPayloadMutable(), schema); |
2394 | break; |
2395 | default: |
    LOG(FATAL) << "Unsupported precision for RWQ-SLWS.";
2397 | } |
2398 | return F->createRowwiseQuantizedSparseLengthsWeightedSum( |
2399 | name, rwqData, dataScales, dataOffsets, weights, indices, lengths, |
2400 | precision, useFP16Accumulation, lengthsMode, avgLength); |
2401 | } |
2402 | |
2403 | RowwiseQuantizedSparseLengthsWeightedSumNode * |
2404 | Function::createRowwiseQuantizedSparseLengthsWeightedSum( |
2405 | llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices, |
2406 | NodeValue lengths, quantization::Schema schema, ElemKind precision, |
2407 | bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) { |
2408 | return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum( |
2409 | this, name, data, weights, indices, lengths, schema, precision, |
2410 | useFP16Accumulation, lengthsMode, avgLength); |
2411 | } |
2412 | |
2413 | RowwiseQuantizedSparseLengthsWeightedSumNode * |
2414 | Function::createRowwiseQuantizedSparseLengthsSum( |
2415 | llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths, |
2416 | quantization::Schema schema, ElemKind precision, bool useFP16Accumulation, |
2417 | LengthsMode lengthsMode, float avgLength) { |
2418 | auto ty = getParent()->uniqueType(precision, {indices.dims()[0]}); |
  auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2420 | return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum( |
2421 | this, name, data, ones, indices, lengths, schema, precision, |
2422 | useFP16Accumulation, lengthsMode, avgLength); |
2423 | } |
2424 | |
2425 | /// Helper used to get specific output type required for |
2426 | /// createRowwiseQuantizedSparseLengthsSum, |
2427 | /// createRowwiseQuantizedSparseLengthsWeightedSum, and |
2428 | /// EmbeddingBagByteRowwiseOffsets. Function \p F is used to get the specific |
2429 | /// type, using inputs \p data and \p segmentsDim to compute output dimensions. |
2430 | static TypeRef |
2431 | getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data, |
2432 | llvm::ArrayRef<dim_t> segmentsDim) { |
2433 | ShapeVector outDims(data.dims().begin(), data.dims().end()); |
2434 | outDims[0] = segmentsDim[0]; |
2435 | // The output column count is the same as the input column count, but |
2436 | // without the extra bytes for the fused scale/offset, as the output is not |
2437 | // fused. |
  CHECK(isFusedQuantizedElemKind(data.getElementType()))
      << "Must use a fused ElemKind for data.";
2440 | outDims[1] -= 2 * ((data.getElementType() == ElemKind::UInt8FusedQTy || |
2441 | data.getElementType() == ElemKind::UInt4FusedQTy) |
2442 | ? sizeof(float) |
2443 | : sizeof(float16_t)); |
  // If using 4-bit quantization then the input data packs two 4-bit elements
  // into one byte, so we need to double the output column count.
2446 | if (data.getElementType() == ElemKind::UInt4FusedFP16QTy || |
2447 | data.getElementType() == ElemKind::UInt4FusedQTy) { |
2448 | outDims[1] *= 2; |
2449 | } |
2450 | const ElemKind outputK = (data.getElementType() == ElemKind::UInt8FusedQTy || |
2451 | data.getElementType() == ElemKind::UInt4FusedQTy) |
2452 | ? ElemKind::FloatTy |
2453 | : ElemKind::Float16Ty; |
2454 | return F->getParent()->uniqueType(outputK, outDims); |
2455 | } |
2456 | |
2457 | FusedRowwiseQuantizedSparseLengthsWeightedSumNode * |
2458 | Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2459 | llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices, |
2460 | NodeValue lengths, bool useFP16Accumulation, LengthsMode lengthsMode, |
2461 | float avgLength) { |
2462 | auto outTy = |
2463 | getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims()); |
2464 | return addNode(new FusedRowwiseQuantizedSparseLengthsWeightedSumNode( |
2465 | name, outTy, data, weights, indices, lengths, useFP16Accumulation, |
2466 | lengthsMode, avgLength)); |
2467 | } |
2468 | |
2469 | FusedRowwiseQuantizedSparseLengthsSumNode * |
2470 | Function::createFusedRowwiseQuantizedSparseLengthsSum( |
2471 | llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths, |
2472 | bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) { |
2473 | auto outTy = |
2474 | getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims()); |
2475 | return addNode(new FusedRowwiseQuantizedSparseLengthsSumNode( |
2476 | name, outTy, data, indices, lengths, useFP16Accumulation, lengthsMode, |
2477 | avgLength)); |
2478 | } |
2479 | |
2480 | /// Helper to get quantized data required for |
2481 | /// RowwiseQuantizedSparseLengthsWeightedSumNode and |
2482 | /// RowwiseQuantizedSparseLengthsSumNode. Function \p F uses float Tensor \p |
/// data to create a rowwise quantized Constant \p rwqData, which contains fused
2484 | /// scales and offsets. |
2485 | static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2486 | Function *F, Tensor &data, ElemKind precision) { |
  // For fused rowwise quantization we must have a two-dimensional input. If
  // passed a single-dimensional data Tensor, add an extra dimension.
2489 | const auto fDims = flattenCdr(data.dims()); |
2490 | Tensor fData = data.getUnowned({fDims.first, fDims.second}); |
2491 | |
  // Note: In rwqData we are using a quantized type, however the scale/offset
  // are set to dummy values 0.0/0. This is because the actually used
  // scale/offset are fused inline with each row. We also expand the second
  // dimension to include space for the fused scale/offset, each stored as a
  // float or float16_t depending on the fused precision.
2497 | switch (precision) { |
2498 | case ElemKind::UInt8FusedQTy: { |
    Constant *rwqData = F->getParent()->createConstant(
        precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float)}, 0.0,
        0, "data");
2502 | quantization::tensorFusedRowwiseQuantization<float>( |
2503 | fData, rwqData->getPayloadMutable()); |
2504 | return rwqData; |
2505 | } |
2506 | case ElemKind::UInt8FusedFP16QTy: { |
    Constant *rwqData = F->getParent()->createConstant(
        precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float16_t)},
        0.0, 0, "data");
2510 | quantization::tensorFusedRowwiseQuantization<float16_t>( |
2511 | fData, rwqData->getPayloadMutable()); |
2512 | return rwqData; |
2513 | } |
2514 | case ElemKind::UInt4FusedFP16QTy: { |
2515 | // We pack 4-bit values into bytes, so given the input size in float we |
2516 | // divide by two and take the ceiling to make sure we have enough space for |
2517 | // all elements. |
2518 | const dim_t outerDim = |
2519 | std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float16_t); |
    Constant *rwqData = F->getParent()->createConstant(
        precision, {fDims.first, outerDim}, 0.0, 0, "data");
2522 | quantization::tensorFusedRowwiseQuantization<float16_t>( |
2523 | fData, rwqData->getPayloadMutable()); |
2524 | return rwqData; |
2525 | } |
2526 | case ElemKind::UInt4FusedQTy: { |
2527 | // We pack 4-bit values into bytes, so given the input size in float we |
2528 | // divide by two and take the ceiling to make sure we have enough space for |
2529 | // all elements. |
2530 | const dim_t outerDim = |
2531 | std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float); |
    Constant *rwqData = F->getParent()->createConstant(
        precision, {fDims.first, outerDim}, 0.0, 0, "data");
2534 | quantization::tensorFusedRowwiseQuantization<float>( |
2535 | fData, rwqData->getPayloadMutable()); |
2536 | return rwqData; |
2537 | } |
2538 | default: |
2539 | llvm_unreachable("Invalid type for FusedRowwiswQuantization." ); |
2540 | } |
2541 | } |
2542 | |
2543 | FusedRowwiseQuantizedSparseLengthsWeightedSumNode * |
2544 | Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2545 | llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices, |
2546 | NodeValue lengths, ElemKind fusedElemKind, bool useFP16Accumulation, |
2547 | LengthsMode lengthsMode, float avgLength) { |
2548 | Constant *rwqData = |
2549 | quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2550 | this, data, fusedElemKind); |
2551 | return createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2552 | name, rwqData, weights, indices, lengths, useFP16Accumulation, |
2553 | lengthsMode, avgLength); |
2554 | } |
2555 | |
2556 | FusedRowwiseQuantizedSparseLengthsSumNode * |
2557 | Function::createFusedRowwiseQuantizedSparseLengthsSum( |
2558 | llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths, |
2559 | ElemKind fusedElemKind, bool useFP16Accumulation, LengthsMode lengthsMode, |
2560 | float avgLength) { |
2561 | Constant *rwqData = |
2562 | quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2563 | this, data, fusedElemKind); |
2564 | return this->createFusedRowwiseQuantizedSparseLengthsSum( |
2565 | name, rwqData, indices, lengths, useFP16Accumulation, lengthsMode, |
2566 | avgLength); |
2567 | } |
2568 | |
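/// Embedding lookup: the output appends the embedding dimension to the
/// indices shape; e.g., {B, L} indices with {V, D} weights produce a
/// {B, L, D} output.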
2569 | EmbeddingNode *Function::createEmbedding(llvm::StringRef name, |
2570 | NodeValue weights, NodeValue indices, |
2571 | int32_t padIdx, bool scale, |
2572 | bool sparse) { |
2573 | auto indDims = indices.dims(); |
2574 | auto wtDims = weights.dims(); |
2575 | |
  assert(wtDims.size() == 2 && "weights must be a 2D tensor");
2577 | |
2578 | ShapeVector outDims(indDims.begin(), indDims.end()); |
  dim_t embeddingDim = wtDims[1];
  outDims.push_back(embeddingDim);
2581 | |
2582 | auto outTy = getParent()->uniqueTypeWithNewShape(weights.getType(), outDims); |
2583 | return addNode( |
2584 | new EmbeddingNode(name, outTy, weights, indices, padIdx, scale, sparse)); |
2585 | } |
2586 | |
2587 | EmbeddingBagNode * |
2588 | Function::createEmbeddingBag(llvm::StringRef name, NodeValue data, |
2589 | NodeValue weights, NodeValue indices, |
2590 | NodeValue offsets, bool hasEndOffset, |
2591 | LengthsMode lengthsMode, float avgLength) { |
2592 | auto inDims = data.dims(); |
2593 | ShapeVector outDims(inDims.begin(), inDims.end()); |
2594 | outDims[0] = hasEndOffset ? offsets.dims()[0] - 1 : offsets.dims()[0]; |
2595 | auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims); |
2596 | return addNode(new EmbeddingBagNode(name, outTy, data, weights, indices, |
2597 | offsets, hasEndOffset, lengthsMode, |
2598 | avgLength)); |
2599 | } |
2600 | |
2601 | EmbeddingBagByteRowwiseOffsetsNode * |
2602 | Function::createEmbeddingBagByteRowwiseOffsets( |
2603 | llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices, |
2604 | NodeValue offsets, ElemKind fusedElemKind, bool useFP16Accumulation, |
2605 | bool hasEndOffset, LengthsMode lengthsMode, float avgLength) { |
2606 | Constant *rwqData = |
2607 | quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum( |
2608 | this, data, fusedElemKind); |
2609 | return createEmbeddingBagByteRowwiseOffsets( |
2610 | name, rwqData, weights, indices, offsets, useFP16Accumulation, |
2611 | hasEndOffset, lengthsMode, avgLength); |
2612 | } |
2613 | |
2614 | EmbeddingBagByteRowwiseOffsetsNode * |
2615 | Function::createEmbeddingBagByteRowwiseOffsets( |
2616 | llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices, |
2617 | NodeValue offsets, bool useFP16Accumulation, bool hasEndOffset, |
2618 | LengthsMode lengthsMode, float avgLength) { |
2619 | std::vector<dim_t> segmentDims(offsets.dims().begin(), offsets.dims().end()); |
2620 | // If hasEndOffset the last offset is just for marking the end of the last |
2621 | // segment. |
2622 | if (hasEndOffset) { |
2623 | segmentDims[0] -= 1; |
2624 | } |
2625 | auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, segmentDims); |
2626 | return addNode(new EmbeddingBagByteRowwiseOffsetsNode( |
2627 | name, outTy, data, weights, indices, offsets, useFP16Accumulation, |
2628 | hasEndOffset, lengthsMode, avgLength)); |
2629 | } |
2630 | |
2631 | LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name, |
2632 | NodeValue lengths) { |
2633 | ShapeVector outDims({lengths.dims()[0], 2}); |
2634 | auto outTy = getParent()->uniqueTypeWithNewShape(lengths.getType(), outDims); |
2635 | return addNode(new LengthsToRangesNode(name, outTy, lengths)); |
2636 | } |
2637 | |
2638 | LengthsRangeFillNode * |
2639 | Function::createLengthsRangeFill(llvm::StringRef name, NodeValue lengths, |
2640 | unsigned_t maxOutputSize) { |
2641 | auto outTy = |
2642 | getParent()->uniqueTypeWithNewShape(lengths.getType(), {maxOutputSize}); |
2643 | return addNode(new LengthsRangeFillNode(name, outTy, lengths)); |
2644 | } |
2645 | |
2646 | GaussianFillNode *Function::createGaussianFill(llvm::StringRef name, |
2647 | NodeValue input, float mean, |
2648 | float scale, float seed) { |
2649 | auto outTy = getParent()->uniqueType(ElemKind::Float16Ty, input.dims()); |
2650 | |
2651 | return addNode(new GaussianFillNode(name, outTy, input, mean, scale, seed)); |
2652 | } |
2653 | |
2654 | BatchSparseToDenseNode *Function::createBatchSparseToDense( |
2655 | llvm::StringRef name, NodeValue lengths, NodeValue indices, |
2656 | NodeValue values, float defaultValue, unsigned_t denseLastDim) { |
2657 | // The output is a 2-D tensor with first dim = number of lengths and second |
2658 | // dim = denseLastDim |
2659 | ShapeVector outDims({lengths.dims()[0], denseLastDim}); |
2660 | auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims); |
2661 | return addNode(new BatchSparseToDenseNode( |
2662 | name, outTy, lengths, indices, values, defaultValue, denseLastDim)); |
2663 | } |
2664 | |
2665 | FillExamplesWithIndicatorNode * |
2666 | Function::createFillExamplesWithIndicator(llvm::StringRef name, NodeValue data, |
2667 | NodeValue indicator) { |
2668 | ShapeVector outDims({indicator.dims()[0]}); |
2669 | if (data.dims().size() > 1) { |
2670 | outDims.insert(outDims.end(), data.dims().begin() + 1, data.dims().end()); |
2671 | } |
2672 | auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims); |
2673 | return addNode( |
2674 | new FillExamplesWithIndicatorNode(name, outTy, data, indicator)); |
2675 | } |
2676 | |
2677 | SparseToDenseMaskNode *Function::createSparseToDenseMask( |
2678 | llvm::StringRef name, NodeValue indices, NodeValue values, |
2679 | NodeValue defaultValue, NodeValue lengths, llvm::ArrayRef<dim_t> mask) { |
2680 | auto lengthsDims = lengths.dims(); |
2681 | auto valueDims = defaultValue.dims(); |
2682 | ShapeVector outDims = {(dim_t)mask.size()}; |
  // If lengths is a 0-dimensional tensor, then there is no batch dimension.
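  // For example, with lengths of shape {N}, a mask of size M and a
  // defaultValue of shape {D}, the output has shape {N, M, D}.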
2684 | if (lengthsDims.size() > 0) { |
2685 | outDims.insert(outDims.begin(), lengthsDims[0]); |
2686 | } |
2687 | outDims.insert(outDims.end(), valueDims.begin(), valueDims.end()); |
2688 | auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims); |
2689 | return addNode(new SparseToDenseMaskNode(name, outTy, indices, values, |
2690 | defaultValue, lengths, mask)); |
2691 | } |
2692 | |
2693 | SparseLabelSplitNode *Function::createSparseLabelSplit(llvm::StringRef name, |
2694 | NodeValue lengths, |
2695 | NodeValue indices, |
2696 | NodeValue values, |
2697 | dim_t numLabels) { |
2698 | const auto numItems = indices.dims()[0]; |
  // The assumption here is that the items are evenly distributed across
  // labels, so all output tensors (excluding offsetMap) hold
  // numItems / numLabels elements per label.
2701 | auto labelValuesTy = getParent()->uniqueTypeWithNewShape( |
2702 | values.getType(), {numLabels, numItems / numLabels}); |
2703 | auto exampleIdsTy = getParent()->uniqueType( |
2704 | ElemKind::Int32ITy, {numLabels, numItems / numLabels}); |
2705 | auto gradientOffsetMapTy = |
2706 | getParent()->uniqueType(ElemKind::Int32ITy, {indices.dims()[0]}); |
2707 | return addNode(new SparseLabelSplitNode(name, labelValuesTy, exampleIdsTy, |
2708 | gradientOffsetMapTy, lengths, indices, |
2709 | values, numLabels)); |
2710 | } |
2711 | |
2712 | SaveNode *Function::createSave(llvm::StringRef name, NodeValue input) { |
2713 | auto *dest = getParent()->createPlaceholder(input.getType(), name, false); |
2714 | return createSave(name, input, dest); |
2715 | } |
2716 | |
2717 | SaveNode *Function::createSave(llvm::StringRef name, NodeValue input, |
2718 | Placeholder *output, bool skipSuffix) { |
2719 | return addNode(new SaveNode(skipSuffix ? name.str() : (name + "_save" ).str(), |
2720 | input, output)); |
2721 | } |
2722 | |
2723 | QuantizationProfileNode * |
2724 | Function::createQuantizationProfile(PlaceholderBindings &bindings, |
2725 | llvm::StringRef name, NodeValue input, |
2726 | dim_t numHistogramBins) { |
2727 | auto *histogram = getParent()->createPlaceholder( |
2728 | ElemKind::FloatTy, {numHistogramBins}, "histogram_" + name.str(), false); |
2729 | bindings.allocate(histogram)->zero(); |
2730 | // Intermediate data used for histogram calculations. |
2731 | // Min tensor value seen so far is kept on the first position. |
2732 | // Max tensor value seen so far is kept on the second position. |
2733 | auto *computationInfoPH = getParent()->createPlaceholder( |
2734 | ElemKind::FloatTy, {2}, "CI_" + name.str(), false); |
2735 | bindings.allocate(computationInfoPH); |
2736 | auto *computationInfoTensor = bindings.get(computationInfoPH); |
2737 | auto handle = computationInfoTensor->getHandle<float>(); |
2738 | handle.raw(0) = std::numeric_limits<float>::max(); |
2739 | handle.raw(1) = std::numeric_limits<float>::lowest(); |
2740 | |
2741 | return addNode(new QuantizationProfileNode( |
2742 | "QI_" + name.str(), input, histogram, computationInfoPH, |
2743 | input.getNode()->getName().str(), input.getResNo())); |
2744 | } |
2745 | |
2746 | template <typename T> |
2747 | IntLookupTableNode * |
2748 | Function::createIntLookupTable(llvm::StringRef name, NodeValue input, |
2749 | llvm::ArrayRef<T> initValues, TypeRef outTy) { |
2750 | assert(initValues.size() == input.getType()->getQuantizedValueCount() && |
2751 | "Lookup table length must match input type!" ); |
2752 | assert(outTy->isType<T>() && "Lookup table element must match output type!" ); |
2753 | if (std::is_same<T, int8_t>::value) { |
2754 | // Create INT8 lookup table. |
2755 | auto *mapping = getParent()->createConstant( |
2756 | ElemKind::Int8QTy, {(dim_t)initValues.size()}, outTy->getScale(), |
2757 | outTy->getOffset(), "mapping" ); |
2758 | mapping->getHandle<T>() = initValues; |
2759 | return addNode(new IntLookupTableNode(name, outTy, input, mapping)); |
2760 | } else if (std::is_same<T, int16_t>::value) { |
2761 | // Create INT16 lookup table. |
2762 | auto *mapping = getParent()->createConstant( |
2763 | ElemKind::Int16QTy, {(dim_t)initValues.size()}, outTy->getScale(), |
2764 | outTy->getOffset(), "mapping" ); |
2765 | mapping->getHandle<T>() = initValues; |
2766 | return addNode(new IntLookupTableNode(name, outTy, input, mapping)); |
2767 | } else if (std::is_same<T, int32_t>::value) { |
2768 | // Create INT32 lookup table. |
2769 | auto *mapping = getParent()->createConstant( |
2770 | ElemKind::Int32QTy, {(dim_t)initValues.size()}, outTy->getScale(), |
2771 | outTy->getOffset(), "mapping" ); |
2772 | mapping->getHandle<T>() = initValues; |
2773 | return addNode(new IntLookupTableNode(name, outTy, input, mapping)); |
2774 | } else { |
2775 | llvm_unreachable("Lookup table type not supported." ); |
2776 | } |
2777 | } |
2778 | |
2779 | IntLookupTableNode * |
2780 | Function::createIntLookupTable(llvm::StringRef name, NodeValue input, |
2781 | std::function<float(float)> func, |
2782 | TypeRef outTy) { |
2783 | if (outTy->isType<int8_t>()) { |
2784 | std::vector<int8_t> initValues = |
2785 | quantization::createMapping<int8_t>(input.getType(), outTy, func); |
2786 | return createIntLookupTable<int8_t>(name, input, initValues, outTy); |
2787 | } else if (outTy->isType<int16_t>()) { |
2788 | std::vector<int16_t> initValues = |
2789 | quantization::createMapping<int16_t>(input.getType(), outTy, func); |
2790 | return createIntLookupTable<int16_t>(name, input, initValues, outTy); |
2791 | } else if (outTy->isType<int32_t>()) { |
2792 | std::vector<int32_t> initValues = |
2793 | quantization::createMapping<int32_t>(input.getType(), outTy, func); |
2794 | return createIntLookupTable<int32_t>(name, input, initValues, outTy); |
2795 | } else { |
2796 | llvm_unreachable("Lookup table type not supported." ); |
2797 | } |
2798 | } |
2799 | |
2800 | LookupTableNode *Function::createLookupTable( |
2801 | llvm::StringRef name, NodeValue input, LUTOperator lutOperator, |
2802 | std::vector<float> &lutOperatorArgs, NodeValue table, NodeValue idxTable, |
2803 | TypeRef outTy) { |
2804 | return addNode(new LookupTableNode(name, outTy, input, table, idxTable, |
2805 | lutOperator, lutOperatorArgs)); |
2806 | } |
2807 | |
2808 | IntLookupTableNode *Function::createIntLog(llvm::StringRef name, |
2809 | NodeValue input, TypeRef outTy) { |
2810 | auto inputRange = input.getType()->getQuantizedValueRange(); |
2811 | (void)inputRange; |
2812 | assert(inputRange.first >= 0 && |
2813 | "Input range must not be negative since this is input to log()." ); |
2814 | auto func = [](float x) -> float { |
    return (x == 0.0f) ? std::log(std::numeric_limits<float>::min())
                       : std::log(x);
2816 | }; |
2817 | return createIntLookupTable(name, input, func, outTy); |
2818 | } |
2819 | |
2820 | IntLookupTableNode *Function::createIntExp(llvm::StringRef name, |
2821 | NodeValue input, TypeRef outTy) { |
2822 | return createIntLookupTable(name, input, expf, outTy); |
2823 | } |
2824 | |
2825 | IntLookupTableNode *Function::createIntTanh(llvm::StringRef name, |
2826 | NodeValue input, TypeRef outTy) { |
2827 | return createIntLookupTable(name, input, tanhf, outTy); |
2828 | } |
2829 | |
2830 | IntLookupTableNode *Function::createIntSigmoid(llvm::StringRef name, |
2831 | NodeValue input, TypeRef outTy) { |
2832 | auto func = [](float x) -> float { return 1.0f / (1.0f + expf(-x)); }; |
2833 | return createIntLookupTable(name, input, func, outTy); |
2834 | } |
2835 | |
2836 | TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input, |
2837 | unsigned_t k, ElemKind outIndicesTyKind) { |
2838 | auto inDims = input.dims(); |
  assert(inDims.size() > 0 && "Input must have at least one dimension.");
  assert(k <= inDims.back() &&
         "k must not exceed the size of the last dimension.");
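  // Only the last dimension shrinks to k; e.g. input {3, 5} with k = 2 yields
  // values and indices of shape {3, 2}.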
2841 | ShapeVector outDims(inDims.begin(), inDims.end()); |
2842 | outDims.back() = k; |
2843 | auto OT = getParent()->uniqueTypeWithNewShape(input.getType(), outDims); |
2844 | return addNode(new TopKNode( |
2845 | name, OT, getParent()->uniqueType(outIndicesTyKind, outDims), input, k)); |
2846 | } |
2847 | |
2848 | TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input, |
2849 | unsigned_t k) { |
2850 | return createTopK(name, input, k, ElemKind::Int64ITy); |
2851 | } |
2852 | |
2853 | ArgMaxNode *Function::createArgMax(llvm::StringRef name, NodeValue input, |
2854 | unsigned_t axis, bool keepDims, |
2855 | ElemKind elemTy) { |
2856 | ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims); |
2857 | auto OT = getParent()->uniqueType(elemTy, outDims); |
2858 | return addNode(new ArgMaxNode(name, OT, input, axis, keepDims)); |
2859 | } |
2860 | |
2861 | ArgMinNode *Function::createArgMin(llvm::StringRef name, NodeValue input, |
2862 | unsigned_t axis, bool keepDims, |
2863 | ElemKind elemTy) { |
2864 | ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims); |
2865 | auto OT = getParent()->uniqueType(elemTy, outDims); |
2866 | return addNode(new ArgMinNode(name, OT, input, axis, keepDims)); |
2867 | } |
2868 | |
2869 | VectorNormNode *Function::createVectorNorm(llvm::StringRef name, |
2870 | NodeValue input, unsigned_t axis, |
2871 | unsigned_t p) { |
2872 | auto outDims = getNewShapeWithoutAxes(input.dims(), axis); |
2873 | auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), outDims); |
2874 | const size_t outNumElements = input.getType()->size() / input.dims()[axis]; |
2875 | (void)outNumElements; |
2876 | assert(outTy->size() == outNumElements && |
2877 | "Incorrect number of elements in the output type." ); |
2878 | auto OT = getParent()->uniqueType(*outTy); |
2879 | return addNode(new VectorNormNode(name, OT, input, axis, p)); |
2880 | } |
2881 | |
2882 | CollectRpnProposalsNode *Function::createCollectRpnProposals( |
2883 | llvm::StringRef name, std::vector<NodeValue> &roisIn, |
2884 | std::vector<NodeValue> &roiProbsIn, int64_t rpnMaxLevel, |
2885 | int64_t rpnMinLevel, unsigned_t rpnPostNmsTopN) { |
2886 | |
2887 | auto boxDim = roisIn[0].dims()[1]; |
2888 | |
2889 | assert(rpnPostNmsTopN > 0 && "RpnPostNmsTopN should be greater than zero" ); |
2890 | |
2891 | ShapeVector outDims{rpnPostNmsTopN, boxDim}; |
2892 | |
2893 | auto OT = getParent()->uniqueTypeWithNewShape(roisIn[0].getType(), outDims); |
2894 | return addNode(new CollectRpnProposalsNode( |
2895 | name, OT, roisIn, roiProbsIn, rpnMaxLevel, rpnMinLevel, rpnPostNmsTopN)); |
2896 | } |
2897 | |
2898 | GatherNode *Function::createGather(llvm::StringRef name, NodeValue data, |
2899 | NodeValue indices, unsigned_t axis) { |
2900 | auto dDims = data.dims(); |
2901 | auto iDims = indices.dims(); |
2902 | assert(dDims.size() > axis); |
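  // The output shape is the data dims before the axis, then the indices dims,
  // then the data dims after the axis; e.g. data {2, 4, 6} gathered with
  // indices {3, 5} along axis 1 yields {2, 3, 5, 6}.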
2903 | ShapeVector outDims; |
2904 | outDims.insert(outDims.end(), dDims.begin(), dDims.begin() + axis); |
2905 | outDims.insert(outDims.end(), iDims.begin(), iDims.end()); |
2906 | outDims.insert(outDims.end(), dDims.begin() + axis + 1, dDims.end()); |
2907 | return addNode(new GatherNode( |
2908 | name, getParent()->uniqueTypeWithNewShape(data.getType(), outDims), data, |
2909 | indices, axis)); |
2910 | } |
2911 | |
2912 | GatherNDNode *Function::createGatherND(llvm::StringRef name, NodeValue data, |
2913 | NodeValue indices, |
2914 | unsigned_t batchDims) { |
2915 | auto dataDims = data.dims(); |
2916 | auto indicesDims = indices.dims(); |
2917 | size_t indicesDimLast = indicesDims.back(); |
2918 | |
2919 | // Validate the input dimensions. |
2920 | assert(dataDims.size() >= 1 && "Input data rank must be >= 1." ); |
2921 | assert(indicesDims.size() >= 1 && "Input indices rank must be >= 1." ); |
2922 | for (size_t idx = 0; idx < batchDims; ++idx) { |
2923 | assert(dataDims[idx] == indicesDims[idx] && |
2924 | "Batch dimensions of data and indices must be the same!" ); |
2925 | } |
2926 | assert(batchDims < std::min(dataDims.size(), indicesDims.size()) && |
2927 | "The number of batch dimensions must be smaller than both the input " |
2928 | "data and indices rank!" ); |
2929 | assert(indicesDimLast >= 1 && |
2930 | "Last dimension of indices must be at least 1!" ); |
2931 | assert(indicesDimLast <= (dataDims.size() - batchDims) && |
2932 | "Last dimension of indices must be at most rank of data without batch " |
2933 | "dimensions!" ); |
2934 | |
2935 | // Compute the output dimensions. |
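  // For example, with batchDims = 0, data {2, 3, 4} and indices {5, 2}: each
  // of the 5 index pairs selects a slice of shape {4}, so the output shape is
  // {5, 4}.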
2936 | size_t outRank = |
2937 | dataDims.size() + indicesDims.size() - indicesDimLast - 1 - batchDims; |
2938 | std::vector<dim_t> outDims(outRank); |
2939 | size_t outIdx = 0; |
2940 | for (size_t idx = 0; idx < batchDims; ++idx) { |
2941 | outDims[outIdx++] = dataDims[idx]; |
2942 | } |
2943 | for (size_t idx = batchDims; idx < indicesDims.size() - 1; ++idx) { |
2944 | outDims[outIdx++] = indicesDims[idx]; |
2945 | } |
2946 | for (size_t idx = batchDims + indicesDimLast; idx < dataDims.size(); ++idx) { |
2947 | outDims[outIdx++] = dataDims[idx]; |
2948 | } |
2949 | auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims); |
2950 | return addNode(new GatherNDNode(name, outTy, data, indices, batchDims)); |
2951 | } |
2952 | |
2953 | GatherElementsNode *Function::createGatherElements(llvm::StringRef name, |
2954 | NodeValue data, |
2955 | NodeValue indices, |
2956 | unsigned_t dim = 0) { |
2957 | const auto iDims = indices.dims(); |
2958 | const auto dRank = data.dims().size(); |
2959 | const auto iRank = iDims.size(); |
2960 | (void)dRank; |
2961 | (void)iRank; |
  // dim is unsigned here, so only the upper bound needs to be checked.
  assert(dim < dRank &&
         "[GatherElements] dim must be in the range [0, rank-1].");
  assert(iRank == dRank &&
         "[GatherElements] Data and indices rank must be equal.");
  assert(dRank > 0 && "[GatherElements] Data and indices rank must be >= 1.");
2967 | |
2968 | return addNode(new GatherElementsNode( |
2969 | name, getParent()->uniqueTypeWithNewShape(data.getType(), iDims), data, |
2970 | indices, dim)); |
2971 | } |
2972 | |
2973 | GatherRangesNode *Function::createGatherRanges(llvm::StringRef name, |
2974 | NodeValue data, NodeValue ranges, |
2975 | unsigned_t maxOutputSize) { |
2976 | auto numRanges = ranges.dims()[0]; |
2977 | return addNode(new GatherRangesNode( |
2978 | name, |
2979 | /*OutputTy=*/ |
2980 | getParent()->uniqueTypeWithNewShape(data.getType(), {maxOutputSize}), |
2981 | /*LengthsTy=*/ |
2982 | getParent()->uniqueTypeWithNewShape(ranges.getType(), numRanges), data, |
2983 | ranges)); |
2984 | } |
2985 | |
2986 | ScatterDataNode *Function::createScatterData(llvm::StringRef name, |
2987 | NodeValue data, NodeValue indices, |
2988 | NodeValue slices, |
2989 | bool cumulative) { |
2990 | return addNode(new ScatterDataNode(name, data, indices, slices, cumulative)); |
2991 | } |
2992 | |
2993 | BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name, |
2994 | NodeValue data, NodeValue lengths, |
2995 | NodeValue values) { |
2996 | auto outTy = getParent()->uniqueTypeWithNewShape( |
2997 | data.getType(), {data.dims()[0], values.dims()[0]}); |
2998 | return addNode(new BatchOneHotNode(name, outTy, data, lengths, values)); |
2999 | } |
3000 | |
3001 | SpaceToDepthNode *Function::createSpaceToDepth(llvm::StringRef name, |
3002 | NodeValue input, |
3003 | unsigned blockSize) { |
3004 | assert(blockSize > 0 && "BlockSize must be >= 1." ); |
3005 | |
3006 | auto inputDim = input.dims(); |
3007 | assert(inputDim.size() == 4 && "Dimension size of 4 is expected." ); |
  assert((inputDim[1] % blockSize == 0 && inputDim[2] % blockSize == 0) &&
         "Height and width must be multiples of blockSize.");
3010 | std::vector<dim_t> newDim = {inputDim[0], inputDim[1] / blockSize, |
3011 | inputDim[2] / blockSize, |
3012 | inputDim[3] * blockSize * blockSize}; |
3013 | auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim); |
3014 | return addNode(new SpaceToDepthNode(name, outTy, input, blockSize)); |
3015 | } |
3016 | |
3017 | ReshapeNode *Function::createDepthToSpace(llvm::StringRef name, NodeValue input, |
3018 | unsigned blockSize, bool isCRD) { |
3019 | assert(blockSize > 0 && "Block size must be >= 1." ); |
3020 | |
3021 | auto inputDim = input.dims(); |
3022 | assert(inputDim.size() == 4 && "Dimension size of 4 is expected." ); |
3023 | dim_t N = inputDim[0]; |
3024 | dim_t H = inputDim[1]; |
3025 | dim_t W = inputDim[2]; |
3026 | dim_t C = inputDim[3]; |
3027 | assert(C % (blockSize * blockSize) == 0 && |
3028 | "Depth should be divisible by block size squared." ); |
3029 | |
3030 | llvm::SmallVector<unsigned_t, 6> shuffle; |
3031 | llvm::SmallVector<dim_t, 6> tmpShape; |
3032 | llvm::SmallVector<dim_t, 4> outShape = {N, H * blockSize, W * blockSize, |
3033 | C / (blockSize * blockSize)}; |
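  // The two modes only differ in how the depth dimension is split before the
  // transpose: CRD keeps the channel axis first, DCR the block axes. E.g.
  // input {1, 2, 3, 8} with blockSize = 2 yields {1, 4, 6, 2} in both modes.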
3034 | if (isCRD) { |
3035 | tmpShape = {N, H, W, C / (blockSize * blockSize), blockSize, blockSize}; |
3036 | shuffle = D2S_CRD; |
3037 | } else { |
3038 | tmpShape = {N, H, W, blockSize, blockSize, C / (blockSize * blockSize)}; |
3039 | shuffle = D2S_DCR; |
3040 | } |
3041 | |
3042 | auto *RN1 = createReshape(name.str() + "_reshape_in" , input, tmpShape); |
3043 | auto *TN = createTranspose(name.str() + "_transpose" , RN1, shuffle); |
3044 | return createReshape(name.str() + "_reshape_out" , TN, outShape); |
3045 | } |
3046 | |
3047 | ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name, |
3048 | NodeValue input, |
3049 | llvm::ArrayRef<float> scale) { |
3050 | auto inputDim = input.dims(); |
  DCHECK_EQ(inputDim.size(), scale.size())
      << "Input dimension size: " << inputDim.size()
      << " and scale size: " << scale.size() << " must be the same.";
3054 | |
3055 | std::vector<dim_t> newDim; |
3056 | |
3057 | for (size_t i = 0; i < scale.size(); i++) { |
3058 | auto newD = dim_t(std::floor(inputDim[i] * scale[i])); |
3059 | DCHECK_GT(newD, 0) << "Scaled dim is " << newD |
3060 | << ", Scaled value needs to be larger than 0." ; |
3061 | newDim.push_back(newD); |
3062 | } |
3063 | |
3064 | auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim); |
3065 | return addNode(new ResizeNearestNode(name, outTy, input, scale)); |
3066 | } |
3067 | |
3068 | ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name, |
3069 | NodeValue input, |
3070 | TypeRef outTy) { |
3071 | auto inputDim = input.dims(); |
3072 | auto outputDim = outTy->dims(); |
  DCHECK_EQ(inputDim.size(), outputDim.size())
      << "Input dimension size: " << inputDim.size()
      << " and output dimension size: " << outputDim.size()
      << " must be the same.";
3076 | |
3077 | std::vector<float> scales; |
3078 | for (size_t i = 0; i < inputDim.size(); i++) { |
3079 | float scale = (outputDim[i] / (float)inputDim[i]); |
3080 | DCHECK_GT(scale, 0.0) << "Scale: " << scale |
3081 | << ", Scale larger than 0 is expected." ; |
3082 | scales.push_back(scale); |
3083 | } |
3084 | |
3085 | return addNode(new ResizeNearestNode(name, outTy, input, scales)); |
3086 | } |
3087 | |
3088 | ResizeBilinearNode * |
3089 | Function::createResizeBilinear(llvm::StringRef name, NodeValue input, |
3090 | llvm::ArrayRef<float> scale) { |
3091 | auto inputDim = input.dims(); |
  DCHECK_EQ(inputDim.size(), scale.size())
      << "Input dimension size: " << inputDim.size()
      << " and scale size: " << scale.size() << " must be the same.";
3095 | |
3096 | std::vector<dim_t> newDim; |
3097 | |
3098 | for (size_t i = 0; i < scale.size(); i++) { |
3099 | auto newD = dim_t(std::floor(inputDim[i] * scale[i])); |
3100 | DCHECK_GT(newD, 0) << "Scaled dim is " << newD |
3101 | << ", Scaled value needs to be larger than 0." ; |
3102 | newDim.push_back(newD); |
3103 | } |
3104 | |
3105 | auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim); |
3106 | return addNode(new ResizeBilinearNode(name, outTy, input, scale)); |
3107 | } |
3108 | |
3109 | ResizeBilinearNode *Function::createResizeBilinear(llvm::StringRef name, |
3110 | NodeValue input, |
3111 | TypeRef outTy) { |
3112 | auto inputDim = input.dims(); |
3113 | auto outputDim = outTy->dims(); |
  DCHECK_EQ(inputDim.size(), outputDim.size())
      << "Input dimension size: " << inputDim.size()
      << " and output dimension size: " << outputDim.size()
      << " must be the same.";
3117 | |
3118 | std::vector<float> scales; |
3119 | for (size_t i = 0; i < inputDim.size(); i++) { |
3120 | float scale = (outputDim[i] / (float)inputDim[i]); |
3121 | DCHECK_GT(scale, 0.0) << "Scale: " << scale |
3122 | << ", Scale larger than 0 is expected." ; |
3123 | scales.push_back(scale); |
3124 | } |
3125 | |
3126 | return addNode(new ResizeBilinearNode(name, outTy, input, scales)); |
3127 | } |
3128 | |
3129 | QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input, |
3130 | TypeRef outTy) { |
3131 | assert(input.getType()->isFPType() && "Input must be a floating type" ); |
3132 | assert(outTy->isQuantizedType() && "Output must be a quantized type" ); |
3133 | assert(input.dims().equals(outTy->dims()) && |
3134 | "Different dimensions for input and output" ); |
3135 | |
3136 | return addNode( |
3137 | new QuantizeNode(name, getParent()->uniqueType(*outTy), input)); |
3138 | } |
3139 | |
3140 | QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input, |
3141 | ElemKind q, float scale, |
3142 | int32_t offset) { |
3143 | TypeRef OT = getParent()->uniqueType(q, input.dims(), scale, offset); |
3144 | return createQuantize(name, input, OT); |
3145 | } |
3146 | |
3147 | DequantizeNode *Function::createDequantize(llvm::StringRef name, |
3148 | NodeValue input, ElemKind k) { |
3149 | assert(input.getType()->isQuantizedType() && |
3150 | "Input must be a quantized type" ); |
3151 | assert(isFloatElemKind(k) && "Result must be float type." ); |
3152 | ShapeVector outShape(input.dims().begin(), input.dims().end()); |
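  // Each row of a fused tensor carries a trailing scale and offset stored as
  // two floats, so the dequantized row is 2 * sizeof(float) bytes narrower.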
3153 | if (input.getElementType() == ElemKind::UInt8FusedQTy) { |
3154 | assert(outShape.size() == 2 && "Fused tensors should be 2D" ); |
3155 | assert(outShape[1] > 2 * sizeof(float) && |
3156 | "Expected space for per-row scale/offset" ); |
3157 | outShape[1] -= 2 * sizeof(float); |
3158 | } |
3159 | TypeRef outTy = getParent()->uniqueType(Type(k, outShape)); |
3160 | return createDequantize(name, input, outTy); |
3161 | } |
3162 | |
3163 | DequantizeNode *Function::createDequantize(llvm::StringRef name, |
3164 | NodeValue input, TypeRef outTy) { |
3165 | assert(input.getType()->isQuantizedType() && |
3166 | "Input must be a quantized type" ); |
3167 | assert(outTy->isFPType() && "Output should be an FP type" ); |
3168 | return addNode(new DequantizeNode(name, outTy, input)); |
3169 | } |
3170 | |
3171 | RescaleQuantizedNode *Function::createRescaleQuantized(llvm::StringRef name, |
3172 | NodeValue input, |
3173 | TypeRef outTy) { |
3174 | assert(input.getType()->isQuantizedType() && |
3175 | "Input must be a quantized type" ); |
3176 | assert(outTy->isQuantizedType() && "Output must be a quantized type" ); |
3177 | assert(input.dims().equals(outTy->dims()) && |
3178 | "Different dimensions for input and output" ); |
3179 | |
3180 | return addNode( |
3181 | new RescaleQuantizedNode(name, getParent()->uniqueType(*outTy), input)); |
3182 | } |
3183 | |
3184 | Node *Function::createWeightedSum(llvm::StringRef name, |
3185 | llvm::ArrayRef<NodeValue> data, |
3186 | llvm::ArrayRef<NodeValue> weights) { |
3187 | assert(data.size() == weights.size() && |
3188 | "Must have same number of data and weights." ); |
3189 | assert(data.size() > 0 && "No inputs provided." ); |
3190 | |
3191 | const auto *outTy = data[0].getType(); |
3192 | |
3193 | // Create a zero splat to bootstrap the adding chain. |
3194 | Node *currAdd = createSplat(name.str() + ".splat" , outTy, 0.); |
3195 | |
3196 | for (size_t i = 0, e = data.size(); i < e; i++) { |
3197 | assert(weights[i].getType()->size() == 1 && |
3198 | "Each provided weight node must be size 1." ); |
3199 | assert(outTy == data[i].getType() && |
3200 | "All data nodes must have the same type." ); |
3201 | |
3202 | // Broadcast the current weight to same shape as the data. |
3203 | auto *bcastW = |
3204 | createBroadcast(name.str() + ".bcastWeight" + std::to_string(i), |
3205 | weights[i], outTy->dims(), /* axis */ 0); |
3206 | |
3207 | // Element-wise multiply the broadcasted weight by the data. |
3208 | auto *scaledD = |
3209 | createMul(name.str() + ".mul" + std::to_string(i), bcastW, data[i]); |
3210 | |
3211 | // Element-wise add the scaled data to the running total. |
3212 | currAdd = |
3213 | createAdd(name.str() + ".add" + std::to_string(i), scaledD, currAdd); |
3214 | } |
3215 | |
3216 | // Return the final weighted sum via the last add in the chain. |
3217 | return currAdd; |
3218 | } |
3219 | |
3220 | Node *Function::createBatchBoxCox(llvm::StringRef name, NodeValue data, |
3221 | NodeValue lambda1, NodeValue lambda2, |
3222 | float epsilon) { |
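  // Per column j this roughly computes (the exact semantics, including the
  // epsilon guard, live in the BatchBoxCox implementation):
  //   y = ((x + lambda2[j])^lambda1[j] - 1) / lambda1[j]  if lambda1[j] != 0
  //   y = ln(x + lambda2[j])                              otherwise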
3223 | assert((lambda1.dims() == lambda2.dims()) && |
3224 | "lambda1 and lambda2 must have the same shape" ); |
3225 | assert((lambda1.getType()->getElementType() == lambda2.getElementType()) && |
3226 | "lambda1 and lambda2 must have the same element type" ); |
3227 | assert((lambda1.getType()->getElementType() == data.getElementType()) && |
3228 | "data and lambdas must have the same element type" ); |
3229 | assert((lambda1.dims().size() == 1) && "lambda1 and lambda2 must be vectors" ); |
3230 | assert((data.dims().size() == 2) && "data must be a matrix" ); |
  assert((data.dims()[1] == lambda1.dims()[0]) &&
         "the second dimension of data must match the length of lambda1 and "
         "lambda2");
3233 | |
3234 | return addNode(new BatchBoxCoxNode(name, data, lambda1, lambda2, epsilon)); |
3235 | } |
3236 | |
3237 | ClipNode *Function::createClip(llvm::StringRef name, NodeValue input, |
3238 | TypeRef outTy, float min, float max) { |
3239 | return addNode(new ClipNode(name, outTy, input, min, max)); |
3240 | } |
3241 | |
3242 | ClipNode *Function::createClip(llvm::StringRef name, NodeValue input, float min, |
3243 | float max) { |
3244 | return addNode(new ClipNode(name, input.getType(), input, min, max)); |
3245 | } |
3246 | |
3247 | ClipNode *Function::createClipMinMaxFP16(llvm::StringRef name, |
3248 | NodeValue input) { |
3249 | return createClip(name, input, kMinFP16, kMaxFP16); |
3250 | } |
3251 | |
3252 | ClipNode *Function::createClipMinMaxBFloat16(llvm::StringRef name, |
3253 | NodeValue input) { |
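  // bfloat16 shares float32's 8-bit exponent, so its representable range
  // matches float32's; FLT_MIN/FLT_MAX are therefore the effective bounds.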
3254 | constexpr float bfloat16Min = FLT_MIN; |
3255 | constexpr float bfloat16Max = FLT_MAX; |
3256 | return createClip(name, input, bfloat16Min, bfloat16Max); |
3257 | } |
3258 | |
3259 | BatchedUnaryEmbeddingsBagsNode *Function::createBatchedUnaryEmbeddingsBags( |
3260 | llvm::StringRef name, NodeValue weights, NodeValue tableOffsets, |
3261 | NodeValue indices, NodeValue offsets) { |
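  // The output shape is {weights.dims()[0], numBags, numTables}, where
  // numTables = tableOffsets.dims()[0] - 1 and
  // numBags = (offsets.dims()[0] - 1) / numTables.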
3262 | auto inDims = weights.dims(); |
3263 | ShapeVector outDims(inDims.begin(), inDims.end()); |
3264 | outDims[0] = weights.dims()[0]; |
3265 | outDims[2] = tableOffsets.dims()[0] - 1; |
3266 | outDims[1] = (offsets.dims()[0] - 1) / outDims[2]; |
3267 | |
3268 | auto outTy = getParent()->uniqueTypeWithNewShape(weights.getType(), outDims); |
3269 | return addNode(new BatchedUnaryEmbeddingsBagsNode( |
3270 | name, outTy, weights, tableOffsets, offsets, indices)); |
3271 | } |
3272 | |
3273 | IntNBitSplitEmbeddingBagsNode *Function::createIntNBitSplitEmbeddingBags( |
3274 | llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights, |
3275 | NodeValue weightsPlacements, NodeValue weightsOffsets, NodeValue weightsTys, |
3276 | NodeValue dimOffsets, int64_t totalDims, NodeValue indices, |
3277 | NodeValue offsets, SplitEmbeddingPoolingMode poolingMode, |
3278 | SplitEmbeddingSparseType outputDtype) { |
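  // The output shape is {numBags, rowBytes}: numBags is derived from the
  // offsets and the table count, and the inner dimension appears to be
  // measured in bytes since the rows live in a byte tensor (totalDims counts
  // output elements across all tables).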
3279 | ShapeVector outDims; |
3280 | outDims.insert(outDims.end(), |
3281 | (offsets.dims()[0] - 1) / (dimOffsets.dims()[0] - 1)); |
3282 | int64_t totalSize = outputDtype == SplitEmbeddingSparseType::EST_FLOAT |
3283 | ? totalDims * sizeof(float) |
3284 | : totalDims * sizeof(float16_t); |
3285 | outDims.insert(outDims.end(), totalSize); |
3286 | |
3287 | auto outTy = |
3288 | getParent()->uniqueTypeWithNewShape(devWeights.getType(), outDims); |
3289 | |
3290 | return addNode(new IntNBitSplitEmbeddingBagsNode( |
3291 | name, outTy, devWeights, uvmWeights, weightsPlacements, weightsOffsets, |
3292 | weightsTys, dimOffsets, indices, offsets, totalDims, poolingMode, |
3293 | outputDtype)); |
3294 | } |
3295 | |
3296 | IntNBitSplitEmbeddingWeightedBagsNode * |
3297 | Function::createIntNBitSplitEmbeddingWeightedBags( |
3298 | llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights, |
3299 | NodeValue weightsPlacements, NodeValue weightsOffsets, NodeValue weightsTys, |
3300 | NodeValue dimOffsets, int64_t totalDims, NodeValue indices, |
3301 | NodeValue offsets, SplitEmbeddingPoolingMode poolingMode, |
3302 | SplitEmbeddingSparseType outputDtype, NodeValue indiceWeights) { |
3303 | ShapeVector outDims; |
3304 | outDims.insert(outDims.end(), |
3305 | (offsets.dims()[0] - 1) / (dimOffsets.dims()[0] - 1)); |
3306 | int64_t totalSize = outputDtype == SplitEmbeddingSparseType::EST_FLOAT |
3307 | ? totalDims * sizeof(float) |
3308 | : totalDims * sizeof(float16_t); |
3309 | outDims.insert(outDims.end(), totalSize); |
3310 | |
3311 | auto outTy = |
3312 | getParent()->uniqueTypeWithNewShape(devWeights.getType(), outDims); |
3313 | |
3314 | return addNode(new IntNBitSplitEmbeddingWeightedBagsNode( |
3315 | name, outTy, devWeights, uvmWeights, weightsPlacements, weightsOffsets, |
3316 | weightsTys, dimOffsets, indices, offsets, indiceWeights, totalDims, |
3317 | poolingMode, outputDtype)); |
3318 | } |
3319 | |
3320 | //===----------------------------------------------------------------------===// |
3321 | // Placeholder-builder methods. |
3322 | //===----------------------------------------------------------------------===// |
3323 | |
3324 | BatchNormalizationNode *Function::createBatchNormalization( |
3325 | PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input, |
3326 | unsigned_t channelIdx, float epsilon, float momentum) { |
3327 | // Figure out how many channels are in the tensor. |
3328 | dim_t channels = input.dims()[channelIdx]; |
3329 | |
3330 | ElemKind inputTy = input.getType()->getElementType(); |
3331 | |
3332 | // Allocate the learnable parameters beta and gamma. |
3333 | auto *beta = |
3334 | getParent()->createPlaceholder(inputTy, {channels}, "beta" , true); |
3335 | bindings.allocate(beta)->init(Tensor::InitKind::Broadcast, 0.1, getPRNG()); |
3336 | |
3337 | auto *scale = |
3338 | getParent()->createPlaceholder(inputTy, {channels}, "scale" , true); |
3339 | bindings.allocate(scale)->init(Tensor::InitKind::Broadcast, 0.001, getPRNG()); |
3340 | |
3341 | auto *mean = |
3342 | getParent()->createPlaceholder(inputTy, {channels}, "mean" , false); |
3343 | bindings.allocate(mean)->zero(); |
3344 | |
3345 | auto *variance = |
3346 | getParent()->createPlaceholder(inputTy, {channels}, "variance" , false); |
3347 | bindings.allocate(variance)->init(Tensor::InitKind::Broadcast, 1.0, |
3348 | getPRNG()); |
3349 | |
3350 | auto resultType = getParent()->uniqueType(inputTy, input.dims()); |
3351 | |
3352 | return createBatchNormalization(name, resultType, input, beta, scale, mean, |
3353 | variance, channelIdx, epsilon, momentum); |
3354 | } |
3355 | |
3356 | ConvolutionNode *Function::createConv( |
3357 | PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input, |
3358 | dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels, |
3359 | llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads, |
3360 | unsigned_t group, llvm::ArrayRef<unsigned_t> dilation, |
3361 | ConvolutionLayout layout) { |
3362 | ShapeNHWC idim = ShapeNHWC(input.dims()); |
3363 | ShapeHW kdim(kernels); |
3364 | PaddingTLBR pdim(pads); |
3365 | (void)pdim; |
3366 | assert((idim.w + pdim.left + pdim.right) >= kdim.width && |
3367 | (idim.h + pdim.top + pdim.bottom) >= kdim.height && |
3368 | "buffer too small for selected stride" ); |
3369 | |
  assert(group > 0 && "group must be greater than 0");
  assert(idim.c % group == 0 &&
         "the number of input channels must be divisible by group");
  assert(outChannels % group == 0 && "outChannels must be divisible by group");
3373 | |
3374 | // Calculate the size and allocate the output buffer. |
3375 | auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, |
3376 | pads, dilation); |
3377 | |
3378 | std::array<dim_t, 4> outDims = { |
3379 | {idim.n, outSz.first, outSz.second, outChannels}}; |
3380 | |
3381 | // Allocate the Filter and Bias tensors. |
3382 | std::array<dim_t, 4> filterDim = { |
3383 | {outChannels, kdim.height, kdim.width, idim.c / group}}; |
3384 | size_t fanIn = kdim.height * kdim.width * idim.c; |
3385 | ElemKind inputTy = input.getType()->getElementType(); |
3386 | assert(isFloatElemKind(inputTy) && "Convolution on non-floating point type?" ); |
3387 | auto *filter = |
3388 | getParent()->createPlaceholder(inputTy, filterDim, "filter" , true); |
3389 | bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn, |
3390 | getPRNG()); |
3391 | |
3392 | auto *bias = |
3393 | getParent()->createPlaceholder(inputTy, {outChannels}, "bias" , true); |
3394 | bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1, |
3395 | getPRNG()); |
3396 | |
3397 | auto OT = getParent()->uniqueType(inputTy, outDims); |
3398 | |
3399 | return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels, |
3400 | strides, pads, group, dilation, layout, |
3401 | FusedActivation::NONE, {})); |
3402 | } |
3403 | |
3404 | ConvolutionNode *Function::createConv(PlaceholderBindings &bindings, |
3405 | llvm::StringRef name, NodeValue input, |
3406 | dim_t outChannels, unsigned_t kernel, |
3407 | unsigned_t stride, unsigned_t pad, |
3408 | unsigned_t group, |
3409 | llvm::ArrayRef<unsigned_t> dilation, |
3410 | ConvolutionLayout layout) { |
3411 | llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad}; |
3412 | llvm::SmallVector<unsigned_t, 2> strides = {stride, stride}; |
3413 | llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel}; |
3414 | return createConv(bindings, name, input, outChannels, kernels, strides, pads, |
3415 | group, dilation, layout); |
3416 | } |
3417 | |
3418 | Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings, |
3419 | llvm::StringRef name, NodeValue input, |
3420 | dim_t outChannels, |
3421 | llvm::ArrayRef<unsigned_t> kernels, |
3422 | llvm::ArrayRef<unsigned_t> strides, |
3423 | llvm::ArrayRef<unsigned_t> pads, |
3424 | unsigned_t group) { |
3425 | ShapeNTHWC idim(input.dims()); |
3426 | ShapeTHW kdim(kernels); |
3427 | |
  assert(group > 0 && "group must be greater than 0");
  assert(idim.c % group == 0 &&
         "the number of input channels must be divisible by group");
  assert(outChannels % group == 0 && "outChannels must be divisible by group");
3431 | |
3432 | // Calculate the size and allocate the output buffer. |
3433 | auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels, |
3434 | strides, pads); |
3435 | |
3436 | std::array<dim_t, 5> outDims = { |
3437 | {idim.n, outSz.temporal_frames, outSz.height, outSz.width, outChannels}}; |
3438 | |
3439 | // Allocate the Filter and Bias tensors. |
3440 | std::array<dim_t, 5> filterDim = {{outChannels, kdim.temporal_frames, |
3441 | kdim.height, kdim.width, idim.c / group}}; |
3442 | |
3443 | dim_t fanIn = kdim.temporal_frames * kdim.height * kdim.width * idim.c; |
3444 | ElemKind inputTy = input.getType()->getElementType(); |
3445 | assert(isFloatElemKind(inputTy) && |
3446 | "Convolution3D on non-floating point type?" ); |
3447 | auto *filter = |
3448 | getParent()->createPlaceholder(inputTy, filterDim, "filter" , true); |
3449 | bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn, |
3450 | getPRNG()); |
3451 | |
3452 | auto *bias = |
3453 | getParent()->createPlaceholder(inputTy, {outChannels}, "bias" , true); |
3454 | bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1, |
3455 | getPRNG()); |
3456 | |
3457 | auto OT = getParent()->uniqueType(inputTy, outDims); |
3458 | |
3459 | assertConv3DDims(input, filter, bias, kernels, strides, pads, group); |
3460 | |
3461 | return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels, |
3462 | strides, pads, group)); |
3463 | } |
3464 | |
3465 | Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings, |
3466 | llvm::StringRef name, NodeValue input, |
3467 | size_t outChannels, unsigned_t kernel, |
3468 | unsigned_t stride, unsigned_t pad, |
3469 | unsigned_t group) { |
3470 | llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad}; |
3471 | llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride}; |
3472 | llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel}; |
3473 | return createConv3D(bindings, name, input, outChannels, kernels, strides, |
3474 | pads, group); |
3475 | } |
3476 | |
3477 | ChannelwiseQuantizedConvolutionNode *Function::createChannelwiseQuantizedConv( |
3478 | llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias, |
3479 | NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales, |
3480 | NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels, |
3481 | llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads, |
3482 | unsigned_t group, llvm::ArrayRef<unsigned_t> dilation, bool quantizeFilter, |
3483 | bool quantizeBias, quantization::Schema schema, ElemKind filterElemQTy, |
3484 | ElemKind biasElemQTy) { |
3485 | |
3486 | // Validate dimensions. |
3487 | bool isConv3D = (input.getType()->dims().size() == 5); |
3488 | if (isConv3D) { |
3489 | assertConv3DDims(input, filter, bias, kernels, strides, pads, group); |
3490 | } else { |
3491 | assertConvDims(input, filter, bias, kernels, strides, pads, group); |
3492 | } |
3493 | |
3494 | // Validate bias precision. |
3495 | auto biasElemKind = bias.getElementType(); |
3496 | DCHECK(biasElemKind == ElemKind::Int8QTy || |
3497 | biasElemKind == ElemKind::Int32QTy || |
3498 | biasElemKind == ElemKind::FloatTy) |
3499 | << "Unsupported element type for ChannelwiseQuantizedConvolution bias: " |
3500 | << Type::getElementName(biasElemKind).str(); |
3501 | |
3502 | // Validate filter precision. |
3503 | auto filterElemKind = filter.getElementType(); |
3504 | DCHECK(filterElemKind == ElemKind::Int8QTy || |
3505 | filterElemKind == ElemKind::FloatTy) |
3506 | << "Unsupported element type for ChannelwiseQuantizedConvolution " |
3507 | << "filter: " << Type::getElementName(filterElemKind).str(); |
3508 | |
3509 | DCHECK(!filterScales.getNode() || dyn_cast<Constant>(filterScales.getNode())) |
3510 | << "Filter scales input to ChannelwiseQuantizedConvolutionNode must be " |
3511 | "null or Constant" ; |
3512 | |
3513 | DCHECK(!filterOffsets.getNode() || |
3514 | dyn_cast<Constant>(filterOffsets.getNode())) |
3515 | << "Filter offsets input to ChannelwiseQuantizedConvolutionNode must be " |
3516 | "null or Constant" ; |
3517 | |
3518 | DCHECK(!biasScales.getNode() || dyn_cast<Constant>(biasScales.getNode())) |
3519 | << "Bias scales input to ChannelwiseQuantizedConvolutionNode must be " |
3520 | "null or Constant" ; |
3521 | |
3522 | DCHECK(!biasOffsets.getNode() || dyn_cast<Constant>(biasOffsets.getNode())) |
3523 | << "Bias offsets input to ChannelwiseQuantizedConvolutionNode must be " |
3524 | "null or Constant" ; |
3525 | |
3526 | // Number of output channels. |
3527 | dim_t numChannels = outTy->dims().back(); |
3528 | dim_t qDim = 0; |
3529 | dim_t qStep = 1; |
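  // With qDim = 0 and qStep = 1, quantization parameters are computed per
  // slice along dimension 0 of the filter, i.e. one (scale, offset) pair per
  // output channel.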
3530 | |
3531 | // Whether filter/bias quantization parameters were explicitly provided. |
3532 | bool filterQParamsGiven = filterScales.getNode() && filterOffsets.getNode(); |
3533 | bool biasQParamsGiven = biasScales.getNode() && biasOffsets.getNode(); |
3534 | |
3535 | // If input filter is FLOAT and filterScales/filterOffsets are NOT provided |
3536 | // then compute them automatically for given schema and filterElemQTy. |
3537 | // If input filter is QUANTIZED then filterScales/filterOffsets are mandatory. |
3538 | if (!filterQParamsGiven) { |
3539 | DCHECK(filterElemKind == ElemKind::FloatTy) |
3540 | << "ChannelwiseQuantizedConvolution: If the input filter is " |
3541 | << "quantized then the filter scales/offsets must be provided!" ; |
3542 | Constant *filterC = dyn_cast<Constant>(filter.getNode()); |
3543 | DCHECK(filterC) |
3544 | << "Filter input to ChannelwiseQuantizedConvolutionNode must be a " |
3545 | "Constant to quantize it" ; |
3546 | Constant *filterScalesC = getParent()->createConstant( |
3547 | ElemKind::FloatTy, {numChannels}, "filterScales" ); |
3548 | Constant *filterOffsetsC = getParent()->createConstant( |
3549 | ElemKind::Int32ITy, {numChannels}, "filterOffsets" ); |
3550 | // Get filter channelwise TensorQuantizationParams. |
3551 | quantization::getTensorQuantizationParams( |
3552 | filterC->getPayload(), filterScalesC->getPayloadMutable(), |
3553 | filterOffsetsC->getPayloadMutable(), schema, filterElemQTy, qDim, |
3554 | qStep); |
3555 | filterScales = NodeValue(filterScalesC); |
3556 | filterOffsets = NodeValue(filterOffsetsC); |
3557 | } |
3558 | |
3559 | // If input bias is FLOAT and biasScales/biasOffsets are NOT provided |
3560 | // then compute them automatically for given schema and biasElemQTy. |
3561 | // If input bias is QUANTIZED and biasScales/biasOffsets are NOT provided |
3562 | // then assume the channel wise quantization parameters are implicitly: |
3563 | // biasScales[i] = inputScale * filterScales[i] and biasOffsets[i] = 0. |
3564 | if (!biasQParamsGiven) { |
3565 | Constant *biasC = dyn_cast<Constant>(bias.getNode()); |
3566 | DCHECK(biasC) |
3567 | << "Bias input to ChannelwiseQuantizedConvolutionNode must be a " |
3568 | "Constant to quantize it" ; |
3569 | Constant *biasScalesC = getParent()->createConstant( |
3570 | ElemKind::FloatTy, {numChannels}, "biasScales" ); |
3571 | Constant *biasOffsetsC = getParent()->createConstant( |
3572 | ElemKind::Int32ITy, {numChannels}, "biasOffsets" ); |
3573 | auto biasScalesH = biasScalesC->getPayload().getHandle<float>(); |
3574 | auto biasOffsetsH = biasOffsetsC->getPayload().getHandle<int32_t>(); |
3575 | Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode()); |
3576 | Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode()); |
3577 | auto filterScalesH = filterScalesC->getPayload().getHandle<float>(); |
3578 | auto filterOffsetsH = filterOffsetsC->getPayload().getHandle<int32_t>(); |
3579 | auto inputScale = input.getType()->getScale(); |
3580 | auto inputOffset = input.getType()->getOffset(); |
3581 | if (biasElemKind == ElemKind::FloatTy) { |
3582 | auto biasH = biasC->getPayload().getHandle<float>(); |
3583 | // Get bias channelwise TensorQuantizationParams. |
3584 | quantization::getTensorQuantizationParams( |
3585 | biasC->getPayload(), biasScalesC->getPayloadMutable(), |
3586 | biasOffsetsC->getPayloadMutable(), schema, biasElemQTy, qDim, qStep); |
3587 | // Specialize the bias channelwise TensorQuantizationParams. |
3588 | for (dim_t idx = 0; idx < numChannels; idx++) { |
3589 | bool biasZero = (biasH.raw(idx) == 0.f); |
3590 | TensorQuantizationParams biasTQP = {biasScalesH.raw(idx), |
3591 | biasOffsetsH.raw(idx)}; |
3592 | TensorQuantizationParams inputTQP = {inputScale, inputOffset}; |
3593 | TensorQuantizationParams filterTQP = {filterScalesH.raw(idx), |
3594 | filterOffsetsH.raw(idx)}; |
3595 | if (filterQParamsGiven) { |
3596 | // Specialize only bias quantization parameters. |
3597 | biasTQP = specializeBiasQuantizationParams( |
3598 | biasTQP, inputTQP, filterTQP, schema, biasElemQTy, biasZero); |
3599 | } else { |
3600 | // Specialize bias and weights quantization parameters. |
3601 | specializeBiasWeightsQuantizationParams( |
3602 | biasTQP, inputTQP, filterTQP, schema, biasElemQTy, biasZero); |
3603 | } |
3604 | biasScalesH.raw(idx) = biasTQP.scale; |
3605 | biasOffsetsH.raw(idx) = biasTQP.offset; |
3606 | filterScalesH.raw(idx) = filterTQP.scale; |
3607 | filterOffsetsH.raw(idx) = filterTQP.offset; |
3608 | } |
3609 | } else { |
3610 | // Set implicit bias channelwise TensorQuantizationParams. |
3611 | for (dim_t idx = 0; idx < numChannels; idx++) { |
3612 | float filterScale = filterScalesH.raw(idx); |
3613 | biasScalesH.raw(idx) = inputScale * filterScale; |
3614 | biasOffsetsH.raw(idx) = 0; |
3615 | } |
3616 | } |
3617 | biasScales = NodeValue(biasScalesC); |
3618 | biasOffsets = NodeValue(biasOffsetsC); |
3619 | } |
3620 | |
3621 | // If input filter is FLOAT then quantize channel wise to filterElemQTy. |
3622 | if (quantizeFilter && filterElemKind == ElemKind::FloatTy) { |
3623 | Constant *filterC = dyn_cast<Constant>(filter.getNode()); |
3624 | DCHECK(filterC) |
3625 | << "Filter input to ChannelwiseQuantizedConvolutionNode must be a " |
3626 | "Constant to quantize it" ; |
3627 | Constant *filterCQ = getParent()->createConstant( |
3628 | filterElemQTy, filterC->getType()->dims(), 1.0, 0, "filter" ); |
3629 | Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode()); |
3630 | Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode()); |
3631 | // Quantize filter channelwise. |
3632 | filterCQ->getPayloadMutable() = quantization::quantizeTensor( |
3633 | filterC->getPayload(), filterScalesC->getPayload(), |
3634 | filterOffsetsC->getPayload(), filterElemQTy, qDim, qStep); |
3635 | filter = NodeValue(filterCQ); |
3636 | } |
3637 | |
3638 | // If input bias is FLOAT then quantize channel wise to biasElemQTy. |
3639 | if (quantizeBias && biasElemKind == ElemKind::FloatTy) { |
3640 | Constant *biasC = dyn_cast<Constant>(bias.getNode()); |
3641 | DCHECK(biasC) |
3642 | << "Bias input to ChannelwiseQuantizedConvolutionNode must be a " |
3643 | "Constant to quantize it" ; |
3644 | Constant *biasCQ = getParent()->createConstant( |
3645 | biasElemQTy, biasC->getType()->dims(), 1.0, 0, "bias" ); |
3646 | Constant *biasScalesC = dyn_cast<Constant>(biasScales.getNode()); |
3647 | Constant *biasOffsetsC = dyn_cast<Constant>(biasOffsets.getNode()); |
3648 | // Quantize bias channelwise. |
3649 | biasCQ->getPayloadMutable() = quantization::quantizeTensor( |
3650 | biasC->getPayload(), biasScalesC->getPayload(), |
3651 | biasOffsetsC->getPayload(), biasElemQTy, qDim, qStep); |
3652 | bias = NodeValue(biasCQ); |
3653 | } |
3654 | |
3655 | auto OT = getParent()->uniqueType(*outTy); |
3656 | return addNode(new ChannelwiseQuantizedConvolutionNode( |
3657 | name, OT, input, filter, bias, filterScales, filterOffsets, biasScales, |
3658 | biasOffsets, kernels, strides, pads, group, dilation, |
3659 | FusedActivation::NONE, {})); |
3660 | } |
3661 | |
3662 | ConvTransposeNode *Function::createConvTranspose( |
3663 | PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input, |
3664 | dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels, |
3665 | llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads, |
3666 | unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) { |
3667 | ShapeNHWC idim = ShapeNHWC(input.dims()); |
3668 | ShapeHW kdim(kernels); |
3669 | PaddingTLBR pdim(pads); |
3670 | (void)pdim; |
3671 | assert((idim.w + pdim.left + pdim.right) >= kdim.width && |
3672 | (idim.h + pdim.top + pdim.bottom) >= kdim.height && |
3673 | "buffer too small for selected stride" ); |
3674 | |
  assert(group > 0 && "group must be greater than 0");
  assert(idim.c % group == 0 &&
         "the number of input channels must be divisible by group");
  assert(outChannels % group == 0 && "outChannels must be divisible by group");
3678 | |
3679 | // Calculate the size and allocate the output buffer. |
3680 | auto outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels, |
3681 | strides, pads, dilation); |
3682 | |
3683 | std::array<dim_t, 4> outDims = { |
3684 | {idim.n, outSz.first, outSz.second, outChannels}}; |
3685 | |
3686 | // Allocate the Filter and Bias tensors. |
3687 | std::array<dim_t, 4> filterDim = { |
3688 | {outChannels, kdim.height, kdim.width, idim.c / group}}; |
3689 | size_t fanIn = kdim.height * kdim.width * idim.c; |
3690 | ElemKind inputTy = input.getType()->getElementType(); |
  assert((inputTy == ElemKind::FloatTy || inputTy == ElemKind::Float16Ty) &&
         "ConvTranspose on non-floating point type?");
3693 | auto *filter = |
3694 | getParent()->createPlaceholder(inputTy, filterDim, "filter" , true); |
3695 | |
3696 | auto *bias = |
3697 | getParent()->createPlaceholder(inputTy, {outChannels}, "bias" , true); |
3698 | bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1, |
3699 | getPRNG()); |
3700 | |
3701 | bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn, |
3702 | getPRNG()); |
3703 | |
3704 | auto OT = getParent()->uniqueType(inputTy, outDims); |
3705 | |
3706 | return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels, |
3707 | strides, pads, group, dilation)); |
3708 | } |
3709 | |
3710 | ConvTransposeNode *Function::createConvTranspose( |
3711 | PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input, |
3712 | dim_t outChannels, unsigned_t kernel, unsigned_t stride, unsigned_t pad, |
3713 | unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) { |
3714 | llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad}; |
3715 | llvm::SmallVector<unsigned_t, 2> strides = {stride, stride}; |
3716 | llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel}; |
3717 | return createConvTranspose(bindings, name, input, outChannels, kernels, |
3718 | strides, pads, group, dilation); |
3719 | } |
3720 | |
3721 | ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input, |
3722 | TypeRef outTy) { |
3723 | return addNode(new ConvertToNode(name, outTy, input)); |
3724 | } |
3725 | |
3726 | ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input, |
3727 | ElemKind k) { |
3728 | auto OT = getParent()->uniqueType(k, input.dims()); |
3729 | return addNode(new ConvertToNode(name, OT, input)); |
3730 | } |
3731 | |
3732 | FullyConnectedNode * |
3733 | Function::createFullyConnected(PlaceholderBindings &bindings, |
3734 | llvm::StringRef name, NodeValue input, |
3735 | dim_t outDepth, unsigned_t axis) { |
3736 | const ElemKind k = input.getType()->getElementType(); |
3737 | |
3738 | // FC always uses 2D input; flatten if necessary. |
3739 | if (input.dims().size() != 2) { |
3740 | input = createFlatten(name.str() + ".reshape2D" , input, axis); |
3741 | } |
3742 | auto *W = getParent()->createPlaceholder(k, {input.dims()[1], outDepth}, |
3743 | "weights" , true); |
3744 | auto *B = getParent()->createPlaceholder(k, {outDepth}, "bias" , true); |
3745 | |
3746 | bindings.allocate(W)->init(Tensor::InitKind::Xavier, input.dims()[1], |
3747 | getPRNG()); |
3748 | bindings.allocate(B)->init(Tensor::InitKind::Broadcast, .1, getPRNG()); |
3749 | |
3750 | auto OT = getParent()->uniqueType(k, {input.dims()[0], outDepth}); |
3751 | return createFullyConnected(name, input, W, B, OT, axis); |
3752 | } |
3753 | |
3754 | Node *Function::createDotProduct(llvm::StringRef name, NodeValue X, |
3755 | NodeValue Y) { |
3756 | auto XDimsSize = X.dims().size(); |
3757 | (void)XDimsSize; |
3758 | |
3759 | assert(X.dims() == Y.dims() && "X and Y must have the same shape" ); |
3760 | assert(((XDimsSize == 1) || (XDimsSize == 2)) && "X and Y must be 1D or 2D" ); |
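  // For 1D inputs the result is the elementwise product; for 2D inputs of
  // shape {B, D} the products are reduced along dim 1, yielding a {B} vector
  // of per-row dot products.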
3761 | |
3762 | // Create Mul node. |
3763 | auto *MN = createMul(name.str() + ".mul" , X, Y); |
3764 | |
3765 | // If X and Y are 1D, the BatchedReduceAdd node is not needed. |
3766 | if (XDimsSize == 1) { |
3767 | return MN; |
3768 | } |
3769 | |
3770 | // Create and return BatchedReduceAdd node. |
3771 | return createBatchedReduceAdd(name.str() + ".bra" , MN, 1); |
3772 | } |
3773 | |
3774 | BatchedPairwiseDotProductNode * |
3775 | Function::createBatchedPairwiseDotProduct(llvm::StringRef name, |
3776 | llvm::ArrayRef<NodeValue> inputs) { |
3777 | assert(!inputs.empty()); |
3778 | dim_t batchCount = inputs[0].getType()->dims()[0]; |
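  // Every distinct unordered pair of inputs is dotted, so n inputs produce
  // n * (n - 1) / 2 outputs per batch element (e.g. 4 inputs -> 6 pairs).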
3779 | dim_t numPairs = inputs.size() * (inputs.size() - 1) / 2; |
3780 | auto *outTy = getParent()->uniqueTypeWithNewShape(inputs[0].getType(), |
3781 | {batchCount, numPairs}); |
3782 | |
3783 | return addNode(new BatchedPairwiseDotProductNode(name, outTy, inputs)); |
3784 | } |
3785 | |
3786 | Node *Function::createElementwiseLinear(llvm::StringRef name, NodeValue X, |
3787 | NodeValue w, NodeValue b, |
3788 | unsigned axis) { |
3789 | auto XDims = X.dims(); |
3790 | auto wDims = w.dims(); |
3791 | auto bDims = b.dims(); |
3792 | |
3793 | // Suppress release mode unused variable warnings. |
3794 | (void)wDims; |
3795 | (void)bDims; |
3796 | |
3797 | // Check that the inputs are sensible. |
3798 | assert(XDims.size() == 2 && "X must be 2D" ); |
3799 | assert((axis == 0 || axis == 1) && "axis must be 0 or 1" ); |
3800 | assert(wDims.size() == 1 && "w must be 1D" ); |
3801 | assert(bDims.size() == 1 && "b must be 1D" ); |
3802 | assert(wDims[0] == XDims[axis] && |
3803 | "size of w must match input dimension of X" ); |
3804 | assert(bDims[0] == XDims[axis] && |
3805 | "size of b must match input dimension of X" ); |
3806 | |
3807 | // Broadcast w and b so that they have the same dimensions as X. |
3808 | auto *broadcastW = |
3809 | createBroadcast(name.str() + ".broadcastW" , w, XDims, axis); |
3810 | auto *broadcastB = |
3811 | createBroadcast(name.str() + ".broadcastB" , b, XDims, axis); |
3812 | |
3813 | // Implement the elementwise linear operation by multiplying X elementwise |
3814 | // with broadcasted w and adding broadcasted b elementwise. |
3815 | auto *wX = createMul(name.str() + ".mul" , broadcastW, X); |
3816 | auto *out = createAdd(name.str() + ".add" , wX, broadcastB); |
3817 | |
3818 | return out; |
3819 | } |
3820 | |
3821 | void Function::createGRU(PlaceholderBindings &bindings, |
3822 | llvm::StringRef namePrefix, |
3823 | llvm::ArrayRef<NodeValue> inputs, unsigned batchSize, |
3824 | unsigned hiddenSize, unsigned outputSize, |
3825 | std::vector<NodeValue> &outputs) { |
3826 | std::string nameBase = namePrefix.str(); |
3827 | const unsigned timeSteps = inputs.size(); |
3828 | assert(timeSteps > 0 && "empty input" ); |
3829 | const unsigned inputSize = inputs.front().dims().back(); |
3830 | assert(inputSize > 0 && "input dimensionality is zero" ); |
3831 | |
3832 | // Initialize the state to zero. |
3833 | Placeholder *HInit = getParent()->createPlaceholder( |
3834 | ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_state" , false); |
3835 | bindings.allocate(HInit)->zero(); |
3836 | Node *Ht = HInit; |
3837 | |
3838 | // Update gate: |
3839 | // Z <- sigmoid(Wxz * x + Whz * h + bz) |
3840 | // Reset gate: |
3841 | // R <- sigmoid(Wxr * x + Whr * h + br) |
3842 | // Hidden state: |
  // h <- Z . h + (1 - Z) . tanh(Wxh * x + Whh * (R . h) + bh)
3844 | |
  // Update gate.
3846 | float bUpdate = 0.1; |
3847 | Placeholder *Wxz = getParent()->createPlaceholder( |
3848 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxz" , true); |
3849 | Placeholder *Whz = getParent()->createPlaceholder( |
3850 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whz" , true); |
3851 | Placeholder *Bz1 = getParent()->createPlaceholder( |
3852 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz1" , true); |
3853 | Placeholder *Bz2 = getParent()->createPlaceholder( |
3854 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz2" , true); |
3855 | |
3856 | bindings.allocate(Wxz)->init(glow::Tensor::InitKind::Xavier, inputSize, |
3857 | getPRNG()); |
3858 | bindings.allocate(Whz)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
3859 | getPRNG()); |
3860 | bindings.allocate(Bz1)->init(glow::Tensor::InitKind::Broadcast, bUpdate, |
3861 | getPRNG()); |
3862 | bindings.allocate(Bz2)->init(glow::Tensor::InitKind::Broadcast, bUpdate, |
3863 | getPRNG()); |
3864 | |
3865 | // Reset gate. |
3866 | float bReset = -1.0; |
3867 | Placeholder *Wxr = getParent()->createPlaceholder( |
3868 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxr" , true); |
3869 | Placeholder *Whr = getParent()->createPlaceholder( |
3870 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whr" , true); |
3871 | Placeholder *Br1 = getParent()->createPlaceholder( |
3872 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".br1" , true); |
3873 | Placeholder *Br2 = getParent()->createPlaceholder( |
3874 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".br2" , true); |
3875 | |
3876 | bindings.allocate(Wxr)->init(glow::Tensor::InitKind::Xavier, inputSize, |
3877 | getPRNG()); |
3878 | bindings.allocate(Whr)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
3879 | getPRNG()); |
3880 | bindings.allocate(Br1)->init(glow::Tensor::InitKind::Broadcast, bReset, |
3881 | getPRNG()); |
3882 | bindings.allocate(Br2)->init(glow::Tensor::InitKind::Broadcast, bReset, |
3883 | getPRNG()); |
3884 | |
// Hidden state.
3886 | float b = 0.1; |
3887 | Placeholder *Wxh = getParent()->createPlaceholder( |
3888 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh" , true); |
3889 | Placeholder *Whh = getParent()->createPlaceholder( |
3890 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh" , true); |
3891 | Placeholder *Bh1 = getParent()->createPlaceholder( |
3892 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh1" , true); |
3893 | Placeholder *Bh2 = getParent()->createPlaceholder( |
3894 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh2" , true); |
3895 | |
3896 | bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize, |
3897 | getPRNG()); |
3898 | bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
3899 | getPRNG()); |
3900 | bindings.allocate(Bh1)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
3901 | bindings.allocate(Bh2)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
3902 | |
// Output layer.
3904 | Placeholder *Why = getParent()->createPlaceholder( |
3905 | ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why" , true); |
3906 | Placeholder *By = getParent()->createPlaceholder( |
3907 | ElemKind::FloatTy, {outputSize}, nameBase + ".by" , true); |
3908 | |
3909 | bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
3910 | getPRNG()); |
3911 | bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
3912 | |
3913 | auto ty = getParent()->uniqueType(ElemKind::FloatTy, {batchSize, hiddenSize}); |
3914 | auto *Ones = createSplat(nameBase + ".ones" , ty, 1.0); |
3915 | |
3916 | std::vector<Node *> outputNodes; |
3917 | for (unsigned t = 0; t < timeSteps; t++) { |
3918 | auto fc1Name = nameBase + ".fc1." + std::to_string(t); |
3919 | auto fc2Name = nameBase + ".fc2." + std::to_string(t); |
3920 | auto add1Name = nameBase + ".add1." + std::to_string(t); |
3921 | auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t); |
3922 | |
3923 | auto *Zt = createSigmoid( |
3924 | sigmoid1Name, |
3925 | createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whz, Bz1), |
3926 | createFullyConnected(fc2Name, inputs[t], Wxz, Bz2))); |
3927 | |
3928 | auto fc3Name = nameBase + ".fc3." + std::to_string(t); |
3929 | auto fc4Name = nameBase + ".fc4." + std::to_string(t); |
3930 | auto add2Name = nameBase + ".add2." + std::to_string(t); |
3931 | auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t); |
3932 | |
3933 | auto *Rt = createSigmoid( |
3934 | sigmoid2Name, |
3935 | createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whr, Br1), |
3936 | createFullyConnected(fc4Name, inputs[t], Wxr, Br2))); |
3937 | |
3938 | auto zhtName = nameBase + ".zh." + std::to_string(t); |
3939 | auto *ZHt = createMul(zhtName, Zt, Ht); |
3940 | |
3941 | auto oneMinusZtName = nameBase + ".1-z." + std::to_string(t); |
3942 | auto *OneMinusZt = createSub(oneMinusZtName, Ones, Zt); |
3943 | |
3944 | auto rhtName = nameBase + ".rh." + std::to_string(t); |
3945 | auto *RHt = createMul(rhtName, Rt, Ht); |
3946 | |
3947 | auto fc5Name = nameBase + ".fc5." + std::to_string(t); |
3948 | auto fc6Name = nameBase + ".fc6." + std::to_string(t); |
3949 | auto add3Name = nameBase + ".add3." + std::to_string(t); |
3950 | auto tanh1Name = nameBase + ".tanh1." + std::to_string(t); |
3951 | |
3952 | auto *Ut = createTanh( |
3953 | tanh1Name, |
3954 | createAdd(add3Name, createFullyConnected(fc5Name, RHt, Whh, Bh1), |
3955 | createFullyConnected(fc6Name, inputs[t], Wxh, Bh2))); |
3956 | |
auto oneMinusZtUtName = nameBase + ".1-zu." + std::to_string(t);
3958 | auto *OneMinusZtUt = createMul(oneMinusZtUtName, OneMinusZt, Ut); |
3959 | |
3960 | auto htName = nameBase + ".H." + std::to_string(t); |
3961 | Ht = createAdd(htName, ZHt, OneMinusZtUt); |
3962 | |
3963 | auto outName = nameBase + ".out." + std::to_string(t); |
3964 | auto *O = createFullyConnected(outName, Ht, Why, By); |
3965 | outputs.push_back(O); |
3966 | } |
3967 | } |
3968 | |
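/// Note: this builder unrolls a vanilla (Elman) RNN. For each step t it
/// computes H_t = tanh(H_{t-1} * Whh + Bhh + x_t * Wxh + Bxh) and emits
/// FC(H_t; Why, Bhy) as that step's output. A usage sketch (illustrative
/// only; `F`, `bindings` and the step inputs x0..x2 are assumed to exist):
///   std::vector<NodeValue> ins = {x0, x1, x2}, outs;
///   F->createSimpleRNN(bindings, "rnn", ins, /*batchSize=*/4,
///                      /*hiddenSize=*/8, /*outputSize=*/2, outs);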
3969 | void Function::createSimpleRNN(PlaceholderBindings &bindings, |
3970 | llvm::StringRef namePrefix, |
3971 | llvm::ArrayRef<NodeValue> inputs, |
3972 | unsigned batchSize, unsigned hiddenSize, |
3973 | unsigned outputSize, |
3974 | std::vector<NodeValue> &outputs) { |
3975 | std::string nameBase = namePrefix.str(); |
3976 | const unsigned timeSteps = inputs.size(); |
3977 | assert(timeSteps > 0 && "empty input" ); |
3978 | const unsigned inputSize = inputs.front().dims().back(); |
3979 | assert(inputSize > 0 && "input dimensionality is zero" ); |
3980 | |
3981 | // Initialize the state to zero. |
3982 | Placeholder *HInit = |
3983 | getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize}, |
3984 | nameBase + ".initial_state" , false); |
3985 | bindings.allocate(HInit)->zero(); |
3986 | Node *Ht = HInit; |
3987 | |
3988 | float b = 0.1; |
3989 | Placeholder *Whh = getParent()->createPlaceholder( |
3990 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh" , true); |
3991 | Placeholder *Bhh = getParent()->createPlaceholder( |
3992 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bhh" , true); |
3993 | Placeholder *Wxh = getParent()->createPlaceholder( |
3994 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh" , true); |
3995 | |
3996 | Placeholder *Bxh = getParent()->createPlaceholder( |
3997 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bxh" , true); |
3998 | Placeholder *Why = getParent()->createPlaceholder( |
3999 | ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why" , true); |
4000 | Placeholder *Bhy = getParent()->createPlaceholder( |
4001 | ElemKind::FloatTy, {outputSize}, nameBase + ".Bhy" , true); |
4002 | |
4003 | bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4004 | getPRNG()); |
4005 | bindings.allocate(Bhh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
4006 | bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize, |
4007 | getPRNG()); |
4008 | bindings.allocate(Bxh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
4009 | bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4010 | getPRNG()); |
4011 | bindings.allocate(Bhy)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
4012 | |
// Unroll backpropagation through time as a loop with shared parameters.
4015 | for (unsigned t = 0; t < timeSteps; t++) { |
4016 | auto fc1Name = nameBase + ".fc1." + std::to_string(t); |
4017 | auto *FC1 = createFullyConnected(fc1Name, Ht, Whh, Bhh); |
4018 | auto fc2Name = nameBase + ".fc2." + std::to_string(t); |
4019 | auto *FC2 = createFullyConnected(fc2Name, inputs[t], Wxh, Bxh); |
4020 | auto aName = nameBase + ".add." + std::to_string(t); |
4021 | auto *A = createAdd(aName, FC1, FC2); |
4022 | auto tanhName = nameBase + ".tanh." + std::to_string(t); |
4023 | auto *H = createTanh(tanhName, A); |
4024 | auto outName = nameBase + ".out." + std::to_string(t); |
4025 | auto *O = createFullyConnected(outName, H, Why, Bhy); |
4026 | outputs.push_back(O); |
4027 | |
4028 | Ht = H; |
4029 | }; |
4030 | } |
4031 | |
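/// Note: LSTMUnitNode consumes the fused, pre-activation gate input \p Input
/// (the summed FullyConnected results for the four gates) together with the
/// previous cell state \p C, and yields the new hidden and cell states as
/// results 0 and 1; see its use in createSingleDirectionLSTM() below.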
4032 | LSTMUnitNode *Function::createLSTMUnit(llvm::StringRef namePrefix, |
4033 | NodeValue Input, NodeValue C) { |
4034 | |
4035 | return addNode(new LSTMUnitNode(namePrefix, Input, C)); |
4036 | } |
4037 | |
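/// Note: the iterator type \p T is what lets a single unroller serve both
/// directions: a forward pass walks the sliced inputs with a plain iterator,
/// while the backward half of a bidirectional LSTM reuses the same code with
/// a reverse_iterator (see createPyTorchLSTM() below).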
4038 | template <class T> |
4039 | std::vector<NodeValue> Function::createSingleDirectionLSTM( |
4040 | std::string nameBase, T inputItr, const int timeSteps, NodeValue Wx, |
4041 | NodeValue Wh, NodeValue Bx, NodeValue Bh, NodeValue &H, NodeValue &C) { |
4042 | |
4043 | std::vector<NodeValue> Hs; |
4044 | auto name = [&nameBase](const char *s, int t) { |
4045 | return strFormat("%s.%s_%d" , nameBase.c_str(), s, t); |
4046 | }; |
4047 | for (int t = 0; t < timeSteps; t++, inputItr++) { |
4048 | |
4049 | auto *result = createAdd( |
4050 | name("add" , t), createFullyConnected(name("fc_1" , t), H, Wh, Bh), |
4051 | createFullyConnected(name("fc_2" , t), *inputItr, Wx, Bx)); |
4052 | |
4053 | auto lstmUnitNode = |
4054 | addNode(new LSTMUnitNode(name("lstm_unit" , t), result, C)); |
4055 | H = lstmUnitNode->getNthResult(0); |
4056 | C = lstmUnitNode->getNthResult(1); |
4057 | |
4058 | Hs.push_back( |
4059 | createReshape(name("H_reshape" , t), H, {1, H.dims()[0], H.dims()[1]}) |
4060 | ->getResult()); |
4061 | } |
4062 | return Hs; |
4063 | } |
4064 | |
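/// Note: layers are stacked by concatenating the per-step hidden states of
/// layer l and feeding them back in as the sliced input of layer l + 1. The
/// 3-D H and C states, shaped {numLayers, batchSize, hiddenSize}, are sliced
/// into one {batchSize, hiddenSize} plane per layer and re-assembled into the
/// same 3-D shape at the end.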
4065 | std::vector<NodeValue> Function::createMultipleLayerSingleDirectionLSTM( |
4066 | std::string nameBase, NodeValue input, unsigned batchSize, |
4067 | unsigned inputSize, const int timeSteps, std::vector<NodeValue> &Wx, |
4068 | std::vector<NodeValue> &Wh, std::vector<NodeValue> &Bx, |
4069 | std::vector<NodeValue> &Bh, NodeValue &H, NodeValue &C) { |
4070 | |
assert(Wx.size() > 0 && Wh.size() > 0 && Bx.size() > 0 && Bh.size() > 0 &&
"Wx, Wh, Bx and Bh should be non-empty vectors");
4073 | |
4074 | auto numLayers = Wx.size(); |
4075 | NodeValue temp_input = input; |
4076 | std::vector<NodeValue> Hs; |
4077 | auto name = [&nameBase](const char *s, int t) { |
4078 | return strFormat("%s.%s_%d" , nameBase.c_str(), s, t); |
4079 | }; |
4080 | std::vector<NodeValue> Hv, Cv; |
4081 | auto reshape2dto3d = [=](NodeValue n, std::string nameStr) { |
4082 | return createReshape(nameStr, n, {1, n.dims()[0], n.dims()[1]}); |
4083 | }; |
4084 | for (unsigned layer = 0; layer < numLayers; layer++) { |
auto slicedInputs = createSlicedInput(temp_input, nameBase, batchSize,
4086 | inputSize, timeSteps); |
4087 | auto Hn = |
4088 | createReshape(name("reshape_hn" , layer), |
4089 | createSlice(name("slice_hn" , layer), H, {layer, 0, 0}, |
4090 | {layer + 1, H.dims()[1], H.dims()[2]}) |
4091 | ->getResult(), |
4092 | {H.dims()[1], H.dims()[2]}) |
4093 | ->getResult(); |
4094 | auto Cn = |
4095 | createReshape(name("reshape_cn" , layer), |
4096 | createSlice(name("slice_cn" , layer), C, {layer, 0, 0}, |
4097 | {layer + 1, C.dims()[1], C.dims()[2]}) |
4098 | ->getResult(), |
4099 | {C.dims()[1], C.dims()[2]}) |
4100 | ->getResult(); |
Hs = createSingleDirectionLSTM(nameBase, slicedInputs.begin(), timeSteps,
4102 | Wx[layer], Wh[layer], Bx[layer], Bh[layer], |
4103 | Hn, Cn); |
4104 | temp_input = createConcat(nameBase + "_lstm_temp_input" , Hs, 0); |
4105 | Hv.emplace_back(reshape2dto3d(Hn, "_lstm_hv" )); |
4106 | Cv.emplace_back(reshape2dto3d(Cn, "_lstm_cv" )); |
4107 | inputSize = temp_input.dims()[2]; |
4108 | } |
4109 | H = createConcat(nameBase + "_lstm_h_output" , Hv, 0); |
4110 | C = createConcat(nameBase + "_lstm_c_output" , Cv, 0); |
4111 | return Hs; |
4112 | } |
4113 | |
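/// Note: this splits a {timeSteps, batchSize, inputSize} sequence tensor into
/// timeSteps separate {batchSize, inputSize} tensors, one per step, which the
/// unrolled cells above consume one at a time.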
4114 | std::vector<NodeValue> Function::createSlicedInput(NodeValue input, |
4115 | std::string &nameBase, |
4116 | unsigned batchSize, |
4117 | unsigned inputSize, |
4118 | const int timeSteps) { |
4119 | std::vector<NodeValue> inputs; |
4120 | auto name = [&nameBase](const char *s, int t) { |
4121 | return strFormat("%s.%s_%d" , nameBase.c_str(), s, t); |
4122 | }; |
4123 | for (unsigned t = 0; t < timeSteps; t++) { |
4124 | auto inputSliced = createSlice(name("slice" , t), input, {t, 0, 0}, |
4125 | {t + 1, batchSize, inputSize}) |
4126 | ->getResult(); |
4127 | inputSliced = |
4128 | createReshape(name("reshape" , t), inputSliced, {batchSize, inputSize}) |
4129 | ->getResult(); |
4130 | inputs.push_back(inputSliced); |
4131 | } |
4132 | return inputs; |
4133 | } |
4134 | |
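/// Note: a minimal unidirectional, single-layer usage sketch (illustrative
/// only; `F` and the NodeValues below are assumed to exist, with `in` shaped
/// {timeSteps, batchSize, inputSize} and Ht/Ct shaped {batchSize, hiddenSize},
/// which selects the single-layer path below; the trailing NodeValue()
/// arguments are the unused reverse weights):
///   std::vector<NodeValue> Wx = {wx}, Wh = {wh}, Bx = {bx}, Bh = {bh};
///   NodeValue out;
///   F->createPyTorchLSTM("lstm", in, Wx, Wh, Bx, Bh, Ht, Ct, out,
///                        /*isBidirectional=*/false, NodeValue(), NodeValue(),
///                        NodeValue(), NodeValue());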
4135 | void Function::createPyTorchLSTM(llvm::StringRef namePrefix, NodeValue input, |
4136 | std::vector<NodeValue> &Wx, |
4137 | std::vector<NodeValue> &Wh, |
4138 | std::vector<NodeValue> &Bx, |
4139 | std::vector<NodeValue> &Bh, NodeValue &Ht, |
4140 | NodeValue &Ct, NodeValue &output, |
4141 | bool isBidirectional, NodeValue WxR, |
4142 | NodeValue WhR, NodeValue BxR, NodeValue BhR) { |
4143 | std::string nameBase = namePrefix.str(); |
4144 | assert(input.dims().back() > 0 && "input dimensionality is zero" ); |
assert((!isBidirectional || WxR != NodeValue(nullptr)) &&
"Bidirectional LSTM must provide reverse weights & biases");
assert(Wx.size() > 0 && Wh.size() > 0 && Bx.size() > 0 && Bh.size() > 0 &&
"Wx, Wh, Bx and Bh should be non-empty vectors");
4149 | |
4150 | std::vector<NodeValue> inputs, outputs; |
4151 | unsigned batchSize, inputSize, timeSteps, hiddenSize; |
4152 | batchSize = input.dims()[1]; |
4153 | inputSize = input.dims()[2]; |
4154 | if (Wh.size() == 1) { |
4155 | hiddenSize = Wh[0].dims()[0]; |
4156 | } else { |
4157 | hiddenSize = Wh[0].dims()[1]; |
4158 | } |
4159 | |
4160 | timeSteps = input.dims()[0]; |
4161 | |
4162 | // Input gate: |
4163 | // I <- sigmoid(Wxi * x + Bxi + Whi * h + Bhi) |
4164 | // Forget gate: |
4165 | // F <- sigmoid(Wxf * x + Bxf + Whf * h + Bhf) |
4166 | // Cell gate: |
4167 | // G <- tanh(Wxg * x + Bxg + Whg * h + Bhg) |
4168 | // Output gate: |
4169 | // O <- sigmoid(Wxo * x + Bxo + Who * h + Bho) |
4170 | // Cell state: |
4171 | // C <- F . C + I . G |
4172 | // Hidden state: |
4173 | // h <- O . tanh(C) |
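// The unrolled graph below realizes these equations via
// createSingleDirectionLSTM(): per step, two FullyConnected nodes (one on
// x_t, one on h_{t-1}) are summed and fed into an LSTMUnit node, which
// applies the gate nonlinearities and updates C and H.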
4174 | |
4175 | if (isBidirectional) { |
// For a bidirectional LSTM we split H and C into two parts, one per
// direction, and compute each direction separately.
4178 | NodeValue Hforward, Hbackward, Cforward, Cbackward; |
4179 | Hforward = createReshape("reshape_hforward" , |
4180 | createSlice("slice_hforward" , Ht, {0, 0, 0}, |
4181 | {1, batchSize, hiddenSize}), |
4182 | {batchSize, hiddenSize}) |
4183 | ->getResult(); |
4184 | Hbackward = createReshape("reshape_hbackward" , |
4185 | createSlice("slice_hbackward" , Ht, {1, 0, 0}, |
4186 | {2, batchSize, hiddenSize}), |
4187 | {batchSize, hiddenSize}) |
4188 | ->getResult(); |
4189 | Cforward = createReshape("reshape_cforward" , |
4190 | createSlice("slice_cforward" , Ct, {0, 0, 0}, |
4191 | {1, batchSize, hiddenSize}), |
4192 | {batchSize, hiddenSize}) |
4193 | ->getResult(); |
4194 | Cbackward = createReshape("reshape_cbackward" , |
4195 | createSlice("slice_cbackward" , Ct, {1, 0, 0}, |
4196 | {2, batchSize, hiddenSize}), |
4197 | {batchSize, hiddenSize}) |
4198 | ->getResult(); |
4199 | |
4200 | auto slicedInputs = |
4201 | createSlicedInput(input, nameBase, batchSize, inputSize, timeSteps); |
4202 | auto outputForwards = |
4203 | createSingleDirectionLSTM<std::vector<NodeValue>::iterator>( |
4204 | nameBase + "_lstm_forward" , slicedInputs.begin(), timeSteps, Wx[0], |
4205 | Wh[0], Bx[0], Bh[0], Hforward, Cforward); |
4206 | auto outputBackwards = |
4207 | createSingleDirectionLSTM<std::vector<NodeValue>::reverse_iterator>( |
4208 | nameBase + "_lstm_backward" , slicedInputs.rbegin(), timeSteps, WxR, |
4209 | WhR, BxR, BhR, Hbackward, Cbackward); |
4210 | std::reverse(outputBackwards.begin(), outputBackwards.end()); |
4211 | NodeValue outputForward = createConcat( |
4212 | nameBase + "_lstm_forward_output_concat" , outputForwards, 0); |
4213 | NodeValue outputBackward = createConcat( |
4214 | nameBase + "_lstm_backward_output_concat" , outputBackwards, 0); |
4215 | output = |
4216 | createConcat("final_output_concat" , {outputForward, outputBackward}, 2) |
4217 | ->getResult(); |
4218 | |
4219 | auto reshape2dto3d = [=](NodeValue n, std::string nameStr) { |
4220 | return createReshape(nameStr, n, {1, n.dims()[0], n.dims()[1]}); |
4221 | }; |
4222 | Ht = createConcat("Ht_Concat" , |
4223 | {reshape2dto3d(Hforward, "final_Hforward_reshape" ), |
4224 | reshape2dto3d(Hbackward, "final_Hbackward_reshape" )}, |
4225 | 0) |
4226 | ->getResult(); |
4227 | Ct = createConcat("Ct_Concat" , |
4228 | {reshape2dto3d(Cforward, "final_Cforward_reshape" ), |
4229 | reshape2dto3d(Cbackward, "final_Cbackward_reshape" )}, |
4230 | 0) |
4231 | ->getResult(); |
4232 | |
4233 | } else { |
4234 | if (Ht.dims().size() == 2) { |
4235 | auto slicedInputs = |
4236 | createSlicedInput(input, nameBase, batchSize, inputSize, timeSteps); |
4237 | outputs = createSingleDirectionLSTM<std::vector<NodeValue>::iterator>( |
4238 | nameBase + "_lstm" , slicedInputs.begin(), timeSteps, Wx[0], Wh[0], |
4239 | Bx[0], Bh[0], Ht, Ct); |
4240 | } else { |
4241 | outputs = createMultipleLayerSingleDirectionLSTM( |
4242 | nameBase + "_lstm" , input, batchSize, inputSize, timeSteps, Wx, Wh, |
4243 | Bx, Bh, Ht, Ct); |
4244 | } |
4245 | output = createConcat(nameBase + "_lstm_output_concat" , outputs, 0); |
4246 | } |
4247 | }; |
4248 | |
4249 | void Function::createLSTM(PlaceholderBindings &bindings, |
4250 | llvm::StringRef namePrefix, |
4251 | llvm::ArrayRef<NodeValue> inputs, unsigned batchSize, |
4252 | unsigned hiddenSize, unsigned outputSize, |
4253 | std::vector<NodeValue> &outputs) { |
4254 | std::string nameBase = namePrefix.str(); |
4255 | const unsigned timeSteps = inputs.size(); |
4256 | assert(timeSteps > 0 && "empty input" ); |
4257 | const unsigned inputSize = inputs.front().dims().back(); |
4258 | assert(inputSize > 0 && "input dimensionality is zero" ); |
4259 | |
4260 | // Initialize the hidden and cell states to zero. |
4261 | Placeholder *HInit = |
4262 | getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize}, |
4263 | "initial_hidden_state" , false); |
4264 | bindings.allocate(HInit)->zero(); |
4265 | Node *Ht = HInit; |
4266 | |
4267 | Placeholder *CInit = getParent()->createPlaceholder( |
4268 | ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_cell_state" , false); |
4269 | bindings.allocate(CInit)->zero(); |
4270 | Node *Ct = CInit; |
4271 | |
4272 | // Forget gate: |
4273 | // F <- sigmoid(Wxf * x + Whf * h + bf) |
4274 | // Input gate: |
4275 | // I <- sigmoid(Wxi * x + Whi * h + bi) |
4276 | // Output gate: |
// O <- sigmoid(Wxo * x + Who * h + bo)
4278 | // Cell state: |
4279 | // C <- F . C + I . tanh(Wxc * x + Whc * h + bc) |
4280 | // Hidden state: |
4281 | // h <- O . tanh(C) |
4282 | |
// Forget gate.
4284 | float bForget = 1.0; |
4285 | Placeholder *Wxf = getParent()->createPlaceholder( |
4286 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxf" , true); |
4287 | Placeholder *Whf = getParent()->createPlaceholder( |
4288 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whf" , true); |
4289 | Placeholder *Bf1 = getParent()->createPlaceholder( |
4290 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf1" , true); |
4291 | Placeholder *Bf2 = getParent()->createPlaceholder( |
4292 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf2" , true); |
4293 | bindings.allocate(Wxf)->init(glow::Tensor::InitKind::Xavier, inputSize, |
4294 | getPRNG()); |
4295 | bindings.allocate(Whf)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4296 | getPRNG()); |
4297 | bindings.allocate(Bf1)->init(glow::Tensor::InitKind::Broadcast, bForget, |
4298 | getPRNG()); |
4299 | bindings.allocate(Bf2)->init(glow::Tensor::InitKind::Broadcast, bForget, |
4300 | getPRNG()); |
4301 | |
// Input gate.
4303 | float bInput = 0.1; |
4304 | Placeholder *Wxi = getParent()->createPlaceholder( |
4305 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxi" , true); |
4306 | Placeholder *Whi = getParent()->createPlaceholder( |
4307 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whi" , true); |
4308 | Placeholder *Bi1 = getParent()->createPlaceholder( |
4309 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi1" , true); |
4310 | Placeholder *Bi2 = getParent()->createPlaceholder( |
4311 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi2" , true); |
4312 | |
4313 | bindings.allocate(Wxi)->init(glow::Tensor::InitKind::Xavier, inputSize, |
4314 | getPRNG()); |
4315 | bindings.allocate(Whi)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4316 | getPRNG()); |
4317 | bindings.allocate(Bi1)->init(glow::Tensor::InitKind::Broadcast, bInput, |
4318 | getPRNG()); |
4319 | bindings.allocate(Bi2)->init(glow::Tensor::InitKind::Broadcast, bInput, |
4320 | getPRNG()); |
4321 | |
// Output gate.
4323 | float bOutput = 0.1; |
4324 | Placeholder *Wxo = getParent()->createPlaceholder( |
4325 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxo" , true); |
4326 | Placeholder *Who = getParent()->createPlaceholder( |
4327 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Who" , true); |
4328 | Placeholder *Bo1 = getParent()->createPlaceholder( |
4329 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo1" , true); |
4330 | Placeholder *Bo2 = getParent()->createPlaceholder( |
4331 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo2" , true); |
4332 | |
4333 | bindings.allocate(Wxo)->init(glow::Tensor::InitKind::Xavier, inputSize, |
4334 | getPRNG()); |
4335 | bindings.allocate(Who)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4336 | getPRNG()); |
4337 | bindings.allocate(Bo1)->init(glow::Tensor::InitKind::Broadcast, bOutput, |
4338 | getPRNG()); |
4339 | bindings.allocate(Bo2)->init(glow::Tensor::InitKind::Broadcast, bOutput, |
4340 | getPRNG()); |
4341 | |
// Cell state.
4343 | float bCell = 0.1; |
4344 | Placeholder *Wxc = getParent()->createPlaceholder( |
4345 | ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxc" , true); |
4346 | Placeholder *Whc = getParent()->createPlaceholder( |
4347 | ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whc" , true); |
4348 | Placeholder *Bc1 = getParent()->createPlaceholder( |
4349 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc1" , true); |
4350 | Placeholder *Bc2 = getParent()->createPlaceholder( |
4351 | ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc2" , true); |
4352 | |
4353 | bindings.allocate(Wxc)->init(glow::Tensor::InitKind::Xavier, inputSize, |
4354 | getPRNG()); |
4355 | bindings.allocate(Whc)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4356 | getPRNG()); |
4357 | bindings.allocate(Bc1)->init(glow::Tensor::InitKind::Broadcast, bCell, |
4358 | getPRNG()); |
4359 | bindings.allocate(Bc2)->init(glow::Tensor::InitKind::Broadcast, bCell, |
4360 | getPRNG()); |
4361 | |
// Output layer.
4363 | float b = 0.1; |
4364 | Placeholder *Why = getParent()->createPlaceholder( |
4365 | ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why" , true); |
4366 | Placeholder *By = getParent()->createPlaceholder( |
4367 | ElemKind::FloatTy, {outputSize}, nameBase + ".by" , true); |
4368 | |
4369 | bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize, |
4370 | getPRNG()); |
4371 | bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG()); |
4372 | |
4373 | std::vector<Node *> outputNodes; |
4374 | for (unsigned t = 0; t < timeSteps; t++) { |
4375 | auto fc1Name = nameBase + ".fc1." + std::to_string(t); |
4376 | auto fc2Name = nameBase + ".fc2." + std::to_string(t); |
4377 | auto add1Name = nameBase + ".add1." + std::to_string(t); |
4378 | auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t); |
4379 | |
4380 | auto *Ft = createSigmoid( |
4381 | sigmoid1Name, |
4382 | createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whf, Bf1), |
4383 | createFullyConnected(fc2Name, inputs[t], Wxf, Bf2))); |
4384 | |
4385 | auto fc3Name = nameBase + ".fc3." + std::to_string(t); |
4386 | auto fc4Name = nameBase + ".fc4." + std::to_string(t); |
4387 | auto add2Name = nameBase + ".add2." + std::to_string(t); |
4388 | auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t); |
4389 | |
4390 | auto *It = createSigmoid( |
4391 | sigmoid2Name, |
4392 | createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whi, Bi1), |
4393 | createFullyConnected(fc4Name, inputs[t], Wxi, Bi2))); |
4394 | |
4395 | auto fc5Name = nameBase + ".fc5." + std::to_string(t); |
4396 | auto fc6Name = nameBase + ".fc6." + std::to_string(t); |
4397 | auto add3Name = nameBase + ".add3." + std::to_string(t); |
4398 | auto sigmoid3Name = nameBase + ".sigmoid3." + std::to_string(t); |
4399 | |
4400 | auto *Ot = createSigmoid( |
4401 | sigmoid3Name, |
4402 | createAdd(add3Name, createFullyConnected(fc5Name, Ht, Who, Bo1), |
4403 | createFullyConnected(fc6Name, inputs[t], Wxo, Bo2))); |
4404 | |
4405 | auto fc7Name = nameBase + ".fc7." + std::to_string(t); |
4406 | auto fc8Name = nameBase + ".fc8." + std::to_string(t); |
4407 | auto add4Name = nameBase + ".add4." + std::to_string(t); |
4408 | auto tanh1Name = nameBase + ".tanh1." + std::to_string(t); |
4409 | |
4410 | auto *CRt = createTanh( |
4411 | tanh1Name, |
4412 | createAdd(add4Name, createFullyConnected(fc7Name, Ht, Whc, Bc1), |
4413 | createFullyConnected(fc8Name, inputs[t], Wxc, Bc2))); |
4414 | |
4415 | auto mul1Name = nameBase + ".mul1." + std::to_string(t); |
4416 | auto mul2Name = nameBase + ".mul2." + std::to_string(t); |
4417 | Ct = createAdd(nameBase + ".C." + std::to_string(t), |
4418 | createMul(mul1Name, Ft, Ct), createMul(mul2Name, It, CRt)); |
4419 | |
4420 | auto htName = nameBase + ".H." + std::to_string(t); |
4421 | auto tanh2Name = nameBase + ".tanh2." + std::to_string(t); |
4422 | Ht = createMul(htName, Ot, createTanh(tanh2Name, Ct)); |
4423 | |
4424 | auto outName = nameBase + ".out." + std::to_string(t); |
4425 | auto *O = createFullyConnected(outName, Ht, Why, By); |
4426 | outputs.push_back(O); |
4427 | } |
4428 | }; |
4429 | |
4430 | void Function::createOnnxRNN(llvm::StringRef namePrefix, NodeValue X, |
4431 | NodeValue W, NodeValue R, NodeValue B, |
4432 | NodeValue initial_h, NodeValue &Y, NodeValue &Y_h, |
4433 | unsigned hiddenSize, RnnDirection direction, |
4434 | std::vector<RnnActivation> &activations) { |
4435 | |
4436 | #define RNN_X_SLICE_RANGE(idx) \ |
4437 | {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize } |
4438 | #define RNN_W_SLICE_RANGE(idx0, idx1) \ |
4439 | {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize } |
4440 | #define RNN_R_SLICE_RANGE(idx0, idx1) \ |
4441 | {idx0, idx1 * hiddenSize, 0}, { \ |
4442 | idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \ |
4443 | } |
4444 | #define RNN_B_SLICE_RANGE(idx0, idx1) \ |
4445 | {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize } |
4446 | #define RNN_H_SLICE_RANGE(idx) \ |
4447 | {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize } |
4448 | #define RNN_CREATE_FC(name, LHS, RHS, BIAS) \ |
4449 | BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \ |
4450 | : (Node *)createMatMul(name, LHS, RHS) |
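// For example, RNN_X_SLICE_RANGE(2) expands to the coordinate pair
// "{2, 0, 0}, {3, batchSize, inputSize}" consumed as the start/end arguments
// of createSlice(), i.e. the slice holding time step 2 for the whole batch.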
4451 | |
4452 | // Operator name. |
4453 | const std::string &opName = namePrefix.str(); |
4454 | |
4455 | // Get all size parameters. |
4456 | dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1; |
4457 | assert(X.dims().size() == 3 && |
4458 | "ONNX RNN input 'X' should have 3 dimensions!" ); |
4459 | dim_t seqLength = X.dims()[0]; |
4460 | dim_t batchSize = X.dims()[1]; |
4461 | dim_t inputSize = X.dims()[2]; |
4462 | |
4463 | // Validate W size. |
4464 | assert(W.dims().size() == 3 && |
4465 | "ONNX RNN input 'W' should have 3 dimensions!" ); |
4466 | assert(W.dims()[0] == numDirections && W.dims()[1] == hiddenSize && |
4467 | W.dims()[2] == inputSize && "ONNX RNN 'W' tensor size invalid!" ); |
4468 | |
4469 | // Validate R size. |
4470 | assert(R.dims().size() == 3 && |
4471 | "ONNX RNN input 'R' should have 3 dimensions!" ); |
4472 | assert(R.dims()[0] == numDirections && R.dims()[1] == hiddenSize && |
4473 | R.dims()[2] == hiddenSize && "ONNX RNN 'R' tensor size invalid!" ); |
4474 | |
4475 | // Validate B size. |
4476 | if (B.getNode()) { |
4477 | assert(B.dims().size() == 2 && |
4478 | "ONNX RNN input 'B' should have 2 dimensions!" ); |
4479 | assert(B.dims()[0] == numDirections && B.dims()[1] == 2 * hiddenSize && |
4480 | "ONNX RNN 'B' tensor size invalid!" ); |
4481 | } |
4482 | |
4483 | // Validate initial_h size if given else create Splat with 0. |
4484 | if (initial_h.getNode()) { |
assert(initial_h.dims().size() == 3 &&
"ONNX RNN input 'initial_h' should have 3 dimensions!");
4487 | assert(initial_h.dims()[0] == numDirections && |
4488 | initial_h.dims()[1] == batchSize && |
4489 | initial_h.dims()[2] == hiddenSize && |
4490 | "ONNX RNN 'initial_h' tensor size invalid!" ); |
4491 | } else { |
4492 | auto splatTy = getParent()->uniqueType( |
4493 | ElemKind::FloatTy, {numDirections, batchSize, hiddenSize}); |
4494 | initial_h = createSplat(opName + ".initial_h" , splatTy, 0.0); |
4495 | } |
4496 | |
4497 | // Validate number of activations. |
4498 | assert(activations.size() == numDirections * 1 && |
4499 | "ONNX RNN activations vector invalid!" ); |
4500 | |
4501 | // Create X slices. |
4502 | std::vector<Node *> Xslices; |
4503 | for (dim_t t = 0; t < seqLength; t++) { |
4504 | auto XsliceName = opName + ".X" + std::to_string(t) + ".slice" ; |
4505 | Node *Xt = createSlice(XsliceName, X, RNN_X_SLICE_RANGE(t)); |
4506 | auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape" ; |
4507 | Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize}); |
4508 | Xslices.push_back(Xt); |
4509 | } |
4510 | |
4511 | // Lambda to load forward/backward RNN cell. |
4512 | auto loadRNNCell = [&](bool forward, std::vector<NodeValue> &Yslices, |
4513 | NodeValue &Hslice) { |
4514 | // Name prefix. |
4515 | std::string dirLabel = forward ? ".fw" : ".bw" ; |
4516 | std::string prefix = opName + ((numDirections > 1) ? dirLabel : "" ); |
4517 | |
4518 | // Slice index used for creating weights slices. |
4519 | dim_t sliceIdx0 = 0; |
4520 | if (direction == RnnDirection::Bidirectional) { |
4521 | sliceIdx0 = forward ? 0 : 1; |
4522 | } |
4523 | |
4524 | // Activations. |
4525 | size_t activationOffset = sliceIdx0 * 1; |
4526 | auto activationF = activations[activationOffset + 0]; |
4527 | |
4528 | // Create W slice (Required). |
4529 | NodeValue Wi = |
4530 | createSlice(prefix + ".Wi." , W, RNN_W_SLICE_RANGE(sliceIdx0, 0)); |
4531 | Wi = createReshape(prefix + ".Wi.reshape" , Wi, {hiddenSize, inputSize}); |
4532 | Wi = createTranspose(prefix + ".Wi.transp" , Wi, {1, 0}); |
4533 | |
4534 | // Create R slice (Required). |
4535 | NodeValue Ri = |
4536 | createSlice(prefix + ".Ri." , R, RNN_R_SLICE_RANGE(sliceIdx0, 0)); |
4537 | Ri = createReshape(prefix + ".Ri.reshape" , Ri, {hiddenSize, hiddenSize}); |
4538 | Ri = createTranspose(prefix + ".Ri.transp" , Ri, {1, 0}); |
4539 | |
4540 | // Create B slices (optional). |
4541 | NodeValue bWi = nullptr; |
4542 | NodeValue bRi = nullptr; |
4543 | |
4544 | if (B) { |
4545 | |
4546 | bWi = createSlice(prefix + ".bWi." , B, RNN_B_SLICE_RANGE(sliceIdx0, 0)); |
4547 | bRi = createSlice(prefix + ".bRi." , B, RNN_B_SLICE_RANGE(sliceIdx0, 1)); |
4548 | |
4549 | bWi = createReshape(prefix + ".bWi.reshape" , bWi, {hiddenSize}); |
4550 | bRi = createReshape(prefix + ".bRi.reshape" , bRi, {hiddenSize}); |
4551 | } |
4552 | |
4553 | // Create H slice for this direction. |
4554 | Node *Hinit = createSlice(prefix + ".H.slice" , initial_h, |
4555 | RNN_H_SLICE_RANGE(sliceIdx0)); |
4556 | Hinit = |
4557 | createReshape(prefix + ".H.reshape" , Hinit, {batchSize, hiddenSize}); |
4558 | |
4559 | // Initialize. |
4560 | Node *Ht = Hinit; |
4561 | |
4562 | // Unroll RNN cell for all time steps. |
4563 | for (size_t t = 0; t < seqLength; t++) { |
4564 | |
4565 | // Input for current time step. |
4566 | // For the reverse RNN cell the inputs are provided in reverse order. |
4567 | Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t]; |
4568 | |
4569 | // Hidden state update: Ht = f(Xt * Wi + bWi + Ht-1 * Ri + bRi). |
4570 | Ht = createAdd(prefix + ".H.add" , |
4571 | RNN_CREATE_FC(prefix + ".H.fc1" , Xt, Wi, bWi), |
4572 | RNN_CREATE_FC(prefix + ".H.fc2" , Ht, Ri, bRi)); |
4573 | Ht = activationF(prefix + ".H.act" , Ht); |
4574 | |
4575 | // Output. |
4576 | Yslices.push_back(Ht); |
4577 | } |
4578 | |
4579 | // Updated states nodes. |
4580 | Hslice = Ht; |
4581 | }; // End of local lambda "loadRNNCell". |
4582 | |
4583 | bool forwardEnabled = ((direction == RnnDirection::Forward) || |
4584 | (direction == RnnDirection::Bidirectional)); |
4585 | bool backwardEnabled = ((direction == RnnDirection::Reverse) || |
4586 | (direction == RnnDirection::Bidirectional)); |
4587 | |
4588 | std::vector<NodeValue> YSlices; |
4589 | std::vector<NodeValue> Hslices; |
4590 | |
4591 | // Load forward RNN. |
4592 | std::vector<NodeValue> forwardYslices; |
4593 | if (forwardEnabled) { |
4594 | NodeValue forwardHslice; |
4595 | loadRNNCell(/* forward */ true, forwardYslices, forwardHslice); |
4596 | Hslices.push_back(forwardHslice); |
4597 | } |
4598 | |
4599 | // Load backward RNN. |
4600 | std::vector<NodeValue> backwardYslices; |
4601 | if (backwardEnabled) { |
4602 | NodeValue backwardHslice; |
4603 | loadRNNCell(/* forward */ false, backwardYslices, backwardHslice); |
4604 | Hslices.push_back(backwardHslice); |
4605 | } |
4606 | |
4607 | // Gather Y slices. |
4608 | for (size_t t = 0; t < seqLength; t++) { |
4609 | if (forwardEnabled) { |
4610 | YSlices.push_back(forwardYslices[t]); |
4611 | } |
4612 | if (backwardEnabled) { |
4613 | YSlices.push_back(backwardYslices[seqLength - 1 - t]); |
4614 | } |
4615 | } |
4616 | |
4617 | // Concatenate Y slices. |
4618 | // Y size is [seqLength, numDirections, batchSize, hiddenSize]. |
4619 | Y = createReshape(opName + ".Y.reshape" , |
4620 | createConcat(opName + ".Y.concat" , YSlices, 0), |
4621 | {seqLength, numDirections, batchSize, hiddenSize}); |
4622 | |
4623 | // Concatenate Y_h slices. |
4624 | // Y_h size is [numDirections, batchSize, hiddenSize]. |
4625 | Y_h = createReshape(opName + ".Y_h.reshape" , |
4626 | createConcat(opName + ".Y_h.concat" , Hslices, 0), |
4627 | {numDirections, batchSize, hiddenSize}); |
4628 | |
4629 | #undef RNN_X_SLICE_RANGE |
4630 | #undef RNN_W_SLICE_RANGE |
4631 | #undef RNN_R_SLICE_RANGE |
4632 | #undef RNN_B_SLICE_RANGE |
4633 | #undef RNN_H_SLICE_RANGE |
4634 | #undef RNN_CREATE_FC |
4635 | } |
4636 | |
4637 | void Function::createOnnxGRU(llvm::StringRef namePrefix, NodeValue X, |
4638 | NodeValue W, NodeValue R, NodeValue B, |
4639 | NodeValue initial_h, NodeValue &Y, NodeValue &Y_h, |
4640 | unsigned hiddenSize, RnnDirection direction, |
4641 | std::vector<RnnActivation> &activations, |
4642 | bool linearBeforeReset) { |
4643 | |
4644 | #define GRU_X_SLICE_RANGE(idx) \ |
4645 | {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize } |
4646 | #define GRU_W_SLICE_RANGE(idx0, idx1) \ |
4647 | {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize } |
4648 | #define GRU_R_SLICE_RANGE(idx0, idx1) \ |
4649 | {idx0, idx1 * hiddenSize, 0}, { \ |
4650 | idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \ |
4651 | } |
4652 | #define GRU_B_SLICE_RANGE(idx0, idx1) \ |
4653 | {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize } |
4654 | #define GRU_H_SLICE_RANGE(idx) \ |
4655 | {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize } |
4656 | #define GRU_CREATE_FC(name, LHS, RHS, BIAS) \ |
4657 | BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \ |
4658 | : (Node *)createMatMul(name, LHS, RHS) |
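// Per the ONNX GRU spec the z, r and h gate blocks are stacked along
// dimension 1 of W and R, and B concatenates [Wb[zrh], Rb[zrh]]; this is why
// 'B' is validated below as 6 * hiddenSize entries per direction.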
4659 | |
4660 | // Operator name. |
4661 | const std::string &opName = namePrefix.str(); |
4662 | |
4663 | // Get all size parameters. |
4664 | dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1; |
4665 | assert(X.dims().size() == 3 && |
4666 | "ONNX GRU input 'X' should have 3 dimensions!" ); |
4667 | dim_t seqLength = X.dims()[0]; |
4668 | dim_t batchSize = X.dims()[1]; |
4669 | dim_t inputSize = X.dims()[2]; |
4670 | |
4671 | // Validate W size. |
4672 | assert(W.dims().size() == 3 && |
4673 | "ONNX GRU input 'W' should have 3 dimensions!" ); |
4674 | assert(W.dims()[0] == numDirections && W.dims()[1] == 3 * hiddenSize && |
4675 | W.dims()[2] == inputSize && "ONNX GRU 'W' tensor size invalid!" ); |
4676 | |
4677 | // Validate R size. |
4678 | assert(R.dims().size() == 3 && |
4679 | "ONNX GRU input 'R' should have 3 dimensions!" ); |
4680 | assert(R.dims()[0] == numDirections && R.dims()[1] == 3 * hiddenSize && |
4681 | R.dims()[2] == hiddenSize && "ONNX GRU 'R' tensor size invalid!" ); |
4682 | |
4683 | // Validate B size. |
4684 | if (B.getNode()) { |
4685 | assert(B.dims().size() == 2 && |
4686 | "ONNX GRU input 'B' should have 2 dimensions!" ); |
4687 | assert(B.dims()[0] == numDirections && B.dims()[1] == 6 * hiddenSize && |
4688 | "ONNX GRU 'B' tensor size invalid!" ); |
4689 | } |
4690 | |
4691 | // Validate initial_h size if given else create Splat with 0. |
4692 | if (initial_h.getNode()) { |
assert(initial_h.dims().size() == 3 &&
"ONNX GRU input 'initial_h' should have 3 dimensions!");
4695 | assert(initial_h.dims()[0] == numDirections && |
4696 | initial_h.dims()[1] == batchSize && |
4697 | initial_h.dims()[2] == hiddenSize && |
4698 | "ONNX GRU 'initial_h' tensor size invalid!" ); |
4699 | } else { |
4700 | auto splatTy = getParent()->uniqueType( |
4701 | ElemKind::FloatTy, {numDirections, batchSize, hiddenSize}); |
4702 | initial_h = createSplat(opName + ".initial_h" , splatTy, 0.0); |
4703 | } |
4704 | |
4705 | // Validate number of activations. |
4706 | assert(activations.size() == numDirections * 2 && |
4707 | "ONNX GRU activations vector invalid!" ); |
4708 | |
4709 | // Create X slices. |
4710 | std::vector<Node *> Xslices; |
4711 | for (dim_t t = 0; t < seqLength; t++) { |
4712 | auto XsliceName = opName + ".X" + std::to_string(t) + ".slice" ; |
4713 | Node *Xt = createSlice(XsliceName, X, GRU_X_SLICE_RANGE(t)); |
4714 | auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape" ; |
4715 | Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize}); |
4716 | Xslices.push_back(Xt); |
4717 | } |
4718 | |
4719 | // Lambda to load forward/backward GRU cell. |
4720 | auto loadGRUCell = [&](bool forward, std::vector<NodeValue> &Yslices, |
4721 | NodeValue &Hslice) { |
4722 | // Name prefix. |
4723 | std::string dirLabel = forward ? ".fw" : ".bw" ; |
4724 | std::string prefix = opName + ((numDirections > 1) ? dirLabel : "" ); |
4725 | |
4726 | // Slice index used for creating weights slices. |
4727 | dim_t sliceIdx0 = 0; |
4728 | if (direction == RnnDirection::Bidirectional) { |
4729 | sliceIdx0 = forward ? 0 : 1; |
4730 | } |
4731 | |
4732 | // Activations. |
4733 | size_t activationOffset = sliceIdx0 * 2; |
4734 | auto activationF = activations[activationOffset + 0]; |
4735 | auto activationG = activations[activationOffset + 1]; |
4736 | |
4737 | // Create W slices (Required). |
4738 | NodeValue Wz = |
4739 | createSlice(prefix + ".Wz." , W, GRU_W_SLICE_RANGE(sliceIdx0, 0)); |
4740 | NodeValue Wr = |
4741 | createSlice(prefix + ".Wr." , W, GRU_W_SLICE_RANGE(sliceIdx0, 1)); |
4742 | NodeValue Wh = |
4743 | createSlice(prefix + ".Wh." , W, GRU_W_SLICE_RANGE(sliceIdx0, 2)); |
4744 | |
4745 | Wz = createReshape(prefix + ".Wz.reshape" , Wz, {hiddenSize, inputSize}); |
4746 | Wr = createReshape(prefix + ".Wr.reshape" , Wr, {hiddenSize, inputSize}); |
4747 | Wh = createReshape(prefix + ".Wh.reshape" , Wh, {hiddenSize, inputSize}); |
4748 | |
4749 | Wz = createTranspose(prefix + ".Wz.transp" , Wz, {1, 0}); |
4750 | Wr = createTranspose(prefix + ".Wr.transp" , Wr, {1, 0}); |
4751 | Wh = createTranspose(prefix + ".Wh.transp" , Wh, {1, 0}); |
4752 | |
4753 | // Create R slices (Required). |
4754 | NodeValue Rz = |
4755 | createSlice(prefix + ".Rz." , R, GRU_R_SLICE_RANGE(sliceIdx0, 0)); |
4756 | NodeValue Rr = |
4757 | createSlice(prefix + ".Rr." , R, GRU_R_SLICE_RANGE(sliceIdx0, 1)); |
4758 | NodeValue Rh = |
4759 | createSlice(prefix + ".Rh." , R, GRU_R_SLICE_RANGE(sliceIdx0, 2)); |
4760 | |
4761 | Rz = createReshape(prefix + ".Rz.reshape" , Rz, {hiddenSize, hiddenSize}); |
4762 | Rr = createReshape(prefix + ".Rr.reshape" , Rr, {hiddenSize, hiddenSize}); |
4763 | Rh = createReshape(prefix + ".Rh.reshape" , Rh, {hiddenSize, hiddenSize}); |
4764 | |
4765 | Rz = createTranspose(prefix + ".Rz.transp" , Rz, {1, 0}); |
4766 | Rr = createTranspose(prefix + ".Rr.transp" , Rr, {1, 0}); |
4767 | Rh = createTranspose(prefix + ".Rh.transp" , Rh, {1, 0}); |
4768 | |
4769 | // Create B slices (optional). |
4770 | NodeValue bWz = nullptr; |
4771 | NodeValue bWr = nullptr; |
4772 | NodeValue bWh = nullptr; |
4773 | NodeValue bRz = nullptr; |
4774 | NodeValue bRr = nullptr; |
4775 | NodeValue bRh = nullptr; |
4776 | |
4777 | if (B) { |
4778 | |
4779 | bWz = createSlice(prefix + ".bWz." , B, GRU_B_SLICE_RANGE(sliceIdx0, 0)); |
4780 | bWr = createSlice(prefix + ".bWr." , B, GRU_B_SLICE_RANGE(sliceIdx0, 1)); |
4781 | bWh = createSlice(prefix + ".bWh." , B, GRU_B_SLICE_RANGE(sliceIdx0, 2)); |
4782 | bRz = createSlice(prefix + ".bRz." , B, GRU_B_SLICE_RANGE(sliceIdx0, 3)); |
4783 | bRr = createSlice(prefix + ".bRr." , B, GRU_B_SLICE_RANGE(sliceIdx0, 4)); |
4784 | bRh = createSlice(prefix + ".bRh." , B, GRU_B_SLICE_RANGE(sliceIdx0, 5)); |
4785 | |
4786 | bWz = createReshape(prefix + ".bWz.reshape" , bWz, {hiddenSize}); |
4787 | bWr = createReshape(prefix + ".bWr.reshape" , bWr, {hiddenSize}); |
4788 | bWh = createReshape(prefix + ".bWh.reshape" , bWh, {hiddenSize}); |
4789 | bRz = createReshape(prefix + ".bRz.reshape" , bRz, {hiddenSize}); |
4790 | bRr = createReshape(prefix + ".bRr.reshape" , bRr, {hiddenSize}); |
4791 | bRh = createReshape(prefix + ".bRh.reshape" , bRh, {hiddenSize}); |
4792 | } |
4793 | |
4794 | // Create H slice for this direction. |
4795 | Node *Hinit = createSlice(prefix + ".H.slice" , initial_h, |
4796 | GRU_H_SLICE_RANGE(sliceIdx0)); |
4797 | Hinit = |
4798 | createReshape(prefix + ".H.reshape" , Hinit, {batchSize, hiddenSize}); |
4799 | |
4800 | // Initialize. |
4801 | Node *Ht = Hinit; |
4802 | |
4803 | // Unroll GRU cell for all time steps. |
4804 | for (size_t t = 0; t < seqLength; t++) { |
4805 | |
4806 | // Input for current time step. |
4807 | // For the reverse GRU cell the inputs are provided in reverse order. |
4808 | Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t]; |
4809 | |
4810 | // Update gate: zt = f(Xt * Wz + bWz + Ht-1 * Rz + bRz). |
4811 | Node *zt = createAdd(prefix + ".Z.add1" , |
4812 | GRU_CREATE_FC(prefix + ".Z.fc1" , Xt, Wz, bWz), |
4813 | GRU_CREATE_FC(prefix + ".Z.fc2" , Ht, Rz, bRz)); |
4814 | zt = activationF(prefix + ".Z.act" , zt); |
4815 | |
4816 | // Reset gate: rt = f(Xt * Wr + bWr + Ht-1 * Rr + bRr). |
4817 | Node *rt = createAdd(prefix + ".R.add1" , |
4818 | GRU_CREATE_FC(prefix + ".R.fc1" , Xt, Wr, bWr), |
4819 | GRU_CREATE_FC(prefix + ".R.fc2" , Ht, Rr, bRr)); |
4820 | rt = activationF(prefix + ".R.act" , rt); |
4821 | |
4822 | // Hidden gate: |
4823 | // For linearBeforeReset = true: |
4824 | // htild = g(Xt * Wh + bWh + rt . (Ht-1 * Rh + bRh)). |
4825 | // For linearBeforeReset = false: |
4826 | // htild = g(Xt * Wh + bWh + (rt . Ht-1) * Rh + bRh). |
4827 | Node *htild; |
4828 | if (linearBeforeReset) { |
4829 | htild = createAdd( |
4830 | prefix + ".Htild.add" , |
4831 | GRU_CREATE_FC(prefix + ".Htild.fc1" , Xt, Wh, bWh), |
4832 | createMul(prefix + ".Htild.reset" , rt, |
4833 | GRU_CREATE_FC(prefix + ".Htild.fc2" , Ht, Rh, bRh))); |
4834 | } else { |
4835 | htild = createAdd( |
4836 | prefix + ".Htild.add" , |
4837 | GRU_CREATE_FC(prefix + ".Htild.fc1" , Xt, Wh, bWh), |
4838 | GRU_CREATE_FC(prefix + ".Htild.fc2" , |
4839 | createMul(prefix + ".Htild.reset" , rt, Ht), Rh, bRh)); |
4840 | } |
4841 | htild = activationG(prefix + ".Htild.act" , htild); |
4842 | |
4843 | // Hidden state update: |
4844 | // Ht = (1 - zt) . htild + zt . Ht-1 = htild - zt . htild + zt . Ht-1. |
4845 | Ht = createAdd(prefix + ".H.add" , |
4846 | createSub(prefix + ".H.sub" , htild, |
4847 | createMul(prefix + ".H.mult1" , zt, htild)), |
4848 | createMul(prefix + ".H.mult2" , zt, Ht)); |
4849 | |
4850 | // Output. |
4851 | Yslices.push_back(Ht); |
4852 | } |
4853 | |
4854 | // Updated states nodes. |
4855 | Hslice = Ht; |
4856 | }; // End of local lambda "loadGRUCell". |
4857 | |
4858 | bool forwardEnabled = ((direction == RnnDirection::Forward) || |
4859 | (direction == RnnDirection::Bidirectional)); |
4860 | bool backwardEnabled = ((direction == RnnDirection::Reverse) || |
4861 | (direction == RnnDirection::Bidirectional)); |
4862 | |
4863 | std::vector<NodeValue> YSlices; |
4864 | std::vector<NodeValue> Hslices; |
4865 | |
4866 | // Load forward GRU. |
4867 | std::vector<NodeValue> forwardYslices; |
4868 | if (forwardEnabled) { |
4869 | NodeValue forwardHslice; |
4870 | loadGRUCell(/* forward */ true, forwardYslices, forwardHslice); |
4871 | Hslices.push_back(forwardHslice); |
4872 | } |
4873 | |
4874 | // Load backward GRU. |
4875 | std::vector<NodeValue> backwardYslices; |
4876 | if (backwardEnabled) { |
4877 | NodeValue backwardHslice; |
4878 | loadGRUCell(/* forward */ false, backwardYslices, backwardHslice); |
4879 | Hslices.push_back(backwardHslice); |
4880 | } |
4881 | |
4882 | // Gather Y slices. |
4883 | for (size_t t = 0; t < seqLength; t++) { |
4884 | if (forwardEnabled) { |
4885 | YSlices.push_back(forwardYslices[t]); |
4886 | } |
4887 | if (backwardEnabled) { |
4888 | YSlices.push_back(backwardYslices[seqLength - 1 - t]); |
4889 | } |
4890 | } |
4891 | |
4892 | // Concatenate Y slices. |
4893 | // Y size is [seqLength, numDirections, batchSize, hiddenSize]. |
4894 | Y = createReshape(opName + ".Y.reshape" , |
4895 | createConcat(opName + ".Y.concat" , YSlices, 0), |
4896 | {seqLength, numDirections, batchSize, hiddenSize}); |
4897 | |
4898 | // Concatenate Y_h slices. |
4899 | // Y_h size is [numDirections, batchSize, hiddenSize]. |
4900 | Y_h = createReshape(opName + ".Y_h.reshape" , |
4901 | createConcat(opName + ".Y_h.concat" , Hslices, 0), |
4902 | {numDirections, batchSize, hiddenSize}); |
4903 | |
4904 | #undef GRU_X_SLICE_RANGE |
4905 | #undef GRU_W_SLICE_RANGE |
4906 | #undef GRU_R_SLICE_RANGE |
4907 | #undef GRU_B_SLICE_RANGE |
4908 | #undef GRU_H_SLICE_RANGE |
4909 | #undef GRU_CREATE_FC |
4910 | } |
4911 | |
4912 | void Function::createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X, |
4913 | NodeValue W, NodeValue R, NodeValue B, |
4914 | NodeValue initial_h, NodeValue initial_c, |
4915 | NodeValue P, NodeValue &Y, NodeValue &Y_h, |
4916 | NodeValue &Y_c, unsigned hiddenSize, |
4917 | RnnDirection direction, |
4918 | std::vector<RnnActivation> &activations, |
4919 | bool inputForget) { |
4920 | |
4921 | #define LSTM_X_SLICE_RANGE(idx) \ |
4922 | {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize } |
4923 | #define LSTM_H_SLICE_RANGE(idx) \ |
4924 | {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize } |
4925 | #define LSTM_C_SLICE_RANGE(idx) \ |
4926 | {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize } |
4927 | #define LSTM_W_SLICE_RANGE(idx0, idx1) \ |
4928 | {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize } |
4929 | #define LSTM_R_SLICE_RANGE(idx0, idx1) \ |
4930 | {idx0, idx1 * hiddenSize, 0}, { \ |
4931 | idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \ |
4932 | } |
4933 | #define LSTM_B_SLICE_RANGE(idx0, idx1) \ |
4934 | {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize } |
4935 | #define LSTM_P_SLICE_RANGE(idx0, idx1) \ |
4936 | {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize } |
4937 | #define LSTM_CREATE_FC(name, LHS, RHS, BIAS) \ |
4938 | BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \ |
4939 | : (Node *)createMatMul(name, LHS, RHS) |
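// Per the ONNX LSTM spec the gate blocks are stacked in i, o, f, c order
// (hence 4 * hiddenSize rows per direction in W and R, and 8 * hiddenSize
// bias entries), and the optional peephole tensor 'P' stacks [Pi, Po, Pf] as
// 3 * hiddenSize entries per direction.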
4940 | |
4941 | // Operator name. |
4942 | const std::string &opName = namePrefix.str(); |
4943 | |
4944 | // Get all size parameters. |
4945 | dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1; |
4946 | assert(X.dims().size() == 3 && |
4947 | "ONNX LSTM input 'X' should have 3 dimensions!" ); |
4948 | dim_t seqLength = X.dims()[0]; |
4949 | dim_t batchSize = X.dims()[1]; |
4950 | dim_t inputSize = X.dims()[2]; |
4951 | |
4952 | // Validate W size. |
4953 | assert(W.dims().size() == 3 && |
4954 | "ONNX LSTM input 'W' should have 3 dimensions!" ); |
4955 | assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize && |
4956 | W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!" ); |
4957 | |
4958 | // Validate R size. |
4959 | assert(R.dims().size() == 3 && |
4960 | "ONNX LSTM input 'R' should have 3 dimensions!" ); |
4961 | assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize && |
4962 | R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!" ); |
4963 | |
4964 | // Validate B size. |
4965 | if (B.getNode()) { |
4966 | assert(B.dims().size() == 2 && |
4967 | "ONNX LSTM input 'B' should have 2 dimensions!" ); |
4968 | assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize && |
4969 | "ONNX LSTM 'B' tensor size invalid!" ); |
4970 | } |
4971 | |
4972 | // Validate initial_h size if given else create Splat with 0. |
4973 | if (initial_h.getNode()) { |
assert(initial_h.dims().size() == 3 &&
"ONNX LSTM input 'initial_h' should have 3 dimensions!");
4976 | assert(initial_h.dims()[0] == numDirections && |
4977 | initial_h.dims()[1] == batchSize && |
4978 | initial_h.dims()[2] == hiddenSize && |
4979 | "ONNX LSTM 'initial_h' tensor size invalid!" ); |
4980 | } else { |
4981 | auto splatTy = getParent()->uniqueType( |
4982 | ElemKind::FloatTy, {numDirections, batchSize, hiddenSize}); |
4983 | initial_h = createSplat(opName + ".initial_h" , splatTy, 0.0); |
4984 | } |
4985 | |
4986 | // Validate initial_c size if given else create Splat with 0. |
4987 | if (initial_c.getNode()) { |
assert(initial_c.dims().size() == 3 &&
"ONNX LSTM input 'initial_c' should have 3 dimensions!");
4990 | assert(initial_c.dims()[0] == numDirections && |
4991 | initial_c.dims()[1] == batchSize && |
4992 | initial_c.dims()[2] == hiddenSize && |
4993 | "ONNX LSTM 'initial_c' tensor size invalid!" ); |
4994 | } else { |
4995 | auto splatTy = getParent()->uniqueType( |
4996 | ElemKind::FloatTy, {numDirections, batchSize, hiddenSize}); |
4997 | initial_c = createSplat(opName + ".initial_c" , splatTy, 0.0); |
4998 | } |
4999 | |
5000 | // Validate P size. |
5001 | if (P.getNode()) { |
5002 | assert(P.dims().size() == 2 && |
5003 | "ONNX LSTM input 'P' should have 2 dimensions!" ); |
5004 | assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize && |
5005 | "ONNX LSTM 'P' tensor size invalid!" ); |
5006 | } |
5007 | |
5008 | // Validate number of activations. |
5009 | assert(activations.size() == numDirections * 3 && |
5010 | "ONNX LSTM activations vector invalid!" ); |
5011 | |
5012 | // Create X slices. |
5013 | std::vector<Node *> Xslices; |
5014 | for (dim_t t = 0; t < seqLength; t++) { |
5015 | auto XsliceName = opName + ".X" + std::to_string(t) + ".slice" ; |
5016 | Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t)); |
5017 | auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape" ; |
5018 | Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize}); |
5019 | Xslices.push_back(Xt); |
5020 | } |
5021 | |
5022 | // Lambda to load forward/backward LSTM cell. |
5023 | auto loadLSTMCell = [&](bool forward, std::vector<NodeValue> &Yslices, |
5024 | NodeValue &Hslice, NodeValue &Cslice) { |
5025 | // Name prefix. |
5026 | std::string dirLabel = forward ? ".fw" : ".bw" ; |
5027 | std::string prefix = opName + ((numDirections > 1) ? dirLabel : "" ); |
5028 | |
5029 | // Slice index used for creating weights slices. |
5030 | dim_t sliceIdx0 = 0; |
5031 | if (direction == RnnDirection::Bidirectional) { |
5032 | sliceIdx0 = forward ? 0 : 1; |
5033 | } |
5034 | |
5035 | // Activations. |
5036 | size_t activationOffset = sliceIdx0 * 3; |
5037 | auto activationF = activations[activationOffset + 0]; |
5038 | auto activationG = activations[activationOffset + 1]; |
5039 | auto activationH = activations[activationOffset + 2]; |
5040 | |
5041 | // Create W slices (Required). |
5042 | NodeValue Wi = |
5043 | createSlice(prefix + ".Wi." , W, LSTM_W_SLICE_RANGE(sliceIdx0, 0)); |
5044 | NodeValue Wo = |
5045 | createSlice(prefix + ".Wo." , W, LSTM_W_SLICE_RANGE(sliceIdx0, 1)); |
5046 | NodeValue Wf = |
5047 | createSlice(prefix + ".Wf." , W, LSTM_W_SLICE_RANGE(sliceIdx0, 2)); |
5048 | NodeValue Wc = |
5049 | createSlice(prefix + ".Wc." , W, LSTM_W_SLICE_RANGE(sliceIdx0, 3)); |
5050 | |
5051 | Wi = createReshape(prefix + ".Wi.reshape" , Wi, {hiddenSize, inputSize}); |
5052 | Wo = createReshape(prefix + ".Wo.reshape" , Wo, {hiddenSize, inputSize}); |
5053 | Wf = createReshape(prefix + ".Wf.reshape" , Wf, {hiddenSize, inputSize}); |
5054 | Wc = createReshape(prefix + ".Wc.reshape" , Wc, {hiddenSize, inputSize}); |
5055 | |
5056 | Wi = createTranspose(prefix + ".Wi.transp" , Wi, {1, 0}); |
5057 | Wo = createTranspose(prefix + ".Wo.transp" , Wo, {1, 0}); |
5058 | Wf = createTranspose(prefix + ".Wf.transp" , Wf, {1, 0}); |
5059 | Wc = createTranspose(prefix + ".Wc.transp" , Wc, {1, 0}); |
5060 | |
5061 | // Create R slices (Required). |
5062 | NodeValue Ri = |
5063 | createSlice(prefix + ".Ri." , R, LSTM_R_SLICE_RANGE(sliceIdx0, 0)); |
5064 | NodeValue Ro = |
5065 | createSlice(prefix + ".Ro." , R, LSTM_R_SLICE_RANGE(sliceIdx0, 1)); |
5066 | NodeValue Rf = |
5067 | createSlice(prefix + ".Rf." , R, LSTM_R_SLICE_RANGE(sliceIdx0, 2)); |
5068 | NodeValue Rc = |
5069 | createSlice(prefix + ".Rc." , R, LSTM_R_SLICE_RANGE(sliceIdx0, 3)); |
5070 | |
5071 | Ri = createReshape(prefix + ".Ri.reshape" , Ri, {hiddenSize, hiddenSize}); |
5072 | Ro = createReshape(prefix + ".Ro.reshape" , Ro, {hiddenSize, hiddenSize}); |
5073 | Rf = createReshape(prefix + ".Rf.reshape" , Rf, {hiddenSize, hiddenSize}); |
5074 | Rc = createReshape(prefix + ".Rc.reshape" , Rc, {hiddenSize, hiddenSize}); |
5075 | |
5076 | Ri = createTranspose(prefix + ".Ri.transp" , Ri, {1, 0}); |
5077 | Ro = createTranspose(prefix + ".Ro.transp" , Ro, {1, 0}); |
5078 | Rf = createTranspose(prefix + ".Rf.transp" , Rf, {1, 0}); |
5079 | Rc = createTranspose(prefix + ".Rc.transp" , Rc, {1, 0}); |
5080 | |
5081 | // Create B slices (optional). |
5082 | NodeValue bWi = nullptr; |
5083 | NodeValue bWo = nullptr; |
5084 | NodeValue bWf = nullptr; |
5085 | NodeValue bWc = nullptr; |
5086 | NodeValue bRi = nullptr; |
5087 | NodeValue bRo = nullptr; |
5088 | NodeValue bRf = nullptr; |
5089 | NodeValue bRc = nullptr; |
5090 | |
5091 | if (B) { |
5092 | |
5093 | bWi = createSlice(prefix + ".bWi." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 0)); |
5094 | bWo = createSlice(prefix + ".bWo." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 1)); |
5095 | bWf = createSlice(prefix + ".bWf." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 2)); |
5096 | bWc = createSlice(prefix + ".bWc." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 3)); |
5097 | bRi = createSlice(prefix + ".bRi." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 4)); |
5098 | bRo = createSlice(prefix + ".bRo." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 5)); |
5099 | bRf = createSlice(prefix + ".bRf." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 6)); |
5100 | bRc = createSlice(prefix + ".bRc." , B, LSTM_B_SLICE_RANGE(sliceIdx0, 7)); |
5101 | |
5102 | bWi = createReshape(prefix + ".bWi.reshape" , bWi, {hiddenSize}); |
5103 | bWo = createReshape(prefix + ".bWo.reshape" , bWo, {hiddenSize}); |
5104 | bWf = createReshape(prefix + ".bWf.reshape" , bWf, {hiddenSize}); |
5105 | bWc = createReshape(prefix + ".bWc.reshape" , bWc, {hiddenSize}); |
5106 | bRi = createReshape(prefix + ".bRi.reshape" , bRi, {hiddenSize}); |
5107 | bRo = createReshape(prefix + ".bRo.reshape" , bRo, {hiddenSize}); |
5108 | bRf = createReshape(prefix + ".bRf.reshape" , bRf, {hiddenSize}); |
5109 | bRc = createReshape(prefix + ".bRc.reshape" , bRc, {hiddenSize}); |
5110 | } |
5111 | |
5112 | // Create P slices (optional). |
5113 | NodeValue Pi = nullptr; |
5114 | NodeValue Po = nullptr; |
5115 | NodeValue Pf = nullptr; |
5116 | |
5117 | if (P) { |
5118 | |
5119 | Pi = createSlice(prefix + ".Pi." , P, LSTM_P_SLICE_RANGE(sliceIdx0, 0)); |
5120 | Po = createSlice(prefix + ".Po." , P, LSTM_P_SLICE_RANGE(sliceIdx0, 1)); |
5121 | Pf = createSlice(prefix + ".Pf." , P, LSTM_P_SLICE_RANGE(sliceIdx0, 2)); |
5122 | |
5123 | // Repeat P slices to match [batchSize, hiddenSize]. |
5124 | Pi = createTile(prefix + ".Pi.repeat" , Pi, batchSize, 0); |
5125 | Po = createTile(prefix + ".Po.repeat" , Po, batchSize, 0); |
5126 | Pf = createTile(prefix + ".Pf.repeat" , Pf, batchSize, 0); |
5127 | } |
5128 | |
5129 | // Create H slice for this direction. |
5130 | Node *Hinit = createSlice(prefix + ".H.slice" , initial_h, |
5131 | LSTM_H_SLICE_RANGE(sliceIdx0)); |
5132 | Hinit = |
5133 | createReshape(prefix + ".H.reshape" , Hinit, {batchSize, hiddenSize}); |
5134 | |
5135 | // Create C slice for this direction. |
5136 | Node *Cinit = createSlice(prefix + ".C.slice" , initial_c, |
5137 | LSTM_C_SLICE_RANGE(sliceIdx0)); |
5138 | Cinit = |
5139 | createReshape(prefix + ".C.reshape" , Cinit, {batchSize, hiddenSize}); |
5140 | |
5141 | // Initialize. |
5142 | Node *Ht = Hinit; |
5143 | Node *Ct = Cinit; |
5144 | |
5145 | // Unroll LSTM cell for all time steps. |
5146 | for (size_t t = 0; t < seqLength; t++) { |
5147 | |
5148 | // Input for current time step. |
5149 | // For the reverse LSTM cell the inputs are provided in reverse order. |
5150 | Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t]; |
5151 | |
5152 | // Forget gate: ft = f(Xt * Wf + bWf + Ht-1 * Rf + bRf + Pf . Ct-1). |
5153 | Node *ft = createAdd(prefix + ".F.add1" , |
5154 | LSTM_CREATE_FC(prefix + ".F.fc1" , Xt, Wf, bWf), |
5155 | LSTM_CREATE_FC(prefix + ".F.fc2" , Ht, Rf, bRf)); |
5156 | if (Pf) { |
5157 | ft = createAdd(prefix + ".F.add2" , ft, |
5158 | createMul(prefix + ".F.mult" , Pf, Ct)); |
5159 | } |
5160 | ft = activationF(prefix + ".F.act" , ft); |
5161 | |
5162 | // Cell state candidate: ctild = g(Xt * Wc + bWc + Ht-1 * Rc + bRc). |
5163 | Node *ctild = |
5164 | createAdd(prefix + ".Ctild.add" , |
5165 | LSTM_CREATE_FC(prefix + ".Ctild.fc1" , Xt, Wc, bWc), |
5166 | LSTM_CREATE_FC(prefix + ".Ctild.fc2" , Ht, Rc, bRc)); |
5167 | ctild = activationG(prefix + ".Ctild.act" , ctild); |
5168 | |
5169 | // Input gate: |
5170 | // For inputForget == true: |
5171 | // it = 1 - ft. |
5172 | // For inputForget == false: |
5173 | // it = f(Xt * Wi + bWi + Ht-1 * Ri + bRi + Pi . Ct-1). |
5174 | Node *it; |
5175 | if (inputForget) { |
5176 | auto splatTy = ft->getNthResult(0).getType(); |
5177 | it = createSub(prefix + ".I.sub" , |
5178 | createSplat(prefix + ".I.splat" , splatTy, 1.0), ft); |
5179 | } else { |
5180 | it = createAdd(prefix + ".I.add1" , |
5181 | LSTM_CREATE_FC(prefix + ".I.fc1" , Xt, Wi, bWi), |
5182 | LSTM_CREATE_FC(prefix + ".I.fc2" , Ht, Ri, bRi)); |
5183 | if (Pi) { |
5184 | it = createAdd(prefix + ".I.add2" , it, |
5185 | createMul(prefix + ".I.mult" , Pi, Ct)); |
5186 | } |
5187 | it = activationF(prefix + ".I.act" , it); |
5188 | } |
5189 | |
5190 | // Cell state update: Ct = ft . Ct-1 + it . ctild. |
5191 | Ct = createAdd(prefix + ".C.add" , createMul(prefix + ".C.mult1" , ft, Ct), |
5192 | createMul(prefix + ".C.mult2" , it, ctild)); |
5193 | |
5194 | // Output gate: ot = f(Xt * Wo + bWo + Ht-1 * Ro + bRo + Po . Ct). |
5195 | Node *ot = createAdd(prefix + ".O.add1" , |
5196 | LSTM_CREATE_FC(prefix + ".O.fc1" , Xt, Wo, bWo), |
5197 | LSTM_CREATE_FC(prefix + ".O.fc2" , Ht, Ro, bRo)); |
5198 | if (Po) { |
5199 | ot = createAdd(prefix + ".O.add2" , ot, |
5200 | createMul(prefix + ".O.mult" , Po, Ct)); |
5201 | } |
5202 | ot = activationF(prefix + ".O.act" , ot); |
5203 | |
5204 | // Hidden state update: Ht = ot . h(Ct). |
5205 | Ht = |
5206 | createMul(prefix + ".H.mult" , ot, activationH(prefix + ".H.act" , Ct)); |
5207 | |
5208 | // Output. |
5209 | Yslices.push_back(Ht); |
5210 | } |
5211 | |
5212 | // Updated states nodes. |
5213 | Hslice = Ht; |
5214 | Cslice = Ct; |
5215 | }; // End of local lambda "loadLSTMCell". |
5216 | |
5217 | bool forwardEnabled = ((direction == RnnDirection::Forward) || |
5218 | (direction == RnnDirection::Bidirectional)); |
5219 | bool backwardEnabled = ((direction == RnnDirection::Reverse) || |
5220 | (direction == RnnDirection::Bidirectional)); |
5221 | |
5222 | std::vector<NodeValue> YSlices; |
5223 | std::vector<NodeValue> Hslices; |
5224 | std::vector<NodeValue> Cslices; |
5225 | |
5226 | // Load forward LSTM. |
5227 | std::vector<NodeValue> forwardYslices; |
5228 | if (forwardEnabled) { |
5229 | NodeValue forwardHslice; |
5230 | NodeValue forwardCslice; |
5231 | loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice, |
5232 | forwardCslice); |
5233 | Hslices.push_back(forwardHslice); |
5234 | Cslices.push_back(forwardCslice); |
5235 | } |
5236 | |
5237 | // Load backward LSTM. |
5238 | std::vector<NodeValue> backwardYslices; |
5239 | if (backwardEnabled) { |
5240 | NodeValue backwardHslice; |
5241 | NodeValue backwardCslice; |
5242 | loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice, |
5243 | backwardCslice); |
5244 | Hslices.push_back(backwardHslice); |
5245 | Cslices.push_back(backwardCslice); |
5246 | } |
5247 | |
5248 | // Gather Y slices. |
5249 | for (size_t t = 0; t < seqLength; t++) { |
5250 | if (forwardEnabled) { |
5251 | YSlices.push_back(forwardYslices[t]); |
5252 | } |
5253 | if (backwardEnabled) { |
5254 | YSlices.push_back(backwardYslices[seqLength - 1 - t]); |
5255 | } |
5256 | } |
5257 | |
5258 | // Concatenate Y slices. |
5259 | // Y size is [seqLength, numDirections, batchSize, hiddenSize]. |
5260 | Y = createReshape(opName + ".Y.reshape" , |
5261 | createConcat(opName + ".Y.concat" , YSlices, 0), |
5262 | {seqLength, numDirections, batchSize, hiddenSize}); |
5263 | |
5264 | // Concatenate Y_h slices. |
5265 | // Y_h size is [numDirections, batchSize, hiddenSize]. |
5266 | Y_h = createReshape(opName + ".Y_h.reshape" , |
5267 | createConcat(opName + ".Y_h.concat" , Hslices, 0), |
5268 | {numDirections, batchSize, hiddenSize}); |
5269 | |
5270 | // Concatenate Y_c slices. |
5271 | // Y_c size is [numDirections, batchSize, hiddenSize]. |
5272 | Y_c = createReshape(opName + ".Y_c.reshape" , |
5273 | createConcat(opName + ".Y_c.concat" , Cslices, 0), |
5274 | {numDirections, batchSize, hiddenSize}); |
5275 | |
5276 | #undef LSTM_X_SLICE_RANGE |
5277 | #undef LSTM_H_SLICE_RANGE |
5278 | #undef LSTM_C_SLICE_RANGE |
5279 | #undef LSTM_W_SLICE_RANGE |
5280 | #undef LSTM_R_SLICE_RANGE |
5281 | #undef LSTM_B_SLICE_RANGE |
5282 | #undef LSTM_P_SLICE_RANGE |
5283 | #undef LSTM_CREATE_FC |
5284 | } |
5285 | |
5286 | TraceEventNode *Function::createTraceEvent(llvm::StringRef eventName, |
5287 | llvm::StringRef eventType, |
5288 | Node *data, unsigned index) { |
5289 | std::string name = (getName() + "_" + eventName + "_instrumentation" ).str(); |
5290 | return addNode( |
5291 | new TraceEventNode(name, data, eventName.str(), eventType.str(), index)); |
5292 | } |
5293 | |
5294 | NonMaxSuppressionNode *Function::createNonMaxSuppressionV4( |
5295 | llvm::StringRef name, NodeValue boxes, NodeValue scores, |
5296 | int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold, |
5297 | float scoreThreshold, ElemKind elTy) { |
5298 | // V4 |
5299 | // Class/Score [BatchNum][BoxNum] |
  // Boxes [BatchNum][BoxNum][4]
5301 | // Result [BatchNum*MaxOutputPerBatch] |
5302 | // NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch] |
5303 | auto scoresDim = scores.dims(); |
5304 | int scoresBoxDim = scoresDim.size() - 1; |
5305 | if (maxOutputBoxesPerClass == 0) { |
5306 | maxOutputBoxesPerClass = scoresDim[scoresBoxDim]; |
5307 | } |
5308 | |
5309 | // Allocating maximum because we don't know how many boxes will actually be |
5310 | // detected. |
5311 | std::vector<dim_t> newDim = {static_cast<dim_t>(maxOutputBoxesPerClass)}; |
5312 | auto indicesTy = getParent()->uniqueType(elTy, newDim); |
5313 | auto numberOfSelectedIndicesTy = getParent()->uniqueType( |
5314 | elTy, {static_cast<dim_t>(maxOutputBoxesPerClass)}); |
5315 | return addNode(new NonMaxSuppressionNode( |
5316 | name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox, |
5317 | maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true)); |
5318 | } |
5319 | |
5320 | NonMaxSuppressionNode * |
5321 | Function::createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes, |
5322 | NodeValue scores, int64_t centerPointBox, |
5323 | int64_t maxOutputBoxesPerClass, |
5324 | float iouThreshold, float scoreThreshold) { |
5325 | return createNonMaxSuppressionV4(name, boxes, scores, centerPointBox, |
5326 | maxOutputBoxesPerClass, iouThreshold, |
5327 | scoreThreshold, ElemKind::Int64ITy); |
5328 | } |
5329 | |
5330 | NonMaxSuppressionNode *Function::createNonMaxSuppressionV4( |
5331 | llvm::StringRef name, NodeValue boxes, NodeValue scores, |
5332 | int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold, |
5333 | float scoreThreshold, TypeRef indicesTy, |
5334 | TypeRef numberOfSelectedIndicesTy) { |
5335 | assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass." ); |
5336 | |
5337 | return addNode(new NonMaxSuppressionNode( |
5338 | name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox, |
5339 | maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true)); |
5340 | } |
5341 | |
5342 | NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX( |
5343 | llvm::StringRef name, NodeValue boxes, NodeValue scores, |
5344 | int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold, |
5345 | float scoreThreshold, ElemKind elTy) { |
5346 | // ONNX |
5347 | // Class/Score [BatchNum][ClassNum][BoxNum] |
5348 | // Box [BatchNum][BoxNum][4] |
5349 | // Result [BatchNum*MaxOutputPerBatch][3] |
5350 | auto boxesDim = boxes.dims(); |
5351 | auto scoresDim = scores.dims(); |
5352 | int scoresBoxDim = scoresDim.size() - 1; |
5353 | int scoresClassDim = scoresDim.size() - 2; |
5354 | int scoresBatchDim = scoresDim.size() - 3; |
5355 | int boxesBatchDim = boxesDim.size() - 3; |
5356 | if (maxOutputBoxesPerClass == 0) { |
5357 | maxOutputBoxesPerClass = scoresDim[scoresBoxDim]; |
5358 | } |
5359 | |
  // Allocating maximum because we don't know how many boxes will actually be
  // detected.
5362 | std::vector<dim_t> newDim = {scoresDim[scoresBatchDim] * |
5363 | scoresDim[scoresClassDim] * |
5364 | static_cast<dim_t>(maxOutputBoxesPerClass), |
5365 | 3}; |
5366 | auto indicesTy = getParent()->uniqueType(elTy, newDim); |
5367 | auto numberOfSelectedIndicesTy = getParent()->uniqueType( |
5368 | elTy, |
5369 | {boxesDim[boxesBatchDim] * static_cast<dim_t>(maxOutputBoxesPerClass)}); |
5370 | return addNode(new NonMaxSuppressionNode( |
5371 | name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox, |
5372 | maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false)); |
5373 | } |
5374 | |
5375 | NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX( |
5376 | llvm::StringRef name, NodeValue boxes, NodeValue scores, |
5377 | int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold, |
5378 | float scoreThreshold) { |
5379 | return createNonMaxSuppressionONNX(name, boxes, scores, centerPointBox, |
5380 | maxOutputBoxesPerClass, iouThreshold, |
5381 | scoreThreshold, ElemKind::Int64ITy); |
5382 | } |
5383 | |
5384 | NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX( |
5385 | llvm::StringRef name, NodeValue boxes, NodeValue scores, |
5386 | int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold, |
5387 | float scoreThreshold, TypeRef indicesTy) { |
5388 | auto boxesDim = boxes.dims(); |
5389 | assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass." ); |
5390 | |
  // Allocating maximum because we don't know how many boxes will actually be
  // detected.
5393 | auto numberOfSelectedIndicesTy = getParent()->uniqueType( |
5394 | ElemKind::Int32ITy, {1, 1, 1, |
5395 | boxesDim[boxesDim.size() - 2] * |
5396 | static_cast<dim_t>(maxOutputBoxesPerClass)}); |
5397 | return addNode(new NonMaxSuppressionNode( |
5398 | name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox, |
5399 | maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false)); |
5400 | } |
5401 | |
5402 | TFLiteDetectionPostProcessNode *Function::createTFLiteDetectionPostProcess( |
5403 | llvm::StringRef name, NodeValue boxes, NodeValue scores, NodeValue anchors, |
5404 | int32_t numClasses, int32_t maxDetections, int32_t maxClassesPerDetection, |
5405 | int32_t maxDetectionsPerClass, float iouThreshold, float scoreThreshold, |
5406 | float xScale, float yScale, float hScale, float wScale, bool regularNMS) { |
5407 | |
5408 | // Maximum number of detections depending on fast/regular method. |
5409 | dim_t numBoxes = anchors.dims()[0]; |
5410 | dim_t fastMaxDetections = |
5411 | std::min(numBoxes, static_cast<dim_t>(maxDetections)); |
5412 | dim_t regularMaxDetections = maxDetections; |
5413 | dim_t numMaxDetections = |
5414 | regularNMS ? regularMaxDetections : fastMaxDetections; |
5415 | |
5416 | // Create output types. We allocate enough size for the worst possible case |
5417 | // when the maximum number of detections is obtained. |
5418 | std::vector<dim_t> detectionBoxesDims = {static_cast<dim_t>(numMaxDetections), |
5419 | 4}; |
5420 | TypeRef detectionBoxesTy = |
5421 | getParent()->uniqueType(ElemKind::FloatTy, detectionBoxesDims); |
5422 | std::vector<dim_t> detectionClassesDims = { |
5423 | static_cast<dim_t>(numMaxDetections)}; |
5424 | TypeRef detectionClassesTy = |
5425 | getParent()->uniqueType(ElemKind::Int32ITy, detectionClassesDims); |
5426 | std::vector<dim_t> detectionScoresDims = { |
5427 | static_cast<dim_t>(numMaxDetections)}; |
5428 | TypeRef detectionScoresTy = |
5429 | getParent()->uniqueType(ElemKind::FloatTy, detectionScoresDims); |
5430 | TypeRef numDetectionsTy = getParent()->uniqueType(ElemKind::Int32ITy, {1}); |
5431 | |
5432 | // Dequantize inputs if quantized. |
5433 | if (boxes.getType()->isQuantizedType()) { |
5434 | boxes = createDequantize(name.str() + ".dequant.boxes" , boxes, |
5435 | ElemKind::FloatTy); |
5436 | } |
5437 | if (scores.getType()->isQuantizedType()) { |
5438 | scores = createDequantize(name.str() + ".dequant.scores" , scores, |
5439 | ElemKind::FloatTy); |
5440 | } |
5441 | if (anchors.getType()->isQuantizedType()) { |
5442 | anchors = createDequantize(name.str() + ".dequant.anchors" , anchors, |
5443 | ElemKind::FloatTy); |
5444 | } |
5445 | |
5446 | // Create node. |
5447 | return addNode(new TFLiteDetectionPostProcessNode( |
5448 | name, detectionBoxesTy, detectionClassesTy, detectionScoresTy, |
5449 | numDetectionsTy, boxes, scores, anchors, numClasses, maxDetections, |
5450 | maxClassesPerDetection, maxDetectionsPerClass, iouThreshold, |
5451 | scoreThreshold, xScale, yScale, hScale, wScale, regularNMS)); |
5452 | } |
5453 | |
5454 | Constant *Function::createCosineWindow(llvm::StringRef name, dim_t length) { |
5455 | auto window = getParent()->createConstant(ElemKind::FloatTy, {length}, name); |
5456 | auto windowH = window->getHandle<float>(); |
5457 | for (dim_t n = 0; n < length; n++) { |
5458 | windowH.raw(n) = |
5459 | 0.5 - 0.5 * cos(2.0 * M_PI * (double)(n) / (double)(length)); |
5460 | } |
5461 | return window; |
5462 | } |
5463 | |
5464 | Constant *Function::createFFTTwiddleFactors(llvm::StringRef name, |
5465 | dim_t fftLength) { |
5466 | auto twiddleFactors = |
5467 | getParent()->createConstant(ElemKind::FloatTy, {2 * fftLength}, name); |
5468 | auto twiddleFactorsH = twiddleFactors->getHandle<float>(); |
5469 | for (dim_t k = 0; k < fftLength; k++) { |
5470 | twiddleFactorsH.raw(2 * k + 0) = |
5471 | cos(2.0 * M_PI * (double)(k) / (double)(fftLength)); |
5472 | twiddleFactorsH.raw(2 * k + 1) = |
5473 | -sin(2.0 * M_PI * (double)(k) / (double)(fftLength)); |
5474 | } |
5475 | return twiddleFactors; |
5476 | } |
5477 | |
5478 | Constant *Function::createFFTBitReverseIndices(llvm::StringRef name, |
5479 | dim_t fftLength) { |
5480 | assert(fftLength >= 1 && "FFT length must be at least 1!" ); |
5481 | // Local function to reverse the bits of a number. |
5482 | auto reverseBits = [](uint64_t bits, dim_t numBits) -> uint64_t { |
    assert((numBits <= 64) &&
           "Maximum number of bits exceeded for 'reverseBits' function!");
    if (numBits == 0) {
      return 0;
    }
    uint64_t bitsRev = 0;
    uint64_t bitsMask = 1;
    uint64_t bitsRevMask = 1ULL << (numBits - 1);
5491 | for (dim_t idx = 0; idx < numBits; idx++) { |
5492 | if (bits & bitsMask) { |
5493 | bitsRev |= bitsRevMask; |
5494 | } |
5495 | bitsMask <<= 1; |
5496 | bitsRevMask >>= 1; |
5497 | } |
5498 | return bitsRev; |
5499 | }; |
5500 | auto bitReverseIndices = |
5501 | getParent()->createConstant(ElemKind::Int32ITy, {fftLength}, name); |
5502 | auto bitReverseIndicesH = bitReverseIndices->getHandle<int32_t>(); |
5503 | dim_t numBits = std::log2((double)fftLength); |
5504 | for (dim_t idx = 0; idx < fftLength; idx++) { |
5505 | bitReverseIndicesH.raw(idx) = |
5506 | static_cast<int32_t>(reverseBits(idx, numBits)); |
5507 | } |
5508 | return bitReverseIndices; |
5509 | } |
5510 | |
5511 | Constant *Function::createFFTComplexToRealWeights(llvm::StringRef name, |
5512 | dim_t fftLength, |
5513 | dim_t outLength) { |
5514 | auto complexToRealWeights = |
5515 | getParent()->createConstant(ElemKind::FloatTy, {2 * outLength}, name); |
5516 | auto complexToRealWeightsH = complexToRealWeights->getHandle<float>(); |
5517 | for (dim_t k = 0; k < outLength; k++) { |
5518 | complexToRealWeightsH.raw(2 * k + 0) = |
5519 | 0.5 * (1 - sin(2.0 * M_PI * (double)(k) / (double)(fftLength))); |
5520 | complexToRealWeightsH.raw(2 * k + 1) = |
5521 | -0.5 * cos(2.0 * M_PI * (double)(k) / (double)(fftLength)); |
5522 | } |
5523 | return complexToRealWeights; |
5524 | } |
5525 | |
5526 | AudioSpectrogramNode *Function::createAudioSpectrogram(llvm::StringRef name, |
5527 | NodeValue input, |
5528 | int64_t windowSize, |
5529 | int64_t windowStride, |
5530 | bool magnitudeSquared) { |
5531 | // Output shape will be windowCount x (fftLength / 2 + 1). |
5532 | dim_t inputLength = input.getType()->size(); |
5533 | dim_t windowCount = std::floor((inputLength - windowSize) / windowStride) + 1; |
5534 | dim_t fftLength = 1 << (dim_t)std::ceil(std::log2((double)windowSize)); |
5535 | auto spectrogramTy = getParent()->uniqueType( |
5536 | ElemKind::FloatTy, {windowCount, fftLength / 2 + 1}); |
5537 | |
5538 | // Create a cosine FFT windowing function. |
5539 | auto window = createCosineWindow(std::string(name) + ".Window" , windowSize); |
5540 | |
5541 | // Create the FFT weights for a fftLength/2 complex FFT. |
5542 | auto twiddleFactors = createFFTTwiddleFactors( |
5543 | std::string(name) + ".TwiddleFactors" , fftLength / 2); |
5544 | auto bitReverseIndices = createFFTBitReverseIndices( |
5545 | std::string(name) + ".BitReverseIndices" , fftLength / 2); |
5546 | |
5547 | // Create the complex to real FFT mapping coefficients. |
5548 | // For small FFT length make sure to generate at least 1 coefficient. |
5549 | auto complexToRealWeights = createFFTComplexToRealWeights( |
5550 | std::string(name) + ".ComplexToRealWeights" , fftLength, |
5551 | (fftLength / 4) >= 1 ? (fftLength / 4) : 1); |
5552 | |
5553 | // Create AudioSpectrogram node. |
5554 | return addNode(new AudioSpectrogramNode( |
5555 | name, spectrogramTy, input, window, twiddleFactors, bitReverseIndices, |
5556 | complexToRealWeights, windowSize, windowStride, magnitudeSquared)); |
5557 | } |
5558 | |
5559 | void Function::createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength, |
5560 | float sampleRate, float lowerFrequency, |
5561 | float upperFrequency, dim_t filterBankCount, |
5562 | Constant *&melWeights, Constant *&melRanges) { |
5563 | auto fftLength = 2 * (spectrogramLength - 1); |
5564 | dim_t numFreqBins = fftLength / 2; |
5565 | dim_t numMelBins = filterBankCount; |
5566 | |
5567 | // Mel frequency scale local lambda function. |
5568 | auto melFreqScale = [](float freq) -> float { |
5569 | return 1127.0f * logf(1.0f + freq / 700.0f); |
5570 | }; |
5571 | |
5572 | // Always exclude DC (TensorFlow implementation choice from HTK). |
5573 | float freqDelta = sampleRate / (float)(fftLength); |
5574 | dim_t freqIdxMin = (dim_t)(1.5 + (lowerFrequency / freqDelta)); |
5575 | dim_t freqIdxMax = (dim_t)(upperFrequency / freqDelta); |
5576 | freqIdxMax = (freqIdxMax >= numFreqBins) ? numFreqBins : freqIdxMax; |
5577 | |
5578 | // Create Mel ranges constant. |
5579 | melRanges = getParent()->createConstant(ElemKind::Int32ITy, {2 * numMelBins}, |
5580 | std::string(prefix) + ".MelRanges" ); |
5581 | auto melRangesH = melRanges->getHandle<int32_t>(); |
5582 | |
5583 | // Mel weights and frequency start/stop (inclusive) buffers. |
5584 | auto melBinFreqWeights = std::make_unique<float[]>(numMelBins * numFreqBins); |
5585 | dim_t melBinFreqWeightsNum = 0; |
5586 | |
5587 | // Mel frequency limits. |
5588 | float melFreqLower = melFreqScale(lowerFrequency); |
5589 | float melFreqUpper = melFreqScale(upperFrequency); |
5590 | float melFreqDelta = (melFreqUpper - melFreqLower) / (numMelBins + 1); |
5591 | for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) { |
5592 | |
5593 | float melFreqLeft = melFreqLower + (melIdx + 0) * melFreqDelta; |
5594 | float melFreqCenter = melFreqLower + (melIdx + 1) * melFreqDelta; |
5595 | float melFreqRight = melFreqLower + (melIdx + 2) * melFreqDelta; |
5596 | |
5597 | int32_t freqIdxStart = -1; |
5598 | int32_t freqIdxStop = -2; |
5599 | |
5600 | for (dim_t freqIdx = freqIdxMin; freqIdx <= freqIdxMax; freqIdx++) { |
5601 | float melFreq = melFreqScale(freqIdx * freqDelta); |
5602 | if ((melFreqLeft < melFreq) && (melFreq < melFreqRight)) { |
5603 | |
5604 | // Compute frequency bin weight for this Mel bin. |
5605 | float weight = 1.0f - std::abs(melFreq - melFreqCenter) / melFreqDelta; |
5606 | |
5607 | // Store the frequency bin weight. |
5608 | melBinFreqWeights[melBinFreqWeightsNum++] = weight; |
5609 | |
5610 | // Update frequency bin start/stop index. |
5611 | if (freqIdxStart == -1) { |
5612 | freqIdxStart = freqIdx; |
5613 | } |
5614 | freqIdxStop = freqIdx; |
5615 | } |
5616 | } |
5617 | |
5618 | // Store the frequency bin start/stop index. |
5619 | melRangesH.raw(2 * melIdx + 0) = freqIdxStart; |
5620 | melRangesH.raw(2 * melIdx + 1) = freqIdxStop; |
5621 | } |
5622 | |
5623 | // Validate Mel ranges. |
5624 | dim_t melBinFreqWeightsNumValidate = 0; |
5625 | for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) { |
5626 | int32_t freqIdxRange = |
5627 | melRangesH.raw(2 * melIdx + 1) - melRangesH.raw(2 * melIdx + 0) + 1; |
5628 | melBinFreqWeightsNumValidate += freqIdxRange; |
5629 | } |
5630 | assert(melBinFreqWeightsNum == melBinFreqWeightsNumValidate && |
5631 | "Invalid Mel ranges" ); |
5632 | |
5633 | // Create Mel weights constant. |
5634 | melWeights = |
5635 | getParent()->createConstant(ElemKind::FloatTy, {melBinFreqWeightsNum}, |
5636 | std::string(prefix) + ".MelWeights" ); |
5637 | auto melWeightsH = melWeights->getHandle<float>(); |
5638 | for (dim_t idx = 0; idx < melBinFreqWeightsNum; idx++) { |
5639 | melWeightsH.raw(idx) = melBinFreqWeights[idx]; |
5640 | } |
5641 | } |
5642 | |
5643 | Constant *Function::createDCTMat(llvm::StringRef name, dim_t N, dim_t K) { |
5644 | Constant *dctMat = |
5645 | getParent()->createConstant(ElemKind::FloatTy, {K, N}, name); |
5646 | auto dctMatH = dctMat->getHandle<float>(); |
5647 | float dctFact = (float)sqrt(2.0 / (double)(N)); |
5648 | for (dim_t k = 0; k < K; k++) { |
5649 | for (dim_t n = 0; n < N; n++) { |
5650 | dctMatH.at({k, n}) = |
5651 | dctFact * cos(M_PI / (double)(N) * ((double)(n) + 0.5) * (double)(k)); |
5652 | } |
5653 | } |
5654 | return dctMat; |
5655 | } |
5656 | |
5657 | MFCCNode *Function::createMFCC(llvm::StringRef name, NodeValue spectrogram, |
5658 | float sampleRate, float lowerFrequency, |
5659 | float upperFrequency, int64_t filterBankCount, |
5660 | int64_t numCoefficients) { |
5661 | // Create the Mel weights. |
5662 | dim_t spectrogramLength = spectrogram.dims()[1]; |
5663 | Constant *melWeights; |
5664 | Constant *melRanges; |
5665 | createMelWeights(name, spectrogramLength, sampleRate, lowerFrequency, |
5666 | upperFrequency, filterBankCount, melWeights, melRanges); |
5667 | |
5668 | // Create the DCT transform matrix. |
5669 | Constant *dctMat = createDCTMat(std::string(name) + ".DCTMat" , |
5670 | filterBankCount, numCoefficients); |
5671 | |
5672 | // Output shape will be windowCount x numCoefficients. |
5673 | dim_t windowCount = spectrogram.dims()[0]; |
5674 | auto coefficientsTy = getParent()->uniqueType( |
5675 | ElemKind::FloatTy, {windowCount, static_cast<dim_t>(numCoefficients)}); |
5676 | |
5677 | // Create MFCC node. |
5678 | return addNode(new MFCCNode(name, coefficientsTy, spectrogram, melWeights, |
5679 | melRanges, dctMat, sampleRate, lowerFrequency, |
5680 | upperFrequency, filterBankCount, |
5681 | numCoefficients)); |
5682 | } |
5683 | |
5684 | ROIAlignNode * |
5685 | Function::createROIAlign(llvm::StringRef name, NodeValue featureMap, |
5686 | NodeValue boxes, NodeValue batchIndices, |
5687 | uint32_t outputHeight, uint32_t outputWidth, |
5688 | uint32_t samplingRatio, float spatialScale, |
5689 | bool aligned, bool rotated, PoolingMode mode) { |
5690 | auto featureMapDims = featureMap.dims(); |
5691 | auto boxesDims = boxes.dims(); |
5692 | std::vector<dim_t> outDim = {boxesDims[0], outputHeight, outputWidth, |
5693 | featureMapDims[3]}; |
5694 | auto outTy = |
5695 | getParent()->uniqueTypeWithNewShape(featureMap.getType(), outDim); |
5696 | return addNode(new ROIAlignNode( |
5697 | name, outTy, featureMap, boxes, batchIndices, mode, outputHeight, |
      outputWidth, samplingRatio, spatialScale, aligned, rotated));
5699 | } |
5700 | |
5701 | BBoxTransformNode *Function::createBBoxTransform( |
5702 | llvm::StringRef name, NodeValue rois, NodeValue deltas, NodeValue imInfo, |
5703 | llvm::ArrayRef<float> weights, bool applyScale, bool rotated, |
5704 | bool angleBoundOn, int64_t angleBoundLo, int64_t angleBoundHi, |
5705 | float clipAngleThresh, bool legacyPlusOne) { |
5706 | auto deltasDims = deltas.dims(); |
5707 | auto imInfoDims = imInfo.dims(); |
5708 | |
5709 | auto boxOutTy = getParent()->uniqueTypeWithNewShape( |
5710 | rois.getType(), {deltasDims[0], deltasDims[1]}); |
  // The roiBatchSplits output keeps the same element type as the rois input.
5712 | auto roiBatchSplitsTy = |
5713 | getParent()->uniqueType(rois.getElementType(), {imInfoDims[0]}); |
5714 | |
5715 | return addNode(new BBoxTransformNode( |
5716 | name, boxOutTy, roiBatchSplitsTy, rois, deltas, imInfo, weights, |
5717 | applyScale, rotated, angleBoundOn, angleBoundLo, angleBoundHi, |
5718 | clipAngleThresh, legacyPlusOne)); |
5719 | } |
5720 | |
5721 | ExternalFunctionCallNode *Function::createExternalFunctionCall( |
5722 | llvm::StringRef name, TypeRef outTy, llvm::ArrayRef<glow::NodeValue> inputs, |
5723 | llvm::StringRef funcName, llvm::StringRef funcImpl, |
5724 | llvm::StringRef funcKind) { |
5725 | return addNode(new ExternalFunctionCallNode(name.str(), outTy, inputs, |
5726 | funcName.str(), funcImpl.str(), |
5727 | funcKind.str())); |
5728 | } |
5729 | |
5730 | //===----------------------------------------------------------------------===// |
5731 | // Graph dumping and printing |
5732 | //===----------------------------------------------------------------------===// |
5733 | |
5734 | void Function::dump() const { |
5735 | llvm::outs() << "Graph structure " << getName() << ":\n" ; |
5736 | for (auto &n : nodes_) { |
5737 | llvm::outs() << n.getDebugDesc(); |
5738 | } |
5739 | } |
5740 | |
5741 | std::string Function::toString(bool skipUsersForStorage, bool skipName) const { |
5742 | std::string storage; |
5743 | llvm::raw_string_ostream os(storage); |
5744 | dump(os, skipUsersForStorage, skipName); |
5745 | return os.str(); |
5746 | } |
5747 | |
5748 | llvm::hash_code Function::getHash() const { |
5749 | // Omit function name when generating the hash. |
5750 | return llvm::hash_value(toString(/* skipUsersForStorage */ false, |
5751 | /* skipName */ true)); |
5752 | } |
5753 | |
5754 | void Function::dump(llvm::raw_ostream &os, bool skipUsersForStorage, |
5755 | bool skipName) const { |
5756 | os << "Graph structure" ; |
5757 | if (!skipName) { |
5758 | os << " " << getName(); |
5759 | } |
5760 | os << ":\n" ; |
5761 | std::set<const Node *, SortNamed> sorted; |
5762 | for (const Node &n : nodes_) { |
5763 | sorted.insert(&n); |
5764 | } |
5765 | for (auto *n : sorted) { |
5766 | os << n->getDebugDesc(); |
5767 | } |
5768 | for (auto *C : getNamedSorted(findConstants())) { |
5769 | os << C->getDebugDesc(skipUsersForStorage); |
5770 | } |
5771 | for (auto *P : getNamedSorted(findPlaceholders())) { |
5772 | os << P->getDebugDesc(skipUsersForStorage); |
5773 | } |
5774 | } |
5775 | |
5776 | /// We can't use NodeWalker here, because it ignores result indices, which |
5777 | /// are critical in generating detailed debug output. |
5778 | class FunctionDottyPrinter : public AbstractDottyPrinter { |
5779 | // A set of already visited (during graph walk) nodes. |
5780 | std::unordered_set<Node *> visitedNodes_{}; |
5781 | |
  /// Recursively traverses inputs of node \p N using Depth First Search.
5783 | /// Each node will be visited no more than once. The method also dumps |
5784 | /// edges with their port identifiers in dotty format. |
5785 | void visitNode(Node *N) { |
5786 | if (visitedNodes_.find(N) != visitedNodes_.end()) |
5787 | return; |
5788 | visitedNodes_.insert(N); |
5789 | |
5790 | dumpNode(N, false); |
5791 | |
5792 | // Print edges for the predicate field, if it's used. |
5793 | if (N->hasPredicate()) { |
5794 | auto pred = N->getPredicate(); |
5795 | size_t resNo = pred.getResNo(); |
5796 | std::ostringstream edge; |
5797 | edge << pred.getNode()->getName().str() << ":" |
5798 | << pred.getNode()->getOutputName(resNo).str() << " -> " |
5799 | << N->getName().str() << ":w" ; |
5800 | dumpEdgeStyle(N, 0, pred, edge); |
5801 | edges_.insert(edge.str()); |
5802 | visitNode(pred); |
5803 | } |
5804 | |
5805 | for (size_t i = 0; i < N->getNumInputs(); i++) { |
5806 | Node *to = N->getNthInput(i).getNode(); |
5807 | size_t resNo = N->getNthInput(i).getResNo(); |
5808 | |
5809 | std::ostringstream edge; |
5810 | edge << to->getName().str() << ":" << to->getOutputName(resNo).str() |
5811 | << " -> " << N->getName().str() << ":" << N->getInputName(i); |
5812 | dumpEdgeStyle(N, i, to, edge); |
5813 | edges_.insert(edge.str()); |
5814 | |
5815 | visitNode(to); |
5816 | } |
5817 | } |
5818 | |
5819 | public: |
5820 | void visitGraph(Function *F) { |
5821 | // Sort nodes before printing the dot so we can diff dot files. |
5822 | std::set<Node *, SortNamed> sorted; |
5823 | for (Node &N : F->getNodes()) { |
5824 | sorted.insert(&N); |
5825 | } |
5826 | for (auto *N : sorted) { |
5827 | visitNode(N); |
5828 | } |
5829 | } |
5830 | }; |
5831 | |
5832 | std::string Function::dumpDAG() { |
5833 | llvm::SmallString<64> dotPath; |
5834 | llvm::sys::fs::createTemporaryFile("dotty_graph_dump" , "dot" , dotPath); |
5835 | dumpDAG(dotPath); |
5836 | |
5837 | return std::string(dotPath.begin(), dotPath.end()); |
5838 | } |
5839 | |
5840 | /// Local utility function to convert an existing DOT file \p dotFile to |
5841 | /// the right format suggested by the file name extension. For example if |
5842 | /// the file name ends with ".pdf" then the file will be converted to PDF. |
5843 | /// For conversion we use the "dot" application (assumed available) and |
5844 | /// the conversion is done in place. |
5845 | static void convertDotFileToRightFormat(llvm::StringRef dotFile) { |
5846 | static const std::vector<std::string> supportedFormats = {"pdf" , "svg" , |
5847 | "png" }; |
5848 | // Get file extension. |
5849 | llvm::StringRef extWithDot = llvm::sys::path::extension(dotFile); |
5850 | std::string ext = extWithDot.take_back(extWithDot.size() - 1).str(); |
5851 | // Convert to new format if supported. |
5852 | if (std::find(supportedFormats.begin(), supportedFormats.end(), ext) != |
5853 | supportedFormats.end()) { |
5854 | std::string cmd = |
5855 | "dot -T" + ext + " " + dotFile.str() + " -o " + dotFile.str(); |
    std::string cmdErr =
        "Error running DOT conversion application with command '" + cmd +
        "'! Check that you have the 'dot' application installed on your "
        "system in order to convert files from DOT format to other formats! "
        "Otherwise use the '.dot' extension so that no conversion is needed!";
5861 | CHECK(!system(cmd.c_str())) << cmdErr; |
5862 | } |
5863 | } |
5864 | |
5865 | void Function::dumpDAG(llvm::StringRef dotFilename) { |
5866 | llvm::StringRef legalDotFilename = dotFilename.take_back(255); |
5867 | llvm::outs() << "Writing dotty graph for Function to: " << legalDotFilename |
5868 | << '\n'; |
5869 | if (dotFilename.size() > 255) { |
5870 | llvm::outs() << "WARNING: Filename " << dotFilename |
5871 | << " is longer than 255 characters, and so was truncated to " |
5872 | << legalDotFilename << '\n'; |
5873 | } |
5874 | |
5875 | FunctionDottyPrinter DP; |
5876 | |
5877 | DP.visitGraph(this); |
5878 | |
5879 | std::ofstream myfile; |
5880 | myfile.open(legalDotFilename.str()); |
5881 | if (myfile.fail()) { |
5882 | LOG(ERROR) << "Unable to open " << legalDotFilename.str() |
5883 | << ", reason: " << strerror(errno); |
5884 | } else { |
5885 | DP.dumpAll(myfile); |
5886 | } |
5887 | myfile.close(); |
5888 | convertDotFileToRightFormat(legalDotFilename); |
5889 | } |
5890 | |
5891 | void Function::dumpDAG(const char *dotFilename) { |
5892 | dumpDAG(llvm::StringRef(dotFilename)); |
5893 | } |
5894 | |
5895 | Node *Function::getNodeByName(llvm::StringRef name) { |
5896 | for (auto &N : getNodes()) { |
5897 | if (N.getName().equals(name)) { |
5898 | return &N; |
5899 | } |
5900 | } |
5901 | return nullptr; |
5902 | } |
5903 | |
5904 | NodeValue Function::getNodeValueByName(llvm::StringRef name) { |
5905 | auto strPair = name.split(':'); |
5906 | // Search node, constant or placeholder. |
5907 | auto nodeName = strPair.first; |
5908 | Node *node = getNodeByName(nodeName); |
5909 | node = node ? node : getParent()->getConstantByName(nodeName); |
5910 | node = node ? node : getParent()->getPlaceholderByNameSlow(nodeName); |
5911 | if (!node || (node->getNumResults() == 0)) { |
5912 | return NodeValue(); |
5913 | } |
5914 | // Get result number. |
5915 | if (node->getNumResults() == 1) { |
5916 | return NodeValue(node); |
5917 | } else { |
5918 | unsigned resNo = 0; |
5919 | CHECK(!strPair.second.getAsInteger(0, resNo)) << "Invalid node value name!" ; |
5920 | return NodeValue(node, resNo); |
5921 | } |
5922 | } |
5923 | |
5924 | void Module::eraseConstant(ConstList::iterator I) { |
5925 | if (I == constants_.end()) |
5926 | return; |
5927 | logStorageDeletion(functions_, *I); |
5928 | delete *I; |
5929 | constants_.erase(I); |
5930 | } |
5931 | |
5932 | void Module::erasePlaceholder(PlaceholderList::iterator I) { |
5933 | if (I == placeholders_.end()) { |
5934 | return; |
5935 | } |
5936 | |
5937 | logStorageDeletion(functions_, *I); |
5938 | delete *I; |
5939 | placeholders_.erase(I); |
5940 | } |
5941 | |
5942 | void Function::eraseNode(NodesList::iterator I) { |
5943 | // Log node deletion. |
5944 | logCtx_->logNodeDeletion(*I); |
5945 | |
5946 | nodes_.erase(I); |
5947 | } |
5948 | |
5949 | Constant *Module::getConstantByName(llvm::StringRef name) const { |
5950 | for (auto *V : getConstants()) { |
5951 | if (V->getName() == name) |
5952 | return V; |
5953 | } |
5954 | return nullptr; |
5955 | } |
5956 | |
5957 | void Function::randomizeConstants( |
5958 | const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants) { |
5959 | LOG(INFO) << "Randomize Constants............" ; |
5960 | for (Constant *c : getParent()->getConstants()) { |
5961 | bool usedHere = false; |
5962 | bool usedElsewhere = false; |
5963 | bool ignored = false; |
5964 | |
5965 | for (auto &user : c->getUsers()) { |
5966 | auto *nodeUser = user.getUser(); |
5967 | if (nodeUser->getParent() == this) { |
5968 | usedHere = true; |
5969 | } else { |
5970 | usedElsewhere = true; |
5971 | } |
5972 | |
5973 | auto kind = nodeUser->getKind(); |
5974 | if (ignoredConstants.count(kind)) { |
5975 | for (auto idx : ignoredConstants.at(kind)) { |
5976 | if (nodeUser->getNthInput(idx).getNode() == c) { |
5977 | ignored = true; |
5978 | break; |
5979 | } |
5980 | } |
5981 | } |
5982 | } |
5983 | |
5984 | if (!usedHere) { |
5985 | continue; |
5986 | } |
5987 | |
5988 | if (usedElsewhere) { |
5989 | LOG(FATAL) << "Can't randomize Constant \"" << c->getName().str() |
5990 | << "\" because it is used by another function" ; |
5991 | } |
5992 | |
5993 | if (ignored) { |
5994 | continue; |
5995 | } |
5996 | |
5997 | auto &payload = c->getPayloadMutable(); |
5998 | |
5999 | switch (c->getElementType()) { |
6000 | case ElemKind::FloatTy: { |
6001 | auto H = payload.getHandle<float>(); |
6002 | auto minMaxArg = H.minMaxArg(); |
6003 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6004 | break; |
6005 | } |
6006 | case ElemKind::Float16Ty: { |
6007 | auto H = payload.getHandle<float16_t>(); |
6008 | auto minMaxArg = H.minMaxArg(); |
6009 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6010 | break; |
6011 | } |
6012 | case ElemKind::BFloat16Ty: { |
6013 | auto H = payload.getHandle<bfloat16_t>(); |
6014 | auto minMaxArg = H.minMaxArg(); |
6015 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6016 | break; |
6017 | } |
6018 | case ElemKind::Int8QTy: { |
6019 | auto H = payload.getHandle<int8_t>(); |
6020 | auto minMaxArg = H.minMaxArg(); |
6021 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6022 | break; |
6023 | } |
6024 | case ElemKind::UInt8QTy: { |
6025 | auto H = payload.getHandle<uint8_t>(); |
6026 | auto minMaxArg = H.minMaxArg(); |
6027 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6028 | break; |
6029 | } |
6030 | case ElemKind::Int16QTy: { |
6031 | auto H = payload.getHandle<int16_t>(); |
6032 | auto minMaxArg = H.minMaxArg(); |
6033 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6034 | break; |
6035 | } |
6036 | case ElemKind::Int32QTy: { |
6037 | auto H = payload.getHandle<int32_t>(); |
6038 | auto minMaxArg = H.minMaxArg(); |
6039 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6040 | break; |
6041 | } |
6042 | case ElemKind::Int32ITy: { |
6043 | auto H = payload.getHandle<int32_t>(); |
6044 | auto minMaxArg = H.minMaxArg(); |
6045 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6046 | break; |
6047 | } |
6048 | case ElemKind::Int64ITy: { |
6049 | auto H = payload.getHandle<int64_t>(); |
6050 | auto minMaxArg = H.minMaxArg(); |
6051 | H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG()); |
6052 | break; |
6053 | } |
    case ElemKind::UInt8FusedQTy:
    case ElemKind::UInt8FusedFP16QTy:
    case ElemKind::UInt4FusedFP16QTy:
    case ElemKind::UInt4FusedQTy:
      payload.getHandle<uint8_t>().randomize(
          std::numeric_limits<uint8_t>::lowest(),
          std::numeric_limits<uint8_t>::max(), getPRNG());
      break;
6074 | case ElemKind::BoolTy: |
6075 | payload.getHandle<bool>().randomize(false, true, getPRNG()); |
6076 | break; |
6077 | default: |
6078 | LOG(FATAL) << "Unsupported ElemKind" ; |
6079 | } |
6080 | } |
6081 | } |
6082 | |
6083 | Placeholder *Module::getPlaceholderByNameSlow(llvm::StringRef name) const { |
6084 | for (auto *P : getPlaceholders()) { |
6085 | if (P->getName() == name) { |
6086 | return P; |
6087 | } |
6088 | } |
6089 | |
6090 | return nullptr; |
6091 | } |
6092 | |
6093 | void Module::eraseConstant(Constant *N) { |
6094 | auto &vars = getConstants(); |
6095 | auto I = std::find(vars.begin(), vars.end(), N); |
6096 | eraseConstant(I); |
6097 | } |
6098 | |
6099 | void Function::eraseNode(Node *N) { |
6100 | if (Constant *V = dyn_cast<Constant>(N)) { |
6101 | return getParent()->eraseConstant(V); |
6102 | } |
6103 | assert(std::find_if(nodes_.begin(), nodes_.end(), |
6104 | [N](const Node &node) -> bool { return &node == N; }) != |
6105 | nodes_.end() && |
6106 | "Could not find node to delete!" ); |
6107 | eraseNode(N->getIterator()); |
6108 | } |
6109 | |
6110 | PlaceholderList Function::findPlaceholders() { |
6111 | PlaceholderList list; |
6112 | for (auto &PH : parent_->getPlaceholders()) { |
6113 | for (auto &user : PH->getUsers()) { |
6114 | if (user.getUser()->getParent() == this) { |
6115 | list.push_back(PH); |
6116 | break; |
6117 | } |
6118 | } |
6119 | } |
6120 | return list; |
6121 | } |
6122 | |
6123 | PlaceholderList Function::findPlaceholders() const { |
6124 | PlaceholderList list; |
6125 | for (auto &PH : parent_->getPlaceholders()) { |
6126 | for (auto &user : PH->getUsers()) { |
6127 | if (user.getUser()->getParent() == this) { |
6128 | list.push_back(PH); |
6129 | break; |
6130 | } |
6131 | } |
6132 | } |
6133 | return list; |
6134 | } |
6135 | |
6136 | ConstList Function::findConstants() { |
6137 | ConstList list; |
6138 | for (auto &constant : parent_->getConstants()) { |
6139 | for (auto &user : constant->getUsers()) { |
6140 | if (user.getUser()->getParent() == this) { |
6141 | list.push_back(constant); |
6142 | break; |
6143 | } |
6144 | } |
6145 | } |
6146 | return list; |
6147 | } |
6148 | |
6149 | ConstList Function::findConstants() const { |
6150 | ConstList list; |
6151 | for (auto &constant : parent_->getConstants()) { |
6152 | for (auto &user : constant->getUsers()) { |
6153 | if (user.getUser()->getParent() == this) { |
6154 | list.push_back(constant); |
6155 | break; |
6156 | } |
6157 | } |
6158 | } |
6159 | return list; |
6160 | } |
6161 | |
6162 | Function *Function::clone(llvm::StringRef newName, |
6163 | llvm::DenseMap<const Node *, Node *> *map, |
6164 | llvm::DenseMap<const Node *, Node *> *currToNewMap) { |
6165 | Module *M = getParent(); |
6166 | auto *newF = M->createFunction(newName); |
6167 | return clone(newF, map, currToNewMap); |
6168 | } |
6169 | |
6170 | Function * |
6171 | Function::clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map, |
6172 | llvm::DenseMap<const Node *, Node *> *currToNewMap) const { |
6173 | // Maps current nodes to new nodes. |
6174 | llvm::DenseMap<const Node *, Node *> currToNew; |
6175 | |
6176 | // Initialize the map from a user-provided map. |
6177 | if (currToNewMap) { |
6178 | currToNew.insert(currToNewMap->begin(), currToNewMap->end()); |
6179 | } |
6180 | |
6181 | // Clone all of the nodes in the function. |
6182 | for (auto &N : getNodes()) { |
6183 | Node *copy = N.clone(); |
6184 | // Record the copy relationship between the graphs. |
6185 | currToNew[&N] = copy; |
6186 | newF->addNode(copy); |
6187 | if (N.hasPredicate()) { |
6188 | copy->setPredicate(N.getPredicate()); |
6189 | } |
6190 | } |
6191 | |
6192 | // At this point we have a new invalid function that points into nodes in |
6193 | // the original function. Here we update the links between the nodes in the |
6194 | // new function. |
6195 | for (auto &N : newF->getNodes()) { |
6196 | // Fix each one of the inputs of this node. |
6197 | for (unsigned inp = 0, e = N.getNumInputs(); inp < e; inp++) { |
6198 | auto input = N.getNthInput(inp); |
6199 | |
6200 | auto it = currToNew.find(input.getNode()); |
6201 | if (it == currToNew.end()) { |
6202 | assert(isa<Storage>(input.getNode()) && |
6203 | "Could not find a mapping for some node!" ); |
6204 | continue; |
6205 | } |
6206 | |
6207 | // Update the node with the edge to the current graph. |
6208 | N.setNthInput(inp, NodeValue(it->second, input.getResNo())); |
6209 | } |
6210 | |
6211 | if (N.hasPredicate()) { |
6212 | auto it = currToNew.find(N.getPredicate().getNode()); |
6213 | if (it != currToNew.end()) { |
6214 | N.setPredicate(NodeValue(it->second, N.getPredicate().getResNo())); |
6215 | } |
6216 | } |
6217 | } |
6218 | |
6219 | // Record the node mapping into the external map. |
6220 | if (map) { |
6221 | assert(map->empty() && "The external map must be empty" ); |
6222 | for (auto it : currToNew) { |
6223 | map->insert(it); |
6224 | } |
6225 | } |
6226 | |
6227 | assert(newF->getNodes().size() == getNodes().size() && "Invalid func size" ); |
6228 | return newF; |
6229 | } |
6230 | |
6231 | /// Verify the input \p idx of a node \p N. Check that the node \p N is in the |
6232 | /// use-list of the corresponding input node. |
6233 | static bool verifyNodeInput(const Node &N, size_t idx) { |
6234 | auto input = N.getNthInput(idx); |
6235 | auto *refN = input.getNode(); |
6236 | // Check that N is in the use-list of the input node and there is a proper |
6237 | // entry for it. |
6238 | for (auto &U : refN->getUsers()) { |
6239 | if (U.getUser() == &N && *U.get() == input) { |
6240 | return true; |
6241 | } |
6242 | } |
6243 | |
6244 | report("Any node referencing another node N must be in the use-list of the " |
6245 | "node N" ); |
6246 | return false; |
6247 | } |
6248 | |
6249 | Module *Module::clone() const { |
6250 | auto *M = new Module; |
6251 | return clone(M); |
6252 | } |
6253 | |
6254 | Module *Module::clone(Module *M) const { |
6255 | // Maps current nodes to new nodes. |
6256 | llvm::DenseMap<const Node *, Node *> currToNew; |
6257 | // Clone placeholders. |
6258 | for (auto &PH : getPlaceholders()) { |
6259 | auto *copyPH = M->createPlaceholder(PH->getType(), PH->getName(), |
6260 | PH->isTraining(), PH->getLayout()); |
6261 | currToNew[PH] = copyPH; |
6262 | } |
6263 | // Clone constants. |
6264 | for (auto &C : getConstants()) { |
    // The cloner cannot decide on its own what to do with constants having
    // unowned payloads. Some kind of policy/hook may be needed in the future
    // to decide what should be done in such cases.
6268 | DCHECK(!C->getPayload().isUnowned()) |
6269 | << "Cannot copy constant " << C->getName().str() |
6270 | << ": Unowned payloads are not supported" ; |
6271 | auto *copyC = M->createConstant(C->getType(), C->getName(), C->getLayout()); |
6272 | copyC->assign(&C->getPayload()); |
6273 | currToNew[C] = copyC; |
6274 | } |
6275 | // Clone all functions. |
6276 | for (auto *F : getFunctions()) { |
6277 | // Create an empty clone function in the new module. |
6278 | auto *copyF = M->createFunction(F->getName()); |
6279 | // Clone function's body into the newly created cloned function. Use the |
6280 | // currToNew to properly map constants and placeholders. |
6281 | F->clone(copyF, nullptr, &currToNew); |
6282 | // Update all types by cloned types. |
6283 | for (auto &N : copyF->getNodes()) { |
6284 | for (unsigned idx = 0, e = N.getNumResults(); idx < e; ++idx) { |
6285 | N.setType(idx, M->uniqueType(*N.getType(idx))); |
6286 | } |
6287 | } |
6288 | } |
6289 | return M; |
6290 | } |
6291 | |
6292 | /// \returns True if \p n is a storage node (constant or placeholder) of the |
6293 | /// function \p F. |
6294 | static bool isGraphStorageNode(Node *n, const Function *F) { |
6295 | auto &vars = F->getParent()->getConstants(); |
6296 | auto &placeholders = F->getParent()->getPlaceholders(); |
6297 | |
6298 | if (Constant *V = dyn_cast<Constant>(n)) { |
6299 | return std::find(vars.begin(), vars.end(), V) != vars.end(); |
6300 | } |
6301 | |
6302 | if (Placeholder *P = dyn_cast<Placeholder>(n)) { |
6303 | return std::find(placeholders.begin(), placeholders.end(), P) != |
6304 | placeholders.end(); |
6305 | } |
6306 | |
6307 | return false; |
6308 | } |
6309 | |
6310 | /// Insert \p node in \p nameToNode and report an error if the insertion fails. |
6311 | /// \returns True if \p node was inserted into \p nameToNode. False otherwise. |
6312 | /// When true is returned that means that \p nameToNode had no other nodes |
6313 | /// registered under \p node.getName(). |
6314 | static bool |
6315 | insertAndReport(std::unordered_map<std::string, const Node *> &nameToNode, |
6316 | const Node &node, const Function &function) { |
6317 | bool inserted = expectCompareTrue( |
6318 | "Node is not unique" , |
6319 | nameToNode.insert({node.getName().str(), &node}).second, true, &function); |
6320 | if (!inserted) { |
6321 | std::string storage; |
6322 | llvm::raw_string_ostream msg(storage); |
6323 | /// Output extra information helping to find the error. |
6324 | msg << "The node with name '" << node.getName() |
6325 | << "' conflicts with a previous definition:\n" ; |
6326 | msg << "Current definition: " << node.getDebugDesc() << "\n" ; |
6327 | msg << "Previous definition: " |
6328 | << nameToNode[node.getName().str()]->getDebugDesc(); |
6329 | report(msg.str().c_str()); |
6330 | return false; |
6331 | } |
6332 | return true; |
6333 | } |
6334 | |
6335 | bool Function::verify(const Backend *backend) const { |
6336 | bool isValid = true; |
  // Check if layout verification is disabled, in which case any layout is
  // accepted for all ops.
6339 | VLOG(1) << "Layout requirements checking is " |
6340 | << (glow::flags::DisableLayoutVerifying ? "disabled" : "enabled" ); |
6341 | if (!glow::flags::DisableLayoutVerifying) { |
6342 | if (backend) { |
6343 | if (backend->getTensorLayoutRequirements().isEnabled()) { |
6344 | isValid &= expectCompareTrue( |
6345 | "Expected correct backend-specific layouts for the graph" , |
6346 | verifyLayouts(*this, backend->getTensorLayoutRequirements()), true, |
6347 | this); |
6348 | } |
6349 | } else { |
6350 | // Always run verification pre-lowering / when we don't have backend: |
6351 | isValid &= expectCompareTrue( |
6352 | "Expected correct Glow canonical layouts for the graph" , |
6353 | verifyLayouts(*this, CanonicalTensorLayout::getInstance()), true, |
6354 | this); |
6355 | } |
6356 | } |
6357 | std::unordered_map<std::string, const Node *> nameToNode; |
6358 | |
6359 | for (auto *V : findConstants()) { |
6360 | isValid &= insertAndReport(nameToNode, *V, *this); |
6361 | isValid &= expectCompareTrue("Constant and its payload must have same type" , |
6362 | *V->getType(), V->getPayload().getType(), V); |
6363 | } |
6364 | |
6365 | nameToNode.clear(); |
6366 | for (const auto &N : nodes_) { |
6367 | isValid &= insertAndReport(nameToNode, N, *this); |
6368 | } |
6369 | |
6370 | // Any node referenced by one of the graph nodes should be part of the |
6371 | // Graph. |
6372 | for (const auto &N : nodes_) { |
6373 | for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) { |
6374 | auto &input = N.getNthInput(idx); |
6375 | // Verify each input of N. |
6376 | isValid &= verifyNodeInput(N, idx); |
6377 | bool foundNode = |
6378 | std::find(nodes_.begin(), nodes_.end(), *input) != nodes_.end(); |
6379 | isValid &= expectCompareTrue( |
6380 | "Every node referenced by one of the graph nodes should be part of " |
6381 | "the graph" , |
6382 | foundNode || isGraphStorageNode(input, this), true, &N); |
6383 | } |
6384 | } |
6385 | |
6386 | // Check that all uses of each node refer to this node. |
6387 | for (const auto &N : nodes_) { |
6388 | for (const auto &U : N.getUsers()) { |
      isValid &= expectCompareTrue<const Node *>(
          "All uses of a node should refer to this node", U.get()->getNode(),
          &N, &N);
    }
  }
6394 | } |
6395 | |
6396 | // Check that all types used by nodes belong to the parent module. |
6397 | auto &types = getParent()->getTypes(); |
6398 | for (const auto &N : nodes_) { |
6399 | for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) { |
6400 | auto ty = N.getType(idx); |
6401 | bool foundType = |
6402 | std::find(types.begin(), types.end(), *ty) != types.end(); |
6403 | isValid &= expectCompareTrue( |
6404 | "Every type used by one of the graph nodes should be part of " |
6405 | "the graph" , |
6406 | foundType, true, &N); |
6407 | } |
6408 | } |
6409 | |
  // Check that there are no zero-volume tensors.
6411 | for (const auto &N : nodes_) { |
6412 | // Check inputs |
6413 | for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) { |
6414 | auto dims = N.getNthInput(idx).dims(); |
6415 | for (auto dim : dims) { |
6416 | if (dim == 0) { |
6417 | LOG(ERROR) << "Found 0 volume input in the " << idx |
6418 | << " input to node " << N.toString() << " with dims " |
6419 | << dims; |
6420 | return false; |
6421 | } |
6422 | } |
6423 | } |
6424 | |
6425 | // Check results |
6426 | for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) { |
6427 | auto dims = N.getNthResult(idx).dims(); |
6428 | for (auto dim : dims) { |
6429 | if (dim == 0) { |
6430 | LOG(ERROR) << "Found 0 volume result in the " << idx |
6431 | << " result from node " << N.toString() << " with dims " |
6432 | << dims; |
6433 | return false; |
6434 | } |
6435 | } |
6436 | } |
6437 | } |
6438 | |
6439 | std::unordered_map<const Placeholder *, const Node *> placeholderWrittenTo; |
6440 | for (const auto &N : nodes_) { |
    isValid &=
        expectCompareTrue("Node is not linked to the function it belongs to",
                          N.getParent(), this, &N);
6444 | isValid &= N.verify(); |
    // Make sure all the placeholders are written at most once, and that
    // constants are never written to.
6447 | for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) { |
6448 | // Placeholders and Constants have no input, so they can only be |
6449 | // written to via overwritten inputs. |
6450 | if (!N.isOverwrittenNthInput(idx)) { |
6451 | continue; |
6452 | } |
6453 | |
6454 | const Node *nthInputNode = N.getNthInput(idx).getNode(); |
6455 | isValid &= expectCompareTrue( |
6456 | "Constants can never be used as an overwritten input" , |
6457 | isa<Constant>(nthInputNode), false, nthInputNode); |
6458 | |
6459 | // Unlike Constants, Placeholders can be used at most once as |
6460 | // overwritten inputs. Keep a map of Placeholders to Nodes that used |
6461 | // them as overwritten inputs, which is also used later to check for |
6462 | // read-after-write dependence violations. |
6463 | const auto *ph = dyn_cast<Placeholder>(nthInputNode); |
6464 | if (!ph) { |
6465 | continue; |
6466 | } |
6467 | auto varToFirstDef = placeholderWrittenTo.find(ph); |
6468 | bool writtenOnce = expectCompareTrue( |
6469 | "Placeholder has more than one write" , |
6470 | varToFirstDef == placeholderWrittenTo.end(), true, ph); |
6471 | if (!writtenOnce) { |
6472 | isValid = false; |
6473 | std::string storage; |
6474 | llvm::raw_string_ostream msg(storage); |
6475 | |
6476 | msg << "Placeholder " << ph->getDebugDesc() << '\n'; |
6477 | msg << "has more than one write; second writer found:\n" ; |
6478 | msg << N.getDebugDesc() << '\n'; |
6479 | msg << varToFirstDef->second->getDebugDesc() << '\n'; |
6480 | |
6481 | report(msg.str().c_str()); |
6482 | } |
6483 | |
6484 | placeholderWrittenTo[ph] = &N; |
6485 | } |
6486 | } |
6487 | |
6488 | // Now check that the placeholders that are written to are either: |
6489 | // - Written by a save node, or |
6490 | // - Are only used by the node that writes them |
6491 | // If this check fails, that means we have implicit memory |
6492 | // dependencies that may not be honored by the scheduler. |
6493 | // Either the input IR is incorrect or the scheduler needs |
6494 | // fixing. |
6495 | for (const auto &varToWrite : placeholderWrittenTo) { |
6496 | if (isa<SaveNode>(varToWrite.second)) { |
6497 | continue; |
6498 | } |
6499 | for (const NodeUse &use : varToWrite.first->getUsers()) { |
6500 | const Node *user = use.getUser(); |
6501 | // Ignore users outside this function. |
6502 | if (user->getParent() != this) { |
6503 | continue; |
6504 | } |
6505 | isValid &= expectCompareTrue( |
6506 | "Implicit read after write memory dependency may not be honored" , |
6507 | user, varToWrite.second, user); |
6508 | } |
6509 | } |
6510 | return isValid; |
6511 | } |
6512 | |
6513 | SaveNode *glow::getOutputSave(Function *F, Placeholder *PH) { |
6514 | // if parent is set for PH, check if it is the same as provided Function. |
6515 | auto *PHP = PH->getParent(); |
6516 | if (PHP != nullptr && F != PHP) { |
6517 | return nullptr; |
6518 | } |
6519 | for (auto &use : PH->getUsers()) { |
6520 | if (auto *save = llvm::dyn_cast<SaveNode>(use.getUser())) { |
6521 | if (save->getParent() == F && save->getPlaceholder() == PH) { |
6522 | return save; |
6523 | } |
6524 | } |
6525 | } |
6526 | return nullptr; |
6527 | } |
6528 | |
6529 | Node *glow::recursiveClone(Function *newF, Node *node, NodeMap &currToNew) { |
6530 | Node *copy = node->clone(); |
6531 | currToNew[node] = copy; |
6532 | newF->addNode(copy); |
6533 | for (unsigned inp = 0, e = copy->getNumInputs(); inp < e; inp++) { |
6534 | auto input = copy->getNthInput(inp); |
6535 | auto it = currToNew.find(input.getNode()); |
6536 | Node *newInput; |
6537 | if (it != currToNew.end()) { |
6538 | newInput = it->second; |
6539 | } else if (llvm::isa<Storage>(input.getNode())) { |
6540 | continue; |
6541 | } else { |
6542 | newInput = recursiveClone(newF, input.getNode(), currToNew); |
6543 | } |
6544 | copy->setNthInput(inp, NodeValue(newInput, input.getResNo())); |
6545 | } |
6546 | return copy; |
6547 | } |
6548 | |
6549 | namespace glow { |
/// \returns true if \p PH is an output placeholder in \p F. This is
/// determined by checking if \p PH has a user inside \p F which uses it as an
/// overwritten input.
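/// For example (illustrative only; \c F, \c someNode and \c outPH are
/// hypothetical names), saving into a placeholder makes it an output, since
/// SaveNode overwrites its Output operand:
/// \code
///   F->createSave("save", someNode, outPH);
///   assert(isOutput(outPH, *F));
/// \endcode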
6553 | bool isOutput(const Placeholder *PH, const Function &F) { |
6554 | for (const auto &use : PH->getUsers()) { |
    // Look through the inputs of the PH's users. If an input is overwritten,
    // check whether it is the PH; if so, return true.
6557 | auto *user = use.getUser(); |
6558 | // Consider only users inside the same function. |
6559 | if (user->getParent() != &F) { |
6560 | continue; |
6561 | } |
6562 | for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) { |
6563 | // If the input is not overwritten we can continue. |
6564 | if (!user->isOverwrittenNthInput(i)) { |
6565 | continue; |
6566 | } |
      auto input = user->getNthInput(i);
6568 | if (input.getNode() == PH) { |
6569 | return true; |
6570 | } |
6571 | } |
6572 | } |
6573 | return false; |
6574 | } |
6575 | |
/// \returns true if \p PH is an input placeholder in \p F.
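/// For example (illustrative only; \c F and \c inPH are hypothetical names),
/// a placeholder read by a compute node is an input:
/// \code
///   F->createAdd("add", inPH, inPH);
///   assert(isInput(inPH, *F));
/// \endcode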
6577 | bool isInput(const Placeholder *PH, const Function &F) { |
  // Check whether the PH is the input to a SaveNode or is used by a
  // non-SaveNode.
6579 | for (const auto &use : PH->getUsers()) { |
6580 | // Consider only users inside the same function. |
6581 | if (use.getUser()->getParent() != &F) { |
6582 | continue; |
6583 | } |
    // Check if PH is an input to a SaveNode.
6585 | if (auto *save = dyn_cast<SaveNode>(use.getUser())) { |
6586 | auto input = save->getInput(); |
      // If the PH is not the input to the SaveNode, keep looking.
6588 | if (input.getNode() != PH) { |
6589 | continue; |
6590 | } |
6591 | } |
6592 | return true; |
6593 | } |
6594 | return false; |
6595 | } |
6596 | |
6597 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod) { |
6598 | mod.dump(os); |
6599 | return os; |
6600 | } |
6601 | |
6602 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod) { |
  assert(mod != nullptr && "Null Pointer.");
6604 | mod->dump(os); |
6605 | return os; |
6606 | } |
6607 | |
6608 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F) { |
6609 | F.dump(os); |
6610 | return os; |
6611 | } |
6612 | |
6613 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F) { |
  assert(F != nullptr && "Null Pointer.");
6615 | F->dump(os); |
6616 | return os; |
6617 | } |
6618 | |
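/// \returns true if \p node is a convolution that computes the same result as
/// a FullyConnected node: a 1x1, unit-stride, unpadded, non-grouped,
/// non-dilated NHWC convolution with no fused activation. Such a convolution
/// applies the same matrix multiply at every output pixel (semantics sketch,
/// not compilable code):
/// \code
///   out[n][h][w][oc] = bias[oc]
///       + sum over ic of (in[n][h][w][ic] * filter[oc][0][0][ic])
/// \endcode
/// When \p enforceInput1x1 is true, the input spatial dimensions must also be
/// 1x1.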
6619 | bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node, |
6620 | bool enforceInput1x1) { |
6621 | bool isConv2D = (node->getInput().getType()->dims().size() == 4); |
6622 | if (!(isConv2D && node->getLayout() == ConvolutionLayout::NHWC && |
6623 | !node->hasFusedActivation())) { |
6624 | return false; |
6625 | } |
6626 | auto filterDims = ShapeNHWC(node->getFilter().getType()->dims()); |
6627 | ShapeHW kernels = ShapeHW(node->getKernels()); |
6628 | ShapeHW strides = ShapeHW(node->getStrides()); |
6629 | PaddingTLBR pads = PaddingTLBR(node->getPads()); |
6630 | auto group = node->getGroup(); |
6631 | auto dilation = node->getDilation(); |
6632 | |
6633 | bool isSame = (filterDims.h == 1) && (filterDims.w == 1); |
6634 | isSame &= (kernels.height == 1) && (kernels.width == 1); |
6635 | isSame &= (strides.height == 1) && (strides.width == 1); |
6636 | isSame &= (pads.top == 0) && (pads.left == 0) && (pads.bottom == 0) && |
6637 | (pads.right == 0); |
6638 | isSame &= (group == 1); |
6639 | isSame &= std::all_of(dilation.begin(), dilation.end(), |
6640 | [](unsigned_t i) { return i == 1; }); |
6641 | |
6642 | if (enforceInput1x1) { |
6643 | auto inputDims = ShapeNHWC(node->getInput().getType()->dims()); |
6644 | isSame &= (inputDims.h == 1) && (inputDims.w == 1); |
6645 | } |
6646 | return isSame; |
6647 | } |
6648 | |
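/// \returns true if \p node computes the same result as a FullyConnected
/// node: Gemm computes Y = alpha * A * B + beta * C, so with alpha == 1,
/// beta == 1, and C present as a 1-D bias vector it is exactly
/// FullyConnected(A, B, C).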
6649 | bool isGemmSameAsFullyConnected(const GemmNode *node) { |
6650 | NodeValue inpC = node->getC(); |
6651 | return (node->getAlpha() == 1.0) && (node->getBeta() == 1.0) && |
6652 | (inpC.getNode()) && (inpC.dims().size() == 1); |
6653 | } |
6654 | |
6655 | } // namespace glow |
6656 | |