1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Graph/Graph.h"
17#include "glow/Backend/Backend.h"
18#include "glow/Flags/Flags.h"
19#include "glow/Graph/Nodes.h"
20#include "glow/Graph/PlaceholderBindings.h"
21#include "glow/Graph/TensorLayout.h"
22#include "glow/Graph/VerifierHelper.h"
23#include "glow/Quantization/Base/Base.h"
24#include "glow/Support/Support.h"
25
26#include "llvm/ADT/DenseMap.h"
27#include "llvm/ADT/SmallString.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/Support/Casting.h"
30#include "llvm/Support/FileSystem.h"
31#include "llvm/Support/Format.h"
32#include "llvm/Support/Path.h"
33#include "llvm/Support/raw_ostream.h"
34
35#ifdef WIN32
36#include <corecrt_math_defines.h>
37#endif
38#include <float.h>
39#include <fstream>
40#include <unordered_set>
41
42using namespace glow;
43using llvm::cast;
44using llvm::dyn_cast;
45using llvm::isa;
46
47namespace {
48/// A helper function to log the deletion of constant/placeholder \p s of a
49/// module into the log context of given functions \p functions.
/// Note: The reason we don't log the deletion of a constant in the function
/// that uses or creates it is that constants/placeholders do not have a
/// function parent (and we cannot rely on a user's function either, because
/// the users may already have been removed). It is therefore best to log
/// constants/placeholders in a Module-level log context and copy the entry
/// into all of the Module's functions.
55void logStorageDeletion(std::list<Function *> functions, Storage *s) {
56 for (auto *F : functions) {
57 F->getLogContext()->logNodeDeletion(*s);
58 }
59 if (functions.size() > 0) {
60 auto *F = *(functions.begin());
61 F->getLogContext()->logNodeDeletion(*s, /* logIntoModule */ true);
62 }
63}
64
65/// A helper function to log the creation of constant/placeholder \p s of a
66/// module into the log context of given functions \p functions.
67/// Same note as for logStorageDeletion().
68void logStorageCreation(std::list<Function *> functions, Storage *s) {
69 for (auto *F : functions) {
70 F->getLogContext()->logNodeCreation(*s);
71 }
72 if (functions.size() > 0) {
73 auto *F = *(functions.begin());
74 F->getLogContext()->logNodeCreation(*s, /* logIntoModule */ true);
75 }
76}
77} // namespace
78
79/// Merge shape \p shape into \p mergeShape, following multidirectional
80/// broadcasting rules.
81static void mergeMultidirectionalBroadcastHelper(std::vector<dim_t> &mergeShape,
82 llvm::ArrayRef<dim_t> shape) {
83 size_t shift = mergeShape.size() - shape.size();
84 for (size_t i = 0, e = shape.size(); i < e; i++) {
85 if (shape[i] == 1) {
86 // Just leave mergeShape[i] as it is.
87 continue;
88 }
89
90 assert(
91 ((shape[i] == mergeShape[shift + i]) || (mergeShape[shift + i] == 1)) &&
92 "Incompatible dimension for the broadcast");
93 mergeShape[shift + i] = shape[i];
94 }
95}
96
97/// Utility function which computes the resulting shape in case of
98/// multidirectional broadcasting.
99static std::vector<dim_t>
100computeMultidirectionalBroadcastHelper(llvm::ArrayRef<dim_t> shape0,
101 llvm::ArrayRef<dim_t> shape1) {
102 size_t numDims0 = shape0.size();
103 size_t numDims1 = shape1.size();
104 size_t newNumDims = std::max(numDims0, numDims1);
105 std::vector<dim_t> reshapeDims(newNumDims, 1);
106
107 mergeMultidirectionalBroadcastHelper(reshapeDims, shape0);
108 mergeMultidirectionalBroadcastHelper(reshapeDims, shape1);
109
110 return reshapeDims;
111}
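
// Worked example (hypothetical shapes) for the helpers above: the shorter
// shape is right-aligned against the longer one, and each aligned pair of
// dimensions must either match or contain a 1.
//   shape0 = {3, 4, 1}, shape1 = {5}     ->  result = {3, 4, 5}
//   shape0 = {2, 1, 6}, shape1 = {7, 1}  ->  result = {2, 7, 6}
//   shape0 = {2, 3},    shape1 = {4, 3}  ->  assertion failure (2 vs. 4)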
112
113std::vector<NodeValue>
114Function::broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs) {
115 dim_t numInputs = inputs.size();
116
117 if (axis > -1) {
118 assert(
        numInputs == 2 &&
        "If axis is specified (not -1), unidirectional broadcasting is used "
        "and the number of inputs must be 2.");
122 return {inputs[0],
123 createBroadcast("broadcast_" + inputs[1].getNode()->getName().str(),
124 inputs[1], inputs[0].dims(), axis)};
125 }
126
  assert(numInputs >= 2 &&
         "Invalid number of inputs passed to broadcastInputs.");
128
129 std::vector<dim_t> targetDim = computeMultidirectionalBroadcastHelper(
130 inputs[0].dims(), inputs[1].dims());
131
132 for (size_t i = 2; i < numInputs; ++i) {
133 targetDim =
134 computeMultidirectionalBroadcastHelper(targetDim, inputs[i].dims());
135 }
136
137 std::vector<NodeValue> out(numInputs);
138 for (size_t i = 0; i < numInputs; ++i) {
139 NodeValue n = inputs[i];
140 auto dims = n.dims();
141 if (dims != llvm::ArrayRef<dim_t>(targetDim)) {
142 unsigned axis = targetDim.size() - dims.size();
143 out[i] = createBroadcast("broadcast_" + n.getNode()->getName().str(), n,
144 targetDim, axis);
145 } else {
146 out[i] = inputs[i];
147 }
148 }
149 return out;
150}
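
// A minimal usage sketch of the multidirectional path above, assuming
// hypothetical NodeValues "lhs" and "rhs" of shapes {4, 1} and {3, 1, 5}:
//
//   std::vector<NodeValue> b = F->broadcastInputs(/* axis */ -1, {lhs, rhs});
//
// Both b[0] and b[1] then have shape {3, 4, 5} and can be fed to an
// elementwise node such as createAdd().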
151
152bool Module::hasFunction(llvm::StringRef name) { return getFunction(name); }
153
154void Module::clearFunctions() {
155 for (auto *F : functions_) {
156 F->clear();
157 }
158}
159
160void Function::clear() {
161 nodes_.clear();
162 uniqueNodeNames_.clear();
163}
164
165Function *Module::getFunction(llvm::StringRef name) {
166 for (auto *F : functions_) {
167 if (F->getName() == name) {
168 return F;
169 }
170 }
171 return nullptr;
172}
173
174Function *Module::createFunction(llvm::StringRef name) {
175 assert(!hasFunction(name) && "A function with this name already exists");
176 Function *F = new Function(this, name);
177 functions_.push_back(F);
178 return F;
179}
180
181void Module::strip() {
182 for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
183 Constant *v = *it;
184 v->clearPayload();
185 }
186}
187
188void Module::clear() {
189 for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
190 Constant *v = *it;
191 logStorageDeletion(functions_, v);
192 delete v;
193 }
194
195 constants_.clear();
196
197 for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e;
198 it++) {
199 Placeholder *p = *it;
200 logStorageDeletion(functions_, p);
201 delete p;
202 }
203
204 eraseFunctions();
205
206 placeholders_.clear();
207}
208
209Module::~Module() { clear(); }
210bool Module::verify() const {
211 bool isValid = true;
212 for (auto *F : functions_) {
213 isValid &= F->verify();
214 }
215 // Check that all types used by constants or placeholders belong to the
216 // module.
217 auto &types = getTypes();
218 for (const auto *PH : getPlaceholders()) {
219 bool foundType =
220 std::find(types.begin(), types.end(), *PH->getType()) != types.end();
221 isValid &=
222 expectCompareTrue("Every type used by placeholders should be part of "
223 "the graph",
224 foundType, true, PH);
225 }
226 for (const auto *C : getConstants()) {
227 bool foundType =
228 std::find(types.begin(), types.end(), *C->getType()) != types.end();
229 isValid &=
230 expectCompareTrue("Every type used by constants should be part of "
231 "the graph",
232 foundType, true, C);
233 }
234 return isValid;
235}
236
237void Module::dump() const {
238 llvm::outs() << "Module structure:\n";
239 for (auto *C : getConstants()) {
240 llvm::outs() << C->getDebugDesc() << "\n";
241 }
242
243 for (auto *P : getPlaceholders()) {
244 llvm::outs() << P->getDebugDesc() << "\n";
245 }
246
247 for (auto *F : functions_) {
248 llvm::outs() << "Function:" << F->getName() << "\n";
249 }
250}
251
252std::string Module::toString() const {
253 std::string storage;
254 llvm::raw_string_ostream os(storage);
255 dump(os);
256 return os.str();
257}
258
259/// Creates a std::set copy of \p unsorted, sorted based on name of each
260/// element, and \returns it.
261template <class T>
262static std::set<T *, SortNamed> getNamedSorted(const std::list<T *> &unsorted) {
263 return std::set<T *, SortNamed>(unsorted.begin(), unsorted.end());
264}
265
266void Module::dump(llvm::raw_ostream &os) const {
267 os << "Module structure:\n";
268 for (auto *C : getNamedSorted(constants_)) {
269 os << C->getDebugDesc() << "\n";
270 }
271 for (auto *P : getNamedSorted(placeholders_)) {
272 os << P->getDebugDesc() << "\n";
273 }
274 for (auto *F : getNamedSorted(functions_)) {
275 os << "Function : " << F->getName() << "\n";
276 }
277}
278
279/// A helper class for visiting and generating the dotty graph file.
280class AbstractDottyPrinter {
281protected:
282 // List of generated vertices.
283 std::vector<std::string> vertices_{};
284 // List of generated edges.
285 std::unordered_set<std::string> edges_{};
286 // Map node addresses to unique numbers.
287 using VertexNumberMap = std::unordered_map<void *, unsigned>;
288 VertexNumberMap vertex_numbers{};
289
  /// Dumps the label for an input/output row, given port names.
291 /// E.g. {"LHS", "RHS"} will produce {<LHS>LHS|<RHS>RHS}
292 void dumpLabelForRow(llvm::ArrayRef<std::string> names, std::ostream &os) {
293 os << "{";
294 for (size_t i = 0; i < names.size(); i++) {
295 if (i) {
296 os << "|";
297 }
298 os << "<" << names[i] << ">" << names[i];
299 }
300 os << "}";
301 }
302
303 void dumpLabel(Node *N, std::ostream &os) {
304 os << "{";
305 if (N->getNumInputs()) {
306 std::vector<std::string> names(N->getNumInputs());
307 for (size_t i = 0; i < names.size(); i++) {
308 names[i] = N->getInputName(i);
309 }
310 dumpLabelForRow(names, os);
311 os << "|";
312 }
313 os << "{" << escapeDottyString(N->getDebugDesc()) << "}";
314 if (N->getNumResults()) {
315 os << "|";
316 std::vector<std::string> names(N->getNumResults());
317 for (size_t i = 0; i < names.size(); i++) {
318 names[i] = N->getOutputName(i).str();
319 }
320 dumpLabelForRow(names, os);
321 }
322 os << "}";
323 }
324
325 void dumpNode(Node *N, bool uniqueNames) {
326 if (!N) {
327 return;
328 }
329 std::ostringstream os;
330 // Print a node descriptor that looks like this:
331 if (uniqueNames) {
332 // vNNNN [ shape = "record" label = "{...}" ];
333 os << uniqueVertexName(N) << "[\n";
334 } else {
335 // <name> [ shape = "record" label = "{...}" ];
336 os << N->getName().str() << "[\n";
337 }
338 os << "\tlabel = \"";
339 dumpLabel(N, os);
340 os << "\"\n";
341 os << "\tshape = \"record\"\n";
342 os << "\tstyle=\"filled,rounded\"\n";
343
344 // Pick a color based on the node kind.
345 unsigned colorIdx = llvm::hash_value(llvm::StringRef(N->getKindName()));
346 auto nodeColor = getDotFileNodeColor(colorIdx);
347
348 if (isa<Constant>(N)) {
349 os << "\tfillcolor=Snow3 color=DeepSkyBlue4\n";
350 } else {
351 os << "\tfillcolor=" << nodeColor << "\n";
352 }
353 os << "penwidth = 2];\n";
354
355 vertices_.push_back(os.str());
356 }
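
  // For illustration, the descriptor pushed to vertices_ for a hypothetical
  // two-input, one-output node looks roughly like:
  //   v0007[
  //     label = "{{<LHS>LHS|<RHS>RHS}|{Add ...}|{<Result>Result}}"
  //     shape = "record"
  //     style="filled,rounded"
  //     fillcolor=...
  //   penwidth = 2];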
357
358 void dumpEdgeStyle(const Node *N, size_t i, Node *to, std::ostream &os) {
359 if (N->isOverwrittenNthInput(i)) {
360 os << " [dir=\"both\"]";
361 }
362 }
363
364 std::string uniqueVertexName(void *N) {
365 VertexNumberMap::iterator i;
366 bool inserted;
367 std::tie(i, inserted) = vertex_numbers.insert(std::make_pair(N, 0u));
368 if (inserted) {
369 i->second = vertex_numbers.size() - 1;
370 }
371
372 std::string buffer;
373 llvm::raw_string_ostream stream(buffer);
374 stream << llvm::format("v%04u", i->second);
375 return stream.str();
376 }
377
378public:
379 void dumpAll(std::ostream &os) {
    CHECK(os) << "Failed to create file to dump Graph";
381
382 os << "digraph DAG {\n\trankdir=TB;\n";
383
384 // Dump vertices:
385 for (auto &v : vertices_) {
386 os << v << "\n";
387 }
388
389 // Dump edges:
390 for (auto &e : edges_) {
391 os << e << ";\n";
392 }
393
394 os << "}";
395 }
396};
397
398class ModuleDottyPrinter : public AbstractDottyPrinter {
  /// Dump the Function as a vertex, then iterate through the constants used
  /// in the function and create the corresponding edges.
401 void visitFunction(Function *F) {
402 std::ostringstream os;
403 // Print a Function descriptor that looks like this:
404 // vNNNN [ label = "{...}" ];
405 os << uniqueVertexName(F) << "[\n"
406 << "\tlabel = \"Function\\l"
407 << "name : " << F->getName().str() << "\\l"
408 << "node count : " << F->getNodes().size() << "\"\n"
409 << "\tshape = box\n"
410 << "\tfillcolor=gray89, style=\"filled,rounded\"\n"
411 << "\t\n"
412 << "];\n";
413 vertices_.push_back(os.str());
414
415 for (auto &N : F->getNodes()) {
416 for (size_t i = 0; i < N.getNumInputs(); i++) {
417 Node *to = N.getNthInput(i).getNode();
418 size_t resNo = N.getNthInput(i).getResNo();
419
420 if (!isa<Constant>(to))
421 continue;
422
423 std::ostringstream edge;
424 edge << uniqueVertexName(to) << ":" << to->getOutputName(resNo).str()
425 << " -> " << uniqueVertexName(F);
426 dumpEdgeStyle(&N, i, to, edge);
427 edges_.insert(edge.str());
428 }
429 }
430 }
431
432public:
433 void visitModule(Module *M) {
434 for (auto N : M->getConstants()) {
435 dumpNode(N, true);
436 }
437
438 for (auto F : M->getFunctions()) {
439 visitFunction(F);
440 }
441 }
442};
443
444// TODO: consider refactoring boilerplate code to new trait: DottyPrintable<ADP>
445void Module::dumpDAG() {
446 llvm::SmallString<64> dotPath;
447 llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
448 dumpDAG(dotPath);
449}
450
451void Module::dumpDAG(llvm::StringRef dotFilename) {
452 llvm::outs() << "Writing dotty graph for Module to: " << dotFilename << '\n';
453
454 ModuleDottyPrinter DP;
455
456 DP.visitModule(this);
457
458 std::ofstream myfile;
459 myfile.open(dotFilename.str());
460 if (myfile.fail()) {
461 LOG(ERROR) << "Unable to open " << dotFilename.str()
462 << ", reason: " << strerror(errno);
463 } else {
464 DP.dumpAll(myfile);
465 }
466 myfile.close();
467}
468
469void Module::dumpDAG(const char *dotFilename) {
470 dumpDAG(llvm::StringRef(dotFilename));
471}
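
// The files written by the dumpDAG() overloads above are plain Graphviz; a
// dump can be rendered offline with, for example:
//   dot -Tpdf module_dump.dot -o module_dump.pdf
// (assuming the Graphviz "dot" tool is available; the file name is only an
// example).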
472
473void Module::eraseFunctions() {
474 while (!functions_.empty()) {
475 eraseFunction(*functions_.begin());
476 }
477}
478
479void Module::eraseFunction(Function *F) {
480 auto it = std::find(functions_.begin(), functions_.end(), F);
481 assert(it != functions_.end() && "Function is not part of a module");
482 functions_.erase(it);
483 delete F;
484}
485
486uint64_t Module::getConstantsSize() {
487 uint64_t size = 0;
488 for (auto *constant : constants_) {
489 size += constant->getPayload().getSizeInBytes();
490 }
491 return size;
492}
493
/// \returns an Error if any non-fused quantized result of \p N violates
/// \p expectDummy, i.e. has a non-dummy scale when dummies are expected, or a
/// dummy scale when they are not.
496static Error verifyDummyQParamResults(const Node &N, bool expectDummy) {
497 for (size_t i = 0, e = N.getNumResults(); i < e; i++) {
498 TypeRef T = N.getType(i);
499 if (T->isQuantizedType() && !T->isFusedQuantizedType()) {
500 const bool isDummy = T->getScale() == dummyScale;
501 if (expectDummy) {
502 RETURN_ERR_IF_NOT(
503 isDummy, strFormat("Expected all dummy scales, but found non-dummy "
504 "inside Function %s: %s",
505 N.getParent()->getName().data(),
506 N.getDebugDesc().data()));
507 } else {
508 RETURN_ERR_IF_NOT(!isDummy,
509 strFormat("Expected no dummy scales, but found one "
510 "inside Function %s: %s",
511 N.getParent()->getName().data(),
512 N.getDebugDesc().data()));
513 }
514 }
515 }
516 return Error::success();
517}
518
519Error Module::verifyDummyQParams(bool expectDummies) {
520 for (const Function *F : getFunctions()) {
521 for (const Node &N : F->getNodes()) {
522 RETURN_IF_ERR(verifyDummyQParamResults(N, expectDummies));
523 }
524 }
525 for (const Placeholder *PH : getPlaceholders()) {
526 RETURN_IF_ERR(verifyDummyQParamResults(*PH, expectDummies));
527 }
528 for (const Constant *C : getConstants()) {
529 RETURN_IF_ERR(verifyDummyQParamResults(*C, expectDummies));
530 }
531 return Error::success();
532}
533
534Function::~Function() {
535 // Delete all of the nodes.
536 for (auto it = nodes_.begin(), e = nodes_.end(); it != e;) {
537 auto cur = it++;
538 eraseNode(&*cur);
539 }
540}
541
542TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) {
543 return uniqueType(Type(elemTy, dims));
544}
545
546TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
547 float scale, int32_t offset) {
548 return uniqueType(Type(elemTy, dims, scale, offset));
549}
550
551TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims) {
552 return uniqueType(Type::newShape(*T, dims));
553}
554
555TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims,
556 llvm::ArrayRef<dim_t> alignments) {
557 return uniqueType(Type::newShape(*T, dims, alignments));
558}
559
560TypeRef Module::uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType) {
561 return uniqueType(Type::newShape(*T, shapeType));
562}
563
564TypeRef Module::uniqueTypeWithNewStrides(TypeRef T, llvm::ArrayRef<dim_t> dims,
565 llvm::ArrayRef<dim_t> strides) {
566 return uniqueType(Type::newStrides(*T, strides));
567}
568
569TypeRef Module::uniqueTypeWithNewQuantParams(TypeRef T,
570 TypeRef quantParamType) {
571 return uniqueType(Type::newQuantparams(*T, quantParamType->getScale(),
572 quantParamType->getOffset()));
573}
574
575TypeRef Module::uniqueType(const Type &T) {
576 for (auto &tp : types_) {
577 if (T.isEqual(tp)) {
578 return &tp;
579 }
580 }
581
582 return &*types_.insert(types_.begin(), T);
583}
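
// A short sketch of the uniquing behavior above, using a hypothetical module
// "M": requesting the same element kind and shape twice yields the same
// TypeRef, so uniqued types can be compared by pointer.
//
//   TypeRef a = M.uniqueType(ElemKind::FloatTy, {4, 10});
//   TypeRef b = M.uniqueType(ElemKind::FloatTy, {4, 10});
//   assert(a == b && "Identical types are uniqued to a single instance.");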
584
585TypeRef Module::getVoidTy() { return uniqueType(Type()); }
586
587/// \returns a ShapeVector of rank axes.size() less than the input \p dims,
588/// where the provided \p axes dimensions are removed from the shape.
589static ShapeVector getNewShapeWithoutAxes(llvm::ArrayRef<dim_t> dims,
590 llvm::ArrayRef<unsigned_t> axes) {
591 assert(axes.size() <= dims.size() &&
592 "Cannot remove more dimensions than exist.");
593 ShapeVector newDims(dims.begin(), dims.end());
594 ShapeVector shapeAxes(axes.begin(), axes.end());
595
596 // Sort so that looping erase below doesn't fail.
597 std::sort(shapeAxes.rbegin(), shapeAxes.rend());
598
599 for (const auto &axis : shapeAxes) {
600 assert(axis <= dims.size() &&
601 "Axis to remove must fit inside dimensions of the provided dims.");
602 newDims.erase(newDims.begin() + axis);
603 }
604 return newDims;
605}
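
// Worked example (hypothetical values): for dims = {2, 3, 4, 5} and
// axes = {1, 3}, the axes are erased from highest to lowest, first dropping
// dimension 3 to get {2, 3, 4} and then dimension 1 to get {2, 4}.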
606
607//===----------------------------------------------------------------------===//
608// Node builders
609//===----------------------------------------------------------------------===//
610
611Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name,
612 bool isTrainable,
613 const std::string &layout) {
614 auto FT = uniqueType(*T);
615 auto *ph = new Placeholder(name, FT, isTrainable, layout);
616 ph->setName(uniqueName(ph->getName(), usedNodeNames_, usedStorageNames_,
617 originalNames_));
618 placeholders_.push_back(ph);
619 logStorageCreation(functions_, ph);
620 return ph;
621}
622
623Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
624 llvm::StringRef name, bool isTrainable,
625 const std::string &layout) {
626 auto FT = uniqueType(T, dims);
627 return createPlaceholder(FT, name, isTrainable, layout);
628}
629
630Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
631 float scale, int32_t offset,
632 llvm::StringRef name, bool isTrainable,
633 const std::string &layout) {
634 auto FT = uniqueType(T, dims, scale, offset);
635 return createPlaceholder(FT, name, isTrainable, layout);
636}
637
638Constant *Module::createConstant(TypeRef T, llvm::StringRef name,
639 const std::string &layout) {
640 auto FT = uniqueType(*T);
641 return addConstant(new Constant(name, FT, layout));
642}
643
644Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
645 llvm::StringRef name,
646 const std::string &layout) {
647 auto FT = uniqueType(T, dims);
648 return createConstant(FT, name, layout);
649}
650
651Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
652 float scale, int32_t offset,
653 llvm::StringRef name,
654 const std::string &layout) {
655 auto FT = uniqueType(T, dims, scale, offset);
656 return createConstant(FT, name, layout);
657}
658
659Constant *Module::createConstant(llvm::StringRef name, const Tensor &tensor,
660 const std::string &layout) {
661 auto *V = createConstant(&tensor.getType(), name, layout);
662 V->assign(&tensor);
663 return V;
664}
665
666Constant *Module::createConstant(llvm::StringRef name, Tensor &&tensor,
667 const std::string &layout) {
668 return addConstant(new Constant(name, std::move(tensor), layout));
669}
670
671std::string Module::getPrefix(llvm::StringRef name) {
672 std::string prefix = name.str();
673 size_t delim = name.rfind("__");
674 if (delim != std::string::npos &&
675 std::all_of(name.begin() + (delim + 2), name.end(),
676 [](unsigned char c) { return ::isdigit(c); })) {
677 prefix = prefix.substr(0, delim);
678 }
679 return prefix;
680}
681
682llvm::StringRef Module::uniqueName(llvm::StringRef name,
683 const llvm::StringSet<> &stringTable,
684 llvm::StringSet<> &updateTable,
685 const llvm::StringSet<> &originalNames) {
686 std::string legalName = legalizeName(name);
687 if (stringTable.find(legalName) == stringTable.end()) {
688 auto it = updateTable.insert(legalName);
689 if (it.second) {
690 return it.first->first();
691 }
692 }
693 // Retain the trailing "__[0-9]+" if it is in the original name.
694 std::string prefix = (originalNames.find(legalName) == originalNames.end())
695 ? Module::getPrefix(legalName)
696 : legalName;
697 for (unsigned i = 1; i < 10000; i++) {
698 auto suffix = std::to_string(i);
699 std::string fullName = prefix + "__" + suffix;
700 if (stringTable.find(fullName) != stringTable.end()) {
701 continue;
702 }
703
704 auto it = updateTable.insert(fullName);
705 if (it.second) {
706 return it.first->first();
707 }
708 }
  llvm_unreachable("Unable to find a unique name.");
710}
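
// Worked example (hypothetical names) of the uniquing scheme above: if
// "conv__2" is requested and already taken, getPrefix() strips the trailing
// "__2" and the loop probes "conv__1", "conv__2", ... until a free name such
// as "conv__3" is found. A name listed in \p originalNames keeps its full
// suffix as the prefix, so a taken "conv__2" would become "conv__2__1"
// instead.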
711
712Constant *Module::addConstant(Constant *V) {
713 V->setName(uniqueName(V->getName(), usedNodeNames_, usedStorageNames_,
714 originalNames_));
715 // Replace the Constant's output type with the equivalent unique type for
716 // this Module to maintain the invariant that each type in the Module is
717 // unique.
718 V->setType(Constant::ResultIndices::OutputIdx, uniqueType(*V->getType()));
719 constants_.push_back(V);
720 logStorageCreation(functions_, V);
721 return V;
722}
723
724/// Check if the 'pads' array has the right size.
725static void assertPadsSize(NodeValue input, llvm::ArrayRef<int> pads) {
  assert((pads.size() == 2 * input.dims().size()) &&
         "the pads array must contain 2 values per dimension");
728}
729
730PadNode *Function::createPad(llvm::StringRef name, NodeValue input,
731 TypeRef outTy, unsigned_t mode,
732 llvm::ArrayRef<int> pads, float value) {
733 assertPadsSize(input, pads);
734 auto OT = getParent()->uniqueType(*outTy);
735 return addNode(new PadNode(name, OT, input, mode, pads, value));
736}
737
738/// Check the kernel size for Conv/Pooling ops.
739static void checkKernelSize(ShapeNHWC idim, llvm::ArrayRef<unsigned_t> kernels,
740 llvm::ArrayRef<unsigned_t> pads) {
741 PaddingTLBR pdim(pads);
742 (void)pdim;
743 ShapeHW kdim(kernels);
744 (void)kdim;
745 assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
746 (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
747 "Kernel size is too large");
748}
749
750/// Check the kernel size for 3D Conv/Pooling ops.
751static void check3DKernelSize(ShapeNTHWC idim,
752 llvm::ArrayRef<unsigned_t> kernels,
753 llvm::ArrayRef<unsigned_t> pads) {
754 PaddingNFTBLR pdim(pads);
755 (void)pdim;
756 ShapeTHW kdim(kernels);
757 (void)kdim;
758 assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
759 (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
760 (idim.t + pdim.near + pdim.far) >= kdim.temporal_frames &&
761 "Kernel size is too large");
762}
763
764/// Check that the dimensions that are passed in when the ConvTranspose is
765/// constructed are correct.
766static void assertConvTransposeDims(NodeValue input, NodeValue filter,
767 NodeValue bias,
768 llvm::ArrayRef<unsigned_t> kernels,
769 llvm::ArrayRef<unsigned_t> strides,
770 llvm::ArrayRef<unsigned_t> pads,
771 unsigned_t group) {
772 ShapeNHWC idim = ShapeNHWC(input.dims());
773 (void)idim;
774 ShapeHW kdim(kernels);
775 (void)kdim;
776 assert(idim.c % group == 0 && "channels number must be divisible by groups");
777
778 // NOTE: here the N in NHWC is abnormal because it is the number of filters
779 // (and therefore the number of output channels of the conv) and not the
780 // batch size. The rest of the dimensions are representative of the input
781 // dimensions to the convolution.
782 ShapeNHWC filterDims(filter.dims());
783 (void)filterDims;
784
785 assert(filterDims.h == kdim.height && filterDims.w == kdim.width &&
786 filterDims.c == idim.c && "Invalid filter dims");
787
788 assert(bias.getType()->size() == filterDims.n * group && "Invalid bias size");
789}
790
791/// Check that the dimensions that are passed in when the convolution is
792/// constructed are correct.
793static void assertConvDims(NodeValue input, NodeValue filter, NodeValue bias,
794 llvm::ArrayRef<unsigned_t> kernels,
795 llvm::ArrayRef<unsigned_t> strides,
796 llvm::ArrayRef<unsigned_t> pads, unsigned_t group) {
797 ShapeNHWC idim = ShapeNHWC(input.dims());
798 ShapeHW kdim(kernels);
799 (void)kdim;
800 checkKernelSize(idim, kernels, pads);
801 assert(idim.c % group == 0 && "channels number must be divisible by groups");
802
803 // NOTE: here the N in NHWC is abnormal because it is the number of filters
804 // (and therefore the number of output channels of the conv) and not the
805 // batch size. The rest of the dimensions are representative of the input
806 // dimensions to the convolution.
807 ShapeNHWC filterDims(filter.dims());
808 (void)filterDims;
809
810 assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
811 filterDims.w == kdim.width && filterDims.c == idim.c / group &&
812 "Invalid filter dims");
813
814 assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
815}
816
817/// Check that the dimensions that are passed in when the 3D convolution is
818/// constructed are correct.
819static void assertConv3DDims(NodeValue input, NodeValue filter, NodeValue bias,
820 llvm::ArrayRef<unsigned_t> kernels,
821 llvm::ArrayRef<unsigned_t> strides,
822 llvm::ArrayRef<unsigned_t> pads,
823 unsigned_t group) {
824 ShapeNTHWC idim(input.dims());
825 ShapeTHW kdim(kernels);
826 (void)kdim;
827 check3DKernelSize(idim, kernels, pads);
828 assert(idim.c % group == 0 && "channels number must be divisible by groups");
829
830 // NOTE: here the N in NTHWC is abnormal because it is the number of filters
831 // (and therefore the number of output channels of the 3d conv) and not the
832 // batch size. The rest of the dimensions are representative of the input
833 // dimensions to the convolution.
834 ShapeNTHWC filterDims(filter.dims());
835 (void)filterDims;
836
837 assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
838 filterDims.w == kdim.width && filterDims.t == kdim.temporal_frames &&
839 filterDims.c == idim.c / group && "Invalid filter dims");
840
841 assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
842}
843
844ConvolutionNode *Function::createConv(
845 llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
846 TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
847 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
848 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation,
849 ConvolutionLayout layout) {
850 assertConvDims(input, filter, bias, kernels, strides, pads, group);
851 auto OT = getParent()->uniqueType(*outTy);
852
853 // If the input is quantized but the bias is not then auto-quantize the
854 // bias.
855 if (input.getType()->isQuantizedType()) {
856 auto biasType = bias.getElementType();
857 if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy) {
858 // Nothing to do
859 } else if (biasType == ElemKind::FloatTy) {
860 auto biasTy = getParent()->uniqueType(
861 glow::ElemKind::Int32QTy, bias.dims(),
862 input.getType()->getScale() * filter.getType()->getScale(),
863 /* offset */ 0);
864 bias = createQuantize("quantized_bias", bias, biasTy);
865 } else {
866 LOG(DFATAL)
867 << "Unsupported element type for bias of quantized convolution: "
868 << Type::getElementName(biasType).str();
869 }
870 }
871
872 return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
873 strides, pads, group, dilation, layout,
874 FusedActivation::NONE, {}));
875}
876
877ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
878 NodeValue filter, NodeValue bias,
879 TypeRef outTy, unsigned_t kernel,
880 unsigned_t stride, unsigned_t pad,
881 unsigned_t group,
882 llvm::ArrayRef<unsigned_t> dilation,
883 ConvolutionLayout layout) {
884 llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
885 llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
886 llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
887 return createConv(name, input, filter, bias, outTy, kernels, strides, pads,
888 group, dilation, layout);
889}
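
// A minimal sketch of building a convolution with the helpers above, assuming
// an NHWC input, hypothetical names/shapes, and the default dilation/layout
// arguments declared in the header:
//
//   Module *mod = F->getParent();
//   auto *in = mod->createPlaceholder(ElemKind::FloatTy, {1, 28, 28, 3}, "in",
//                                     /* isTrainable */ false);
//   auto *flt = mod->createConstant(ElemKind::FloatTy, {8, 5, 5, 3}, "filter");
//   auto *bias = mod->createConstant(ElemKind::FloatTy, {8}, "bias");
//   auto *outTy = mod->uniqueType(ElemKind::FloatTy, {1, 28, 28, 8});
//   auto *conv = F->createConv("conv", in, flt, bias, outTy, /* kernel */ 5,
//                              /* stride */ 1, /* pad */ 2, /* group */ 1);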
890
891Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
892 NodeValue filter, NodeValue bias,
893 TypeRef outTy,
894 llvm::ArrayRef<unsigned_t> kernels,
895 llvm::ArrayRef<unsigned_t> strides,
896 llvm::ArrayRef<unsigned_t> pads,
897 unsigned_t group) {
898 assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
899 auto OT = getParent()->uniqueType(*outTy);
900
901 // If the input is quantized but the bias is not then auto-quantize the
902 // bias.
903 if (input.getType()->isQuantizedType()) {
904 auto biasType = bias.getElementType();
905 if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy ||
906 biasType == ElemKind::Int16QTy) {
907 // Nothing to do
908 } else if (biasType == ElemKind::FloatTy) {
909 auto biasTy = getParent()->uniqueType(
910 glow::ElemKind::Int32QTy, bias.dims(),
911 input.getType()->getScale() * filter.getType()->getScale(),
912 /* offset */ 0);
913 bias = createQuantize("quantized_bias", bias, biasTy);
914 } else {
915 LOG(DFATAL)
916 << "Unsupported element type for bias of quantized convolution: "
917 << Type::getElementName(biasType).str();
918 }
919 }
920 return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
921 strides, pads, group));
922}
923
924Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
925 NodeValue filter, NodeValue bias,
926 TypeRef outTy, unsigned_t kernel,
927 unsigned_t stride, unsigned_t pad,
928 unsigned_t group) {
929 llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
930 llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
931 llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
932 return createConv3D(name, input, filter, bias, outTy, kernels, strides, pads,
933 group);
934}
935
936ConvTransposeNode *Function::createConvTranspose(
937 llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
938 TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
939 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
940 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) {
941 assertConvTransposeDims(input, filter, bias, kernels, strides, pads, group);
942 auto OT = getParent()->uniqueType(*outTy);
943 return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
944 strides, pads, group, dilation));
945}
946
947ConvTransposeNode *Function::createConvTranspose(
948 llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
949 TypeRef outTy, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
950 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) {
951 llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
952 llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
953 llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
954 return createConvTranspose(name, input, filter, bias, outTy, kernels, strides,
955 pads, group, dilation);
956}
957
958MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
959 llvm::ArrayRef<unsigned_t> kernels,
960 llvm::ArrayRef<unsigned_t> strides,
961 llvm::ArrayRef<unsigned_t> pads,
962 ElemKind elemTyAMT,
963 ConvolutionLayout layout) {
964 ShapeNHWC idim = ShapeNHWC(input.dims());
965 checkKernelSize(idim, kernels, pads);
966
967 auto outSz =
968 calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
969 auto OT = getParent()->uniqueTypeWithNewShape(
970 input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
971 auto AMT = getParent()->uniqueType(
972 elemTyAMT, {idim.n, outSz.first, outSz.second, idim.c});
973
974 return addNode(
975 new MaxPoolNode(name, OT, AMT, input, kernels, strides, pads, layout));
976}
977
978MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
979 unsigned_t kernel, unsigned_t stride,
980 unsigned_t pad, ElemKind elemTyAMT,
981 ConvolutionLayout layout) {
982 llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
983 llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
984 llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
985 return createMaxPool(name, input, kernels, strides, pads, elemTyAMT, layout);
986}
987
988AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
989 llvm::ArrayRef<unsigned_t> kernels,
990 llvm::ArrayRef<unsigned_t> strides,
991 llvm::ArrayRef<unsigned_t> pads,
992 ConvolutionLayout layout,
993 bool countIncludePads) {
994 if (!is3DData(layout)) {
995
996 ShapeNHWC idim = ShapeNHWC(input.dims());
997 checkKernelSize(idim, kernels, pads);
998
999 auto outSz =
1000 calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
1001 auto OT = getParent()->uniqueTypeWithNewShape(
1002 input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
1003 return addNode(new AvgPoolNode(name, OT, input, kernels, strides, pads,
1004 layout, countIncludePads));
1005
1006 } else {
1007 ShapeNTHWC idim = ShapeNTHWC(input.dims());
1008 check3DKernelSize(idim, kernels, pads);
1009
1010 auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
1011 strides, pads);
1012 auto OT = getParent()->uniqueTypeWithNewShape(
1013 input.getType(),
1014 {idim.n, outSz.temporal_frames, outSz.height, outSz.width, idim.c});
1015 return addNode(new AvgPoolNode(name, OT, input, kernels, strides, pads,
1016 layout, countIncludePads));
1017 }
1018}
1019
1020AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
1021 TypeRef outTy,
1022 llvm::ArrayRef<unsigned_t> kernels,
1023 llvm::ArrayRef<unsigned_t> strides,
1024 llvm::ArrayRef<unsigned_t> pads,
1025 ConvolutionLayout layout,
1026 bool countIncludePads) {
1027 if (!is3DData(layout)) {
1028
1029 ShapeNHWC idim = ShapeNHWC(input.dims());
1030 ShapeHW kdim(kernels);
1031 (void)kdim;
1032 checkKernelSize(idim, kernels, pads);
1033 return addNode(new AvgPoolNode(name, outTy, input, kernels, strides, pads,
1034 layout, countIncludePads));
1035
1036 } else {
1037
1038 ShapeNTHWC idim = ShapeNTHWC(input.dims());
1039 ShapeTHW kdim(kernels);
1040 (void)kdim;
1041 check3DKernelSize(idim, kernels, pads);
1042 return addNode(new AvgPoolNode(name, outTy, input, kernels, strides, pads,
1043 layout, countIncludePads));
1044 }
1045}
1046
1047AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
1048 unsigned_t kernel, unsigned_t stride,
1049 unsigned_t pad, ConvolutionLayout layout,
1050 bool countIncludePads) {
1051 if (!is3DData(layout)) {
1052
1053 llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
1054 llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
1055 llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
1056 return createAvgPool(name, input, kernels, strides, pads, layout,
1057 countIncludePads);
1058
1059 } else {
1060
1061 llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
1062 llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
1063 llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
1064 return createAvgPool(name, input, kernels, strides, pads, layout,
1065 countIncludePads);
1066 }
1067}
1068
1069AdaptiveAvgPoolNode *Function::createAdaptiveAvgPool(llvm::StringRef name,
1070 NodeValue input,
1071 TypeRef outTy) {
1072 return addNode(new AdaptiveAvgPoolNode(name, outTy, input));
1073}
1074
1075GemmNode *Function::createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
1076 NodeValue C, float alpha, float beta,
1077 bool transposeA, bool transposeB) {
1078 std::vector<dim_t> outDims(2);
1079 outDims[0] = transposeA ? A.dims()[1] : A.dims()[0];
1080 outDims[1] = transposeB ? B.dims()[0] : B.dims()[1];
1081 TypeRef outTy = getParent()->uniqueTypeWithNewShape(A.getType(), outDims);
1082 return createGemm(name, outTy, A, B, C, alpha, beta, transposeA, transposeB);
1083}
1084
1085GemmNode *Function::createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
1086 NodeValue B, NodeValue C, float alpha,
1087 float beta, bool transposeA, bool transposeB) {
1088 // If C operand is not given then we create a 1D splat with 0.
1089 if (!C.getNode()) {
1090 TypeRef splatTy =
1091 getParent()->uniqueTypeWithNewShape(outTy, {outTy->dims()[1]});
1092 C = createSplat(name.str() + ".SplatC", splatTy, 0.0f);
1093 }
  // If the C operand is a 2D constant, check whether it is a broadcast of a
  // 1D tensor. If so, slice and reshape C back to 1D.
1096 if (auto *constC = llvm::dyn_cast<Constant>(C.getNode())) {
1097 if ((constC->dims().size() == 2) && (constC->getPayload().isTiled(0))) {
1098 // Slice and reshape to 1D.
1099 dim_t lengthC = constC->dims()[1];
1100 C = createSlice(name.str() + ".SliceC", C, {0, 0}, {1, lengthC});
1101 C = createReshape(name.str() + ".ReshapeC", C, {lengthC});
1102 }
1103 }
1104 TypeRef OT = getParent()->uniqueType(*outTy);
1105 return addNode(
1106 new GemmNode(name, OT, A, B, C, alpha, beta, transposeA, transposeB));
1107}
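
// For reference, Gemm here follows the usual ONNX-style semantics
// Y = alpha * op(A) * op(B) + beta * C, where op() optionally transposes its
// argument. A worked example of the shape logic above (hypothetical shapes):
// with A of shape {4, 6}, B of shape {5, 6} and transposeB = true, the result
// has shape {A.dims()[0], B.dims()[0]} = {4, 5}, and a missing C is replaced
// by a zero splat of shape {5}.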
1108
1109DynamicQuantizedFullyConnectedNode *
1110Function::createDynamicQuantizedFullyConnected(llvm::StringRef name,
1111 NodeValue input, NodeValue W,
1112 NodeValue B, bool isSymmetric,
1113 bool isPerBatchElement) {
1114 TypeRef T = input.getType();
1115 TypeRef OT =
1116 getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});
1117 return addNode(new DynamicQuantizedFullyConnectedNode(
1118 name, OT, input, W, B, isSymmetric, isPerBatchElement));
1119}
1120
1121DynamicRowwiseQuantizedFullyConnectedNode *
1122Function::createDynamicRowwiseQuantizedFullyConnected(
1123 llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B,
1124 NodeValue scales, NodeValue offsets, bool isSymmetric,
1125 bool isPerBatchElement) {
1126 TypeRef T = input.getType();
1127 TypeRef OT =
1128 getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});
1129 return addNode(new DynamicRowwiseQuantizedFullyConnectedNode(
1130 name, OT, input, W, B, scales, offsets, isSymmetric, isPerBatchElement));
1131}
1132
1133FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
1134 NodeValue input, Storage *W,
1135 Storage *B,
1136 unsigned_t axis) {
1137 TypeRef T = input.getType();
1138 TypeRef OT = getParent()->uniqueTypeWithNewShape(
1139 T, {input.dims()[0], B->getType()->dims()[0]});
1140
1141 return createFullyConnected(name, input, W, B, OT, axis);
1142}
1143
1144FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
1145 NodeValue input, NodeValue W,
1146 NodeValue B,
1147 unsigned_t axis) {
1148 TypeRef T = input.getType();
1149 TypeRef OT =
1150 getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});
1151
1152 return createFullyConnected(name, input, W, B, OT, axis);
1153}
1154
1155FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
1156 NodeValue input, NodeValue W,
1157 NodeValue B, TypeRef outTy,
1158 unsigned_t axis) {
1159 assert(outTy->dims().size() == 2 && "Invalid number of dimensions");
1160
1161 // FC always uses 2D input; flatten if necessary.
1162 if (input.dims().size() != 2) {
1163 input = createFlatten(name.str() + ".reshape2D", input, axis);
1164 }
1165
1166 TypeRef OT = getParent()->uniqueType(*outTy);
1167 return addNode(new FullyConnectedNode(name, OT, input, W, B));
1168}
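
// A short sketch (hypothetical shapes) of the flattening behavior above: a
// {2, 3, 4} input with axis = 1 is flattened to {2, 12} before the node is
// created, so with weights of shape {12, 10} and a bias of shape {10} the
// declared output type should be {2, 10}.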
1169
1170RowwiseQuantizedFullyConnectedNode *
1171Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
1172 NodeValue input, NodeValue W,
1173 Constant *scales,
1174 Constant *offsets, NodeValue B,
1175 TypeRef outTy) {
1176 return addNode(new RowwiseQuantizedFullyConnectedNode(name, outTy, input, W,
1177 scales, offsets, B));
1178}
1179
1180RowwiseQuantizedFullyConnectedNode *
1181Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
1182 NodeValue input, NodeValue W,
1183 NodeValue B, TypeRef outTy,
1184 quantization::Schema schema,
1185 bool transposeWeight) {
  // Since W is constant, quantize it at compile time. The quantized data is
  // stored in qWeights, the per-row scales in scales, and the per-row offsets
  // in offsets.
1189 Constant *weights = llvm::cast<Constant>(W);
1190 CHECK(weights)
1191 << "Expected RowwiseQuantizedFullyConnected weights to be a Constant";
1192
1193 dim_t numRows = transposeWeight ? weights->getType()->dims()[1]
1194 : weights->getType()->dims()[0];
1195 dim_t numCols = transposeWeight ? weights->getType()->dims()[0]
1196 : weights->getType()->dims()[1];
1197
  // Storage created with Int8QTy/Int16QTy is assumed to hold quantized data,
  // so a scale and offset would normally have to be provided. For rowwise
  // quantization, however, the scales and offsets are stored in separate
  // vectors, so we pass dummy scale and offset values here.
1202 auto *qWeights = getParent()->createConstant(
1203 ElemKind::Int8QTy, {numRows, numCols}, 0.0, 0, "weights.rwqfc");
1204 auto *scales =
1205 getParent()->createConstant(ElemKind::FloatTy, {numRows}, "scales.rwqfc");
1206 auto *offsets = getParent()->createConstant(ElemKind::Int32ITy, {numRows},
1207 "offsets.rwqfc");
1208
1209 Tensor wt;
1210 if (transposeWeight) {
    // This happens when the RowwiseQuantizedFullyConnected node is converted
    // from a quantized FullyConnected node in Glow's quantization procedure.
    // In FC the weights are stored transposed (i.e. I * W + B), whereas
    // RowwiseQuantizedFullyConnected stores the weights as-is (i.e.
    // I * W(T) + B).
1216 weights->getPayloadMutable().transpose(&wt, {1, 0});
1217 } else {
1218 wt.assign(&(weights->getPayload()));
1219 }
1220
1221 // Note: Using int32_t offset here as that is what RWQ-FC expects.
1222 quantization::tensorRowwiseQuantization<float, int32_t, int8_t>(
1223 wt, qWeights->getPayloadMutable(), scales->getPayloadMutable(),
1224 offsets->getPayloadMutable(), schema);
1225
1226 return addNode(new RowwiseQuantizedFullyConnectedNode(
1227 name, outTy, input, qWeights, scales, offsets, B));
1228}
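
// For reference, rowwise quantization as used above picks a scale s_r and
// offset o_r for each row r of the weight matrix and stores
//   q[r][c] = clip(round(w[r][c] / s_r) + o_r)
// as int8, keeping the per-row s_r / o_r in the separate "scales" and
// "offsets" constants created above. The exact rounding and range depend on
// the chosen quantization::Schema.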
1229
1230ReluNode *Function::createRelu(llvm::StringRef name, TypeRef outTy,
1231 NodeValue input) {
1232 return addNode(new ReluNode(name, outTy, input));
1233}
1234
1235ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input,
1236 TypeRef outTy) {
1237 return createRelu(name, outTy, input);
1238}
1239
1240ReluNode *Function::createRelu(llvm::StringRef name, NodeValue input) {
1241 return createRelu(name, input.getType(), input);
1242}
1243
1244ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input) {
1245 return createRelu(name, input);
1246}
1247
1248GeluNode *Function::createGelu(llvm::StringRef name, NodeValue input) {
1249 return addNode(new GeluNode(name, input.getType(), input));
1250}
1251
1252GeluNode *Function::createGELU(llvm::StringRef name, NodeValue input) {
1253 return createGelu(name, input);
1254}
1255
1256PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
1257 NodeValue slope, TypeRef outTy) {
1258 return addNode(new PReluNode(name, outTy, input, slope));
1259}
1260
1261PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
1262 NodeValue slope) {
1263 return addNode(new PReluNode(name, input.getType(), input, slope));
1264}
1265
1266SigmoidNode *Function::createSigmoid(llvm::StringRef name, TypeRef outTy,
1267 NodeValue input) {
1268 return addNode(new SigmoidNode(name, outTy, input));
1269}
1270
1271SigmoidNode *Function::createSigmoid(llvm::StringRef name, NodeValue input) {
1272 return createSigmoid(name, input.getType(), input);
1273}
1274
1275SwishNode *Function::createSwish(llvm::StringRef name, NodeValue input) {
1276 return createSwish(name, getParent()->uniqueType(*input.getType()), input);
1277}
1278
1279SwishNode *Function::createSwish(llvm::StringRef name, TypeRef OT,
1280 NodeValue input) {
1281 return addNode(new SwishNode(name, OT, input));
1282}
1283
1284SwishNode *Function::createSwish(llvm::StringRef name, NodeValue input,
1285 TypeRef OT) {
1286 return createSwish(name, OT, input);
1287}
1288
1289ClipNode *Function::createHardSigmoid(llvm::StringRef name, TypeRef outTy,
1290 NodeValue input, float alpha,
1291 float beta) {
1292 auto ty = input.getType();
1293
1294 // max(0, min(1, alpha * x + beta))
1295 auto *alphaSplat = createSplat(name.str() + ".alpha", ty, alpha);
1296 auto *betaSplat = createSplat(name.str() + ".beta", ty, beta);
1297
1298 auto *mul = createMul(name.str() + ".mul", alphaSplat, input);
1299 auto *add = createAdd(name.str() + ".add", mul, betaSplat);
1300
1301 return createClip(name.str() + ".clip", add, outTy, 0, 1);
1302}
1303
1304ClipNode *Function::createHardSigmoid(llvm::StringRef name, NodeValue input,
1305 float alpha, float beta) {
1306 return createHardSigmoid(name, input.getType(), input, alpha, beta);
1307}
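
// Worked example (hypothetical values) for the hard sigmoid above with the
// common alpha = 0.2, beta = 0.5: an input of -4.0 maps to
// max(0, min(1, 0.2 * -4.0 + 0.5)) = 0.0, an input of 1.0 maps to 0.7, and
// any input >= 2.5 saturates at 1.0.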
1308
1309TanhNode *Function::createTanh(llvm::StringRef name, TypeRef outTy,
1310 NodeValue input) {
1311 return addNode(new TanhNode(name, outTy, input));
1312}
1313
1314TanhNode *Function::createTanh(llvm::StringRef name, NodeValue input) {
1315 return createTanh(name, input.getType(), input);
1316}
1317
1318SoftPlusNode *Function::createSoftPlus(llvm::StringRef name, NodeValue input,
1319 TypeRef outTy) {
1320 if (!outTy) {
1321 outTy = getParent()->uniqueType(*input.getType());
1322 }
1323 return addNode(new SoftPlusNode(name, outTy, input));
1324}
1325
1326SoftMaxNode *Function::createSoftMax(llvm::StringRef name, NodeValue input,
1327 NodeValue selected, TypeRef outTy,
1328 float beta) {
1329 // Create input multiplier with beta.
1330 if (beta != 1.0) {
1331 auto *splat = createSplat(name, input.getType(), beta);
1332 input = createMul(name, input, splat);
1333 }
1334 // By default, pick the input type.
1335 if (!outTy) {
1336 outTy = getParent()->uniqueType(*input.getType());
1337 }
1338 return addNode(new SoftMaxNode(name, outTy, input, selected));
1339}
1340
1341LogSoftMaxNode *Function::createLogSoftMax(llvm::StringRef name,
1342 NodeValue input, NodeValue selected,
1343 TypeRef outTy, float beta) {
1344 // Create input multiplier with beta.
1345 if (beta != 1.0) {
1346 auto *splat = createSplat(name, input.getType(), beta);
1347 input = createMul(name, input, splat);
1348 }
1349 // By default, pick the input type.
1350 if (!outTy) {
1351 outTy = getParent()->uniqueType(*input.getType());
1352 }
1353 return addNode(new LogSoftMaxNode(name, outTy, input, selected));
1354}
1355
1356CrossEntropyLossNode *Function::createCrossEntropyLoss(llvm::StringRef name,
1357 NodeValue input,
1358 NodeValue labels) {
1359 auto ty = getParent()->uniqueTypeWithNewShape(input.getType(), {1});
1360 return addNode(new CrossEntropyLossNode(name, ty, input, labels));
1361}
1362
1363RegressionNode *Function::createRegression(llvm::StringRef name,
1364 NodeValue input,
1365 NodeValue expected) {
1366 return addNode(new RegressionNode(name, input, expected));
1367}
1368
1369SigmoidCrossEntropyWithLogitsNode *
1370Function::createSigmoidCrossEntropyWithLogits(llvm::StringRef name,
1371 NodeValue logits,
1372 NodeValue targets) {
1373 assert(logits.dims().size() > 1);
1374 std::vector<dim_t> outDims(logits.dims().begin(), logits.dims().end() - 1);
1375 auto ty = getParent()->uniqueTypeWithNewShape(logits.getType(), outDims);
1376 return addNode(
1377 new SigmoidCrossEntropyWithLogitsNode(name, ty, logits, targets));
1378}
1379
1380ReshapeNode *Function::createReshape(llvm::StringRef name, NodeValue input,
1381 llvm::ArrayRef<dim_t> shape,
1382 llvm::StringRef layout) {
1383 auto TR = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1384 DCHECK_EQ(TR->size(), input.getType()->size())
1385 << "Reshape to a different size";
1386 return addNode(
1387 new ReshapeNode(name.str(), TR, input, shape.vec(), layout.str()));
1388}
1389
1390TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input,
1391 llvm::ArrayRef<unsigned_t> shuffle,
1392 const std::string &layout) {
1393 ShapeVector shape;
1394 auto dims = input.dims();
1395 for (size_t i = 0; i < dims.size(); i++) {
1396 shape.push_back(dims[shuffle[i]]);
1397 }
1398
  // Helper to compare the shuffle against a known layout conversion:
1400 auto compareShuffle = [&](const std::vector<unsigned_t> targetShuffle) {
1401 auto shuffleVec = shuffle.vec();
1402 return targetShuffle.size() == dims.size() &&
1403 std::equal(shuffleVec.begin(), shuffleVec.end(),
1404 targetShuffle.begin());
1405 };
1406
1407 auto currLayout = layout;
1408 if (currLayout == ANY_LAYOUT) {
1409 // If layout got a default value, change it based on shuffle:
1410 // TODO: remove the shuffle and replace it with layout.
1411 if (compareShuffle(NCHW2NHWC) || compareShuffle(HWCN2NHWC)) {
1412 currLayout = "NHWC";
1413 } else if (compareShuffle(NCTHW2NTHWC)) {
1414 currLayout = "NTHWC";
1415 } else if (compareShuffle(NHWC2NCHW)) {
1416 currLayout = "NCHW";
1417 } else if (compareShuffle(NTHWC2NCTHW)) {
1418 currLayout = "NCTHW";
1419 } else if (compareShuffle(NHWC2HWNC)) {
1420 currLayout = "HWNC";
1421 } else if (compareShuffle(CNHW2NHWC)) {
1422 currLayout = "NHWC";
1423 }
1424 }
1425
1426 auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1427 return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout));
1428}
1429
1430FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input,
1431 unsigned_t axis) {
1432 auto OT = getParent()->uniqueType(*input.getType());
1433 return addNode(new FlipNode(name, OT, input, axis));
1434}
1435
1436BroadcastNode *Function::createBroadcast(llvm::StringRef name, NodeValue input,
1437 UnsignedArrayRef newShape,
1438 unsigned_t axis) {
1439 auto OT = getParent()->uniqueTypeWithNewShape(input.getType(), newShape);
1440 return addNode(new BroadcastNode(name, OT, input, axis, newShape.vec()));
1441}
1442
1443std::array<Node *, 2> Function::createRMSNorm(llvm::StringRef name, NodeValue X,
1444 NodeValue gamma, NodeValue beta,
1445 float epsilon) {
1446 // np.square(X)
1447 auto square = createSquare(name.str() + ".square", X);
1448
1449 // np.mean(np.square(X), axis=1)
1450 auto mean = createBatchedReduceMean(name.str() + ".mean", square, {1});
1451
1452 // np.mean(np.square(X), axis=1) + eps
1453 auto eps = getParent()->createConstant(ElemKind::FloatTy, 1, "eps");
1454 eps->getPayloadMutable() = {epsilon};
1455 auto bcastEps =
1456 createBroadcast(name.str() + ".bcastEps", eps, {X.dims()[0]}, 0);
1457 auto addEps = createAdd(name.str() + ".addEps", mean, bcastEps);
1458
1459 // np.sqrt(np.mean(np.square(X), axis=1) + eps)
1460 auto sqrt = createPow(name.str() + ".sqrt", addEps, 0.5f);
1461
1462 // rrms = 1.0 / np.sqrt(np.mean(np.square(X), axis=1) + eps)
1463 auto one = getParent()->createConstant(ElemKind::FloatTy, 1, "one");
1464 one->getPayloadMutable() = {1};
1465 auto bcastOne =
1466 createBroadcast(name.str() + ".bcastOne", one, {X.dims()[0]}, 0);
1467 auto rrms = createDiv(name.str() + ".rrms", bcastOne, sqrt);
1468
1469 // np.expand_dims(rrms, axis=1)
1470 auto reshape = createReshape(name.str() + ".expandD", rrms, {X.dims()[0], 1});
1471 auto bcastReshape =
1472 createBroadcast(name.str() + ".bcastReshape", reshape, X.dims(), 0);
1473
1474 // X * np.expand_dims(rrms, axis=1)
1475 auto mul = createMul(name.str() + "mul", X, bcastReshape);
1476
1477 // X * np.expand_dims(rrms, axis=1) * gamma
1478 auto bcastGamma =
1479 createBroadcast(name.str() + ".bcastGamma", gamma, X.dims(), 1);
1480 auto mulGamma = createMul(name.str() + ".mulGamma", mul, bcastGamma);
1481
1482 // Y = X * np.expand_dims(rrms, axis=1) * gamma + beta
1483 auto bcastBeta =
1484 createBroadcast(name.str() + ".bcastBeta", beta, X.dims(), 1);
1485 auto Y = createAdd(name.str() + ".Y", mulGamma, bcastBeta);
1486
1487 return {Y, rrms};
1488}
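
// Putting the steps above together, the pair of results computes, in the same
// numpy terms used in the comments:
//   rrms = 1.0 / np.sqrt(np.mean(np.square(X), axis=1) + eps)
//   Y    = X * np.expand_dims(rrms, axis=1) * gamma + beta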
1489
/// \returns true if \p T1 and \p T2 have exactly the same type except for
/// dimension \p dim. Logs an error when returning false.
1493 if (T1->getElementType() != T2->getElementType()) {
1494 LOG(ERROR) << "Different types " << (int)T1->getElementType() << " "
1495 << (int)T2->getElementType();
1496 return false;
1497 }
1498
1499 auto D1 = T1->dims();
1500 auto D2 = T2->dims();
1501
1502 if (D1.size() != D2.size()) {
1503 LOG(ERROR) << "Different size " << D1.size() << " " << D2.size();
1504 return false;
1505 }
1506
1507 for (unsigned i = 0, e = D1.size(); i < e; i++) {
1508 // Ignore the dimension \p dim.
1509 if (i == dim) {
1510 continue;
1511 }
1512
1513 if (D1[i] != D2[i]) {
1514 LOG(ERROR) << "Different dimension at " << i << " " << D1[i] << " "
1515 << D2[i];
1516 return false;
1517 }
1518 }
1519
1520 return true;
1521}
1522
1523ConcatNode *Function::createConcat(llvm::StringRef name,
1524 llvm::ArrayRef<NodeValue> inputs,
1525 unsigned_t dimension) {
1526 for (int i = 1, e = inputs.size(); i < e; i++) {
1527 assert(sameSameShapeExceptDim(inputs[i].getType(), inputs[0].getType(),
1528 dimension) &&
1529 "Invalid type");
1530 (void)sameSameShapeExceptDim;
1531 }
1532 auto inDim = inputs[0].dims();
1533
1534 ShapeVector shape(inDim.begin(), inDim.end());
1535
  // We are concatenating the tensors along a specific dimension, which
  // increases the size of the result along that dimension.
1538 shape[dimension] = 0;
1539 for (auto I : inputs) {
1540 shape[dimension] += I.getType()->dims()[dimension];
1541 }
1542
1543 auto NT = getParent()->uniqueTypeWithNewShape(inputs[0].getType(), shape);
1544 std::vector<NodeValue> ops;
1545 ops.reserve(inputs.size());
1546 for (auto I : inputs) {
1547 ops.emplace_back(I);
1548 }
1549 return addNode(new ConcatNode(name, NT, ops, dimension));
1550}
1551
1552ConcatNode *Function::createConcat(llvm::StringRef name,
1553 llvm::ArrayRef<NodeValue> inputs,
1554 unsigned_t dimension, TypeRef outTy) {
1555 std::vector<NodeValue> ops;
1556 ops.reserve(inputs.size());
1557 for (auto I : inputs) {
1558 ops.emplace_back(I);
1559 }
1560
1561 TypeRef OT = getParent()->uniqueType(*outTy);
1562 return addNode(new ConcatNode(name, OT, ops, dimension));
1563}
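
// Worked example (hypothetical shapes) of the shape rule above: concatenating
// inputs of shapes {2, 3, 4} and {2, 5, 4} along dimension 1 yields a result
// of shape {2, 8, 4}; all dimensions other than the concatenation dimension
// must match.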
1564
1565TileNode *Function::createTile(llvm::StringRef name, NodeValue input,
1566 unsigned_t tiles, unsigned_t axis,
1567 TypeRef outTy) {
1568 assert(tiles > 0 && "Tiles must be non-zero.");
  assert(axis < input.dims().size() &&
         "Axis must fall in range of source dims.");
1571
1572 if (outTy == nullptr) {
1573 ShapeVector outShape(input.dims().begin(), input.dims().end());
1574 outShape[axis] *= tiles;
1575 outTy = getParent()->uniqueTypeWithNewShape(input.getType(), outShape);
1576 }
1577
1578 return addNode(new TileNode(name, outTy, input, tiles, axis));
1579}
1580
1581TileNode *Function::createTile(llvm::StringRef name, NodeValue input,
1582 llvm::ArrayRef<unsigned_t> tiles,
1583 llvm::ArrayRef<unsigned_t> axes) {
1584 assert(tiles.size() && "The array of tiles is empty!");
1585 assert(axes.size() && "The array of axes is empty!");
  assert(tiles.size() == axes.size() &&
         "The arrays of tiles and axes must have the same size!");
1588 TileNode *tileNode = nullptr;
1589 for (size_t idx = 0; idx < tiles.size(); ++idx) {
1590 tileNode = createTile(name.str() + "." + std::to_string(idx),
1591 tileNode ? tileNode->getResult() : input, tiles[idx],
1592 axes[idx]);
1593 }
1594 return tileNode;
1595}
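
// Worked example (hypothetical values) for the multi-axis overload above:
// tiling a {4, 5} input with tiles = {2, 3} along axes = {0, 1} first creates
// an {8, 5} tile along axis 0 and then an {8, 15} tile along axis 1, and the
// final TileNode is returned.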
1596
1597InsertTensorNode *Function::createInsertTensor(llvm::StringRef name,
1598 NodeValue big, NodeValue small,
1599 llvm::ArrayRef<dim_t> start,
1600 unsigned_t count,
1601 unsigned_t axis) {
1602 return addNode(new InsertTensorNode(name, big, small, start, count, axis));
1603}
1604
1605SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,
1606 llvm::ArrayRef<dim_t> start, TypeRef outTy) {
1607 assert(input.dims().size() == start.size() &&
1608 "Start and input dims should match");
1609 assert(outTy->dims().size() == start.size() &&
1610 "Output and start dims should match");
1611
1612 for (unsigned i = 0, e = input.dims().size(); i < e; i++) {
1613 assert(start[i] + outTy->dims()[i] <= input.dims()[i] &&
1614 "Input/Output/Start dims mismatch");
1615 }
1616
1617 TypeRef OT = getParent()->uniqueType(*outTy);
1618 return addNode(new SliceNode(name, OT, input, start));
1619}
1620
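/// Create a SliceNode named \p name that extracts the region of \p input
/// spanning [\p begin[i], \p end[i]) along every dimension i.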
1621SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,
1622 llvm::ArrayRef<dim_t> begin,
1623 llvm::ArrayRef<dim_t> end) {
1624 std::vector<dim_t> beginV, shape;
1625 auto dims = input.dims();
1626 assert(begin.size() == end.size() && "Begin and End dimensions should match");
1627 assert(begin.size() == dims.size() &&
1628 "Begin and Input dimensions should match");
1629 for (unsigned i = 0; i < dims.size(); i++) {
1630 dim_t beginI = begin[i];
1631 dim_t endI = end[i];
1632 dim_t dimI = dims[i];
1633 (void)dimI;
1634 assert(beginI >= 0 && "Illegal Begin indices");
1635 assert(endI > 0 && "Illegal End indices");
1636 assert(beginI < dimI && "Illegal Begin indices");
1637 assert(endI <= dimI && "Illegal End indices");
1638 assert(endI > beginI && "Illegal Begin and End indices");
1639 beginV.push_back(beginI);
1640 shape.push_back(endI - beginI);
1641 }
1642
1643 auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1644 return addNode(new SliceNode(name, NT, input, beginV));
1645}
1646
1647Node *Function::createChannelShuffle(llvm::StringRef name, NodeValue input,
1648 size_t group, size_t kernel) {
1649 return addNode(
1650 new ChannelShuffleNode(name, input.getType(), input, group, kernel));
1651}
1652
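/// Create a ReshapeNode that removes from \p input the size-1 dimensions
/// listed in \p axes.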
1653ReshapeNode *Function::createSqueeze(llvm::StringRef name, NodeValue input,
1654 llvm::ArrayRef<dim_t> axes) {
1655 assert(!axes.empty() && "Parameter `axes` must be provided.");
1656
1657 ShapeVector shapeAxes(axes.begin(), axes.end());
1658
  // Sort and unique the values in axes so that
  // 1. each dimension is only removed once;
  // 2. the dimensions to squeeze can be validated in ascending order.
1662 std::sort(shapeAxes.begin(), shapeAxes.end());
1663 shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()),
1664 shapeAxes.end());
1665 auto inDims = input.dims();
  assert(shapeAxes.back() < inDims.size() &&
         "Each axis to squeeze must be less than the input rank.");
1669
1670 ShapeVector newDims;
1671 size_t j = 0;
1672 for (size_t i = 0, e = inDims.size(); i < e; i++) {
1673 if (j < shapeAxes.size() && shapeAxes[j] == i) {
1674 assert(inDims[i] == 1 && "The dimension to squeeze must be 1.");
1675 j++;
1676 } else {
1677 newDims.push_back(inDims[i]);
1678 }
1679 }
1680 return createReshape(name.str() + ".reshape", input, newDims);
1681}
1682
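/// Create a ReshapeNode that inserts size-1 dimensions into \p input at the
/// output positions listed in \p axes.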
1683ReshapeNode *Function::createExpandDims(llvm::StringRef name, NodeValue input,
1684 llvm::ArrayRef<dim_t> axes) {
1685 assert(!axes.empty() && "Parameter `axes` must be provided.");
1686
1687 // Dimensions provided in axes are for the output tensor, so we sort them
1688 // and unique them to make sure they are processed correctly and in the
1689 // right order.
1690 ShapeVector shapeAxes(axes.begin(), axes.end());
1691 std::sort(shapeAxes.begin(), shapeAxes.end());
1692 shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()),
1693 shapeAxes.end());
1694
1695 const auto inDims = input.dims();
1696
1697 // The total number of dimensions in the new shape is equal to the original
1698 // shape size plus the uniqued new shape axes, which represents where to
1699 // insert dimensions of 1 into the output tensor's shape.
1700 const size_t totalNumNewDims = shapeAxes.size() + inDims.size();
1701 assert(totalNumNewDims <= max_tensor_dimensions &&
1702 "New expanded shape has too many dimensions.");
1703 assert(shapeAxes.back() < totalNumNewDims &&
1704 "Specified axis expands outside size of output tensor shape.");
1705 ShapeVector newDims;
1706 for (size_t i = 0, j = 0, k = 0; k < totalNumNewDims; k++) {
1707 if (j < shapeAxes.size() && shapeAxes[j] == k) {
1708 newDims.push_back(1);
1709 j++;
1710 } else {
1711 assert(i < inDims.size() && "Somehow overflowing inDims.");
1712 newDims.push_back(inDims[i]);
1713 i++;
1714 }
1715 }
1716
1717 // Create a reshape of the original data with the newly determined
1718 // dimensions.
1719 return createReshape(name.str() + ".expanddims", input, newDims);
1720}
1721
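/// Create a ReshapeNode that flattens \p input into a 2D tensor, collapsing
/// the dimensions before \p axis into the first output dimension and the
/// remaining ones into the second (or {1, totalSize} when \p axis is 0).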
1722ReshapeNode *Function::createFlatten(llvm::StringRef name, NodeValue input,
1723 unsigned_t axis) {
1724 std::pair<dim_t, dim_t> xDim;
1725 if (axis == 0) {
1726 dim_t d = 1;
1727 for (auto dim : input.getType()->dims()) {
1728 d *= dim;
1729 }
1730 xDim = {1, d};
1731 } else {
1732 xDim = flattenCdr(input.getType()->dims(), axis);
1733 }
1734 return createReshape(name, input, {xDim.first, xDim.second});
1735}
1736
1737ReshapeNode *Function::createFlattenV1(llvm::StringRef name, NodeValue input,
1738 unsigned_t axis) {
1739 auto xDim = collapseShape(input.getType()->dims(), axis);
1740 return createReshape(name, input, {xDim.first, xDim.second});
1741}
1742
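/// Split \p input into \p outputNum slices along \p axis and append them to
/// \p outputs. If \p split is empty the slices are equally sized; otherwise
/// \p split[i] gives the size of slice i along \p axis.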
1743void Function::createSplit(llvm::StringRef name, NodeValue input,
1744 unsigned_t outputNum, unsigned_t axis,
1745 llvm::ArrayRef<dim_t> split,
1746 std::vector<SliceNode *> &outputs) {
1747 auto inDims = input.dims();
1748 if (split.empty()) {
    assert(inDims[axis] % outputNum == 0 &&
           "Dimension to split must be divisible by the number of outputs.");
  } else {
    assert(outputNum == split.size() &&
           "Number of split sizes must be equal to the number of outputs.");
1754 }
1755
1756 ShapeVector start(inDims.size(), 0);
1757 ShapeVector end(inDims.begin(), inDims.end());
1758 end[axis] = 0;
1759
1760 outputs.resize(outputNum);
1761 for (size_t i = 0; i < outputNum; i++) {
1762 size_t curLength = split.empty() ? inDims[axis] / outputNum : split[i];
1763 end[axis] += curLength;
1764 outputs[i] =
1765 createSlice(name.str() + ".out" + std::to_string(i), input, start, end);
1766 start[axis] = end[axis];
1767 }
1768
1769 assert(end[axis] == inDims[axis] &&
1770 "Total size of results must be equal to input size.");
1771}
1772
1773BatchNormalizationNode *Function::createBatchNormalization(
1774 llvm::StringRef name, TypeRef resType, NodeValue input, NodeValue beta,
1775 NodeValue scale, NodeValue mean, NodeValue var, unsigned_t channelIdx,
1776 float epsilon, float momentum) {
1777 return addNode(new BatchNormalizationNode(name, resType, input, scale, beta,
1778 mean, var, channelIdx, epsilon,
1779 momentum));
1780}
1781
1782InstanceNormalizationNode *
1783Function::createInstanceNormalization(llvm::StringRef name, NodeValue input,
1784 NodeValue beta, NodeValue scale,
1785 unsigned_t channelIdx, float epsilon) {
1786 return addNode(new InstanceNormalizationNode(name, input, scale, beta,
1787 channelIdx, epsilon));
1788}
1789
1790LayerNormalizationNode *
1791Function::createLayerNormalization(llvm::StringRef name, TypeRef outTy,
1792 NodeValue input, NodeValue scale,
1793 NodeValue bias, float epsilon) {
1794 return addNode(
1795 new LayerNormalizationNode(name, outTy, input, scale, bias, epsilon));
1796}
1797
1798BucketizeNode *Function::createBucketizeNode(llvm::StringRef name,
1799 NodeValue input,
1800 llvm::ArrayRef<float> boundaries) {
1801 auto OT = getParent()->uniqueType(ElemKind::Int32ITy, input.dims());
1802 return addNode(new BucketizeNode(name, OT, input, boundaries));
1803}
1804
1805LocalResponseNormalizationNode *Function::createLocalResponseNormalization(
1806 llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize,
1807 float alpha, float beta, float k) {
1808 // The output tensor is of the same shape as the input tensor.
1809 return addNode(new LocalResponseNormalizationNode(name, input, halfWindowSize,
1810 alpha, beta, k));
1811}
1812
1813ModuloNode *Function::createModulo(llvm::StringRef name, NodeValue input,
1814 int64_t divisor, bool signFollowDivisor) {
1815 // The output tensor is of the same shape as the input tensor.
1816 auto OT = getParent()->uniqueType(*input.getType());
1817 return addNode(new ModuloNode(name, OT, input, divisor, signFollowDivisor));
1818}
1819
1820NotNode *Function::createNot(llvm::StringRef name, NodeValue input) {
1821 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims());
1822 return addNode(new NotNode(name, OT, input));
1823}
1824
1825BitwiseNotNode *Function::createBitwiseNot(llvm::StringRef name,
1826 NodeValue input) {
1827 TypeRef OT = getParent()->uniqueType(*input.getType());
1828 return addNode(new BitwiseNotNode(name, OT, input));
1829}
1830
1831#define UNARY_ARITHMETIC_FUN_DEF(NODE_NAME_) \
1832 NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \
1833 NodeValue input) { \
1834 return create##NODE_NAME_(name, input.getType(), input); \
1835 } \
1836 NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \
1837 TypeRef T, NodeValue input) { \
1838 TypeRef OT = getParent()->uniqueType(*T); \
1839 return addNode(new NODE_NAME_##Node(name, OT, input)); \
1840 }
1841UNARY_ARITHMETIC_FUN_DEF(Abs)
1842UNARY_ARITHMETIC_FUN_DEF(Neg)
1843UNARY_ARITHMETIC_FUN_DEF(Floor)
1844UNARY_ARITHMETIC_FUN_DEF(Sign)
1845UNARY_ARITHMETIC_FUN_DEF(Ceil)
1846UNARY_ARITHMETIC_FUN_DEF(Round)
1847UNARY_ARITHMETIC_FUN_DEF(Sqrt)
1848UNARY_ARITHMETIC_FUN_DEF(Rsqrt)
1849UNARY_ARITHMETIC_FUN_DEF(Reciprocal)
1850UNARY_ARITHMETIC_FUN_DEF(Sin)
1851UNARY_ARITHMETIC_FUN_DEF(Cos)
1852UNARY_ARITHMETIC_FUN_DEF(Erf)
1853UNARY_ARITHMETIC_FUN_DEF(Truncate)
1854UNARY_ARITHMETIC_FUN_DEF(HardSwish)
1855#undef UNARY_ARITHMETIC_FUN_DEF
1856
1857#define ARITHMETIC_FUN_DEF(NODE_NAME_) \
1858 NODE_NAME_##Node *Function::create##NODE_NAME_( \
1859 llvm::StringRef name, NodeValue LHS, NodeValue RHS) { \
1860 return create##NODE_NAME_(name, LHS.getType(), LHS, RHS); \
1861 } \
1862 NODE_NAME_##Node *Function::create##NODE_NAME_( \
1863 llvm::StringRef name, TypeRef T, NodeValue LHS, NodeValue RHS) { \
1864 DCHECK(LHS.dims() == RHS.dims()) \
1865 << "Invalid operand shapes LHS:" << LHS.getNode()->getName().str() \
1866 << " RHS: " << RHS.getNode()->getName().str() << " " << LHS.dims() \
1867 << " vs " << RHS.dims(); \
1868 TypeRef OT = getParent()->uniqueType(*T); \
1869 return addNode(new NODE_NAME_##Node(name, OT, LHS, RHS)); \
1870 }
1871ARITHMETIC_FUN_DEF(Add);
1872ARITHMETIC_FUN_DEF(Mul);
1873ARITHMETIC_FUN_DEF(Sub);
1874ARITHMETIC_FUN_DEF(Div);
1875ARITHMETIC_FUN_DEF(Max);
1876ARITHMETIC_FUN_DEF(Min);
1877ARITHMETIC_FUN_DEF(Pow);
1878ARITHMETIC_FUN_DEF(And);
1879ARITHMETIC_FUN_DEF(Or);
1880ARITHMETIC_FUN_DEF(Xor);
1881ARITHMETIC_FUN_DEF(BitwiseAnd);
1882ARITHMETIC_FUN_DEF(BitwiseOr);
1883ARITHMETIC_FUN_DEF(BitwiseXor);
1884ARITHMETIC_FUN_DEF(Fmod);
1885#undef ARITHMETIC_FUN_DEF
1886
1887#define TRIGONOMETRIC_FUN_DEF(NODE_NAME_) \
1888 NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \
1889 NodeValue input) { \
1890 return create##NODE_NAME_(name, input.getType(), input); \
1891 } \
1892 NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \
1893 TypeRef T, NodeValue input) { \
1894 TypeRef OT = getParent()->uniqueType(*T); \
1895 return addNode(new NODE_NAME_##Node(name, OT, input)); \
1896 }
1897
1898TRIGONOMETRIC_FUN_DEF(Acos)
1899TRIGONOMETRIC_FUN_DEF(Asin)
1900TRIGONOMETRIC_FUN_DEF(Atan)
1901#undef TRIGONOMETRIC_FUN_DEF
1902
1903FloorDivNode *Function::createFloorDiv(llvm::StringRef name, NodeValue LHS,
1904 NodeValue RHS, bool truncate) {
1905 return createFloorDiv(name, LHS.getType(), LHS, RHS, truncate);
1906}
1907
1908FloorDivNode *Function::createFloorDiv(llvm::StringRef name, TypeRef outTy,
1909 NodeValue LHS, NodeValue RHS,
1910 bool truncate) {
1911 DCHECK(LHS.dims() == RHS.dims())
1912 << "Invalid operand shapes LHS:" << LHS.getNode()->getName().str()
1913 << " RHS: " << RHS.getNode()->getName().str() << " " << LHS.dims()
1914 << " vs " << RHS.dims();
1915 TypeRef OT = getParent()->uniqueType(*outTy);
1916 return addNode(new FloorDivNode(name, OT, LHS, RHS, truncate));
1917}
1918
1919FloorDivNode *Function::createFloorDivWithBroadcast(llvm::StringRef name,
1920 int axis, NodeValue LHS,
1921 NodeValue RHS,
1922 bool truncate) {
1923 std::vector<NodeValue> inputs = broadcastInputs(axis, {LHS, RHS});
1924 return createFloorDiv(name, inputs[0].getType(), inputs[0], inputs[1],
1925 truncate);
1926}
1927
1928FloorDivNode *Function::createFloorDivWithBroadcast(llvm::StringRef name,
1929 int axis, TypeRef outTy,
1930 NodeValue LHS,
1931 NodeValue RHS,
1932 bool truncate) {
1933 std::vector<NodeValue> inputs = broadcastInputs(axis, {LHS, RHS});
1934 return createFloorDiv(name, outTy, inputs[0], inputs[1], truncate);
1935}
1936
1937CmpLTENode *Function::createCmpLTE(llvm::StringRef name, NodeValue LHS,
1938 NodeValue RHS) {
1939 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1940 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1941 return addNode(new CmpLTENode(name, OT, LHS, RHS));
1942}
1943
1944CmpLTNode *Function::createCmpLT(llvm::StringRef name, NodeValue LHS,
1945 NodeValue RHS) {
1946 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1947 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1948 return addNode(new CmpLTNode(name, OT, LHS, RHS));
1949}
1950
1951CmpLTENode *Function::createCmpGTE(llvm::StringRef name, NodeValue LHS,
1952 NodeValue RHS) {
1953 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1954 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1955 return addNode(new CmpLTENode(name, OT, RHS, LHS));
1956}
1957
1958CmpLTNode *Function::createCmpGT(llvm::StringRef name, NodeValue LHS,
1959 NodeValue RHS) {
1960 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1961 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1962 return addNode(new CmpLTNode(name, OT, RHS, LHS));
1963}
1964
1965CmpEQNode *Function::createCmpEQ(llvm::StringRef name, NodeValue LHS,
1966 NodeValue RHS) {
1967 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1968 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1969 return addNode(new CmpEQNode(name, OT, LHS, RHS));
1970}
1971
1972CmpNEQNode *Function::createCmpNEQ(llvm::StringRef name, NodeValue LHS,
1973 NodeValue RHS) {
1974 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1975 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1976 return addNode(new CmpNEQNode(name, OT, LHS, RHS));
1977}
1978
1979MulNode *Function::createSquare(llvm::StringRef name, NodeValue input) {
1980 return createMul(name, input, input);
1981}
1982
1983MulNode *Function::createSquare(llvm::StringRef name, TypeRef outTy,
1984 NodeValue input) {
1985 return createMul(name, outTy, input, input);
1986}
1987
1988LeakyReluNode *Function::createLeakyRELU(llvm::StringRef name, NodeValue input,
1989 float alpha) {
1990 return addNode(new LeakyReluNode(name, input.getType(), input, alpha));
1991}
1992
1993LeakyReluNode *Function::createLeakyRELU(llvm::StringRef name, TypeRef outTy,
1994 NodeValue input, float alpha) {
1995 return addNode(new LeakyReluNode(name, outTy, input, alpha));
1996}
1997
1998IsNaNNode *Function::createIsNaN(llvm::StringRef name, NodeValue input) {
1999 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims());
2000 return addNode(new IsNaNNode(name, OT, input));
2001}
2002
2003ReplaceNaNNode *Function::createReplaceNaN(llvm::StringRef name,
2004 NodeValue input, float value) {
2005 return addNode(new ReplaceNaNNode(name, input.getType(), input, value));
2006}
2007
2008PowNode *Function::createPow(llvm::StringRef name, NodeValue base, float exp) {
2009 auto *SP = createSplat(name, base.getType(), exp);
2010 return createPow(name, base, SP);
2011}
2012
2013LogNode *Function::createLog(llvm::StringRef name, NodeValue input) {
2014 return createLog(name, input.getType(), input);
2015}
2016
2017LogNode *Function::createLog(llvm::StringRef name, TypeRef outTy,
2018 NodeValue input) {
2019 return addNode(new LogNode(name, outTy, input));
2020}
2021
2022LogNode *Function::createLog(llvm::StringRef name, NodeValue input,
2023 TypeRef outTy) {
2024 return createLog(name, outTy, input);
2025}
2026
2027ExpNode *Function::createExp(llvm::StringRef name, NodeValue input) {
2028 return addNode(new ExpNode(name, input.getType(), input));
2029}
2030
2031ExpNode *Function::createExp(llvm::StringRef name, TypeRef outTy,
2032 NodeValue input) {
2033 return addNode(new ExpNode(name, outTy, input));
2034}
2035
2036LogitNode *Function::createLogit(llvm::StringRef name, NodeValue input,
2037 float eps) {
2038 return addNode(new LogitNode(name, input.getType(), input, eps));
2039}
2040
2041NonZeroNode *Function::createNonZero(llvm::StringRef name, NodeValue Cond) {
2042 auto outTy = getParent()->uniqueType(ElemKind::Int32ITy, {Cond.dims()[0], 1});
2043 return addNode(new NonZeroNode(name, outTy, Cond));
2044}
2045
2046SelectNode *Function::createSelect(llvm::StringRef name, TypeRef outTy,
2047 NodeValue Cond, NodeValue LHS,
2048 NodeValue RHS) {
2049 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
2050 assert(LHS.dims() == Cond.dims() && "Invalid operand shapes");
2051 assert(LHS.dims() == outTy->dims() && "Invalid result shape");
2052 auto OT = getParent()->uniqueType(*outTy);
2053 return addNode(new SelectNode(name, OT, Cond, LHS, RHS));
2054}
2055
2056SelectNode *Function::createSelect(llvm::StringRef name, NodeValue Cond,
2057 NodeValue LHS, NodeValue RHS) {
2058 auto inDims = LHS.dims();
2059 assert(inDims.size() > 0);
2060 ShapeVector outDims(inDims.begin(), inDims.end());
2061 auto OT = getParent()->uniqueTypeWithNewShape(LHS.getType(), outDims);
2062 return createSelect(name, OT, Cond, LHS, RHS);
2063}
2064
2065SplatNode *Function::createSplat(llvm::StringRef name, TypeRef ty,
2066 float value) {
2067 return addNode(new SplatNode(name, getParent()->uniqueType(*ty), value));
2068}
2069
2070TouchNode *Function::createTouch(llvm::StringRef name, TypeRef ty) {
2071 return addNode(new TouchNode(name, getParent()->uniqueType(*ty)));
2072}
2073
2074MatMulNode *Function::createMatMul(llvm::StringRef name, TypeRef outTy,
2075 NodeValue lhs, NodeValue rhs) {
2076 return addNode(
2077 new MatMulNode(name, getParent()->uniqueType(*outTy), lhs, rhs));
2078}
2079
2080MatMulNode *Function::createMatMul(llvm::StringRef name, NodeValue lhs,
2081 NodeValue rhs) {
2082 auto LT = lhs.getType();
2083 auto RT = rhs.getType();
2084 auto LDims = LT->dims();
2085 auto RDims = RT->dims();
2086 assert(lhs.getType()->getElementType() == rhs.getType()->getElementType());
2087
2088 auto ty =
2089 getParent()->uniqueTypeWithNewShape(lhs.getType(), {LDims[0], RDims[1]});
2090 return createMatMul(name, ty, lhs, rhs);
2091}
2092
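/// Create a BatchMatMulNode that multiplies LHS {numBatches, N, M} by
/// RHS {numBatches, M, P} (or {M, P}, which is broadcast across the batch),
/// producing {numBatches, N, P}; inputs with more than three dimensions are
/// first flattened to 3D. E.g. {4, 2, 3} x {3, 5} -> {4, 2, 5}.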
2093BatchMatMulNode *Function::createBatchMatMul(llvm::StringRef name,
2094 NodeValue LHS, NodeValue RHS) {
2095 const size_t numDimsLHS = LHS.dims().size();
2096 if (numDimsLHS > 3) {
2097 const size_t numDimsRHS = RHS.dims().size();
2098 std::vector<dim_t> newLHSShape = {0, LHS.dims()[numDimsLHS - 2],
2099 LHS.dims()[numDimsLHS - 1]};
2100 newLHSShape[0] = LHS.getType()->size() / (newLHSShape[1] * newLHSShape[2]);
2101 LHS = createReshape(name.str() + ".reshapeLHS3D", LHS, newLHSShape);
2102 std::vector<dim_t> newRHSShape = {0, RHS.dims()[numDimsRHS - 2],
2103 RHS.dims()[numDimsRHS - 1]};
2104 newRHSShape[0] = RHS.getType()->size() / (newRHSShape[1] * newRHSShape[2]);
2105 RHS = createReshape(name.str() + ".reshapeRHS3D", RHS, newRHSShape);
2106 }
2107
2108 const size_t numDimsRHS = RHS.dims().size();
2109 assert(LHS.dims().size() == 3 && "LHS must be 3 dimensional.");
2110 assert((numDimsRHS == 2 || numDimsRHS == 3) &&
2111 "RHS must be 2 or 3 dimensional.");
2112 // If necessary, expand the RHS input to be 3D by adding initial leading
2113 // dim.
2114 if (numDimsRHS == 2) {
2115 RHS = createExpandDims(name.str() + ".reshapeRHS", RHS, {0});
2116 }
2117 // If necessary, Tile the RHS input so it matches the numBatches of LHS.
2118 if (RHS.dims()[0] == 1 && LHS.dims()[0] != 1) {
    RHS = createTile(name.str() + ".tileRHS", RHS, LHS.dims()[0], /*axis*/ 0);
2120 }
2121
2122 // LHS = {numBatches, N, M}
2123 // RHS = {numBatches, M, P}
2124 // Result = {numBatches, N, P}
2125 const dim_t numBatches = LHS.dims()[0];
2126 const dim_t N = LHS.dims()[1];
2127 const dim_t M = LHS.dims()[2];
2128 (void)M;
2129 const dim_t P = RHS.dims()[2];
2130 assert((RHS.dims()[0] == numBatches) && "Batch sizes are invalid.");
2131 assert((RHS.dims()[1] == M) && "Batch matmul dimensions are invalid.");
2132
2133 auto OT =
2134 getParent()->uniqueTypeWithNewShape(LHS.getType(), {numBatches, N, P});
2135 return addNode(new BatchMatMulNode(name, OT, LHS, RHS));
2136}
2137
2138BatchedReduceAddNode *
2139Function::createBatchedReduceAdd(llvm::StringRef name, TypeRef outTy,
2140 NodeValue batch,
2141 llvm::ArrayRef<unsigned_t> axes) {
2142 assert(axes.size() == 1 && "Only supporting single reduction for now.");
2143 auto axis = axes[0];
2144
2145 // Calculate the expected total number of elements in the output tensor
2146 // based on the number of elements in the batch divided by the axis
2147 // dimension.
2148 const size_t outNumElements = batch.getType()->size() / batch.dims()[axis];
2149 (void)outNumElements;
2150 assert(outTy->size() == outNumElements &&
2151 "Incorrect number of elements in the output type.");
2152 auto OT = getParent()->uniqueType(*outTy);
2153 return addNode(new BatchedReduceAddNode(name, OT, batch, axis));
2154}
2155
2156BatchedReduceSumSquareNode *
2157Function::createBatchedReduceSumSquare(llvm::StringRef name, NodeValue batch,
2158 llvm::ArrayRef<unsigned_t> axes) {
2159 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
2160 auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
2161 return createBatchedReduceSumSquare(name, OT, batch, axes);
2162}
2163
2164BatchedReduceSumSquareNode *
2165Function::createBatchedReduceSumSquare(llvm::StringRef name, TypeRef outTy,
2166 NodeValue batch,
2167 llvm::ArrayRef<unsigned_t> axes) {
2168 assert(axes.size() == 1 && "Only supporting single reduction for now.");
2169 auto axis = axes[0];
2170
2171 // Calculate the expected total number of elements in the output tensor
2172 // based on the number of elements in the batch divided by the axis
2173 // dimension.
2174 const size_t outNumElements = batch.getType()->size() / batch.dims()[axis];
2175 (void)outNumElements;
2176 assert(outTy->size() == outNumElements &&
2177 "Incorrect number of elements in the output type.");
2178 auto OT = getParent()->uniqueType(*outTy);
2179 return addNode(new BatchedReduceSumSquareNode(name, OT, batch, axis));
2180}
2181
2182BatchedReduceAddNode *
2183Function::createBatchedReduceAdd(llvm::StringRef name, NodeValue batch,
2184 llvm::ArrayRef<unsigned_t> axes) {
2185 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
2186 auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
2187 return createBatchedReduceAdd(name, OT, batch, axes);
2188}
2189
2190BatchedReduceMeanNode *
2191Function::createBatchedReduceMean(llvm::StringRef name, TypeRef outTy,
2192 NodeValue batch,
2193 llvm::ArrayRef<unsigned_t> axes) {
2194 auto OT = getParent()->uniqueType(*outTy);
2195 return addNode(new BatchedReduceMeanNode(name, OT, batch, axes));
2196}
2197
2198BatchedReduceMeanNode *
2199Function::createBatchedReduceMean(llvm::StringRef name, NodeValue batch,
2200 llvm::ArrayRef<unsigned_t> axes) {
2201 // Create new shape with specified dimensions either reduced or removed.
2202 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
2203 auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
2204 return createBatchedReduceMean(name, OT, batch, axes);
2205}
2206
2207BatchedReduceMinNode *
2208Function::createBatchedReduceMin(llvm::StringRef name, TypeRef outTy,
2209 NodeValue batch,
2210 llvm::ArrayRef<unsigned_t> axes) {
2211 auto OT = getParent()->uniqueType(*outTy);
2212 return addNode(new BatchedReduceMinNode(name, OT, batch, axes));
2213}
2214
2215BatchedReduceMinNode *
2216Function::createBatchedReduceMin(llvm::StringRef name, NodeValue batch,
2217 llvm::ArrayRef<unsigned_t> axes) {
2218 // Create new shape with specified dimensions either reduced or removed.
2219 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
2220 auto OT = getParent()->uniqueType(batch.getType()->getElementType(), outDims);
2221 return addNode(new BatchedReduceMinNode(name, OT, batch, axes));
2222}
2223
2224BatchedReduceMaxNode *
2225Function::createBatchedReduceMax(llvm::StringRef name, TypeRef outTy,
2226 NodeValue batch,
2227 llvm::ArrayRef<unsigned_t> axes) {
2228 auto OT = getParent()->uniqueType(*outTy);
2229 return addNode(new BatchedReduceMaxNode(name, OT, batch, axes));
2230}
2231
2232BatchedReduceMaxNode *
2233Function::createBatchedReduceMax(llvm::StringRef name, NodeValue batch,
2234 llvm::ArrayRef<unsigned_t> axes) {
2235 // Create new shape with specified dimensions either reduced or removed.
2236 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
2237 auto OT = getParent()->uniqueType(batch.getType()->getElementType(), outDims);
2238 return addNode(new BatchedReduceMaxNode(name, OT, batch, axes));
2239}
2240
2241BatchedReduceProdNode *
2242Function::createBatchedReduceProd(llvm::StringRef name, TypeRef outTy,
2243 NodeValue batch,
2244 llvm::ArrayRef<unsigned_t> axes) {
2245 assert(axes.size() == 1 && "Only supporting single reduction for now.");
2246 auto axis = axes[0];
2247
2248 // Calculate the expected total number of elements in the output tensor
2249 // based on the number of elements in the batch divided by the axis
2250 // dimension.
2251 const size_t outNumElements = batch.getType()->size() / batch.dims()[axis];
2252 (void)outNumElements;
2253 assert(outTy->size() == outNumElements &&
2254 "Incorrect number of elements in the output type.");
2255 auto OT = getParent()->uniqueType(*outTy);
2256 return addNode(new BatchedReduceProdNode(name, OT, batch, axis));
2257}
2258
2259BatchedReduceProdNode *
2260Function::createBatchedReduceProd(llvm::StringRef name, NodeValue batch,
2261 llvm::ArrayRef<unsigned_t> axes) {
2262 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
2263 auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
2264 return createBatchedReduceProd(name, OT, batch, axes);
2265}
2266
2267BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name,
2268 NodeValue batch, NodeValue slice) {
2269 return addNode(new BatchedAddNode(name, batch.getType(), batch, slice));
2270}
2271
2272BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name, TypeRef outTy,
2273 NodeValue batch, NodeValue slice) {
2274 return addNode(
2275 new BatchedAddNode(name, getParent()->uniqueType(*outTy), batch, slice));
2276}
2277
2278BatchedMulNode *Function::createBatchedMul(llvm::StringRef name,
2279 NodeValue batch, NodeValue slice) {
2280 return addNode(new BatchedMulNode(name, batch.getType(), batch, slice));
2281}
2282
2283BatchedMulNode *Function::createBatchedMul(llvm::StringRef name, TypeRef outTy,
2284 NodeValue batch, NodeValue slice) {
2285 return addNode(
2286 new BatchedMulNode(name, getParent()->uniqueType(*outTy), batch, slice));
2287}
2288
2289CumSumNode *Function::createCumSum(llvm::StringRef name, NodeValue input,
2290 int64_t dim, bool exclusive, bool reverse) {
2291 return addNode(
2292 new CumSumNode(name, input.getType(), input, dim, exclusive, reverse));
2293}
2294
2295LengthsSumNode *Function::createLengthsSum(llvm::StringRef name, NodeValue data,
2296 NodeValue lengths) {
2297 ShapeVector outDims(data.dims().begin(), data.dims().end());
2298 outDims[0] = lengths.dims()[0];
2299 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2300 return addNode(new LengthsSumNode(name, outTy, data, lengths));
2301}
2302
2303SparseLengthsSumNode *
2304Function::createSparseLengthsSum(llvm::StringRef name, NodeValue data,
2305 NodeValue indices, NodeValue lengths,
2306 LengthsMode lengthsMode, float avgLength) {
2307 auto inDims = data.dims();
2308 ShapeVector outDims(inDims.begin(), inDims.end());
2309 outDims[0] = lengths.dims()[0];
2310 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2311 return addNode(new SparseLengthsSumNode(name, outTy, data, indices, lengths,
2312 lengthsMode, avgLength));
2313}
2314
2315SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum(
2316 llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2317 NodeValue lengths, LengthsMode lengthsMode, float avgLength) {
2318 auto inDims = data.dims();
2319 ShapeVector outDims(inDims.begin(), inDims.end());
2320 outDims[0] = lengths.dims()[0];
2321 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2322 return addNode(new SparseLengthsWeightedSumNode(
2323 name, outTy, data, weights, indices, lengths, lengthsMode, avgLength));
2324}
2325
2326SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum(
2327 llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights,
2328 NodeValue indices, NodeValue lengths, LengthsMode lengthsMode,
2329 float avgLength) {
2330 return addNode(new SparseLengthsWeightedSumNode(
2331 name, outTy, data, weights, indices, lengths, lengthsMode, avgLength));
2332}
2333
2334RowwiseQuantizedSparseLengthsWeightedSumNode *
2335Function::createRowwiseQuantizedSparseLengthsWeightedSum(
2336 llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets,
2337 NodeValue weights, NodeValue indices, NodeValue lengths, ElemKind precision,
2338 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2339 auto inDims = data->dims();
2340 ShapeVector outDims(inDims.begin(), inDims.end());
2341 outDims[0] = lengths.dims()[0];
2342 auto outTy = getParent()->uniqueType(precision, outDims);
2343 return addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode(
2344 name, outTy, data, scales, offsets, weights, indices, lengths,
2345 useFP16Accumulation, lengthsMode, avgLength));
2346}
2347
2348RowwiseQuantizedSparseLengthsWeightedSumNode *
2349Function::createRowwiseQuantizedSparseLengthsSum(
2350 llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets,
2351 NodeValue indices, NodeValue lengths, ElemKind precision,
2352 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2353 auto ty = getParent()->uniqueType(precision, {indices.dims()[0]});
2354 auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2355 return createRowwiseQuantizedSparseLengthsWeightedSum(
2356 name, data, scales, offsets, ones, indices, lengths, precision,
2357 useFP16Accumulation, lengthsMode, avgLength);
2358}
2359
/// Helper to create a RowwiseQuantizedSparseLengthsWeightedSumNode in the
/// Function \p F with \p name, using \p data, \p weights, \p indices, and \p
/// lengths as inputs. The provided float Tensor \p data is rowwise quantized,
/// creating Constants for the rowwise quantized data as well as Scales and
/// Offsets, in the Module containing \p F.
2365static RowwiseQuantizedSparseLengthsWeightedSumNode *
2366quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2367 Function *F, llvm::StringRef name, Tensor &data, NodeValue weights,
2368 NodeValue indices, NodeValue lengths, quantization::Schema schema,
2369 ElemKind precision, bool useFP16Accumulation, LengthsMode lengthsMode,
2370 float avgLength) {
2371 auto inDims = data.dims();
2372
2373 // Note: In rwqData, we are using a quantized type, however the scale/offset
2374 // are set to dummy values 0.0/0. This is because the actually used
2375 // scale/offset come from dataScales and dataOffsets.
2376 Constant *rwqData = F->getParent()->createConstant(ElemKind::UInt8QTy, inDims,
2377 0.0, 0, "data");
2378 Constant *dataScales =
2379 F->getParent()->createConstant(precision, {inDims[0]}, "dataScales");
2380 Constant *dataOffsets =
2381 F->getParent()->createConstant(precision, {inDims[0]}, "dataOffsets");
2382
2383 // Note: Using floating point offset here as that is what RWQ-SLWS expects.
2384 switch (precision) {
2385 case ElemKind::FloatTy:
2386 quantization::tensorRowwiseQuantization<float, float, uint8_t>(
2387 data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(),
2388 dataOffsets->getPayloadMutable(), schema);
2389 break;
2390 case ElemKind::Float16Ty:
2391 quantization::tensorRowwiseQuantization<float16_t, float16_t, uint8_t>(
2392 data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(),
2393 dataOffsets->getPayloadMutable(), schema);
2394 break;
2395 default:
2396 LOG(FATAL) << "Unsupported precision for RWQ-SLWS.";
2397 }
2398 return F->createRowwiseQuantizedSparseLengthsWeightedSum(
2399 name, rwqData, dataScales, dataOffsets, weights, indices, lengths,
2400 precision, useFP16Accumulation, lengthsMode, avgLength);
2401}
2402
2403RowwiseQuantizedSparseLengthsWeightedSumNode *
2404Function::createRowwiseQuantizedSparseLengthsWeightedSum(
2405 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2406 NodeValue lengths, quantization::Schema schema, ElemKind precision,
2407 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2408 return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2409 this, name, data, weights, indices, lengths, schema, precision,
2410 useFP16Accumulation, lengthsMode, avgLength);
2411}
2412
2413RowwiseQuantizedSparseLengthsWeightedSumNode *
2414Function::createRowwiseQuantizedSparseLengthsSum(
2415 llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
2416 quantization::Schema schema, ElemKind precision, bool useFP16Accumulation,
2417 LengthsMode lengthsMode, float avgLength) {
2418 auto ty = getParent()->uniqueType(precision, {indices.dims()[0]});
2419 auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2420 return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2421 this, name, data, ones, indices, lengths, schema, precision,
2422 useFP16Accumulation, lengthsMode, avgLength);
2423}
2424
2425/// Helper used to get specific output type required for
2426/// createRowwiseQuantizedSparseLengthsSum,
2427/// createRowwiseQuantizedSparseLengthsWeightedSum, and
2428/// EmbeddingBagByteRowwiseOffsets. Function \p F is used to get the specific
2429/// type, using inputs \p data and \p segmentsDim to compute output dimensions.
2430static TypeRef
2431getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
2432 llvm::ArrayRef<dim_t> segmentsDim) {
2433 ShapeVector outDims(data.dims().begin(), data.dims().end());
2434 outDims[0] = segmentsDim[0];
2435 // The output column count is the same as the input column count, but
2436 // without the extra bytes for the fused scale/offset, as the output is not
2437 // fused.
2438 CHECK(isFusedQuantizedElemKind(data.getElementType()))
2439 << "Must use a fused ElemKind for data.";
2440 outDims[1] -= 2 * ((data.getElementType() == ElemKind::UInt8FusedQTy ||
2441 data.getElementType() == ElemKind::UInt4FusedQTy)
2442 ? sizeof(float)
2443 : sizeof(float16_t));
  // If using 4-bit quantization, then the input data has packed two 4-bit
  // elements into one byte, so we need to double the output column count
  // (outDims[1]).
2446 if (data.getElementType() == ElemKind::UInt4FusedFP16QTy ||
2447 data.getElementType() == ElemKind::UInt4FusedQTy) {
2448 outDims[1] *= 2;
2449 }
2450 const ElemKind outputK = (data.getElementType() == ElemKind::UInt8FusedQTy ||
2451 data.getElementType() == ElemKind::UInt4FusedQTy)
2452 ? ElemKind::FloatTy
2453 : ElemKind::Float16Ty;
2454 return F->getParent()->uniqueType(outputK, outDims);
2455}
2456
2457FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
2458Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2459 llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2460 NodeValue lengths, bool useFP16Accumulation, LengthsMode lengthsMode,
2461 float avgLength) {
2462 auto outTy =
2463 getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
2464 return addNode(new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(
2465 name, outTy, data, weights, indices, lengths, useFP16Accumulation,
2466 lengthsMode, avgLength));
2467}
2468
2469FusedRowwiseQuantizedSparseLengthsSumNode *
2470Function::createFusedRowwiseQuantizedSparseLengthsSum(
2471 llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths,
2472 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2473 auto outTy =
2474 getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
2475 return addNode(new FusedRowwiseQuantizedSparseLengthsSumNode(
2476 name, outTy, data, indices, lengths, useFP16Accumulation, lengthsMode,
2477 avgLength));
2478}
2479
/// Helper to get the quantized data required for
/// FusedRowwiseQuantizedSparseLengthsWeightedSumNode and
/// FusedRowwiseQuantizedSparseLengthsSumNode. Function \p F uses the float
/// Tensor \p data to create and return a rowwise quantized Constant which
/// contains the fused scales and offsets.
2485static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2486 Function *F, Tensor &data, ElemKind precision) {
2487 // For fused rowwise quantization, we must have a two-dimensional input. If
2488 // passed in a single dimensional data Tensor then add an extra dimension.
2489 const auto fDims = flattenCdr(data.dims());
2490 Tensor fData = data.getUnowned({fDims.first, fDims.second});
2491
2492 // Note: In rwqData, we are using a quantized type, however the scale/offset
2493 // are set to dummy values 0.0/0. This is because the actually used
2494 // scale/offset are fused inline with each row. Also, we expand the second
  // dimension to include space for the scale and offset, each 4 bytes (float)
  // or 2 bytes (float16_t), depending on the fused precision.
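  // Resulting row layout: [ quantized payload ... | fused scale/offset ].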
2497 switch (precision) {
2498 case ElemKind::UInt8FusedQTy: {
2499 Constant *rwqData = F->getParent()->createConstant(
2500 precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float)}, 0.0,
2501 0, "data");
2502 quantization::tensorFusedRowwiseQuantization<float>(
2503 fData, rwqData->getPayloadMutable());
2504 return rwqData;
2505 }
2506 case ElemKind::UInt8FusedFP16QTy: {
2507 Constant *rwqData = F->getParent()->createConstant(
2508 precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float16_t)},
2509 0.0, 0, "data");
2510 quantization::tensorFusedRowwiseQuantization<float16_t>(
2511 fData, rwqData->getPayloadMutable());
2512 return rwqData;
2513 }
2514 case ElemKind::UInt4FusedFP16QTy: {
    // We pack two 4-bit values into each byte, so we divide the input column
    // count by two and take the ceiling to make sure we have enough space for
    // all elements, then add space for the fused scale/offset.
2518 const dim_t outerDim =
2519 std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float16_t);
2520 Constant *rwqData = F->getParent()->createConstant(
2521 precision, {fDims.first, outerDim}, 0.0, 0, "data");
2522 quantization::tensorFusedRowwiseQuantization<float16_t>(
2523 fData, rwqData->getPayloadMutable());
2524 return rwqData;
2525 }
2526 case ElemKind::UInt4FusedQTy: {
    // We pack two 4-bit values into each byte, so we divide the input column
    // count by two and take the ceiling to make sure we have enough space for
    // all elements, then add space for the fused scale/offset.
2530 const dim_t outerDim =
2531 std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float);
2532 Constant *rwqData = F->getParent()->createConstant(
2533 precision, {fDims.first, outerDim}, 0.0, 0, "data");
2534 quantization::tensorFusedRowwiseQuantization<float>(
2535 fData, rwqData->getPayloadMutable());
2536 return rwqData;
2537 }
2538 default:
    llvm_unreachable("Invalid type for FusedRowwiseQuantization.");
2540 }
2541}
2542
2543FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
2544Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2545 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2546 NodeValue lengths, ElemKind fusedElemKind, bool useFP16Accumulation,
2547 LengthsMode lengthsMode, float avgLength) {
2548 Constant *rwqData =
2549 quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2550 this, data, fusedElemKind);
2551 return createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2552 name, rwqData, weights, indices, lengths, useFP16Accumulation,
2553 lengthsMode, avgLength);
2554}
2555
2556FusedRowwiseQuantizedSparseLengthsSumNode *
2557Function::createFusedRowwiseQuantizedSparseLengthsSum(
2558 llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
2559 ElemKind fusedElemKind, bool useFP16Accumulation, LengthsMode lengthsMode,
2560 float avgLength) {
2561 Constant *rwqData =
2562 quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2563 this, data, fusedElemKind);
2564 return this->createFusedRowwiseQuantizedSparseLengthsSum(
2565 name, rwqData, indices, lengths, useFP16Accumulation, lengthsMode,
2566 avgLength);
2567}
2568
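/// Create an EmbeddingNode that looks up rows of the 2D \p weights table by
/// \p indices; the output shape is indices.dims() followed by the embedding
/// dimension weights.dims()[1].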
2569EmbeddingNode *Function::createEmbedding(llvm::StringRef name,
2570 NodeValue weights, NodeValue indices,
2571 int32_t padIdx, bool scale,
2572 bool sparse) {
2573 auto indDims = indices.dims();
2574 auto wtDims = weights.dims();
2575
2576 assert(wtDims.size() == 2 && "weights must be a 2D tensor");
2577
2578 ShapeVector outDims(indDims.begin(), indDims.end());
2579 dim_t embedding_dim = wtDims[1];
2580 outDims.push_back(embedding_dim);
2581
2582 auto outTy = getParent()->uniqueTypeWithNewShape(weights.getType(), outDims);
2583 return addNode(
2584 new EmbeddingNode(name, outTy, weights, indices, padIdx, scale, sparse));
2585}
2586
2587EmbeddingBagNode *
2588Function::createEmbeddingBag(llvm::StringRef name, NodeValue data,
2589 NodeValue weights, NodeValue indices,
2590 NodeValue offsets, bool hasEndOffset,
2591 LengthsMode lengthsMode, float avgLength) {
2592 auto inDims = data.dims();
2593 ShapeVector outDims(inDims.begin(), inDims.end());
2594 outDims[0] = hasEndOffset ? offsets.dims()[0] - 1 : offsets.dims()[0];
2595 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2596 return addNode(new EmbeddingBagNode(name, outTy, data, weights, indices,
2597 offsets, hasEndOffset, lengthsMode,
2598 avgLength));
2599}
2600
2601EmbeddingBagByteRowwiseOffsetsNode *
2602Function::createEmbeddingBagByteRowwiseOffsets(
2603 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2604 NodeValue offsets, ElemKind fusedElemKind, bool useFP16Accumulation,
2605 bool hasEndOffset, LengthsMode lengthsMode, float avgLength) {
2606 Constant *rwqData =
2607 quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2608 this, data, fusedElemKind);
2609 return createEmbeddingBagByteRowwiseOffsets(
2610 name, rwqData, weights, indices, offsets, useFP16Accumulation,
2611 hasEndOffset, lengthsMode, avgLength);
2612}
2613
2614EmbeddingBagByteRowwiseOffsetsNode *
2615Function::createEmbeddingBagByteRowwiseOffsets(
2616 llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2617 NodeValue offsets, bool useFP16Accumulation, bool hasEndOffset,
2618 LengthsMode lengthsMode, float avgLength) {
2619 std::vector<dim_t> segmentDims(offsets.dims().begin(), offsets.dims().end());
  // If hasEndOffset is true, the last offset just marks the end of the last
  // segment.
2622 if (hasEndOffset) {
2623 segmentDims[0] -= 1;
2624 }
2625 auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, segmentDims);
2626 return addNode(new EmbeddingBagByteRowwiseOffsetsNode(
2627 name, outTy, data, weights, indices, offsets, useFP16Accumulation,
2628 hasEndOffset, lengthsMode, avgLength));
2629}
2630
2631LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
2632 NodeValue lengths) {
2633 ShapeVector outDims({lengths.dims()[0], 2});
2634 auto outTy = getParent()->uniqueTypeWithNewShape(lengths.getType(), outDims);
2635 return addNode(new LengthsToRangesNode(name, outTy, lengths));
2636}
2637
2638LengthsRangeFillNode *
2639Function::createLengthsRangeFill(llvm::StringRef name, NodeValue lengths,
2640 unsigned_t maxOutputSize) {
2641 auto outTy =
2642 getParent()->uniqueTypeWithNewShape(lengths.getType(), {maxOutputSize});
2643 return addNode(new LengthsRangeFillNode(name, outTy, lengths));
2644}
2645
2646GaussianFillNode *Function::createGaussianFill(llvm::StringRef name,
2647 NodeValue input, float mean,
2648 float scale, float seed) {
2649 auto outTy = getParent()->uniqueType(ElemKind::Float16Ty, input.dims());
2650
2651 return addNode(new GaussianFillNode(name, outTy, input, mean, scale, seed));
2652}
2653
2654BatchSparseToDenseNode *Function::createBatchSparseToDense(
2655 llvm::StringRef name, NodeValue lengths, NodeValue indices,
2656 NodeValue values, float defaultValue, unsigned_t denseLastDim) {
2657 // The output is a 2-D tensor with first dim = number of lengths and second
2658 // dim = denseLastDim
2659 ShapeVector outDims({lengths.dims()[0], denseLastDim});
2660 auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
2661 return addNode(new BatchSparseToDenseNode(
2662 name, outTy, lengths, indices, values, defaultValue, denseLastDim));
2663}
2664
2665FillExamplesWithIndicatorNode *
2666Function::createFillExamplesWithIndicator(llvm::StringRef name, NodeValue data,
2667 NodeValue indicator) {
2668 ShapeVector outDims({indicator.dims()[0]});
2669 if (data.dims().size() > 1) {
2670 outDims.insert(outDims.end(), data.dims().begin() + 1, data.dims().end());
2671 }
2672 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2673 return addNode(
2674 new FillExamplesWithIndicatorNode(name, outTy, data, indicator));
2675}
2676
2677SparseToDenseMaskNode *Function::createSparseToDenseMask(
2678 llvm::StringRef name, NodeValue indices, NodeValue values,
2679 NodeValue defaultValue, NodeValue lengths, llvm::ArrayRef<dim_t> mask) {
2680 auto lengthsDims = lengths.dims();
2681 auto valueDims = defaultValue.dims();
2682 ShapeVector outDims = {(dim_t)mask.size()};
2683 // If lengths is 0-dimensional tensor, then there is no batch dimension.
2684 if (lengthsDims.size() > 0) {
2685 outDims.insert(outDims.begin(), lengthsDims[0]);
2686 }
2687 outDims.insert(outDims.end(), valueDims.begin(), valueDims.end());
2688 auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
2689 return addNode(new SparseToDenseMaskNode(name, outTy, indices, values,
2690 defaultValue, lengths, mask));
2691}
2692
2693SparseLabelSplitNode *Function::createSparseLabelSplit(llvm::StringRef name,
2694 NodeValue lengths,
2695 NodeValue indices,
2696 NodeValue values,
2697 dim_t numLabels) {
2698 const auto numItems = indices.dims()[0];
2699 // The assumption here is that all output tensors (excluding offsetMap)
2700 // will have the same number of elements, i.e. numItems / numLabels.
2701 auto labelValuesTy = getParent()->uniqueTypeWithNewShape(
2702 values.getType(), {numLabels, numItems / numLabels});
2703 auto exampleIdsTy = getParent()->uniqueType(
2704 ElemKind::Int32ITy, {numLabels, numItems / numLabels});
2705 auto gradientOffsetMapTy =
2706 getParent()->uniqueType(ElemKind::Int32ITy, {indices.dims()[0]});
2707 return addNode(new SparseLabelSplitNode(name, labelValuesTy, exampleIdsTy,
2708 gradientOffsetMapTy, lengths, indices,
2709 values, numLabels));
2710}
2711
2712SaveNode *Function::createSave(llvm::StringRef name, NodeValue input) {
2713 auto *dest = getParent()->createPlaceholder(input.getType(), name, false);
2714 return createSave(name, input, dest);
2715}
2716
2717SaveNode *Function::createSave(llvm::StringRef name, NodeValue input,
2718 Placeholder *output, bool skipSuffix) {
2719 return addNode(new SaveNode(skipSuffix ? name.str() : (name + "_save").str(),
2720 input, output));
2721}
2722
2723QuantizationProfileNode *
2724Function::createQuantizationProfile(PlaceholderBindings &bindings,
2725 llvm::StringRef name, NodeValue input,
2726 dim_t numHistogramBins) {
2727 auto *histogram = getParent()->createPlaceholder(
2728 ElemKind::FloatTy, {numHistogramBins}, "histogram_" + name.str(), false);
2729 bindings.allocate(histogram)->zero();
2730 // Intermediate data used for histogram calculations.
2731 // Min tensor value seen so far is kept on the first position.
2732 // Max tensor value seen so far is kept on the second position.
2733 auto *computationInfoPH = getParent()->createPlaceholder(
2734 ElemKind::FloatTy, {2}, "CI_" + name.str(), false);
2735 bindings.allocate(computationInfoPH);
2736 auto *computationInfoTensor = bindings.get(computationInfoPH);
2737 auto handle = computationInfoTensor->getHandle<float>();
2738 handle.raw(0) = std::numeric_limits<float>::max();
2739 handle.raw(1) = std::numeric_limits<float>::lowest();
2740
2741 return addNode(new QuantizationProfileNode(
2742 "QI_" + name.str(), input, histogram, computationInfoPH,
2743 input.getNode()->getName().str(), input.getResNo()));
2744}
2745
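/// Create an IntLookupTableNode that maps each quantized value of \p input to
/// the corresponding entry of \p initValues. The mapping is stored as a
/// quantized Constant whose element kind, scale, and offset match \p outTy.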
2746template <typename T>
2747IntLookupTableNode *
2748Function::createIntLookupTable(llvm::StringRef name, NodeValue input,
2749 llvm::ArrayRef<T> initValues, TypeRef outTy) {
2750 assert(initValues.size() == input.getType()->getQuantizedValueCount() &&
2751 "Lookup table length must match input type!");
2752 assert(outTy->isType<T>() && "Lookup table element must match output type!");
2753 if (std::is_same<T, int8_t>::value) {
2754 // Create INT8 lookup table.
2755 auto *mapping = getParent()->createConstant(
2756 ElemKind::Int8QTy, {(dim_t)initValues.size()}, outTy->getScale(),
2757 outTy->getOffset(), "mapping");
2758 mapping->getHandle<T>() = initValues;
2759 return addNode(new IntLookupTableNode(name, outTy, input, mapping));
2760 } else if (std::is_same<T, int16_t>::value) {
2761 // Create INT16 lookup table.
2762 auto *mapping = getParent()->createConstant(
2763 ElemKind::Int16QTy, {(dim_t)initValues.size()}, outTy->getScale(),
2764 outTy->getOffset(), "mapping");
2765 mapping->getHandle<T>() = initValues;
2766 return addNode(new IntLookupTableNode(name, outTy, input, mapping));
2767 } else if (std::is_same<T, int32_t>::value) {
2768 // Create INT32 lookup table.
2769 auto *mapping = getParent()->createConstant(
2770 ElemKind::Int32QTy, {(dim_t)initValues.size()}, outTy->getScale(),
2771 outTy->getOffset(), "mapping");
2772 mapping->getHandle<T>() = initValues;
2773 return addNode(new IntLookupTableNode(name, outTy, input, mapping));
2774 } else {
2775 llvm_unreachable("Lookup table type not supported.");
2776 }
2777}
2778
2779IntLookupTableNode *
2780Function::createIntLookupTable(llvm::StringRef name, NodeValue input,
2781 std::function<float(float)> func,
2782 TypeRef outTy) {
2783 if (outTy->isType<int8_t>()) {
2784 std::vector<int8_t> initValues =
2785 quantization::createMapping<int8_t>(input.getType(), outTy, func);
2786 return createIntLookupTable<int8_t>(name, input, initValues, outTy);
2787 } else if (outTy->isType<int16_t>()) {
2788 std::vector<int16_t> initValues =
2789 quantization::createMapping<int16_t>(input.getType(), outTy, func);
2790 return createIntLookupTable<int16_t>(name, input, initValues, outTy);
2791 } else if (outTy->isType<int32_t>()) {
2792 std::vector<int32_t> initValues =
2793 quantization::createMapping<int32_t>(input.getType(), outTy, func);
2794 return createIntLookupTable<int32_t>(name, input, initValues, outTy);
2795 } else {
2796 llvm_unreachable("Lookup table type not supported.");
2797 }
2798}
2799
2800LookupTableNode *Function::createLookupTable(
2801 llvm::StringRef name, NodeValue input, LUTOperator lutOperator,
2802 std::vector<float> &lutOperatorArgs, NodeValue table, NodeValue idxTable,
2803 TypeRef outTy) {
2804 return addNode(new LookupTableNode(name, outTy, input, table, idxTable,
2805 lutOperator, lutOperatorArgs));
2806}
2807
2808IntLookupTableNode *Function::createIntLog(llvm::StringRef name,
2809 NodeValue input, TypeRef outTy) {
2810 auto inputRange = input.getType()->getQuantizedValueRange();
2811 (void)inputRange;
2812 assert(inputRange.first >= 0 &&
2813 "Input range must not be negative since this is input to log().");
2814 auto func = [](float x) -> float {
2815 return (x == 0.0) ? std::log(std::numeric_limits<float>::min()) : log(x);
2816 };
2817 return createIntLookupTable(name, input, func, outTy);
2818}
2819
2820IntLookupTableNode *Function::createIntExp(llvm::StringRef name,
2821 NodeValue input, TypeRef outTy) {
2822 return createIntLookupTable(name, input, expf, outTy);
2823}
2824
2825IntLookupTableNode *Function::createIntTanh(llvm::StringRef name,
2826 NodeValue input, TypeRef outTy) {
2827 return createIntLookupTable(name, input, tanhf, outTy);
2828}
2829
2830IntLookupTableNode *Function::createIntSigmoid(llvm::StringRef name,
2831 NodeValue input, TypeRef outTy) {
2832 auto func = [](float x) -> float { return 1.0f / (1.0f + expf(-x)); };
2833 return createIntLookupTable(name, input, func, outTy);
2834}
2835
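/// Create a TopKNode named \p name that returns the \p k largest elements
/// along the last dimension of \p input, together with their indices of kind
/// \p outIndicesTyKind.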
2836TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input,
2837 unsigned_t k, ElemKind outIndicesTyKind) {
2838 auto inDims = input.dims();
2839 assert(inDims.size() > 0);
2840 assert(k <= inDims.back());
2841 ShapeVector outDims(inDims.begin(), inDims.end());
2842 outDims.back() = k;
2843 auto OT = getParent()->uniqueTypeWithNewShape(input.getType(), outDims);
2844 return addNode(new TopKNode(
2845 name, OT, getParent()->uniqueType(outIndicesTyKind, outDims), input, k));
2846}
2847
2848TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input,
2849 unsigned_t k) {
2850 return createTopK(name, input, k, ElemKind::Int64ITy);
2851}
2852
2853ArgMaxNode *Function::createArgMax(llvm::StringRef name, NodeValue input,
2854 unsigned_t axis, bool keepDims,
2855 ElemKind elemTy) {
2856 ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims);
2857 auto OT = getParent()->uniqueType(elemTy, outDims);
2858 return addNode(new ArgMaxNode(name, OT, input, axis, keepDims));
2859}
2860
2861ArgMinNode *Function::createArgMin(llvm::StringRef name, NodeValue input,
2862 unsigned_t axis, bool keepDims,
2863 ElemKind elemTy) {
2864 ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims);
2865 auto OT = getParent()->uniqueType(elemTy, outDims);
2866 return addNode(new ArgMinNode(name, OT, input, axis, keepDims));
2867}
2868
2869VectorNormNode *Function::createVectorNorm(llvm::StringRef name,
2870 NodeValue input, unsigned_t axis,
2871 unsigned_t p) {
2872 auto outDims = getNewShapeWithoutAxes(input.dims(), axis);
2873 auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), outDims);
2874 const size_t outNumElements = input.getType()->size() / input.dims()[axis];
2875 (void)outNumElements;
2876 assert(outTy->size() == outNumElements &&
2877 "Incorrect number of elements in the output type.");
2878 auto OT = getParent()->uniqueType(*outTy);
2879 return addNode(new VectorNormNode(name, OT, input, axis, p));
2880}
2881
2882CollectRpnProposalsNode *Function::createCollectRpnProposals(
2883 llvm::StringRef name, std::vector<NodeValue> &roisIn,
2884 std::vector<NodeValue> &roiProbsIn, int64_t rpnMaxLevel,
2885 int64_t rpnMinLevel, unsigned_t rpnPostNmsTopN) {
2886
2887 auto boxDim = roisIn[0].dims()[1];
2888
2889 assert(rpnPostNmsTopN > 0 && "RpnPostNmsTopN should be greater than zero");
2890
2891 ShapeVector outDims{rpnPostNmsTopN, boxDim};
2892
2893 auto OT = getParent()->uniqueTypeWithNewShape(roisIn[0].getType(), outDims);
2894 return addNode(new CollectRpnProposalsNode(
2895 name, OT, roisIn, roiProbsIn, rpnMaxLevel, rpnMinLevel, rpnPostNmsTopN));
2896}
2897
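/// Create a GatherNode that gathers slices of \p data along \p axis at the
/// positions given by \p indices; the output shape is data.dims() with the
/// \p axis dimension replaced by indices.dims().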
2898GatherNode *Function::createGather(llvm::StringRef name, NodeValue data,
2899 NodeValue indices, unsigned_t axis) {
2900 auto dDims = data.dims();
2901 auto iDims = indices.dims();
2902 assert(dDims.size() > axis);
2903 ShapeVector outDims;
2904 outDims.insert(outDims.end(), dDims.begin(), dDims.begin() + axis);
2905 outDims.insert(outDims.end(), iDims.begin(), iDims.end());
2906 outDims.insert(outDims.end(), dDims.begin() + axis + 1, dDims.end());
2907 return addNode(new GatherNode(
2908 name, getParent()->uniqueTypeWithNewShape(data.getType(), outDims), data,
2909 indices, axis));
2910}
2911
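/// Create a GatherNDNode that gathers slices of \p data addressed by the last
/// dimension of \p indices, treating the first \p batchDims dimensions as
/// shared batch dimensions. E.g. data {2, 3, 4} and indices {2, 2} with
/// \p batchDims 0 produce an output of shape {2, 4}.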
2912GatherNDNode *Function::createGatherND(llvm::StringRef name, NodeValue data,
2913 NodeValue indices,
2914 unsigned_t batchDims) {
2915 auto dataDims = data.dims();
2916 auto indicesDims = indices.dims();
2917 size_t indicesDimLast = indicesDims.back();
2918
2919 // Validate the input dimensions.
2920 assert(dataDims.size() >= 1 && "Input data rank must be >= 1.");
2921 assert(indicesDims.size() >= 1 && "Input indices rank must be >= 1.");
2922 for (size_t idx = 0; idx < batchDims; ++idx) {
2923 assert(dataDims[idx] == indicesDims[idx] &&
2924 "Batch dimensions of data and indices must be the same!");
2925 }
2926 assert(batchDims < std::min(dataDims.size(), indicesDims.size()) &&
2927 "The number of batch dimensions must be smaller than both the input "
2928 "data and indices rank!");
2929 assert(indicesDimLast >= 1 &&
2930 "Last dimension of indices must be at least 1!");
2931 assert(indicesDimLast <= (dataDims.size() - batchDims) &&
2932 "Last dimension of indices must be at most rank of data without batch "
2933 "dimensions!");
2934
2935 // Compute the output dimensions.
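  // For example, data {2,3,4,5} and indices {2,2,2} with batchDims = 1 give
  // outRank = 4 + 3 - 2 - 1 - 1 = 3 and an output shape of {2,2,5}.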
2936 size_t outRank =
2937 dataDims.size() + indicesDims.size() - indicesDimLast - 1 - batchDims;
2938 std::vector<dim_t> outDims(outRank);
2939 size_t outIdx = 0;
2940 for (size_t idx = 0; idx < batchDims; ++idx) {
2941 outDims[outIdx++] = dataDims[idx];
2942 }
2943 for (size_t idx = batchDims; idx < indicesDims.size() - 1; ++idx) {
2944 outDims[outIdx++] = indicesDims[idx];
2945 }
2946 for (size_t idx = batchDims + indicesDimLast; idx < dataDims.size(); ++idx) {
2947 outDims[outIdx++] = dataDims[idx];
2948 }
2949 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2950 return addNode(new GatherNDNode(name, outTy, data, indices, batchDims));
2951}
2952
2953GatherElementsNode *Function::createGatherElements(llvm::StringRef name,
2954 NodeValue data,
2955 NodeValue indices,
2956 unsigned_t dim = 0) {
2957 const auto iDims = indices.dims();
2958 const auto dRank = data.dims().size();
2959 const auto iRank = iDims.size();
2960 (void)dRank;
2961 (void)iRank;
2962 assert((dim < 0 ? dim >= -dRank : dim < dRank) &&
         "[GatherElements] dim must be in the range [-rank, rank-1].");
2964 assert(iRank == dRank &&
2965 "[GatherElements] Data and indices rank must be equal.");
2966 assert(dRank > 0 && "[GatherElements] Data and indices rank must be >= 1.");
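  // The result has the same shape as 'indices'; each output element is read
  // from 'data' along 'dim' at the index stored in the corresponding position.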
2967
2968 return addNode(new GatherElementsNode(
2969 name, getParent()->uniqueTypeWithNewShape(data.getType(), iDims), data,
2970 indices, dim));
2971}
2972
2973GatherRangesNode *Function::createGatherRanges(llvm::StringRef name,
2974 NodeValue data, NodeValue ranges,
2975 unsigned_t maxOutputSize) {
2976 auto numRanges = ranges.dims()[0];
2977 return addNode(new GatherRangesNode(
2978 name,
2979 /*OutputTy=*/
2980 getParent()->uniqueTypeWithNewShape(data.getType(), {maxOutputSize}),
2981 /*LengthsTy=*/
2982 getParent()->uniqueTypeWithNewShape(ranges.getType(), numRanges), data,
2983 ranges));
2984}
2985
2986ScatterDataNode *Function::createScatterData(llvm::StringRef name,
2987 NodeValue data, NodeValue indices,
2988 NodeValue slices,
2989 bool cumulative) {
2990 return addNode(new ScatterDataNode(name, data, indices, slices, cumulative));
2991}
2992
2993BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name,
2994 NodeValue data, NodeValue lengths,
2995 NodeValue values) {
2996 auto outTy = getParent()->uniqueTypeWithNewShape(
2997 data.getType(), {data.dims()[0], values.dims()[0]});
2998 return addNode(new BatchOneHotNode(name, outTy, data, lengths, values));
2999}
3000
3001SpaceToDepthNode *Function::createSpaceToDepth(llvm::StringRef name,
3002 NodeValue input,
3003 unsigned blockSize) {
3004 assert(blockSize > 0 && "BlockSize must be >= 1.");
3005
3006 auto inputDim = input.dims();
3007 assert(inputDim.size() == 4 && "Dimension size of 4 is expected.");
3008 assert((inputDim[1] % blockSize == 0 && inputDim[2] % blockSize == 0) &&
         "Height and width must be multiples of blockSize.");
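  // Each blockSize x blockSize spatial block is folded into the channel
  // dimension: H and W shrink by blockSize while C grows by blockSize^2.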
3010 std::vector<dim_t> newDim = {inputDim[0], inputDim[1] / blockSize,
3011 inputDim[2] / blockSize,
3012 inputDim[3] * blockSize * blockSize};
3013 auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
3014 return addNode(new SpaceToDepthNode(name, outTy, input, blockSize));
3015}
3016
3017ReshapeNode *Function::createDepthToSpace(llvm::StringRef name, NodeValue input,
3018 unsigned blockSize, bool isCRD) {
3019 assert(blockSize > 0 && "Block size must be >= 1.");
3020
3021 auto inputDim = input.dims();
3022 assert(inputDim.size() == 4 && "Dimension size of 4 is expected.");
3023 dim_t N = inputDim[0];
3024 dim_t H = inputDim[1];
3025 dim_t W = inputDim[2];
3026 dim_t C = inputDim[3];
3027 assert(C % (blockSize * blockSize) == 0 &&
3028 "Depth should be divisible by block size squared.");
3029
3030 llvm::SmallVector<unsigned_t, 6> shuffle;
3031 llvm::SmallVector<dim_t, 6> tmpShape;
3032 llvm::SmallVector<dim_t, 4> outShape = {N, H * blockSize, W * blockSize,
3033 C / (blockSize * blockSize)};
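  // CRD (column-row-depth) splits the channel dimension as {C/b^2, b, b};
  // DCR (depth-column-row) splits it as {b, b, C/b^2}. The transpose then
  // moves the two block dims next to H and W before the final reshape
  // interleaves them into the spatial dimensions.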
3034 if (isCRD) {
3035 tmpShape = {N, H, W, C / (blockSize * blockSize), blockSize, blockSize};
3036 shuffle = D2S_CRD;
3037 } else {
3038 tmpShape = {N, H, W, blockSize, blockSize, C / (blockSize * blockSize)};
3039 shuffle = D2S_DCR;
3040 }
3041
3042 auto *RN1 = createReshape(name.str() + "_reshape_in", input, tmpShape);
3043 auto *TN = createTranspose(name.str() + "_transpose", RN1, shuffle);
3044 return createReshape(name.str() + "_reshape_out", TN, outShape);
3045}
3046
3047ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name,
3048 NodeValue input,
3049 llvm::ArrayRef<float> scale) {
3050 auto inputDim = input.dims();
3051 DCHECK_EQ(inputDim.size(), scale.size())
3052 << "Input Dimension size: " << inputDim.size()
      << " Scale size: " << scale.size() << " should be the same.";
3054
3055 std::vector<dim_t> newDim;
3056
3057 for (size_t i = 0; i < scale.size(); i++) {
3058 auto newD = dim_t(std::floor(inputDim[i] * scale[i]));
3059 DCHECK_GT(newD, 0) << "Scaled dim is " << newD
                       << ", scaled value must be greater than 0.";
3061 newDim.push_back(newD);
3062 }
3063
3064 auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
3065 return addNode(new ResizeNearestNode(name, outTy, input, scale));
3066}
3067
3068ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name,
3069 NodeValue input,
3070 TypeRef outTy) {
3071 auto inputDim = input.dims();
3072 auto outputDim = outTy->dims();
3073 DCHECK_EQ(inputDim.size(), outputDim.size())
3074 << "Input dimension size: " << inputDim.size()
      << " output dimension size: " << outputDim.size()
      << " should be the same.";
3076
3077 std::vector<float> scales;
3078 for (size_t i = 0; i < inputDim.size(); i++) {
3079 float scale = (outputDim[i] / (float)inputDim[i]);
3080 DCHECK_GT(scale, 0.0) << "Scale: " << scale
3081 << ", Scale larger than 0 is expected.";
3082 scales.push_back(scale);
3083 }
3084
3085 return addNode(new ResizeNearestNode(name, outTy, input, scales));
3086}
3087
3088ResizeBilinearNode *
3089Function::createResizeBilinear(llvm::StringRef name, NodeValue input,
3090 llvm::ArrayRef<float> scale) {
3091 auto inputDim = input.dims();
3092 DCHECK_EQ(inputDim.size(), scale.size())
3093 << "Input Dimension size: " << inputDim.size()
      << " Scale size: " << scale.size() << " should be the same.";
3095
3096 std::vector<dim_t> newDim;
3097
3098 for (size_t i = 0; i < scale.size(); i++) {
3099 auto newD = dim_t(std::floor(inputDim[i] * scale[i]));
3100 DCHECK_GT(newD, 0) << "Scaled dim is " << newD
                       << ", scaled value must be greater than 0.";
3102 newDim.push_back(newD);
3103 }
3104
3105 auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
3106 return addNode(new ResizeBilinearNode(name, outTy, input, scale));
3107}
3108
3109ResizeBilinearNode *Function::createResizeBilinear(llvm::StringRef name,
3110 NodeValue input,
3111 TypeRef outTy) {
3112 auto inputDim = input.dims();
3113 auto outputDim = outTy->dims();
3114 DCHECK_EQ(inputDim.size(), outputDim.size())
3115 << "Input dimension size: " << inputDim.size()
      << " output dimension size: " << outputDim.size()
      << " should be the same.";
3117
3118 std::vector<float> scales;
3119 for (size_t i = 0; i < inputDim.size(); i++) {
3120 float scale = (outputDim[i] / (float)inputDim[i]);
3121 DCHECK_GT(scale, 0.0) << "Scale: " << scale
3122 << ", Scale larger than 0 is expected.";
3123 scales.push_back(scale);
3124 }
3125
3126 return addNode(new ResizeBilinearNode(name, outTy, input, scales));
3127}
3128
3129QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
3130 TypeRef outTy) {
3131 assert(input.getType()->isFPType() && "Input must be a floating type");
3132 assert(outTy->isQuantizedType() && "Output must be a quantized type");
3133 assert(input.dims().equals(outTy->dims()) &&
3134 "Different dimensions for input and output");
3135
3136 return addNode(
3137 new QuantizeNode(name, getParent()->uniqueType(*outTy), input));
3138}
3139
3140QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
3141 ElemKind q, float scale,
3142 int32_t offset) {
3143 TypeRef OT = getParent()->uniqueType(q, input.dims(), scale, offset);
3144 return createQuantize(name, input, OT);
3145}
3146
3147DequantizeNode *Function::createDequantize(llvm::StringRef name,
3148 NodeValue input, ElemKind k) {
3149 assert(input.getType()->isQuantizedType() &&
3150 "Input must be a quantized type");
3151 assert(isFloatElemKind(k) && "Result must be float type.");
3152 ShapeVector outShape(input.dims().begin(), input.dims().end());
3153 if (input.getElementType() == ElemKind::UInt8FusedQTy) {
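    // Fused rowwise-quantized tensors append a per-row scale and offset (two
    // floats) to every row; drop those columns from the dequantized shape.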
3154 assert(outShape.size() == 2 && "Fused tensors should be 2D");
3155 assert(outShape[1] > 2 * sizeof(float) &&
3156 "Expected space for per-row scale/offset");
3157 outShape[1] -= 2 * sizeof(float);
3158 }
3159 TypeRef outTy = getParent()->uniqueType(Type(k, outShape));
3160 return createDequantize(name, input, outTy);
3161}
3162
3163DequantizeNode *Function::createDequantize(llvm::StringRef name,
3164 NodeValue input, TypeRef outTy) {
3165 assert(input.getType()->isQuantizedType() &&
3166 "Input must be a quantized type");
3167 assert(outTy->isFPType() && "Output should be an FP type");
3168 return addNode(new DequantizeNode(name, outTy, input));
3169}
3170
3171RescaleQuantizedNode *Function::createRescaleQuantized(llvm::StringRef name,
3172 NodeValue input,
3173 TypeRef outTy) {
3174 assert(input.getType()->isQuantizedType() &&
3175 "Input must be a quantized type");
3176 assert(outTy->isQuantizedType() && "Output must be a quantized type");
3177 assert(input.dims().equals(outTy->dims()) &&
3178 "Different dimensions for input and output");
3179
3180 return addNode(
3181 new RescaleQuantizedNode(name, getParent()->uniqueType(*outTy), input));
3182}
3183
3184Node *Function::createWeightedSum(llvm::StringRef name,
3185 llvm::ArrayRef<NodeValue> data,
3186 llvm::ArrayRef<NodeValue> weights) {
3187 assert(data.size() == weights.size() &&
3188 "Must have same number of data and weights.");
3189 assert(data.size() > 0 && "No inputs provided.");
3190
3191 const auto *outTy = data[0].getType();
3192
3193 // Create a zero splat to bootstrap the adding chain.
3194 Node *currAdd = createSplat(name.str() + ".splat", outTy, 0.);
3195
3196 for (size_t i = 0, e = data.size(); i < e; i++) {
3197 assert(weights[i].getType()->size() == 1 &&
3198 "Each provided weight node must be size 1.");
3199 assert(outTy == data[i].getType() &&
3200 "All data nodes must have the same type.");
3201
3202 // Broadcast the current weight to same shape as the data.
3203 auto *bcastW =
3204 createBroadcast(name.str() + ".bcastWeight" + std::to_string(i),
3205 weights[i], outTy->dims(), /* axis */ 0);
3206
3207 // Element-wise multiply the broadcasted weight by the data.
3208 auto *scaledD =
3209 createMul(name.str() + ".mul" + std::to_string(i), bcastW, data[i]);
3210
3211 // Element-wise add the scaled data to the running total.
3212 currAdd =
3213 createAdd(name.str() + ".add" + std::to_string(i), scaledD, currAdd);
3214 }
3215
3216 // Return the final weighted sum via the last add in the chain.
3217 return currAdd;
3218}
3219
3220Node *Function::createBatchBoxCox(llvm::StringRef name, NodeValue data,
3221 NodeValue lambda1, NodeValue lambda2,
3222 float epsilon) {
3223 assert((lambda1.dims() == lambda2.dims()) &&
3224 "lambda1 and lambda2 must have the same shape");
3225 assert((lambda1.getType()->getElementType() == lambda2.getElementType()) &&
3226 "lambda1 and lambda2 must have the same element type");
3227 assert((lambda1.getType()->getElementType() == data.getElementType()) &&
3228 "data and lambdas must have the same element type");
3229 assert((lambda1.dims().size() == 1) && "lambda1 and lambda2 must be vectors");
3230 assert((data.dims().size() == 2) && "data must be a matrix");
  assert((data.dims()[1] == lambda1.dims()[0]) &&
         "the second dimension of data must match the length of lambda1/lambda2");
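  // The node applies (roughly) the two-parameter Box-Cox transform per
  // element:
  //   y = ((x + lambda2)^lambda1 - 1) / lambda1  if lambda1 != 0,
  //   y = ln(x + lambda2)                        otherwise,
  // with 'epsilon' guarding the degenerate cases.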
3233
3234 return addNode(new BatchBoxCoxNode(name, data, lambda1, lambda2, epsilon));
3235}
3236
3237ClipNode *Function::createClip(llvm::StringRef name, NodeValue input,
3238 TypeRef outTy, float min, float max) {
3239 return addNode(new ClipNode(name, outTy, input, min, max));
3240}
3241
3242ClipNode *Function::createClip(llvm::StringRef name, NodeValue input, float min,
3243 float max) {
3244 return addNode(new ClipNode(name, input.getType(), input, min, max));
3245}
3246
3247ClipNode *Function::createClipMinMaxFP16(llvm::StringRef name,
3248 NodeValue input) {
3249 return createClip(name, input, kMinFP16, kMaxFP16);
3250}
3251
3252ClipNode *Function::createClipMinMaxBFloat16(llvm::StringRef name,
3253 NodeValue input) {
3254 constexpr float bfloat16Min = FLT_MIN;
3255 constexpr float bfloat16Max = FLT_MAX;
3256 return createClip(name, input, bfloat16Min, bfloat16Max);
3257}
3258
3259BatchedUnaryEmbeddingsBagsNode *Function::createBatchedUnaryEmbeddingsBags(
3260 llvm::StringRef name, NodeValue weights, NodeValue tableOffsets,
3261 NodeValue indices, NodeValue offsets) {
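  // The output shape is {weights.dims()[0], batchSize, numTables}, where
  // numTables = tableOffsets.dims()[0] - 1 and batchSize is recovered from
  // the flattened offsets as (offsets.dims()[0] - 1) / numTables.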
3262 auto inDims = weights.dims();
3263 ShapeVector outDims(inDims.begin(), inDims.end());
3264 outDims[0] = weights.dims()[0];
3265 outDims[2] = tableOffsets.dims()[0] - 1;
3266 outDims[1] = (offsets.dims()[0] - 1) / outDims[2];
3267
3268 auto outTy = getParent()->uniqueTypeWithNewShape(weights.getType(), outDims);
3269 return addNode(new BatchedUnaryEmbeddingsBagsNode(
3270 name, outTy, weights, tableOffsets, offsets, indices));
3271}
3272
3273IntNBitSplitEmbeddingBagsNode *Function::createIntNBitSplitEmbeddingBags(
3274 llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights,
3275 NodeValue weightsPlacements, NodeValue weightsOffsets, NodeValue weightsTys,
3276 NodeValue dimOffsets, int64_t totalDims, NodeValue indices,
3277 NodeValue offsets, SplitEmbeddingPoolingMode poolingMode,
3278 SplitEmbeddingSparseType outputDtype) {
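  // The output is 2D: the first dimension is the batch size,
  // (offsets.dims()[0] - 1) / (dimOffsets.dims()[0] - 1), and the second is
  // totalDims scaled by the byte size of the requested output precision.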
3279 ShapeVector outDims;
3280 outDims.insert(outDims.end(),
3281 (offsets.dims()[0] - 1) / (dimOffsets.dims()[0] - 1));
3282 int64_t totalSize = outputDtype == SplitEmbeddingSparseType::EST_FLOAT
3283 ? totalDims * sizeof(float)
3284 : totalDims * sizeof(float16_t);
3285 outDims.insert(outDims.end(), totalSize);
3286
3287 auto outTy =
3288 getParent()->uniqueTypeWithNewShape(devWeights.getType(), outDims);
3289
3290 return addNode(new IntNBitSplitEmbeddingBagsNode(
3291 name, outTy, devWeights, uvmWeights, weightsPlacements, weightsOffsets,
3292 weightsTys, dimOffsets, indices, offsets, totalDims, poolingMode,
3293 outputDtype));
3294}
3295
3296IntNBitSplitEmbeddingWeightedBagsNode *
3297Function::createIntNBitSplitEmbeddingWeightedBags(
3298 llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights,
3299 NodeValue weightsPlacements, NodeValue weightsOffsets, NodeValue weightsTys,
3300 NodeValue dimOffsets, int64_t totalDims, NodeValue indices,
3301 NodeValue offsets, SplitEmbeddingPoolingMode poolingMode,
3302 SplitEmbeddingSparseType outputDtype, NodeValue indiceWeights) {
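  // The output shape is computed exactly as in createIntNBitSplitEmbeddingBags;
  // the only difference is the additional per-index weights input.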
3303 ShapeVector outDims;
3304 outDims.insert(outDims.end(),
3305 (offsets.dims()[0] - 1) / (dimOffsets.dims()[0] - 1));
3306 int64_t totalSize = outputDtype == SplitEmbeddingSparseType::EST_FLOAT
3307 ? totalDims * sizeof(float)
3308 : totalDims * sizeof(float16_t);
3309 outDims.insert(outDims.end(), totalSize);
3310
3311 auto outTy =
3312 getParent()->uniqueTypeWithNewShape(devWeights.getType(), outDims);
3313
3314 return addNode(new IntNBitSplitEmbeddingWeightedBagsNode(
3315 name, outTy, devWeights, uvmWeights, weightsPlacements, weightsOffsets,
3316 weightsTys, dimOffsets, indices, offsets, indiceWeights, totalDims,
3317 poolingMode, outputDtype));
3318}
3319
3320//===----------------------------------------------------------------------===//
3321// Placeholder-builder methods.
3322//===----------------------------------------------------------------------===//
3323
3324BatchNormalizationNode *Function::createBatchNormalization(
3325 PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
3326 unsigned_t channelIdx, float epsilon, float momentum) {
3327 // Figure out how many channels are in the tensor.
3328 dim_t channels = input.dims()[channelIdx];
3329
3330 ElemKind inputTy = input.getType()->getElementType();
3331
3332 // Allocate the learnable parameters beta and gamma.
3333 auto *beta =
3334 getParent()->createPlaceholder(inputTy, {channels}, "beta", true);
3335 bindings.allocate(beta)->init(Tensor::InitKind::Broadcast, 0.1, getPRNG());
3336
3337 auto *scale =
3338 getParent()->createPlaceholder(inputTy, {channels}, "scale", true);
3339 bindings.allocate(scale)->init(Tensor::InitKind::Broadcast, 0.001, getPRNG());
3340
3341 auto *mean =
3342 getParent()->createPlaceholder(inputTy, {channels}, "mean", false);
3343 bindings.allocate(mean)->zero();
3344
3345 auto *variance =
3346 getParent()->createPlaceholder(inputTy, {channels}, "variance", false);
3347 bindings.allocate(variance)->init(Tensor::InitKind::Broadcast, 1.0,
3348 getPRNG());
3349
3350 auto resultType = getParent()->uniqueType(inputTy, input.dims());
3351
3352 return createBatchNormalization(name, resultType, input, beta, scale, mean,
3353 variance, channelIdx, epsilon, momentum);
3354}
3355
3356ConvolutionNode *Function::createConv(
3357 PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
3358 dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
3359 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
3360 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation,
3361 ConvolutionLayout layout) {
3362 ShapeNHWC idim = ShapeNHWC(input.dims());
3363 ShapeHW kdim(kernels);
3364 PaddingTLBR pdim(pads);
3365 (void)pdim;
3366 assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
3367 (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
3368 "buffer too small for selected stride");
3369
3370 assert(group > 0 && "group should be larger than 0");
3371 assert(idim.c % group == 0 && "channels number must be divisible by groups");
3372 assert(outChannels % group == 0 && "outChannels must be divisible by groups");
3373
3374 // Calculate the size and allocate the output buffer.
3375 auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides,
3376 pads, dilation);
3377
3378 std::array<dim_t, 4> outDims = {
3379 {idim.n, outSz.first, outSz.second, outChannels}};
3380
3381 // Allocate the Filter and Bias tensors.
3382 std::array<dim_t, 4> filterDim = {
3383 {outChannels, kdim.height, kdim.width, idim.c / group}};
3384 size_t fanIn = kdim.height * kdim.width * idim.c;
3385 ElemKind inputTy = input.getType()->getElementType();
3386 assert(isFloatElemKind(inputTy) && "Convolution on non-floating point type?");
3387 auto *filter =
3388 getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
3389 bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
3390 getPRNG());
3391
3392 auto *bias =
3393 getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
3394 bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
3395 getPRNG());
3396
3397 auto OT = getParent()->uniqueType(inputTy, outDims);
3398
3399 return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
3400 strides, pads, group, dilation, layout,
3401 FusedActivation::NONE, {}));
3402}
3403
3404ConvolutionNode *Function::createConv(PlaceholderBindings &bindings,
3405 llvm::StringRef name, NodeValue input,
3406 dim_t outChannels, unsigned_t kernel,
3407 unsigned_t stride, unsigned_t pad,
3408 unsigned_t group,
3409 llvm::ArrayRef<unsigned_t> dilation,
3410 ConvolutionLayout layout) {
3411 llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
3412 llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
3413 llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
3414 return createConv(bindings, name, input, outChannels, kernels, strides, pads,
3415 group, dilation, layout);
3416}
3417
3418Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings,
3419 llvm::StringRef name, NodeValue input,
3420 dim_t outChannels,
3421 llvm::ArrayRef<unsigned_t> kernels,
3422 llvm::ArrayRef<unsigned_t> strides,
3423 llvm::ArrayRef<unsigned_t> pads,
3424 unsigned_t group) {
3425 ShapeNTHWC idim(input.dims());
3426 ShapeTHW kdim(kernels);
3427
3428 assert(group > 0 && "group should be larger than 0");
3429 assert(idim.c % group == 0 && "channels number must be divisible by groups");
3430 assert(outChannels % group == 0 && "outChannels must be divisible by groups");
3431
3432 // Calculate the size and allocate the output buffer.
3433 auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
3434 strides, pads);
3435
3436 std::array<dim_t, 5> outDims = {
3437 {idim.n, outSz.temporal_frames, outSz.height, outSz.width, outChannels}};
3438
3439 // Allocate the Filter and Bias tensors.
3440 std::array<dim_t, 5> filterDim = {{outChannels, kdim.temporal_frames,
3441 kdim.height, kdim.width, idim.c / group}};
3442
3443 dim_t fanIn = kdim.temporal_frames * kdim.height * kdim.width * idim.c;
3444 ElemKind inputTy = input.getType()->getElementType();
3445 assert(isFloatElemKind(inputTy) &&
3446 "Convolution3D on non-floating point type?");
3447 auto *filter =
3448 getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
3449 bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
3450 getPRNG());
3451
3452 auto *bias =
3453 getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
3454 bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
3455 getPRNG());
3456
3457 auto OT = getParent()->uniqueType(inputTy, outDims);
3458
3459 assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
3460
3461 return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
3462 strides, pads, group));
3463}
3464
3465Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings,
3466 llvm::StringRef name, NodeValue input,
3467 size_t outChannels, unsigned_t kernel,
3468 unsigned_t stride, unsigned_t pad,
3469 unsigned_t group) {
3470 llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
3471 llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
3472 llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
3473 return createConv3D(bindings, name, input, outChannels, kernels, strides,
3474 pads, group);
3475}
3476
3477ChannelwiseQuantizedConvolutionNode *Function::createChannelwiseQuantizedConv(
3478 llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
3479 NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales,
3480 NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
3481 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
3482 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation, bool quantizeFilter,
3483 bool quantizeBias, quantization::Schema schema, ElemKind filterElemQTy,
3484 ElemKind biasElemQTy) {
3485
3486 // Validate dimensions.
3487 bool isConv3D = (input.getType()->dims().size() == 5);
3488 if (isConv3D) {
3489 assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
3490 } else {
3491 assertConvDims(input, filter, bias, kernels, strides, pads, group);
3492 }
3493
3494 // Validate bias precision.
3495 auto biasElemKind = bias.getElementType();
3496 DCHECK(biasElemKind == ElemKind::Int8QTy ||
3497 biasElemKind == ElemKind::Int32QTy ||
3498 biasElemKind == ElemKind::FloatTy)
3499 << "Unsupported element type for ChannelwiseQuantizedConvolution bias: "
3500 << Type::getElementName(biasElemKind).str();
3501
3502 // Validate filter precision.
3503 auto filterElemKind = filter.getElementType();
3504 DCHECK(filterElemKind == ElemKind::Int8QTy ||
3505 filterElemKind == ElemKind::FloatTy)
3506 << "Unsupported element type for ChannelwiseQuantizedConvolution "
3507 << "filter: " << Type::getElementName(filterElemKind).str();
3508
3509 DCHECK(!filterScales.getNode() || dyn_cast<Constant>(filterScales.getNode()))
3510 << "Filter scales input to ChannelwiseQuantizedConvolutionNode must be "
3511 "null or Constant";
3512
3513 DCHECK(!filterOffsets.getNode() ||
3514 dyn_cast<Constant>(filterOffsets.getNode()))
3515 << "Filter offsets input to ChannelwiseQuantizedConvolutionNode must be "
3516 "null or Constant";
3517
3518 DCHECK(!biasScales.getNode() || dyn_cast<Constant>(biasScales.getNode()))
3519 << "Bias scales input to ChannelwiseQuantizedConvolutionNode must be "
3520 "null or Constant";
3521
3522 DCHECK(!biasOffsets.getNode() || dyn_cast<Constant>(biasOffsets.getNode()))
3523 << "Bias offsets input to ChannelwiseQuantizedConvolutionNode must be "
3524 "null or Constant";
3525
3526 // Number of output channels.
3527 dim_t numChannels = outTy->dims().back();
3528 dim_t qDim = 0;
3529 dim_t qStep = 1;
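  // Filter/bias quantization parameters are computed per output channel, i.e.
  // along dimension 0 with a step of 1.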
3530
3531 // Whether filter/bias quantization parameters were explicitly provided.
3532 bool filterQParamsGiven = filterScales.getNode() && filterOffsets.getNode();
3533 bool biasQParamsGiven = biasScales.getNode() && biasOffsets.getNode();
3534
3535 // If input filter is FLOAT and filterScales/filterOffsets are NOT provided
3536 // then compute them automatically for given schema and filterElemQTy.
3537 // If input filter is QUANTIZED then filterScales/filterOffsets are mandatory.
3538 if (!filterQParamsGiven) {
3539 DCHECK(filterElemKind == ElemKind::FloatTy)
3540 << "ChannelwiseQuantizedConvolution: If the input filter is "
3541 << "quantized then the filter scales/offsets must be provided!";
3542 Constant *filterC = dyn_cast<Constant>(filter.getNode());
3543 DCHECK(filterC)
3544 << "Filter input to ChannelwiseQuantizedConvolutionNode must be a "
3545 "Constant to quantize it";
3546 Constant *filterScalesC = getParent()->createConstant(
3547 ElemKind::FloatTy, {numChannels}, "filterScales");
3548 Constant *filterOffsetsC = getParent()->createConstant(
3549 ElemKind::Int32ITy, {numChannels}, "filterOffsets");
3550 // Get filter channelwise TensorQuantizationParams.
3551 quantization::getTensorQuantizationParams(
3552 filterC->getPayload(), filterScalesC->getPayloadMutable(),
3553 filterOffsetsC->getPayloadMutable(), schema, filterElemQTy, qDim,
3554 qStep);
3555 filterScales = NodeValue(filterScalesC);
3556 filterOffsets = NodeValue(filterOffsetsC);
3557 }
3558
3559 // If input bias is FLOAT and biasScales/biasOffsets are NOT provided
3560 // then compute them automatically for given schema and biasElemQTy.
3561 // If input bias is QUANTIZED and biasScales/biasOffsets are NOT provided
3562 // then assume the channel wise quantization parameters are implicitly:
3563 // biasScales[i] = inputScale * filterScales[i] and biasOffsets[i] = 0.
3564 if (!biasQParamsGiven) {
3565 Constant *biasC = dyn_cast<Constant>(bias.getNode());
3566 DCHECK(biasC)
3567 << "Bias input to ChannelwiseQuantizedConvolutionNode must be a "
3568 "Constant to quantize it";
3569 Constant *biasScalesC = getParent()->createConstant(
3570 ElemKind::FloatTy, {numChannels}, "biasScales");
3571 Constant *biasOffsetsC = getParent()->createConstant(
3572 ElemKind::Int32ITy, {numChannels}, "biasOffsets");
3573 auto biasScalesH = biasScalesC->getPayload().getHandle<float>();
3574 auto biasOffsetsH = biasOffsetsC->getPayload().getHandle<int32_t>();
3575 Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode());
3576 Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode());
3577 auto filterScalesH = filterScalesC->getPayload().getHandle<float>();
3578 auto filterOffsetsH = filterOffsetsC->getPayload().getHandle<int32_t>();
3579 auto inputScale = input.getType()->getScale();
3580 auto inputOffset = input.getType()->getOffset();
3581 if (biasElemKind == ElemKind::FloatTy) {
3582 auto biasH = biasC->getPayload().getHandle<float>();
3583 // Get bias channelwise TensorQuantizationParams.
3584 quantization::getTensorQuantizationParams(
3585 biasC->getPayload(), biasScalesC->getPayloadMutable(),
3586 biasOffsetsC->getPayloadMutable(), schema, biasElemQTy, qDim, qStep);
3587 // Specialize the bias channelwise TensorQuantizationParams.
3588 for (dim_t idx = 0; idx < numChannels; idx++) {
3589 bool biasZero = (biasH.raw(idx) == 0.f);
3590 TensorQuantizationParams biasTQP = {biasScalesH.raw(idx),
3591 biasOffsetsH.raw(idx)};
3592 TensorQuantizationParams inputTQP = {inputScale, inputOffset};
3593 TensorQuantizationParams filterTQP = {filterScalesH.raw(idx),
3594 filterOffsetsH.raw(idx)};
3595 if (filterQParamsGiven) {
3596 // Specialize only bias quantization parameters.
3597 biasTQP = specializeBiasQuantizationParams(
3598 biasTQP, inputTQP, filterTQP, schema, biasElemQTy, biasZero);
3599 } else {
3600 // Specialize bias and weights quantization parameters.
3601 specializeBiasWeightsQuantizationParams(
3602 biasTQP, inputTQP, filterTQP, schema, biasElemQTy, biasZero);
3603 }
3604 biasScalesH.raw(idx) = biasTQP.scale;
3605 biasOffsetsH.raw(idx) = biasTQP.offset;
3606 filterScalesH.raw(idx) = filterTQP.scale;
3607 filterOffsetsH.raw(idx) = filterTQP.offset;
3608 }
3609 } else {
3610 // Set implicit bias channelwise TensorQuantizationParams.
3611 for (dim_t idx = 0; idx < numChannels; idx++) {
3612 float filterScale = filterScalesH.raw(idx);
3613 biasScalesH.raw(idx) = inputScale * filterScale;
3614 biasOffsetsH.raw(idx) = 0;
3615 }
3616 }
3617 biasScales = NodeValue(biasScalesC);
3618 biasOffsets = NodeValue(biasOffsetsC);
3619 }
3620
3621 // If input filter is FLOAT then quantize channel wise to filterElemQTy.
3622 if (quantizeFilter && filterElemKind == ElemKind::FloatTy) {
3623 Constant *filterC = dyn_cast<Constant>(filter.getNode());
3624 DCHECK(filterC)
3625 << "Filter input to ChannelwiseQuantizedConvolutionNode must be a "
3626 "Constant to quantize it";
3627 Constant *filterCQ = getParent()->createConstant(
3628 filterElemQTy, filterC->getType()->dims(), 1.0, 0, "filter");
3629 Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode());
3630 Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode());
3631 // Quantize filter channelwise.
3632 filterCQ->getPayloadMutable() = quantization::quantizeTensor(
3633 filterC->getPayload(), filterScalesC->getPayload(),
3634 filterOffsetsC->getPayload(), filterElemQTy, qDim, qStep);
3635 filter = NodeValue(filterCQ);
3636 }
3637
3638 // If input bias is FLOAT then quantize channel wise to biasElemQTy.
3639 if (quantizeBias && biasElemKind == ElemKind::FloatTy) {
3640 Constant *biasC = dyn_cast<Constant>(bias.getNode());
3641 DCHECK(biasC)
3642 << "Bias input to ChannelwiseQuantizedConvolutionNode must be a "
3643 "Constant to quantize it";
3644 Constant *biasCQ = getParent()->createConstant(
3645 biasElemQTy, biasC->getType()->dims(), 1.0, 0, "bias");
3646 Constant *biasScalesC = dyn_cast<Constant>(biasScales.getNode());
3647 Constant *biasOffsetsC = dyn_cast<Constant>(biasOffsets.getNode());
3648 // Quantize bias channelwise.
3649 biasCQ->getPayloadMutable() = quantization::quantizeTensor(
3650 biasC->getPayload(), biasScalesC->getPayload(),
3651 biasOffsetsC->getPayload(), biasElemQTy, qDim, qStep);
3652 bias = NodeValue(biasCQ);
3653 }
3654
3655 auto OT = getParent()->uniqueType(*outTy);
3656 return addNode(new ChannelwiseQuantizedConvolutionNode(
3657 name, OT, input, filter, bias, filterScales, filterOffsets, biasScales,
3658 biasOffsets, kernels, strides, pads, group, dilation,
3659 FusedActivation::NONE, {}));
3660}
3661
3662ConvTransposeNode *Function::createConvTranspose(
3663 PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
3664 dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
3665 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
3666 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) {
3667 ShapeNHWC idim = ShapeNHWC(input.dims());
3668 ShapeHW kdim(kernels);
3669 PaddingTLBR pdim(pads);
3670 (void)pdim;
3671 assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
3672 (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
3673 "buffer too small for selected stride");
3674
3675 assert(group > 0 && "group should be larger than 0");
3676 assert(idim.c % group == 0 && "channels number must be divisible by groups");
3677 assert(outChannels % group == 0 && "outChannels must be divisible by groups");
3678
3679 // Calculate the size and allocate the output buffer.
3680 auto outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels,
3681 strides, pads, dilation);
3682
3683 std::array<dim_t, 4> outDims = {
3684 {idim.n, outSz.first, outSz.second, outChannels}};
3685
3686 // Allocate the Filter and Bias tensors.
3687 std::array<dim_t, 4> filterDim = {
3688 {outChannels, kdim.height, kdim.width, idim.c / group}};
3689 size_t fanIn = kdim.height * kdim.width * idim.c;
3690 ElemKind inputTy = input.getType()->getElementType();
3691 assert((inputTy == ElemKind::FloatTy || inputTy == ElemKind::Float16Ty) &&
         "ConvTranspose on non-floating point type?");
3693 auto *filter =
3694 getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
3695
3696 auto *bias =
3697 getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
3698 bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
3699 getPRNG());
3700
3701 bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
3702 getPRNG());
3703
3704 auto OT = getParent()->uniqueType(inputTy, outDims);
3705
3706 return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
3707 strides, pads, group, dilation));
3708}
3709
3710ConvTransposeNode *Function::createConvTranspose(
3711 PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
3712 dim_t outChannels, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
3713 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation) {
3714 llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
3715 llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
3716 llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
3717 return createConvTranspose(bindings, name, input, outChannels, kernels,
3718 strides, pads, group, dilation);
3719}
3720
3721ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
3722 TypeRef outTy) {
3723 return addNode(new ConvertToNode(name, outTy, input));
3724}
3725
3726ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
3727 ElemKind k) {
3728 auto OT = getParent()->uniqueType(k, input.dims());
3729 return addNode(new ConvertToNode(name, OT, input));
3730}
3731
3732FullyConnectedNode *
3733Function::createFullyConnected(PlaceholderBindings &bindings,
3734 llvm::StringRef name, NodeValue input,
3735 dim_t outDepth, unsigned_t axis) {
3736 const ElemKind k = input.getType()->getElementType();
3737
3738 // FC always uses 2D input; flatten if necessary.
3739 if (input.dims().size() != 2) {
3740 input = createFlatten(name.str() + ".reshape2D", input, axis);
3741 }
3742 auto *W = getParent()->createPlaceholder(k, {input.dims()[1], outDepth},
3743 "weights", true);
3744 auto *B = getParent()->createPlaceholder(k, {outDepth}, "bias", true);
3745
3746 bindings.allocate(W)->init(Tensor::InitKind::Xavier, input.dims()[1],
3747 getPRNG());
3748 bindings.allocate(B)->init(Tensor::InitKind::Broadcast, .1, getPRNG());
3749
3750 auto OT = getParent()->uniqueType(k, {input.dims()[0], outDepth});
3751 return createFullyConnected(name, input, W, B, OT, axis);
3752}
3753
3754Node *Function::createDotProduct(llvm::StringRef name, NodeValue X,
3755 NodeValue Y) {
3756 auto XDimsSize = X.dims().size();
3757 (void)XDimsSize;
3758
3759 assert(X.dims() == Y.dims() && "X and Y must have the same shape");
3760 assert(((XDimsSize == 1) || (XDimsSize == 2)) && "X and Y must be 1D or 2D");
3761
3762 // Create Mul node.
3763 auto *MN = createMul(name.str() + ".mul", X, Y);
3764
3765 // If X and Y are 1D, the BatchedReduceAdd node is not needed.
3766 if (XDimsSize == 1) {
3767 return MN;
3768 }
3769
3770 // Create and return BatchedReduceAdd node.
3771 return createBatchedReduceAdd(name.str() + ".bra", MN, 1);
3772}
3773
3774BatchedPairwiseDotProductNode *
3775Function::createBatchedPairwiseDotProduct(llvm::StringRef name,
3776 llvm::ArrayRef<NodeValue> inputs) {
3777 assert(!inputs.empty());
3778 dim_t batchCount = inputs[0].getType()->dims()[0];
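  // One output column per unordered pair of inputs: n * (n - 1) / 2 pairs.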
3779 dim_t numPairs = inputs.size() * (inputs.size() - 1) / 2;
3780 auto *outTy = getParent()->uniqueTypeWithNewShape(inputs[0].getType(),
3781 {batchCount, numPairs});
3782
3783 return addNode(new BatchedPairwiseDotProductNode(name, outTy, inputs));
3784}
3785
3786Node *Function::createElementwiseLinear(llvm::StringRef name, NodeValue X,
3787 NodeValue w, NodeValue b,
3788 unsigned axis) {
3789 auto XDims = X.dims();
3790 auto wDims = w.dims();
3791 auto bDims = b.dims();
3792
3793 // Suppress release mode unused variable warnings.
3794 (void)wDims;
3795 (void)bDims;
3796
3797 // Check that the inputs are sensible.
3798 assert(XDims.size() == 2 && "X must be 2D");
3799 assert((axis == 0 || axis == 1) && "axis must be 0 or 1");
3800 assert(wDims.size() == 1 && "w must be 1D");
3801 assert(bDims.size() == 1 && "b must be 1D");
3802 assert(wDims[0] == XDims[axis] &&
3803 "size of w must match input dimension of X");
3804 assert(bDims[0] == XDims[axis] &&
3805 "size of b must match input dimension of X");
3806
3807 // Broadcast w and b so that they have the same dimensions as X.
3808 auto *broadcastW =
3809 createBroadcast(name.str() + ".broadcastW", w, XDims, axis);
3810 auto *broadcastB =
3811 createBroadcast(name.str() + ".broadcastB", b, XDims, axis);
3812
3813 // Implement the elementwise linear operation by multiplying X elementwise
3814 // with broadcasted w and adding broadcasted b elementwise.
3815 auto *wX = createMul(name.str() + ".mul", broadcastW, X);
3816 auto *out = createAdd(name.str() + ".add", wX, broadcastB);
3817
3818 return out;
3819}
3820
3821void Function::createGRU(PlaceholderBindings &bindings,
3822 llvm::StringRef namePrefix,
3823 llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
3824 unsigned hiddenSize, unsigned outputSize,
3825 std::vector<NodeValue> &outputs) {
3826 std::string nameBase = namePrefix.str();
3827 const unsigned timeSteps = inputs.size();
3828 assert(timeSteps > 0 && "empty input");
3829 const unsigned inputSize = inputs.front().dims().back();
3830 assert(inputSize > 0 && "input dimensionality is zero");
3831
3832 // Initialize the state to zero.
3833 Placeholder *HInit = getParent()->createPlaceholder(
3834 ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_state", false);
3835 bindings.allocate(HInit)->zero();
3836 Node *Ht = HInit;
3837
3838 // Update gate:
3839 // Z <- sigmoid(Wxz * x + Whz * h + bz)
3840 // Reset gate:
3841 // R <- sigmoid(Wxr * x + Whr * h + br)
3842 // Hidden state:
3843 // h <- Z . h + (1 - Z) tanh (Wxh * x + Whh * (R . h) + bh)
3844
3845 // update gate
3846 float bUpdate = 0.1;
3847 Placeholder *Wxz = getParent()->createPlaceholder(
3848 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxz", true);
3849 Placeholder *Whz = getParent()->createPlaceholder(
3850 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whz", true);
3851 Placeholder *Bz1 = getParent()->createPlaceholder(
3852 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz1", true);
3853 Placeholder *Bz2 = getParent()->createPlaceholder(
3854 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz2", true);
3855
3856 bindings.allocate(Wxz)->init(glow::Tensor::InitKind::Xavier, inputSize,
3857 getPRNG());
3858 bindings.allocate(Whz)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3859 getPRNG());
3860 bindings.allocate(Bz1)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
3861 getPRNG());
3862 bindings.allocate(Bz2)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
3863 getPRNG());
3864
3865 // Reset gate.
3866 float bReset = -1.0;
3867 Placeholder *Wxr = getParent()->createPlaceholder(
3868 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxr", true);
3869 Placeholder *Whr = getParent()->createPlaceholder(
3870 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whr", true);
3871 Placeholder *Br1 = getParent()->createPlaceholder(
3872 ElemKind::FloatTy, {hiddenSize}, nameBase + ".br1", true);
3873 Placeholder *Br2 = getParent()->createPlaceholder(
3874 ElemKind::FloatTy, {hiddenSize}, nameBase + ".br2", true);
3875
3876 bindings.allocate(Wxr)->init(glow::Tensor::InitKind::Xavier, inputSize,
3877 getPRNG());
3878 bindings.allocate(Whr)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3879 getPRNG());
3880 bindings.allocate(Br1)->init(glow::Tensor::InitKind::Broadcast, bReset,
3881 getPRNG());
3882 bindings.allocate(Br2)->init(glow::Tensor::InitKind::Broadcast, bReset,
3883 getPRNG());
3884
3885 // hidden state
3886 float b = 0.1;
3887 Placeholder *Wxh = getParent()->createPlaceholder(
3888 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
3889 Placeholder *Whh = getParent()->createPlaceholder(
3890 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
3891 Placeholder *Bh1 = getParent()->createPlaceholder(
3892 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh1", true);
3893 Placeholder *Bh2 = getParent()->createPlaceholder(
3894 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh2", true);
3895
3896 bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize,
3897 getPRNG());
3898 bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3899 getPRNG());
3900 bindings.allocate(Bh1)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3901 bindings.allocate(Bh2)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3902
3903 // Output Layer.
3904 Placeholder *Why = getParent()->createPlaceholder(
3905 ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
3906 Placeholder *By = getParent()->createPlaceholder(
3907 ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);
3908
3909 bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3910 getPRNG());
3911 bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3912
3913 auto ty = getParent()->uniqueType(ElemKind::FloatTy, {batchSize, hiddenSize});
3914 auto *Ones = createSplat(nameBase + ".ones", ty, 1.0);
3915
3916 std::vector<Node *> outputNodes;
3917 for (unsigned t = 0; t < timeSteps; t++) {
3918 auto fc1Name = nameBase + ".fc1." + std::to_string(t);
3919 auto fc2Name = nameBase + ".fc2." + std::to_string(t);
3920 auto add1Name = nameBase + ".add1." + std::to_string(t);
3921 auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);
3922
3923 auto *Zt = createSigmoid(
3924 sigmoid1Name,
3925 createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whz, Bz1),
3926 createFullyConnected(fc2Name, inputs[t], Wxz, Bz2)));
3927
3928 auto fc3Name = nameBase + ".fc3." + std::to_string(t);
3929 auto fc4Name = nameBase + ".fc4." + std::to_string(t);
3930 auto add2Name = nameBase + ".add2." + std::to_string(t);
3931 auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);
3932
3933 auto *Rt = createSigmoid(
3934 sigmoid2Name,
3935 createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whr, Br1),
3936 createFullyConnected(fc4Name, inputs[t], Wxr, Br2)));
3937
3938 auto zhtName = nameBase + ".zh." + std::to_string(t);
3939 auto *ZHt = createMul(zhtName, Zt, Ht);
3940
3941 auto oneMinusZtName = nameBase + ".1-z." + std::to_string(t);
3942 auto *OneMinusZt = createSub(oneMinusZtName, Ones, Zt);
3943
3944 auto rhtName = nameBase + ".rh." + std::to_string(t);
3945 auto *RHt = createMul(rhtName, Rt, Ht);
3946
3947 auto fc5Name = nameBase + ".fc5." + std::to_string(t);
3948 auto fc6Name = nameBase + ".fc6." + std::to_string(t);
3949 auto add3Name = nameBase + ".add3." + std::to_string(t);
3950 auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);
3951
3952 auto *Ut = createTanh(
3953 tanh1Name,
3954 createAdd(add3Name, createFullyConnected(fc5Name, RHt, Whh, Bh1),
3955 createFullyConnected(fc6Name, inputs[t], Wxh, Bh2)));
3956
    auto oneMinusZtUtName = nameBase + ".1-zu." + std::to_string(t);
3958 auto *OneMinusZtUt = createMul(oneMinusZtUtName, OneMinusZt, Ut);
3959
3960 auto htName = nameBase + ".H." + std::to_string(t);
3961 Ht = createAdd(htName, ZHt, OneMinusZtUt);
3962
3963 auto outName = nameBase + ".out." + std::to_string(t);
3964 auto *O = createFullyConnected(outName, Ht, Why, By);
3965 outputs.push_back(O);
3966 }
3967}
3968
3969void Function::createSimpleRNN(PlaceholderBindings &bindings,
3970 llvm::StringRef namePrefix,
3971 llvm::ArrayRef<NodeValue> inputs,
3972 unsigned batchSize, unsigned hiddenSize,
3973 unsigned outputSize,
3974 std::vector<NodeValue> &outputs) {
3975 std::string nameBase = namePrefix.str();
3976 const unsigned timeSteps = inputs.size();
3977 assert(timeSteps > 0 && "empty input");
3978 const unsigned inputSize = inputs.front().dims().back();
3979 assert(inputSize > 0 && "input dimensionality is zero");
3980
3981 // Initialize the state to zero.
3982 Placeholder *HInit =
3983 getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
3984 nameBase + ".initial_state", false);
3985 bindings.allocate(HInit)->zero();
3986 Node *Ht = HInit;
3987
3988 float b = 0.1;
3989 Placeholder *Whh = getParent()->createPlaceholder(
3990 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
3991 Placeholder *Bhh = getParent()->createPlaceholder(
3992 ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bhh", true);
3993 Placeholder *Wxh = getParent()->createPlaceholder(
3994 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
3995
3996 Placeholder *Bxh = getParent()->createPlaceholder(
3997 ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bxh", true);
3998 Placeholder *Why = getParent()->createPlaceholder(
3999 ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
4000 Placeholder *Bhy = getParent()->createPlaceholder(
4001 ElemKind::FloatTy, {outputSize}, nameBase + ".Bhy", true);
4002
4003 bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4004 getPRNG());
4005 bindings.allocate(Bhh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
4006 bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize,
4007 getPRNG());
4008 bindings.allocate(Bxh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
4009 bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4010 getPRNG());
4011 bindings.allocate(Bhy)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
4012
  // Unroll backpropagation through time as a loop with the shared
  // parameters.
4015 for (unsigned t = 0; t < timeSteps; t++) {
4016 auto fc1Name = nameBase + ".fc1." + std::to_string(t);
4017 auto *FC1 = createFullyConnected(fc1Name, Ht, Whh, Bhh);
4018 auto fc2Name = nameBase + ".fc2." + std::to_string(t);
4019 auto *FC2 = createFullyConnected(fc2Name, inputs[t], Wxh, Bxh);
4020 auto aName = nameBase + ".add." + std::to_string(t);
4021 auto *A = createAdd(aName, FC1, FC2);
4022 auto tanhName = nameBase + ".tanh." + std::to_string(t);
4023 auto *H = createTanh(tanhName, A);
4024 auto outName = nameBase + ".out." + std::to_string(t);
4025 auto *O = createFullyConnected(outName, H, Why, Bhy);
4026 outputs.push_back(O);
4027
4028 Ht = H;
4029 };
4030}
4031
4032LSTMUnitNode *Function::createLSTMUnit(llvm::StringRef namePrefix,
4033 NodeValue Input, NodeValue C) {
4034
4035 return addNode(new LSTMUnitNode(namePrefix, Input, C));
4036}
4037
4038template <class T>
4039std::vector<NodeValue> Function::createSingleDirectionLSTM(
4040 std::string nameBase, T inputItr, const int timeSteps, NodeValue Wx,
4041 NodeValue Wh, NodeValue Bx, NodeValue Bh, NodeValue &H, NodeValue &C) {
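  // Unroll one LSTM direction over 'timeSteps': at every step the previous
  // hidden state and the current input are projected through fully connected
  // layers, summed, and fed into an LSTMUnit that updates H and C in place.
  // The per-step hidden states are returned reshaped to {1, batch, hidden} so
  // the caller can concatenate them along the time dimension.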
4042
4043 std::vector<NodeValue> Hs;
4044 auto name = [&nameBase](const char *s, int t) {
4045 return strFormat("%s.%s_%d", nameBase.c_str(), s, t);
4046 };
4047 for (int t = 0; t < timeSteps; t++, inputItr++) {
4048
4049 auto *result = createAdd(
4050 name("add", t), createFullyConnected(name("fc_1", t), H, Wh, Bh),
4051 createFullyConnected(name("fc_2", t), *inputItr, Wx, Bx));
4052
4053 auto lstmUnitNode =
4054 addNode(new LSTMUnitNode(name("lstm_unit", t), result, C));
4055 H = lstmUnitNode->getNthResult(0);
4056 C = lstmUnitNode->getNthResult(1);
4057
4058 Hs.push_back(
4059 createReshape(name("H_reshape", t), H, {1, H.dims()[0], H.dims()[1]})
4060 ->getResult());
4061 }
4062 return Hs;
4063}
4064
4065std::vector<NodeValue> Function::createMultipleLayerSingleDirectionLSTM(
4066 std::string nameBase, NodeValue input, unsigned batchSize,
4067 unsigned inputSize, const int timeSteps, std::vector<NodeValue> &Wx,
4068 std::vector<NodeValue> &Wh, std::vector<NodeValue> &Bx,
4069 std::vector<NodeValue> &Bh, NodeValue &H, NodeValue &C) {
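  // Stack several single-direction LSTM layers: layer i slices its own
  // {batch, hidden} plane out of the 3D H and C states, runs an unrolled LSTM
  // over its input sequence, and its concatenated per-step outputs become the
  // next layer's input. The final H and C are rebuilt by concatenating the
  // per-layer states along the layer dimension.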
4070
4071 assert(Wx.size() > 0 && Wh.size() > 0 && Bx.size() > 0 && Bh.size() > 0 &&
4072 "Wx, Wh, Bx and Bh should be non empty vectors");
4073
4074 auto numLayers = Wx.size();
4075 NodeValue temp_input = input;
4076 std::vector<NodeValue> Hs;
4077 auto name = [&nameBase](const char *s, int t) {
4078 return strFormat("%s.%s_%d", nameBase.c_str(), s, t);
4079 };
4080 std::vector<NodeValue> Hv, Cv;
4081 auto reshape2dto3d = [=](NodeValue n, std::string nameStr) {
4082 return createReshape(nameStr, n, {1, n.dims()[0], n.dims()[1]});
4083 };
4084 for (unsigned layer = 0; layer < numLayers; layer++) {
4085 auto slidedInputs = createSlicedInput(temp_input, nameBase, batchSize,
4086 inputSize, timeSteps);
4087 auto Hn =
4088 createReshape(name("reshape_hn", layer),
4089 createSlice(name("slice_hn", layer), H, {layer, 0, 0},
4090 {layer + 1, H.dims()[1], H.dims()[2]})
4091 ->getResult(),
4092 {H.dims()[1], H.dims()[2]})
4093 ->getResult();
4094 auto Cn =
4095 createReshape(name("reshape_cn", layer),
4096 createSlice(name("slice_cn", layer), C, {layer, 0, 0},
4097 {layer + 1, C.dims()[1], C.dims()[2]})
4098 ->getResult(),
4099 {C.dims()[1], C.dims()[2]})
4100 ->getResult();
4101 Hs = createSingleDirectionLSTM(nameBase, slidedInputs.begin(), timeSteps,
4102 Wx[layer], Wh[layer], Bx[layer], Bh[layer],
4103 Hn, Cn);
4104 temp_input = createConcat(nameBase + "_lstm_temp_input", Hs, 0);
4105 Hv.emplace_back(reshape2dto3d(Hn, "_lstm_hv"));
4106 Cv.emplace_back(reshape2dto3d(Cn, "_lstm_cv"));
4107 inputSize = temp_input.dims()[2];
4108 }
4109 H = createConcat(nameBase + "_lstm_h_output", Hv, 0);
4110 C = createConcat(nameBase + "_lstm_c_output", Cv, 0);
4111 return Hs;
4112}
4113
4114std::vector<NodeValue> Function::createSlicedInput(NodeValue input,
4115 std::string &nameBase,
4116 unsigned batchSize,
4117 unsigned inputSize,
4118 const int timeSteps) {
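  // Split a {timeSteps, batchSize, inputSize} sequence into one 2D
  // {batchSize, inputSize} slice per time step.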
4119 std::vector<NodeValue> inputs;
4120 auto name = [&nameBase](const char *s, int t) {
4121 return strFormat("%s.%s_%d", nameBase.c_str(), s, t);
4122 };
4123 for (unsigned t = 0; t < timeSteps; t++) {
4124 auto inputSliced = createSlice(name("slice", t), input, {t, 0, 0},
4125 {t + 1, batchSize, inputSize})
4126 ->getResult();
4127 inputSliced =
4128 createReshape(name("reshape", t), inputSliced, {batchSize, inputSize})
4129 ->getResult();
4130 inputs.push_back(inputSliced);
4131 }
4132 return inputs;
4133}
4134
4135void Function::createPyTorchLSTM(llvm::StringRef namePrefix, NodeValue input,
4136 std::vector<NodeValue> &Wx,
4137 std::vector<NodeValue> &Wh,
4138 std::vector<NodeValue> &Bx,
4139 std::vector<NodeValue> &Bh, NodeValue &Ht,
4140 NodeValue &Ct, NodeValue &output,
4141 bool isBidirectional, NodeValue WxR,
4142 NodeValue WhR, NodeValue BxR, NodeValue BhR) {
4143 std::string nameBase = namePrefix.str();
4144 assert(input.dims().back() > 0 && "input dimensionality is zero");
4145 assert((!isBidirectional || WxR != NodeValue(nullptr)) &&
4146 "Bidirectional LSTM must provide reverse weights & biases");
4147 assert(Wx.size() > 0 && Wh.size() > 0 && Bx.size() > 0 && Bh.size() > 0 &&
4148 "Wx, Wh, Bx and Bh should be non empty vectors");
4149
4150 std::vector<NodeValue> inputs, outputs;
4151 unsigned batchSize, inputSize, timeSteps, hiddenSize;
4152 batchSize = input.dims()[1];
4153 inputSize = input.dims()[2];
4154 if (Wh.size() == 1) {
4155 hiddenSize = Wh[0].dims()[0];
4156 } else {
4157 hiddenSize = Wh[0].dims()[1];
4158 }
4159
4160 timeSteps = input.dims()[0];
4161
4162 // Input gate:
4163 // I <- sigmoid(Wxi * x + Bxi + Whi * h + Bhi)
4164 // Forget gate:
4165 // F <- sigmoid(Wxf * x + Bxf + Whf * h + Bhf)
4166 // Cell gate:
4167 // G <- tanh(Wxg * x + Bxg + Whg * h + Bhg)
4168 // Output gate:
4169 // O <- sigmoid(Wxo * x + Bxo + Who * h + Bho)
4170 // Cell state:
4171 // C <- F . C + I . G
4172 // Hidden state:
4173 // h <- O . tanh(C)
4174
4175 if (isBidirectional) {
    // For a bidirectional LSTM, split H and C into two parts, one per
    // direction, and compute each part separately.
4178 NodeValue Hforward, Hbackward, Cforward, Cbackward;
4179 Hforward = createReshape("reshape_hforward",
4180 createSlice("slice_hforward", Ht, {0, 0, 0},
4181 {1, batchSize, hiddenSize}),
4182 {batchSize, hiddenSize})
4183 ->getResult();
4184 Hbackward = createReshape("reshape_hbackward",
4185 createSlice("slice_hbackward", Ht, {1, 0, 0},
4186 {2, batchSize, hiddenSize}),
4187 {batchSize, hiddenSize})
4188 ->getResult();
4189 Cforward = createReshape("reshape_cforward",
4190 createSlice("slice_cforward", Ct, {0, 0, 0},
4191 {1, batchSize, hiddenSize}),
4192 {batchSize, hiddenSize})
4193 ->getResult();
4194 Cbackward = createReshape("reshape_cbackward",
4195 createSlice("slice_cbackward", Ct, {1, 0, 0},
4196 {2, batchSize, hiddenSize}),
4197 {batchSize, hiddenSize})
4198 ->getResult();
4199
4200 auto slicedInputs =
4201 createSlicedInput(input, nameBase, batchSize, inputSize, timeSteps);
4202 auto outputForwards =
4203 createSingleDirectionLSTM<std::vector<NodeValue>::iterator>(
4204 nameBase + "_lstm_forward", slicedInputs.begin(), timeSteps, Wx[0],
4205 Wh[0], Bx[0], Bh[0], Hforward, Cforward);
4206 auto outputBackwards =
4207 createSingleDirectionLSTM<std::vector<NodeValue>::reverse_iterator>(
4208 nameBase + "_lstm_backward", slicedInputs.rbegin(), timeSteps, WxR,
4209 WhR, BxR, BhR, Hbackward, Cbackward);
4210 std::reverse(outputBackwards.begin(), outputBackwards.end());
4211 NodeValue outputForward = createConcat(
4212 nameBase + "_lstm_forward_output_concat", outputForwards, 0);
4213 NodeValue outputBackward = createConcat(
4214 nameBase + "_lstm_backward_output_concat", outputBackwards, 0);
4215 output =
4216 createConcat("final_output_concat", {outputForward, outputBackward}, 2)
4217 ->getResult();
4218
4219 auto reshape2dto3d = [=](NodeValue n, std::string nameStr) {
4220 return createReshape(nameStr, n, {1, n.dims()[0], n.dims()[1]});
4221 };
4222 Ht = createConcat("Ht_Concat",
4223 {reshape2dto3d(Hforward, "final_Hforward_reshape"),
4224 reshape2dto3d(Hbackward, "final_Hbackward_reshape")},
4225 0)
4226 ->getResult();
4227 Ct = createConcat("Ct_Concat",
4228 {reshape2dto3d(Cforward, "final_Cforward_reshape"),
4229 reshape2dto3d(Cbackward, "final_Cbackward_reshape")},
4230 0)
4231 ->getResult();
4232
4233 } else {
4234 if (Ht.dims().size() == 2) {
4235 auto slicedInputs =
4236 createSlicedInput(input, nameBase, batchSize, inputSize, timeSteps);
4237 outputs = createSingleDirectionLSTM<std::vector<NodeValue>::iterator>(
4238 nameBase + "_lstm", slicedInputs.begin(), timeSteps, Wx[0], Wh[0],
4239 Bx[0], Bh[0], Ht, Ct);
4240 } else {
4241 outputs = createMultipleLayerSingleDirectionLSTM(
4242 nameBase + "_lstm", input, batchSize, inputSize, timeSteps, Wx, Wh,
4243 Bx, Bh, Ht, Ct);
4244 }
4245 output = createConcat(nameBase + "_lstm_output_concat", outputs, 0);
4246 }
4247};
4248
4249void Function::createLSTM(PlaceholderBindings &bindings,
4250 llvm::StringRef namePrefix,
4251 llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
4252 unsigned hiddenSize, unsigned outputSize,
4253 std::vector<NodeValue> &outputs) {
4254 std::string nameBase = namePrefix.str();
4255 const unsigned timeSteps = inputs.size();
4256 assert(timeSteps > 0 && "empty input");
4257 const unsigned inputSize = inputs.front().dims().back();
4258 assert(inputSize > 0 && "input dimensionality is zero");
4259
4260 // Initialize the hidden and cell states to zero.
4261 Placeholder *HInit =
4262 getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
4263 "initial_hidden_state", false);
4264 bindings.allocate(HInit)->zero();
4265 Node *Ht = HInit;
4266
4267 Placeholder *CInit = getParent()->createPlaceholder(
4268 ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_cell_state", false);
4269 bindings.allocate(CInit)->zero();
4270 Node *Ct = CInit;
4271
4272 // Forget gate:
4273 // F <- sigmoid(Wxf * x + Whf * h + bf)
4274 // Input gate:
4275 // I <- sigmoid(Wxi * x + Whi * h + bi)
4276 // Output gate:
4277  // O <- sigmoid(Wxo * x + Who * h + bo)
4278 // Cell state:
4279 // C <- F . C + I . tanh(Wxc * x + Whc * h + bc)
4280 // Hidden state:
4281 // h <- O . tanh(C)
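  //
  // Note: each gate below is assembled from two fully connected layers (one on
  // the hidden state, one on the current input), each with its own bias
  // placeholder, so the effective gate bias is the sum of the two
  // (e.g. bf = Bf1 + Bf2).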
4282
4283 // forget gate
4284 float bForget = 1.0;
4285 Placeholder *Wxf = getParent()->createPlaceholder(
4286 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxf", true);
4287 Placeholder *Whf = getParent()->createPlaceholder(
4288 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whf", true);
4289 Placeholder *Bf1 = getParent()->createPlaceholder(
4290 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf1", true);
4291 Placeholder *Bf2 = getParent()->createPlaceholder(
4292 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf2", true);
4293 bindings.allocate(Wxf)->init(glow::Tensor::InitKind::Xavier, inputSize,
4294 getPRNG());
4295 bindings.allocate(Whf)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4296 getPRNG());
4297 bindings.allocate(Bf1)->init(glow::Tensor::InitKind::Broadcast, bForget,
4298 getPRNG());
4299 bindings.allocate(Bf2)->init(glow::Tensor::InitKind::Broadcast, bForget,
4300 getPRNG());
4301
4302 // input gate
4303 float bInput = 0.1;
4304 Placeholder *Wxi = getParent()->createPlaceholder(
4305 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxi", true);
4306 Placeholder *Whi = getParent()->createPlaceholder(
4307 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whi", true);
4308 Placeholder *Bi1 = getParent()->createPlaceholder(
4309 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi1", true);
4310 Placeholder *Bi2 = getParent()->createPlaceholder(
4311 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi2", true);
4312
4313 bindings.allocate(Wxi)->init(glow::Tensor::InitKind::Xavier, inputSize,
4314 getPRNG());
4315 bindings.allocate(Whi)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4316 getPRNG());
4317 bindings.allocate(Bi1)->init(glow::Tensor::InitKind::Broadcast, bInput,
4318 getPRNG());
4319 bindings.allocate(Bi2)->init(glow::Tensor::InitKind::Broadcast, bInput,
4320 getPRNG());
4321
4322 // output gate
4323 float bOutput = 0.1;
4324 Placeholder *Wxo = getParent()->createPlaceholder(
4325 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxo", true);
4326 Placeholder *Who = getParent()->createPlaceholder(
4327 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Who", true);
4328 Placeholder *Bo1 = getParent()->createPlaceholder(
4329 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo1", true);
4330 Placeholder *Bo2 = getParent()->createPlaceholder(
4331 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo2", true);
4332
4333 bindings.allocate(Wxo)->init(glow::Tensor::InitKind::Xavier, inputSize,
4334 getPRNG());
4335 bindings.allocate(Who)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4336 getPRNG());
4337 bindings.allocate(Bo1)->init(glow::Tensor::InitKind::Broadcast, bOutput,
4338 getPRNG());
4339 bindings.allocate(Bo2)->init(glow::Tensor::InitKind::Broadcast, bOutput,
4340 getPRNG());
4341
4342 // cell state
4343 float bCell = 0.1;
4344 Placeholder *Wxc = getParent()->createPlaceholder(
4345 ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxc", true);
4346 Placeholder *Whc = getParent()->createPlaceholder(
4347 ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whc", true);
4348 Placeholder *Bc1 = getParent()->createPlaceholder(
4349 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc1", true);
4350 Placeholder *Bc2 = getParent()->createPlaceholder(
4351 ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc2", true);
4352
4353 bindings.allocate(Wxc)->init(glow::Tensor::InitKind::Xavier, inputSize,
4354 getPRNG());
4355 bindings.allocate(Whc)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4356 getPRNG());
4357 bindings.allocate(Bc1)->init(glow::Tensor::InitKind::Broadcast, bCell,
4358 getPRNG());
4359 bindings.allocate(Bc2)->init(glow::Tensor::InitKind::Broadcast, bCell,
4360 getPRNG());
4361
4362 // output layer
4363 float b = 0.1;
4364 Placeholder *Why = getParent()->createPlaceholder(
4365 ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
4366 Placeholder *By = getParent()->createPlaceholder(
4367 ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);
4368
4369 bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
4370 getPRNG());
4371 bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
4372
4373 std::vector<Node *> outputNodes;
4374 for (unsigned t = 0; t < timeSteps; t++) {
4375 auto fc1Name = nameBase + ".fc1." + std::to_string(t);
4376 auto fc2Name = nameBase + ".fc2." + std::to_string(t);
4377 auto add1Name = nameBase + ".add1." + std::to_string(t);
4378 auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);
4379
4380 auto *Ft = createSigmoid(
4381 sigmoid1Name,
4382 createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whf, Bf1),
4383 createFullyConnected(fc2Name, inputs[t], Wxf, Bf2)));
4384
4385 auto fc3Name = nameBase + ".fc3." + std::to_string(t);
4386 auto fc4Name = nameBase + ".fc4." + std::to_string(t);
4387 auto add2Name = nameBase + ".add2." + std::to_string(t);
4388 auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);
4389
4390 auto *It = createSigmoid(
4391 sigmoid2Name,
4392 createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whi, Bi1),
4393 createFullyConnected(fc4Name, inputs[t], Wxi, Bi2)));
4394
4395 auto fc5Name = nameBase + ".fc5." + std::to_string(t);
4396 auto fc6Name = nameBase + ".fc6." + std::to_string(t);
4397 auto add3Name = nameBase + ".add3." + std::to_string(t);
4398 auto sigmoid3Name = nameBase + ".sigmoid3." + std::to_string(t);
4399
4400 auto *Ot = createSigmoid(
4401 sigmoid3Name,
4402 createAdd(add3Name, createFullyConnected(fc5Name, Ht, Who, Bo1),
4403 createFullyConnected(fc6Name, inputs[t], Wxo, Bo2)));
4404
4405 auto fc7Name = nameBase + ".fc7." + std::to_string(t);
4406 auto fc8Name = nameBase + ".fc8." + std::to_string(t);
4407 auto add4Name = nameBase + ".add4." + std::to_string(t);
4408 auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);
4409
4410 auto *CRt = createTanh(
4411 tanh1Name,
4412 createAdd(add4Name, createFullyConnected(fc7Name, Ht, Whc, Bc1),
4413 createFullyConnected(fc8Name, inputs[t], Wxc, Bc2)));
4414
4415 auto mul1Name = nameBase + ".mul1." + std::to_string(t);
4416 auto mul2Name = nameBase + ".mul2." + std::to_string(t);
4417 Ct = createAdd(nameBase + ".C." + std::to_string(t),
4418 createMul(mul1Name, Ft, Ct), createMul(mul2Name, It, CRt));
4419
4420 auto htName = nameBase + ".H." + std::to_string(t);
4421 auto tanh2Name = nameBase + ".tanh2." + std::to_string(t);
4422 Ht = createMul(htName, Ot, createTanh(tanh2Name, Ct));
4423
4424 auto outName = nameBase + ".out." + std::to_string(t);
4425 auto *O = createFullyConnected(outName, Ht, Why, By);
4426 outputs.push_back(O);
4427 }
4428}
4429
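/// Shape summary for the ONNX RNN inputs, as checked by the asserts below:
///   X : [seqLength, batchSize, inputSize]
///   W : [numDirections, hiddenSize, inputSize]
///   R : [numDirections, hiddenSize, hiddenSize]
///   B : [numDirections, 2 * hiddenSize]                 (optional)
///   initial_h : [numDirections, batchSize, hiddenSize]  (optional; a zero
///   Splat is created when it is not provided)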
4430void Function::createOnnxRNN(llvm::StringRef namePrefix, NodeValue X,
4431 NodeValue W, NodeValue R, NodeValue B,
4432 NodeValue initial_h, NodeValue &Y, NodeValue &Y_h,
4433 unsigned hiddenSize, RnnDirection direction,
4434 std::vector<RnnActivation> &activations) {
4435
4436#define RNN_X_SLICE_RANGE(idx) \
4437 {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
4438#define RNN_W_SLICE_RANGE(idx0, idx1) \
4439 {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
4440#define RNN_R_SLICE_RANGE(idx0, idx1) \
4441 {idx0, idx1 * hiddenSize, 0}, { \
4442 idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \
4443 }
4444#define RNN_B_SLICE_RANGE(idx0, idx1) \
4445 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4446#define RNN_H_SLICE_RANGE(idx) \
4447 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4448#define RNN_CREATE_FC(name, LHS, RHS, BIAS) \
4449 BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \
4450 : (Node *)createMatMul(name, LHS, RHS)
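// The *_SLICE_RANGE macros above expand to the pair of coordinate lists
// ("start" and "end") expected by createSlice(), and RNN_CREATE_FC emits a
// FullyConnected node when a bias is present and a plain MatMul otherwise.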
4451
4452 // Operator name.
4453 const std::string &opName = namePrefix.str();
4454
4455 // Get all size parameters.
4456 dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
4457 assert(X.dims().size() == 3 &&
4458 "ONNX RNN input 'X' should have 3 dimensions!");
4459 dim_t seqLength = X.dims()[0];
4460 dim_t batchSize = X.dims()[1];
4461 dim_t inputSize = X.dims()[2];
4462
4463 // Validate W size.
4464 assert(W.dims().size() == 3 &&
4465 "ONNX RNN input 'W' should have 3 dimensions!");
4466 assert(W.dims()[0] == numDirections && W.dims()[1] == hiddenSize &&
4467 W.dims()[2] == inputSize && "ONNX RNN 'W' tensor size invalid!");
4468
4469 // Validate R size.
4470 assert(R.dims().size() == 3 &&
4471 "ONNX RNN input 'R' should have 3 dimensions!");
4472 assert(R.dims()[0] == numDirections && R.dims()[1] == hiddenSize &&
4473 R.dims()[2] == hiddenSize && "ONNX RNN 'R' tensor size invalid!");
4474
4475 // Validate B size.
4476 if (B.getNode()) {
4477 assert(B.dims().size() == 2 &&
4478 "ONNX RNN input 'B' should have 2 dimensions!");
4479 assert(B.dims()[0] == numDirections && B.dims()[1] == 2 * hiddenSize &&
4480 "ONNX RNN 'B' tensor size invalid!");
4481 }
4482
4483  // Validate the initial_h size if given, otherwise create a zero Splat.
4484  if (initial_h.getNode()) {
4485    assert(initial_h.dims().size() == 3 &&
4486           "ONNX RNN input 'initial_h' should have 3 dimensions!");
4487 assert(initial_h.dims()[0] == numDirections &&
4488 initial_h.dims()[1] == batchSize &&
4489 initial_h.dims()[2] == hiddenSize &&
4490 "ONNX RNN 'initial_h' tensor size invalid!");
4491 } else {
4492 auto splatTy = getParent()->uniqueType(
4493 ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
4494 initial_h = createSplat(opName + ".initial_h", splatTy, 0.0);
4495 }
4496
4497 // Validate number of activations.
4498 assert(activations.size() == numDirections * 1 &&
4499 "ONNX RNN activations vector invalid!");
4500
4501 // Create X slices.
4502 std::vector<Node *> Xslices;
4503 for (dim_t t = 0; t < seqLength; t++) {
4504 auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
4505 Node *Xt = createSlice(XsliceName, X, RNN_X_SLICE_RANGE(t));
4506 auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
4507 Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
4508 Xslices.push_back(Xt);
4509 }
4510
4511 // Lambda to load forward/backward RNN cell.
4512 auto loadRNNCell = [&](bool forward, std::vector<NodeValue> &Yslices,
4513 NodeValue &Hslice) {
4514 // Name prefix.
4515 std::string dirLabel = forward ? ".fw" : ".bw";
4516 std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
4517
4518 // Slice index used for creating weights slices.
4519 dim_t sliceIdx0 = 0;
4520 if (direction == RnnDirection::Bidirectional) {
4521 sliceIdx0 = forward ? 0 : 1;
4522 }
4523
4524 // Activations.
4525 size_t activationOffset = sliceIdx0 * 1;
4526 auto activationF = activations[activationOffset + 0];
4527
4528 // Create W slice (Required).
4529 NodeValue Wi =
4530 createSlice(prefix + ".Wi.", W, RNN_W_SLICE_RANGE(sliceIdx0, 0));
4531 Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
4532 Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
4533
4534 // Create R slice (Required).
4535 NodeValue Ri =
4536 createSlice(prefix + ".Ri.", R, RNN_R_SLICE_RANGE(sliceIdx0, 0));
4537 Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
4538 Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
4539
4540 // Create B slices (optional).
4541 NodeValue bWi = nullptr;
4542 NodeValue bRi = nullptr;
4543
4544 if (B) {
4545
4546 bWi = createSlice(prefix + ".bWi.", B, RNN_B_SLICE_RANGE(sliceIdx0, 0));
4547 bRi = createSlice(prefix + ".bRi.", B, RNN_B_SLICE_RANGE(sliceIdx0, 1));
4548
4549 bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
4550 bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
4551 }
4552
4553 // Create H slice for this direction.
4554 Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
4555 RNN_H_SLICE_RANGE(sliceIdx0));
4556 Hinit =
4557 createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
4558
4559 // Initialize.
4560 Node *Ht = Hinit;
4561
4562 // Unroll RNN cell for all time steps.
4563 for (size_t t = 0; t < seqLength; t++) {
4564
4565 // Input for current time step.
4566 // For the reverse RNN cell the inputs are provided in reverse order.
4567 Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
4568
4569 // Hidden state update: Ht = f(Xt * Wi + bWi + Ht-1 * Ri + bRi).
4570 Ht = createAdd(prefix + ".H.add",
4571 RNN_CREATE_FC(prefix + ".H.fc1", Xt, Wi, bWi),
4572 RNN_CREATE_FC(prefix + ".H.fc2", Ht, Ri, bRi));
4573 Ht = activationF(prefix + ".H.act", Ht);
4574
4575 // Output.
4576 Yslices.push_back(Ht);
4577 }
4578
4579 // Updated states nodes.
4580 Hslice = Ht;
4581 }; // End of local lambda "loadRNNCell".
4582
4583 bool forwardEnabled = ((direction == RnnDirection::Forward) ||
4584 (direction == RnnDirection::Bidirectional));
4585 bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
4586 (direction == RnnDirection::Bidirectional));
4587
4588 std::vector<NodeValue> YSlices;
4589 std::vector<NodeValue> Hslices;
4590
4591 // Load forward RNN.
4592 std::vector<NodeValue> forwardYslices;
4593 if (forwardEnabled) {
4594 NodeValue forwardHslice;
4595 loadRNNCell(/* forward */ true, forwardYslices, forwardHslice);
4596 Hslices.push_back(forwardHslice);
4597 }
4598
4599 // Load backward RNN.
4600 std::vector<NodeValue> backwardYslices;
4601 if (backwardEnabled) {
4602 NodeValue backwardHslice;
4603 loadRNNCell(/* forward */ false, backwardYslices, backwardHslice);
4604 Hslices.push_back(backwardHslice);
4605 }
4606
4607 // Gather Y slices.
4608 for (size_t t = 0; t < seqLength; t++) {
4609 if (forwardEnabled) {
4610 YSlices.push_back(forwardYslices[t]);
4611 }
4612 if (backwardEnabled) {
4613 YSlices.push_back(backwardYslices[seqLength - 1 - t]);
4614 }
4615 }
4616
4617 // Concatenate Y slices.
4618 // Y size is [seqLength, numDirections, batchSize, hiddenSize].
4619 Y = createReshape(opName + ".Y.reshape",
4620 createConcat(opName + ".Y.concat", YSlices, 0),
4621 {seqLength, numDirections, batchSize, hiddenSize});
4622
4623 // Concatenate Y_h slices.
4624 // Y_h size is [numDirections, batchSize, hiddenSize].
4625 Y_h = createReshape(opName + ".Y_h.reshape",
4626 createConcat(opName + ".Y_h.concat", Hslices, 0),
4627 {numDirections, batchSize, hiddenSize});
4628
4629#undef RNN_X_SLICE_RANGE
4630#undef RNN_W_SLICE_RANGE
4631#undef RNN_R_SLICE_RANGE
4632#undef RNN_B_SLICE_RANGE
4633#undef RNN_H_SLICE_RANGE
4634#undef RNN_CREATE_FC
4635}
4636
4637void Function::createOnnxGRU(llvm::StringRef namePrefix, NodeValue X,
4638 NodeValue W, NodeValue R, NodeValue B,
4639 NodeValue initial_h, NodeValue &Y, NodeValue &Y_h,
4640 unsigned hiddenSize, RnnDirection direction,
4641 std::vector<RnnActivation> &activations,
4642 bool linearBeforeReset) {
4643
4644#define GRU_X_SLICE_RANGE(idx) \
4645 {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
4646#define GRU_W_SLICE_RANGE(idx0, idx1) \
4647 {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
4648#define GRU_R_SLICE_RANGE(idx0, idx1) \
4649 {idx0, idx1 * hiddenSize, 0}, { \
4650 idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \
4651 }
4652#define GRU_B_SLICE_RANGE(idx0, idx1) \
4653 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4654#define GRU_H_SLICE_RANGE(idx) \
4655 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4656#define GRU_CREATE_FC(name, LHS, RHS, BIAS) \
4657 BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \
4658 : (Node *)createMatMul(name, LHS, RHS)
4659
4660 // Operator name.
4661 const std::string &opName = namePrefix.str();
4662
4663 // Get all size parameters.
4664 dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
4665 assert(X.dims().size() == 3 &&
4666 "ONNX GRU input 'X' should have 3 dimensions!");
4667 dim_t seqLength = X.dims()[0];
4668 dim_t batchSize = X.dims()[1];
4669 dim_t inputSize = X.dims()[2];
4670
4671 // Validate W size.
4672 assert(W.dims().size() == 3 &&
4673 "ONNX GRU input 'W' should have 3 dimensions!");
4674 assert(W.dims()[0] == numDirections && W.dims()[1] == 3 * hiddenSize &&
4675 W.dims()[2] == inputSize && "ONNX GRU 'W' tensor size invalid!");
4676
4677 // Validate R size.
4678 assert(R.dims().size() == 3 &&
4679 "ONNX GRU input 'R' should have 3 dimensions!");
4680 assert(R.dims()[0] == numDirections && R.dims()[1] == 3 * hiddenSize &&
4681 R.dims()[2] == hiddenSize && "ONNX GRU 'R' tensor size invalid!");
4682
4683 // Validate B size.
4684 if (B.getNode()) {
4685 assert(B.dims().size() == 2 &&
4686 "ONNX GRU input 'B' should have 2 dimensions!");
4687 assert(B.dims()[0] == numDirections && B.dims()[1] == 6 * hiddenSize &&
4688 "ONNX GRU 'B' tensor size invalid!");
4689 }
4690
4691  // Validate the initial_h size if given, otherwise create a zero Splat.
4692  if (initial_h.getNode()) {
4693    assert(initial_h.dims().size() == 3 &&
4694           "ONNX GRU input 'initial_h' should have 3 dimensions!");
4695 assert(initial_h.dims()[0] == numDirections &&
4696 initial_h.dims()[1] == batchSize &&
4697 initial_h.dims()[2] == hiddenSize &&
4698 "ONNX GRU 'initial_h' tensor size invalid!");
4699 } else {
4700 auto splatTy = getParent()->uniqueType(
4701 ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
4702 initial_h = createSplat(opName + ".initial_h", splatTy, 0.0);
4703 }
4704
4705 // Validate number of activations.
4706 assert(activations.size() == numDirections * 2 &&
4707 "ONNX GRU activations vector invalid!");
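  // The activations are packed per direction as [f, g]: "f" is applied to the
  // update/reset gates and "g" to the hidden gate. For the bidirectional case
  // the forward pair comes first, followed by the backward pair.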
4708
4709 // Create X slices.
4710 std::vector<Node *> Xslices;
4711 for (dim_t t = 0; t < seqLength; t++) {
4712 auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
4713 Node *Xt = createSlice(XsliceName, X, GRU_X_SLICE_RANGE(t));
4714 auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
4715 Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
4716 Xslices.push_back(Xt);
4717 }
4718
4719 // Lambda to load forward/backward GRU cell.
4720 auto loadGRUCell = [&](bool forward, std::vector<NodeValue> &Yslices,
4721 NodeValue &Hslice) {
4722 // Name prefix.
4723 std::string dirLabel = forward ? ".fw" : ".bw";
4724 std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
4725
4726 // Slice index used for creating weights slices.
4727 dim_t sliceIdx0 = 0;
4728 if (direction == RnnDirection::Bidirectional) {
4729 sliceIdx0 = forward ? 0 : 1;
4730 }
4731
4732 // Activations.
4733 size_t activationOffset = sliceIdx0 * 2;
4734 auto activationF = activations[activationOffset + 0];
4735 auto activationG = activations[activationOffset + 1];
4736
4737 // Create W slices (Required).
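    // Per the ONNX GRU definition, the gates are packed along dimension 1 of W
    // in the order [z | r | h]; the slice indices 0, 1, 2 below pick them out.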
4738 NodeValue Wz =
4739 createSlice(prefix + ".Wz.", W, GRU_W_SLICE_RANGE(sliceIdx0, 0));
4740 NodeValue Wr =
4741 createSlice(prefix + ".Wr.", W, GRU_W_SLICE_RANGE(sliceIdx0, 1));
4742 NodeValue Wh =
4743 createSlice(prefix + ".Wh.", W, GRU_W_SLICE_RANGE(sliceIdx0, 2));
4744
4745 Wz = createReshape(prefix + ".Wz.reshape", Wz, {hiddenSize, inputSize});
4746 Wr = createReshape(prefix + ".Wr.reshape", Wr, {hiddenSize, inputSize});
4747 Wh = createReshape(prefix + ".Wh.reshape", Wh, {hiddenSize, inputSize});
4748
4749 Wz = createTranspose(prefix + ".Wz.transp", Wz, {1, 0});
4750 Wr = createTranspose(prefix + ".Wr.transp", Wr, {1, 0});
4751 Wh = createTranspose(prefix + ".Wh.transp", Wh, {1, 0});
4752
4753 // Create R slices (Required).
4754 NodeValue Rz =
4755 createSlice(prefix + ".Rz.", R, GRU_R_SLICE_RANGE(sliceIdx0, 0));
4756 NodeValue Rr =
4757 createSlice(prefix + ".Rr.", R, GRU_R_SLICE_RANGE(sliceIdx0, 1));
4758 NodeValue Rh =
4759 createSlice(prefix + ".Rh.", R, GRU_R_SLICE_RANGE(sliceIdx0, 2));
4760
4761 Rz = createReshape(prefix + ".Rz.reshape", Rz, {hiddenSize, hiddenSize});
4762 Rr = createReshape(prefix + ".Rr.reshape", Rr, {hiddenSize, hiddenSize});
4763 Rh = createReshape(prefix + ".Rh.reshape", Rh, {hiddenSize, hiddenSize});
4764
4765 Rz = createTranspose(prefix + ".Rz.transp", Rz, {1, 0});
4766 Rr = createTranspose(prefix + ".Rr.transp", Rr, {1, 0});
4767 Rh = createTranspose(prefix + ".Rh.transp", Rh, {1, 0});
4768
4769 // Create B slices (optional).
4770 NodeValue bWz = nullptr;
4771 NodeValue bWr = nullptr;
4772 NodeValue bWh = nullptr;
4773 NodeValue bRz = nullptr;
4774 NodeValue bRr = nullptr;
4775 NodeValue bRh = nullptr;
4776
4777 if (B) {
4778
4779 bWz = createSlice(prefix + ".bWz.", B, GRU_B_SLICE_RANGE(sliceIdx0, 0));
4780 bWr = createSlice(prefix + ".bWr.", B, GRU_B_SLICE_RANGE(sliceIdx0, 1));
4781 bWh = createSlice(prefix + ".bWh.", B, GRU_B_SLICE_RANGE(sliceIdx0, 2));
4782 bRz = createSlice(prefix + ".bRz.", B, GRU_B_SLICE_RANGE(sliceIdx0, 3));
4783 bRr = createSlice(prefix + ".bRr.", B, GRU_B_SLICE_RANGE(sliceIdx0, 4));
4784 bRh = createSlice(prefix + ".bRh.", B, GRU_B_SLICE_RANGE(sliceIdx0, 5));
4785
4786 bWz = createReshape(prefix + ".bWz.reshape", bWz, {hiddenSize});
4787 bWr = createReshape(prefix + ".bWr.reshape", bWr, {hiddenSize});
4788 bWh = createReshape(prefix + ".bWh.reshape", bWh, {hiddenSize});
4789 bRz = createReshape(prefix + ".bRz.reshape", bRz, {hiddenSize});
4790 bRr = createReshape(prefix + ".bRr.reshape", bRr, {hiddenSize});
4791 bRh = createReshape(prefix + ".bRh.reshape", bRh, {hiddenSize});
4792 }
4793
4794 // Create H slice for this direction.
4795 Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
4796 GRU_H_SLICE_RANGE(sliceIdx0));
4797 Hinit =
4798 createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
4799
4800 // Initialize.
4801 Node *Ht = Hinit;
4802
4803 // Unroll GRU cell for all time steps.
4804 for (size_t t = 0; t < seqLength; t++) {
4805
4806 // Input for current time step.
4807 // For the reverse GRU cell the inputs are provided in reverse order.
4808 Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
4809
4810 // Update gate: zt = f(Xt * Wz + bWz + Ht-1 * Rz + bRz).
4811 Node *zt = createAdd(prefix + ".Z.add1",
4812 GRU_CREATE_FC(prefix + ".Z.fc1", Xt, Wz, bWz),
4813 GRU_CREATE_FC(prefix + ".Z.fc2", Ht, Rz, bRz));
4814 zt = activationF(prefix + ".Z.act", zt);
4815
4816 // Reset gate: rt = f(Xt * Wr + bWr + Ht-1 * Rr + bRr).
4817 Node *rt = createAdd(prefix + ".R.add1",
4818 GRU_CREATE_FC(prefix + ".R.fc1", Xt, Wr, bWr),
4819 GRU_CREATE_FC(prefix + ".R.fc2", Ht, Rr, bRr));
4820 rt = activationF(prefix + ".R.act", rt);
4821
4822 // Hidden gate:
4823 // For linearBeforeReset = true:
4824 // htild = g(Xt * Wh + bWh + rt . (Ht-1 * Rh + bRh)).
4825 // For linearBeforeReset = false:
4826 // htild = g(Xt * Wh + bWh + (rt . Ht-1) * Rh + bRh).
4827 Node *htild;
4828 if (linearBeforeReset) {
4829 htild = createAdd(
4830 prefix + ".Htild.add",
4831 GRU_CREATE_FC(prefix + ".Htild.fc1", Xt, Wh, bWh),
4832 createMul(prefix + ".Htild.reset", rt,
4833 GRU_CREATE_FC(prefix + ".Htild.fc2", Ht, Rh, bRh)));
4834 } else {
4835 htild = createAdd(
4836 prefix + ".Htild.add",
4837 GRU_CREATE_FC(prefix + ".Htild.fc1", Xt, Wh, bWh),
4838 GRU_CREATE_FC(prefix + ".Htild.fc2",
4839 createMul(prefix + ".Htild.reset", rt, Ht), Rh, bRh));
4840 }
4841 htild = activationG(prefix + ".Htild.act", htild);
4842
4843 // Hidden state update:
4844 // Ht = (1 - zt) . htild + zt . Ht-1 = htild - zt . htild + zt . Ht-1.
4845 Ht = createAdd(prefix + ".H.add",
4846 createSub(prefix + ".H.sub", htild,
4847 createMul(prefix + ".H.mult1", zt, htild)),
4848 createMul(prefix + ".H.mult2", zt, Ht));
4849
4850 // Output.
4851 Yslices.push_back(Ht);
4852 }
4853
4854 // Updated states nodes.
4855 Hslice = Ht;
4856 }; // End of local lambda "loadGRUCell".
4857
4858 bool forwardEnabled = ((direction == RnnDirection::Forward) ||
4859 (direction == RnnDirection::Bidirectional));
4860 bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
4861 (direction == RnnDirection::Bidirectional));
4862
4863 std::vector<NodeValue> YSlices;
4864 std::vector<NodeValue> Hslices;
4865
4866 // Load forward GRU.
4867 std::vector<NodeValue> forwardYslices;
4868 if (forwardEnabled) {
4869 NodeValue forwardHslice;
4870 loadGRUCell(/* forward */ true, forwardYslices, forwardHslice);
4871 Hslices.push_back(forwardHslice);
4872 }
4873
4874 // Load backward GRU.
4875 std::vector<NodeValue> backwardYslices;
4876 if (backwardEnabled) {
4877 NodeValue backwardHslice;
4878 loadGRUCell(/* forward */ false, backwardYslices, backwardHslice);
4879 Hslices.push_back(backwardHslice);
4880 }
4881
4882 // Gather Y slices.
4883 for (size_t t = 0; t < seqLength; t++) {
4884 if (forwardEnabled) {
4885 YSlices.push_back(forwardYslices[t]);
4886 }
4887 if (backwardEnabled) {
4888 YSlices.push_back(backwardYslices[seqLength - 1 - t]);
4889 }
4890 }
4891
4892 // Concatenate Y slices.
4893 // Y size is [seqLength, numDirections, batchSize, hiddenSize].
4894 Y = createReshape(opName + ".Y.reshape",
4895 createConcat(opName + ".Y.concat", YSlices, 0),
4896 {seqLength, numDirections, batchSize, hiddenSize});
4897
4898 // Concatenate Y_h slices.
4899 // Y_h size is [numDirections, batchSize, hiddenSize].
4900 Y_h = createReshape(opName + ".Y_h.reshape",
4901 createConcat(opName + ".Y_h.concat", Hslices, 0),
4902 {numDirections, batchSize, hiddenSize});
4903
4904#undef GRU_X_SLICE_RANGE
4905#undef GRU_W_SLICE_RANGE
4906#undef GRU_R_SLICE_RANGE
4907#undef GRU_B_SLICE_RANGE
4908#undef GRU_H_SLICE_RANGE
4909#undef GRU_CREATE_FC
4910}
4911
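/// Example usage (an illustrative sketch; the names "F", "X", "W", "R", "B",
/// "initial_h" and "initial_c" are assumptions, and the lambdas assume that
/// RnnActivation is callable as act(name, node), which is how it is invoked
/// inside loadLSTMCell below):
///
///   std::vector<RnnActivation> activations = {
///       [&](llvm::StringRef name, Node *n) { return F->createSigmoid(name, n); },
///       [&](llvm::StringRef name, Node *n) { return F->createTanh(name, n); },
///       [&](llvm::StringRef name, Node *n) { return F->createTanh(name, n); }};
///   NodeValue Y, Y_h, Y_c;
///   F->createOnnxLSTM("lstm", X, W, R, B, initial_h, initial_c,
///                     /* P */ nullptr, Y, Y_h, Y_c, hiddenSize,
///                     RnnDirection::Forward, activations,
///                     /* inputForget */ false);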
4912void Function::createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X,
4913 NodeValue W, NodeValue R, NodeValue B,
4914 NodeValue initial_h, NodeValue initial_c,
4915 NodeValue P, NodeValue &Y, NodeValue &Y_h,
4916 NodeValue &Y_c, unsigned hiddenSize,
4917 RnnDirection direction,
4918 std::vector<RnnActivation> &activations,
4919 bool inputForget) {
4920
4921#define LSTM_X_SLICE_RANGE(idx) \
4922 {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
4923#define LSTM_H_SLICE_RANGE(idx) \
4924 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4925#define LSTM_C_SLICE_RANGE(idx) \
4926 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4927#define LSTM_W_SLICE_RANGE(idx0, idx1) \
4928 {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
4929#define LSTM_R_SLICE_RANGE(idx0, idx1) \
4930 {idx0, idx1 * hiddenSize, 0}, { \
4931 idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \
4932 }
4933#define LSTM_B_SLICE_RANGE(idx0, idx1) \
4934 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4935#define LSTM_P_SLICE_RANGE(idx0, idx1) \
4936 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4937#define LSTM_CREATE_FC(name, LHS, RHS, BIAS) \
4938 BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \
4939 : (Node *)createMatMul(name, LHS, RHS)
4940
4941 // Operator name.
4942 const std::string &opName = namePrefix.str();
4943
4944 // Get all size parameters.
4945 dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
4946 assert(X.dims().size() == 3 &&
4947 "ONNX LSTM input 'X' should have 3 dimensions!");
4948 dim_t seqLength = X.dims()[0];
4949 dim_t batchSize = X.dims()[1];
4950 dim_t inputSize = X.dims()[2];
4951
4952 // Validate W size.
4953 assert(W.dims().size() == 3 &&
4954 "ONNX LSTM input 'W' should have 3 dimensions!");
4955 assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize &&
4956 W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!");
4957
4958 // Validate R size.
4959 assert(R.dims().size() == 3 &&
4960 "ONNX LSTM input 'R' should have 3 dimensions!");
4961 assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize &&
4962 R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!");
4963
4964 // Validate B size.
4965 if (B.getNode()) {
4966 assert(B.dims().size() == 2 &&
4967 "ONNX LSTM input 'B' should have 2 dimensions!");
4968 assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize &&
4969 "ONNX LSTM 'B' tensor size invalid!");
4970 }
4971
4972  // Validate the initial_h size if given, otherwise create a zero Splat.
4973  if (initial_h.getNode()) {
4974    assert(initial_h.dims().size() == 3 &&
4975           "ONNX LSTM input 'initial_h' should have 3 dimensions!");
4976 assert(initial_h.dims()[0] == numDirections &&
4977 initial_h.dims()[1] == batchSize &&
4978 initial_h.dims()[2] == hiddenSize &&
4979 "ONNX LSTM 'initial_h' tensor size invalid!");
4980 } else {
4981 auto splatTy = getParent()->uniqueType(
4982 ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
4983 initial_h = createSplat(opName + ".initial_h", splatTy, 0.0);
4984 }
4985
4986  // Validate the initial_c size if given, otherwise create a zero Splat.
4987  if (initial_c.getNode()) {
4988    assert(initial_c.dims().size() == 3 &&
4989           "ONNX LSTM input 'initial_c' should have 3 dimensions!");
4990 assert(initial_c.dims()[0] == numDirections &&
4991 initial_c.dims()[1] == batchSize &&
4992 initial_c.dims()[2] == hiddenSize &&
4993 "ONNX LSTM 'initial_c' tensor size invalid!");
4994 } else {
4995 auto splatTy = getParent()->uniqueType(
4996 ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
4997 initial_c = createSplat(opName + ".initial_c", splatTy, 0.0);
4998 }
4999
5000 // Validate P size.
5001 if (P.getNode()) {
5002 assert(P.dims().size() == 2 &&
5003 "ONNX LSTM input 'P' should have 2 dimensions!");
5004 assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize &&
5005 "ONNX LSTM 'P' tensor size invalid!");
5006 }
5007
5008 // Validate number of activations.
5009 assert(activations.size() == numDirections * 3 &&
5010 "ONNX LSTM activations vector invalid!");
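  // The activations are packed per direction as [f, g, h]: "f" is applied to
  // the input/forget/output gates, "g" to the cell candidate, and "h" to the
  // cell state when producing the hidden state. For the bidirectional case the
  // forward triplet comes first, followed by the backward triplet.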
5011
5012 // Create X slices.
5013 std::vector<Node *> Xslices;
5014 for (dim_t t = 0; t < seqLength; t++) {
5015 auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
5016 Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t));
5017 auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
5018 Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
5019 Xslices.push_back(Xt);
5020 }
5021
5022 // Lambda to load forward/backward LSTM cell.
5023 auto loadLSTMCell = [&](bool forward, std::vector<NodeValue> &Yslices,
5024 NodeValue &Hslice, NodeValue &Cslice) {
5025 // Name prefix.
5026 std::string dirLabel = forward ? ".fw" : ".bw";
5027 std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
5028
5029 // Slice index used for creating weights slices.
5030 dim_t sliceIdx0 = 0;
5031 if (direction == RnnDirection::Bidirectional) {
5032 sliceIdx0 = forward ? 0 : 1;
5033 }
5034
5035 // Activations.
5036 size_t activationOffset = sliceIdx0 * 3;
5037 auto activationF = activations[activationOffset + 0];
5038 auto activationG = activations[activationOffset + 1];
5039 auto activationH = activations[activationOffset + 2];
5040
5041 // Create W slices (Required).
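    // Per the ONNX LSTM definition, the gates are packed along dimension 1 of
    // W in the order [i | o | f | c]; the slice indices 0..3 below select them.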
5042 NodeValue Wi =
5043 createSlice(prefix + ".Wi.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 0));
5044 NodeValue Wo =
5045 createSlice(prefix + ".Wo.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 1));
5046 NodeValue Wf =
5047 createSlice(prefix + ".Wf.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 2));
5048 NodeValue Wc =
5049 createSlice(prefix + ".Wc.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 3));
5050
5051 Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
5052 Wo = createReshape(prefix + ".Wo.reshape", Wo, {hiddenSize, inputSize});
5053 Wf = createReshape(prefix + ".Wf.reshape", Wf, {hiddenSize, inputSize});
5054 Wc = createReshape(prefix + ".Wc.reshape", Wc, {hiddenSize, inputSize});
5055
5056 Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
5057 Wo = createTranspose(prefix + ".Wo.transp", Wo, {1, 0});
5058 Wf = createTranspose(prefix + ".Wf.transp", Wf, {1, 0});
5059 Wc = createTranspose(prefix + ".Wc.transp", Wc, {1, 0});
5060
5061 // Create R slices (Required).
5062 NodeValue Ri =
5063 createSlice(prefix + ".Ri.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 0));
5064 NodeValue Ro =
5065 createSlice(prefix + ".Ro.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 1));
5066 NodeValue Rf =
5067 createSlice(prefix + ".Rf.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 2));
5068 NodeValue Rc =
5069 createSlice(prefix + ".Rc.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 3));
5070
5071 Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
5072 Ro = createReshape(prefix + ".Ro.reshape", Ro, {hiddenSize, hiddenSize});
5073 Rf = createReshape(prefix + ".Rf.reshape", Rf, {hiddenSize, hiddenSize});
5074 Rc = createReshape(prefix + ".Rc.reshape", Rc, {hiddenSize, hiddenSize});
5075
5076 Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
5077 Ro = createTranspose(prefix + ".Ro.transp", Ro, {1, 0});
5078 Rf = createTranspose(prefix + ".Rf.transp", Rf, {1, 0});
5079 Rc = createTranspose(prefix + ".Rc.transp", Rc, {1, 0});
5080
5081 // Create B slices (optional).
5082 NodeValue bWi = nullptr;
5083 NodeValue bWo = nullptr;
5084 NodeValue bWf = nullptr;
5085 NodeValue bWc = nullptr;
5086 NodeValue bRi = nullptr;
5087 NodeValue bRo = nullptr;
5088 NodeValue bRf = nullptr;
5089 NodeValue bRc = nullptr;
5090
5091 if (B) {
5092
5093 bWi = createSlice(prefix + ".bWi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 0));
5094 bWo = createSlice(prefix + ".bWo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 1));
5095 bWf = createSlice(prefix + ".bWf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 2));
5096 bWc = createSlice(prefix + ".bWc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 3));
5097 bRi = createSlice(prefix + ".bRi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 4));
5098 bRo = createSlice(prefix + ".bRo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 5));
5099 bRf = createSlice(prefix + ".bRf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 6));
5100 bRc = createSlice(prefix + ".bRc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 7));
5101
5102 bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
5103 bWo = createReshape(prefix + ".bWo.reshape", bWo, {hiddenSize});
5104 bWf = createReshape(prefix + ".bWf.reshape", bWf, {hiddenSize});
5105 bWc = createReshape(prefix + ".bWc.reshape", bWc, {hiddenSize});
5106 bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
5107 bRo = createReshape(prefix + ".bRo.reshape", bRo, {hiddenSize});
5108 bRf = createReshape(prefix + ".bRf.reshape", bRf, {hiddenSize});
5109 bRc = createReshape(prefix + ".bRc.reshape", bRc, {hiddenSize});
5110 }
5111
5112 // Create P slices (optional).
5113 NodeValue Pi = nullptr;
5114 NodeValue Po = nullptr;
5115 NodeValue Pf = nullptr;
5116
5117 if (P) {
5118
5119 Pi = createSlice(prefix + ".Pi.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 0));
5120 Po = createSlice(prefix + ".Po.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 1));
5121 Pf = createSlice(prefix + ".Pf.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 2));
5122
5123 // Repeat P slices to match [batchSize, hiddenSize].
5124 Pi = createTile(prefix + ".Pi.repeat", Pi, batchSize, 0);
5125 Po = createTile(prefix + ".Po.repeat", Po, batchSize, 0);
5126 Pf = createTile(prefix + ".Pf.repeat", Pf, batchSize, 0);
5127 }
5128
5129 // Create H slice for this direction.
5130 Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
5131 LSTM_H_SLICE_RANGE(sliceIdx0));
5132 Hinit =
5133 createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
5134
5135 // Create C slice for this direction.
5136 Node *Cinit = createSlice(prefix + ".C.slice", initial_c,
5137 LSTM_C_SLICE_RANGE(sliceIdx0));
5138 Cinit =
5139 createReshape(prefix + ".C.reshape", Cinit, {batchSize, hiddenSize});
5140
5141 // Initialize.
5142 Node *Ht = Hinit;
5143 Node *Ct = Cinit;
5144
5145 // Unroll LSTM cell for all time steps.
5146 for (size_t t = 0; t < seqLength; t++) {
5147
5148 // Input for current time step.
5149 // For the reverse LSTM cell the inputs are provided in reverse order.
5150 Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
5151
5152 // Forget gate: ft = f(Xt * Wf + bWf + Ht-1 * Rf + bRf + Pf . Ct-1).
5153 Node *ft = createAdd(prefix + ".F.add1",
5154 LSTM_CREATE_FC(prefix + ".F.fc1", Xt, Wf, bWf),
5155 LSTM_CREATE_FC(prefix + ".F.fc2", Ht, Rf, bRf));
5156 if (Pf) {
5157 ft = createAdd(prefix + ".F.add2", ft,
5158 createMul(prefix + ".F.mult", Pf, Ct));
5159 }
5160 ft = activationF(prefix + ".F.act", ft);
5161
5162 // Cell state candidate: ctild = g(Xt * Wc + bWc + Ht-1 * Rc + bRc).
5163 Node *ctild =
5164 createAdd(prefix + ".Ctild.add",
5165 LSTM_CREATE_FC(prefix + ".Ctild.fc1", Xt, Wc, bWc),
5166 LSTM_CREATE_FC(prefix + ".Ctild.fc2", Ht, Rc, bRc));
5167 ctild = activationG(prefix + ".Ctild.act", ctild);
5168
5169 // Input gate:
5170 // For inputForget == true:
5171 // it = 1 - ft.
5172 // For inputForget == false:
5173 // it = f(Xt * Wi + bWi + Ht-1 * Ri + bRi + Pi . Ct-1).
5174 Node *it;
5175 if (inputForget) {
5176 auto splatTy = ft->getNthResult(0).getType();
5177 it = createSub(prefix + ".I.sub",
5178 createSplat(prefix + ".I.splat", splatTy, 1.0), ft);
5179 } else {
5180 it = createAdd(prefix + ".I.add1",
5181 LSTM_CREATE_FC(prefix + ".I.fc1", Xt, Wi, bWi),
5182 LSTM_CREATE_FC(prefix + ".I.fc2", Ht, Ri, bRi));
5183 if (Pi) {
5184 it = createAdd(prefix + ".I.add2", it,
5185 createMul(prefix + ".I.mult", Pi, Ct));
5186 }
5187 it = activationF(prefix + ".I.act", it);
5188 }
5189
5190 // Cell state update: Ct = ft . Ct-1 + it . ctild.
5191 Ct = createAdd(prefix + ".C.add", createMul(prefix + ".C.mult1", ft, Ct),
5192 createMul(prefix + ".C.mult2", it, ctild));
5193
5194 // Output gate: ot = f(Xt * Wo + bWo + Ht-1 * Ro + bRo + Po . Ct).
5195 Node *ot = createAdd(prefix + ".O.add1",
5196 LSTM_CREATE_FC(prefix + ".O.fc1", Xt, Wo, bWo),
5197 LSTM_CREATE_FC(prefix + ".O.fc2", Ht, Ro, bRo));
5198 if (Po) {
5199 ot = createAdd(prefix + ".O.add2", ot,
5200 createMul(prefix + ".O.mult", Po, Ct));
5201 }
5202 ot = activationF(prefix + ".O.act", ot);
5203
5204 // Hidden state update: Ht = ot . h(Ct).
5205 Ht =
5206 createMul(prefix + ".H.mult", ot, activationH(prefix + ".H.act", Ct));
5207
5208 // Output.
5209 Yslices.push_back(Ht);
5210 }
5211
5212 // Updated states nodes.
5213 Hslice = Ht;
5214 Cslice = Ct;
5215 }; // End of local lambda "loadLSTMCell".
5216
5217 bool forwardEnabled = ((direction == RnnDirection::Forward) ||
5218 (direction == RnnDirection::Bidirectional));
5219 bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
5220 (direction == RnnDirection::Bidirectional));
5221
5222 std::vector<NodeValue> YSlices;
5223 std::vector<NodeValue> Hslices;
5224 std::vector<NodeValue> Cslices;
5225
5226 // Load forward LSTM.
5227 std::vector<NodeValue> forwardYslices;
5228 if (forwardEnabled) {
5229 NodeValue forwardHslice;
5230 NodeValue forwardCslice;
5231 loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice,
5232 forwardCslice);
5233 Hslices.push_back(forwardHslice);
5234 Cslices.push_back(forwardCslice);
5235 }
5236
5237 // Load backward LSTM.
5238 std::vector<NodeValue> backwardYslices;
5239 if (backwardEnabled) {
5240 NodeValue backwardHslice;
5241 NodeValue backwardCslice;
5242 loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice,
5243 backwardCslice);
5244 Hslices.push_back(backwardHslice);
5245 Cslices.push_back(backwardCslice);
5246 }
5247
5248 // Gather Y slices.
5249 for (size_t t = 0; t < seqLength; t++) {
5250 if (forwardEnabled) {
5251 YSlices.push_back(forwardYslices[t]);
5252 }
5253 if (backwardEnabled) {
5254 YSlices.push_back(backwardYslices[seqLength - 1 - t]);
5255 }
5256 }
5257
5258 // Concatenate Y slices.
5259 // Y size is [seqLength, numDirections, batchSize, hiddenSize].
5260 Y = createReshape(opName + ".Y.reshape",
5261 createConcat(opName + ".Y.concat", YSlices, 0),
5262 {seqLength, numDirections, batchSize, hiddenSize});
5263
5264 // Concatenate Y_h slices.
5265 // Y_h size is [numDirections, batchSize, hiddenSize].
5266 Y_h = createReshape(opName + ".Y_h.reshape",
5267 createConcat(opName + ".Y_h.concat", Hslices, 0),
5268 {numDirections, batchSize, hiddenSize});
5269
5270 // Concatenate Y_c slices.
5271 // Y_c size is [numDirections, batchSize, hiddenSize].
5272 Y_c = createReshape(opName + ".Y_c.reshape",
5273 createConcat(opName + ".Y_c.concat", Cslices, 0),
5274 {numDirections, batchSize, hiddenSize});
5275
5276#undef LSTM_X_SLICE_RANGE
5277#undef LSTM_H_SLICE_RANGE
5278#undef LSTM_C_SLICE_RANGE
5279#undef LSTM_W_SLICE_RANGE
5280#undef LSTM_R_SLICE_RANGE
5281#undef LSTM_B_SLICE_RANGE
5282#undef LSTM_P_SLICE_RANGE
5283#undef LSTM_CREATE_FC
5284}
5285
5286TraceEventNode *Function::createTraceEvent(llvm::StringRef eventName,
5287 llvm::StringRef eventType,
5288 Node *data, unsigned index) {
5289 std::string name = (getName() + "_" + eventName + "_instrumentation").str();
5290 return addNode(
5291 new TraceEventNode(name, data, eventName.str(), eventType.str(), index));
5292}
5293
5294NonMaxSuppressionNode *Function::createNonMaxSuppressionV4(
5295 llvm::StringRef name, NodeValue boxes, NodeValue scores,
5296 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
5297 float scoreThreshold, ElemKind elTy) {
5298 // V4
5299 // Class/Score [BatchNum][BoxNum]
5300  // Boxes [BatchNum][BoxNum][4]
5301 // Result [BatchNum*MaxOutputPerBatch]
5302 // NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch]
5303 auto scoresDim = scores.dims();
5304 int scoresBoxDim = scoresDim.size() - 1;
5305 if (maxOutputBoxesPerClass == 0) {
5306 maxOutputBoxesPerClass = scoresDim[scoresBoxDim];
5307 }
5308
5309 // Allocating maximum because we don't know how many boxes will actually be
5310 // detected.
5311 std::vector<dim_t> newDim = {static_cast<dim_t>(maxOutputBoxesPerClass)};
5312 auto indicesTy = getParent()->uniqueType(elTy, newDim);
5313 auto numberOfSelectedIndicesTy = getParent()->uniqueType(
5314 elTy, {static_cast<dim_t>(maxOutputBoxesPerClass)});
5315 return addNode(new NonMaxSuppressionNode(
5316 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
5317 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true));
5318}
5319
5320NonMaxSuppressionNode *
5321Function::createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
5322 NodeValue scores, int64_t centerPointBox,
5323 int64_t maxOutputBoxesPerClass,
5324 float iouThreshold, float scoreThreshold) {
5325 return createNonMaxSuppressionV4(name, boxes, scores, centerPointBox,
5326 maxOutputBoxesPerClass, iouThreshold,
5327 scoreThreshold, ElemKind::Int64ITy);
5328}
5329
5330NonMaxSuppressionNode *Function::createNonMaxSuppressionV4(
5331 llvm::StringRef name, NodeValue boxes, NodeValue scores,
5332 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
5333 float scoreThreshold, TypeRef indicesTy,
5334 TypeRef numberOfSelectedIndicesTy) {
5335 assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass.");
5336
5337 return addNode(new NonMaxSuppressionNode(
5338 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
5339 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true));
5340}
5341
5342NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
5343 llvm::StringRef name, NodeValue boxes, NodeValue scores,
5344 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
5345 float scoreThreshold, ElemKind elTy) {
5346 // ONNX
5347 // Class/Score [BatchNum][ClassNum][BoxNum]
5348 // Box [BatchNum][BoxNum][4]
5349 // Result [BatchNum*MaxOutputPerBatch][3]
5350 auto boxesDim = boxes.dims();
5351 auto scoresDim = scores.dims();
5352 int scoresBoxDim = scoresDim.size() - 1;
5353 int scoresClassDim = scoresDim.size() - 2;
5354 int scoresBatchDim = scoresDim.size() - 3;
5355 int boxesBatchDim = boxesDim.size() - 3;
5356 if (maxOutputBoxesPerClass == 0) {
5357 maxOutputBoxesPerClass = scoresDim[scoresBoxDim];
5358 }
5359
5360  // Allocating maximum because we don't know how many boxes will actually be
5361  // detected.
5362 std::vector<dim_t> newDim = {scoresDim[scoresBatchDim] *
5363 scoresDim[scoresClassDim] *
5364 static_cast<dim_t>(maxOutputBoxesPerClass),
5365 3};
5366 auto indicesTy = getParent()->uniqueType(elTy, newDim);
5367 auto numberOfSelectedIndicesTy = getParent()->uniqueType(
5368 elTy,
5369 {boxesDim[boxesBatchDim] * static_cast<dim_t>(maxOutputBoxesPerClass)});
5370 return addNode(new NonMaxSuppressionNode(
5371 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
5372 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false));
5373}
5374
5375NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
5376 llvm::StringRef name, NodeValue boxes, NodeValue scores,
5377 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
5378 float scoreThreshold) {
5379 return createNonMaxSuppressionONNX(name, boxes, scores, centerPointBox,
5380 maxOutputBoxesPerClass, iouThreshold,
5381 scoreThreshold, ElemKind::Int64ITy);
5382}
5383
5384NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
5385 llvm::StringRef name, NodeValue boxes, NodeValue scores,
5386 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
5387 float scoreThreshold, TypeRef indicesTy) {
5388 auto boxesDim = boxes.dims();
5389 assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass.");
5390
5391  // Allocating maximum because we don't know how many boxes will actually be
5392  // detected.
5393 auto numberOfSelectedIndicesTy = getParent()->uniqueType(
5394 ElemKind::Int32ITy, {1, 1, 1,
5395 boxesDim[boxesDim.size() - 2] *
5396 static_cast<dim_t>(maxOutputBoxesPerClass)});
5397 return addNode(new NonMaxSuppressionNode(
5398 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
5399 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false));
5400}
5401
5402TFLiteDetectionPostProcessNode *Function::createTFLiteDetectionPostProcess(
5403 llvm::StringRef name, NodeValue boxes, NodeValue scores, NodeValue anchors,
5404 int32_t numClasses, int32_t maxDetections, int32_t maxClassesPerDetection,
5405 int32_t maxDetectionsPerClass, float iouThreshold, float scoreThreshold,
5406 float xScale, float yScale, float hScale, float wScale, bool regularNMS) {
5407
5408 // Maximum number of detections depending on fast/regular method.
5409 dim_t numBoxes = anchors.dims()[0];
5410 dim_t fastMaxDetections =
5411 std::min(numBoxes, static_cast<dim_t>(maxDetections));
5412 dim_t regularMaxDetections = maxDetections;
5413 dim_t numMaxDetections =
5414 regularNMS ? regularMaxDetections : fastMaxDetections;
5415
5416 // Create output types. We allocate enough size for the worst possible case
5417 // when the maximum number of detections is obtained.
5418 std::vector<dim_t> detectionBoxesDims = {static_cast<dim_t>(numMaxDetections),
5419 4};
5420 TypeRef detectionBoxesTy =
5421 getParent()->uniqueType(ElemKind::FloatTy, detectionBoxesDims);
5422 std::vector<dim_t> detectionClassesDims = {
5423 static_cast<dim_t>(numMaxDetections)};
5424 TypeRef detectionClassesTy =
5425 getParent()->uniqueType(ElemKind::Int32ITy, detectionClassesDims);
5426 std::vector<dim_t> detectionScoresDims = {
5427 static_cast<dim_t>(numMaxDetections)};
5428 TypeRef detectionScoresTy =
5429 getParent()->uniqueType(ElemKind::FloatTy, detectionScoresDims);
5430 TypeRef numDetectionsTy = getParent()->uniqueType(ElemKind::Int32ITy, {1});
5431
5432 // Dequantize inputs if quantized.
5433 if (boxes.getType()->isQuantizedType()) {
5434 boxes = createDequantize(name.str() + ".dequant.boxes", boxes,
5435 ElemKind::FloatTy);
5436 }
5437 if (scores.getType()->isQuantizedType()) {
5438 scores = createDequantize(name.str() + ".dequant.scores", scores,
5439 ElemKind::FloatTy);
5440 }
5441 if (anchors.getType()->isQuantizedType()) {
5442 anchors = createDequantize(name.str() + ".dequant.anchors", anchors,
5443 ElemKind::FloatTy);
5444 }
5445
5446 // Create node.
5447 return addNode(new TFLiteDetectionPostProcessNode(
5448 name, detectionBoxesTy, detectionClassesTy, detectionScoresTy,
5449 numDetectionsTy, boxes, scores, anchors, numClasses, maxDetections,
5450 maxClassesPerDetection, maxDetectionsPerClass, iouThreshold,
5451 scoreThreshold, xScale, yScale, hScale, wScale, regularNMS));
5452}
5453
5454Constant *Function::createCosineWindow(llvm::StringRef name, dim_t length) {
5455 auto window = getParent()->createConstant(ElemKind::FloatTy, {length}, name);
5456 auto windowH = window->getHandle<float>();
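  // This generates a Hann (raised-cosine) window:
  //   w[n] = 0.5 - 0.5 * cos(2 * pi * n / length).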
5457 for (dim_t n = 0; n < length; n++) {
5458 windowH.raw(n) =
5459 0.5 - 0.5 * cos(2.0 * M_PI * (double)(n) / (double)(length));
5460 }
5461 return window;
5462}
5463
5464Constant *Function::createFFTTwiddleFactors(llvm::StringRef name,
5465 dim_t fftLength) {
5466 auto twiddleFactors =
5467 getParent()->createConstant(ElemKind::FloatTy, {2 * fftLength}, name);
5468 auto twiddleFactorsH = twiddleFactors->getHandle<float>();
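  // The factors are stored as interleaved (real, imag) pairs:
  //   W[k] = exp(-j * 2 * pi * k / fftLength)
  //        = cos(2 * pi * k / fftLength) - j * sin(2 * pi * k / fftLength).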
5469 for (dim_t k = 0; k < fftLength; k++) {
5470 twiddleFactorsH.raw(2 * k + 0) =
5471 cos(2.0 * M_PI * (double)(k) / (double)(fftLength));
5472 twiddleFactorsH.raw(2 * k + 1) =
5473 -sin(2.0 * M_PI * (double)(k) / (double)(fftLength));
5474 }
5475 return twiddleFactors;
5476}
5477
5478Constant *Function::createFFTBitReverseIndices(llvm::StringRef name,
5479 dim_t fftLength) {
5480 assert(fftLength >= 1 && "FFT length must be at least 1!");
5481 // Local function to reverse the bits of a number.
5482 auto reverseBits = [](uint64_t bits, dim_t numBits) -> uint64_t {
5483    assert((numBits <= 64) &&
5484           "Maximum number of bits exceeded for 'reverseBits' function!");
5485    if (numBits == 0) {
5486      return 0;
5487    }
5488    uint64_t bitsRev = 0;
5489    uint64_t bitsMask = 1;
5490    uint64_t bitsRevMask = uint64_t(1) << (numBits - 1);
5491 for (dim_t idx = 0; idx < numBits; idx++) {
5492 if (bits & bitsMask) {
5493 bitsRev |= bitsRevMask;
5494 }
5495 bitsMask <<= 1;
5496 bitsRevMask >>= 1;
5497 }
5498 return bitsRev;
5499 };
5500 auto bitReverseIndices =
5501 getParent()->createConstant(ElemKind::Int32ITy, {fftLength}, name);
5502 auto bitReverseIndicesH = bitReverseIndices->getHandle<int32_t>();
5503 dim_t numBits = std::log2((double)fftLength);
5504 for (dim_t idx = 0; idx < fftLength; idx++) {
5505 bitReverseIndicesH.raw(idx) =
5506 static_cast<int32_t>(reverseBits(idx, numBits));
5507 }
5508 return bitReverseIndices;
5509}
5510
5511Constant *Function::createFFTComplexToRealWeights(llvm::StringRef name,
5512 dim_t fftLength,
5513 dim_t outLength) {
5514 auto complexToRealWeights =
5515 getParent()->createConstant(ElemKind::FloatTy, {2 * outLength}, name);
5516 auto complexToRealWeightsH = complexToRealWeights->getHandle<float>();
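  // The coefficients are stored as interleaved (real, imag) pairs and are
  // consumed by the AudioSpectrogram node to map the fftLength/2 point complex
  // FFT output onto the spectrum of the fftLength point real FFT.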
5517 for (dim_t k = 0; k < outLength; k++) {
5518 complexToRealWeightsH.raw(2 * k + 0) =
5519 0.5 * (1 - sin(2.0 * M_PI * (double)(k) / (double)(fftLength)));
5520 complexToRealWeightsH.raw(2 * k + 1) =
5521 -0.5 * cos(2.0 * M_PI * (double)(k) / (double)(fftLength));
5522 }
5523 return complexToRealWeights;
5524}
5525
5526AudioSpectrogramNode *Function::createAudioSpectrogram(llvm::StringRef name,
5527 NodeValue input,
5528 int64_t windowSize,
5529 int64_t windowStride,
5530 bool magnitudeSquared) {
5531 // Output shape will be windowCount x (fftLength / 2 + 1).
5532 dim_t inputLength = input.getType()->size();
5533 dim_t windowCount = std::floor((inputLength - windowSize) / windowStride) + 1;
5534 dim_t fftLength = 1 << (dim_t)std::ceil(std::log2((double)windowSize));
5535 auto spectrogramTy = getParent()->uniqueType(
5536 ElemKind::FloatTy, {windowCount, fftLength / 2 + 1});
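  // For example, a windowSize of 640 samples is rounded up to an FFT length of
  // 1024 (the next power of two), giving 1024 / 2 + 1 = 513 bins per window.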
5537
5538 // Create a cosine FFT windowing function.
5539 auto window = createCosineWindow(std::string(name) + ".Window", windowSize);
5540
5541 // Create the FFT weights for a fftLength/2 complex FFT.
5542 auto twiddleFactors = createFFTTwiddleFactors(
5543 std::string(name) + ".TwiddleFactors", fftLength / 2);
5544 auto bitReverseIndices = createFFTBitReverseIndices(
5545 std::string(name) + ".BitReverseIndices", fftLength / 2);
5546
5547 // Create the complex to real FFT mapping coefficients.
5548 // For small FFT length make sure to generate at least 1 coefficient.
5549 auto complexToRealWeights = createFFTComplexToRealWeights(
5550 std::string(name) + ".ComplexToRealWeights", fftLength,
5551 (fftLength / 4) >= 1 ? (fftLength / 4) : 1);
5552
5553 // Create AudioSpectrogram node.
5554 return addNode(new AudioSpectrogramNode(
5555 name, spectrogramTy, input, window, twiddleFactors, bitReverseIndices,
5556 complexToRealWeights, windowSize, windowStride, magnitudeSquared));
5557}
5558
5559void Function::createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength,
5560 float sampleRate, float lowerFrequency,
5561 float upperFrequency, dim_t filterBankCount,
5562 Constant *&melWeights, Constant *&melRanges) {
5563 auto fftLength = 2 * (spectrogramLength - 1);
5564 dim_t numFreqBins = fftLength / 2;
5565 dim_t numMelBins = filterBankCount;
5566
5567 // Mel frequency scale local lambda function.
5568 auto melFreqScale = [](float freq) -> float {
5569 return 1127.0f * logf(1.0f + freq / 700.0f);
5570 };
5571
5572 // Always exclude DC (TensorFlow implementation choice from HTK).
5573 float freqDelta = sampleRate / (float)(fftLength);
5574 dim_t freqIdxMin = (dim_t)(1.5 + (lowerFrequency / freqDelta));
5575 dim_t freqIdxMax = (dim_t)(upperFrequency / freqDelta);
5576 freqIdxMax = (freqIdxMax >= numFreqBins) ? numFreqBins : freqIdxMax;
5577
5578 // Create Mel ranges constant.
5579 melRanges = getParent()->createConstant(ElemKind::Int32ITy, {2 * numMelBins},
5580 std::string(prefix) + ".MelRanges");
5581 auto melRangesH = melRanges->getHandle<int32_t>();
5582
5583 // Mel weights and frequency start/stop (inclusive) buffers.
5584 auto melBinFreqWeights = std::make_unique<float[]>(numMelBins * numFreqBins);
5585 dim_t melBinFreqWeightsNum = 0;
5586
5587 // Mel frequency limits.
5588 float melFreqLower = melFreqScale(lowerFrequency);
5589 float melFreqUpper = melFreqScale(upperFrequency);
5590 float melFreqDelta = (melFreqUpper - melFreqLower) / (numMelBins + 1);
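  // Each Mel bin is a triangular filter in the Mel domain: it spans
  // [melFreqLeft, melFreqRight], peaks at melFreqCenter, and its weights fall
  // off linearly as 1 - |mel - melFreqCenter| / melFreqDelta.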
5591 for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) {
5592
5593 float melFreqLeft = melFreqLower + (melIdx + 0) * melFreqDelta;
5594 float melFreqCenter = melFreqLower + (melIdx + 1) * melFreqDelta;
5595 float melFreqRight = melFreqLower + (melIdx + 2) * melFreqDelta;
5596
5597 int32_t freqIdxStart = -1;
5598 int32_t freqIdxStop = -2;
5599
5600 for (dim_t freqIdx = freqIdxMin; freqIdx <= freqIdxMax; freqIdx++) {
5601 float melFreq = melFreqScale(freqIdx * freqDelta);
5602 if ((melFreqLeft < melFreq) && (melFreq < melFreqRight)) {
5603
5604 // Compute frequency bin weight for this Mel bin.
5605 float weight = 1.0f - std::abs(melFreq - melFreqCenter) / melFreqDelta;
5606
5607 // Store the frequency bin weight.
5608 melBinFreqWeights[melBinFreqWeightsNum++] = weight;
5609
5610 // Update frequency bin start/stop index.
5611 if (freqIdxStart == -1) {
5612 freqIdxStart = freqIdx;
5613 }
5614 freqIdxStop = freqIdx;
5615 }
5616 }
5617
5618 // Store the frequency bin start/stop index.
5619 melRangesH.raw(2 * melIdx + 0) = freqIdxStart;
5620 melRangesH.raw(2 * melIdx + 1) = freqIdxStop;
5621 }
5622
5623 // Validate Mel ranges.
5624 dim_t melBinFreqWeightsNumValidate = 0;
5625 for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) {
5626 int32_t freqIdxRange =
5627 melRangesH.raw(2 * melIdx + 1) - melRangesH.raw(2 * melIdx + 0) + 1;
5628 melBinFreqWeightsNumValidate += freqIdxRange;
5629 }
5630 assert(melBinFreqWeightsNum == melBinFreqWeightsNumValidate &&
5631 "Invalid Mel ranges");
5632
5633 // Create Mel weights constant.
5634 melWeights =
5635 getParent()->createConstant(ElemKind::FloatTy, {melBinFreqWeightsNum},
5636 std::string(prefix) + ".MelWeights");
5637 auto melWeightsH = melWeights->getHandle<float>();
5638 for (dim_t idx = 0; idx < melBinFreqWeightsNum; idx++) {
5639 melWeightsH.raw(idx) = melBinFreqWeights[idx];
5640 }
5641}
5642
5643Constant *Function::createDCTMat(llvm::StringRef name, dim_t N, dim_t K) {
5644 Constant *dctMat =
5645 getParent()->createConstant(ElemKind::FloatTy, {K, N}, name);
5646 auto dctMatH = dctMat->getHandle<float>();
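  // DCT-II transform matrix, scaled by sqrt(2 / N):
  //   dctMat[k][n] = sqrt(2 / N) * cos(pi / N * (n + 0.5) * k).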
5647 float dctFact = (float)sqrt(2.0 / (double)(N));
5648 for (dim_t k = 0; k < K; k++) {
5649 for (dim_t n = 0; n < N; n++) {
5650 dctMatH.at({k, n}) =
5651 dctFact * cos(M_PI / (double)(N) * ((double)(n) + 0.5) * (double)(k));
5652 }
5653 }
5654 return dctMat;
5655}
5656
5657MFCCNode *Function::createMFCC(llvm::StringRef name, NodeValue spectrogram,
5658 float sampleRate, float lowerFrequency,
5659 float upperFrequency, int64_t filterBankCount,
5660 int64_t numCoefficients) {
5661 // Create the Mel weights.
5662 dim_t spectrogramLength = spectrogram.dims()[1];
5663 Constant *melWeights;
5664 Constant *melRanges;
5665 createMelWeights(name, spectrogramLength, sampleRate, lowerFrequency,
5666 upperFrequency, filterBankCount, melWeights, melRanges);
5667
5668 // Create the DCT transform matrix.
5669 Constant *dctMat = createDCTMat(std::string(name) + ".DCTMat",
5670 filterBankCount, numCoefficients);
5671
5672 // Output shape will be windowCount x numCoefficients.
5673 dim_t windowCount = spectrogram.dims()[0];
5674 auto coefficientsTy = getParent()->uniqueType(
5675 ElemKind::FloatTy, {windowCount, static_cast<dim_t>(numCoefficients)});
5676
5677 // Create MFCC node.
5678 return addNode(new MFCCNode(name, coefficientsTy, spectrogram, melWeights,
5679 melRanges, dctMat, sampleRate, lowerFrequency,
5680 upperFrequency, filterBankCount,
5681 numCoefficients));
5682}
5683
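// Shape note for createROIAlign below: the feature map is expected in NHWC
// layout, and the output has shape
// {boxes.dims()[0], outputHeight, outputWidth, featureMap.dims()[3]},
// i.e. one pooled outputHeight x outputWidth patch per box, preserving the
// channel count of the feature map.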
5684ROIAlignNode *
5685Function::createROIAlign(llvm::StringRef name, NodeValue featureMap,
5686 NodeValue boxes, NodeValue batchIndices,
5687 uint32_t outputHeight, uint32_t outputWidth,
5688 uint32_t samplingRatio, float spatialScale,
5689 bool aligned, bool rotated, PoolingMode mode) {
5690 auto featureMapDims = featureMap.dims();
5691 auto boxesDims = boxes.dims();
5692 std::vector<dim_t> outDim = {boxesDims[0], outputHeight, outputWidth,
5693 featureMapDims[3]};
5694 auto outTy =
5695 getParent()->uniqueTypeWithNewShape(featureMap.getType(), outDim);
5696 return addNode(new ROIAlignNode(
5697 name, outTy, featureMap, boxes, batchIndices, mode, outputHeight,
      outputWidth, samplingRatio, spatialScale, aligned, rotated));
5699}
5700
5701BBoxTransformNode *Function::createBBoxTransform(
5702 llvm::StringRef name, NodeValue rois, NodeValue deltas, NodeValue imInfo,
5703 llvm::ArrayRef<float> weights, bool applyScale, bool rotated,
5704 bool angleBoundOn, int64_t angleBoundLo, int64_t angleBoundHi,
5705 float clipAngleThresh, bool legacyPlusOne) {
5706 auto deltasDims = deltas.dims();
5707 auto imInfoDims = imInfo.dims();
5708
5709 auto boxOutTy = getParent()->uniqueTypeWithNewShape(
5710 rois.getType(), {deltasDims[0], deltasDims[1]});
5711 // Forcing roiBatchSplitsTy to always be Float.
5712 auto roiBatchSplitsTy =
      getParent()->uniqueType(ElemKind::FloatTy, {imInfoDims[0]});
5714
5715 return addNode(new BBoxTransformNode(
5716 name, boxOutTy, roiBatchSplitsTy, rois, deltas, imInfo, weights,
5717 applyScale, rotated, angleBoundOn, angleBoundLo, angleBoundHi,
5718 clipAngleThresh, legacyPlusOne));
5719}
5720
5721ExternalFunctionCallNode *Function::createExternalFunctionCall(
5722 llvm::StringRef name, TypeRef outTy, llvm::ArrayRef<glow::NodeValue> inputs,
5723 llvm::StringRef funcName, llvm::StringRef funcImpl,
5724 llvm::StringRef funcKind) {
5725 return addNode(new ExternalFunctionCallNode(name.str(), outTy, inputs,
5726 funcName.str(), funcImpl.str(),
5727 funcKind.str()));
5728}
5729
5730//===----------------------------------------------------------------------===//
5731// Graph dumping and printing
5732//===----------------------------------------------------------------------===//
5733
5734void Function::dump() const {
5735 llvm::outs() << "Graph structure " << getName() << ":\n";
5736 for (auto &n : nodes_) {
5737 llvm::outs() << n.getDebugDesc();
5738 }
5739}
5740
5741std::string Function::toString(bool skipUsersForStorage, bool skipName) const {
5742 std::string storage;
5743 llvm::raw_string_ostream os(storage);
5744 dump(os, skipUsersForStorage, skipName);
5745 return os.str();
5746}
5747
5748llvm::hash_code Function::getHash() const {
5749 // Omit function name when generating the hash.
5750 return llvm::hash_value(toString(/* skipUsersForStorage */ false,
5751 /* skipName */ true));
5752}
5753
5754void Function::dump(llvm::raw_ostream &os, bool skipUsersForStorage,
5755 bool skipName) const {
5756 os << "Graph structure";
5757 if (!skipName) {
5758 os << " " << getName();
5759 }
5760 os << ":\n";
5761 std::set<const Node *, SortNamed> sorted;
5762 for (const Node &n : nodes_) {
5763 sorted.insert(&n);
5764 }
5765 for (auto *n : sorted) {
5766 os << n->getDebugDesc();
5767 }
5768 for (auto *C : getNamedSorted(findConstants())) {
5769 os << C->getDebugDesc(skipUsersForStorage);
5770 }
5771 for (auto *P : getNamedSorted(findPlaceholders())) {
5772 os << P->getDebugDesc(skipUsersForStorage);
5773 }
5774}
5775
5776/// We can't use NodeWalker here, because it ignores result indices, which
5777/// are critical in generating detailed debug output.
5778class FunctionDottyPrinter : public AbstractDottyPrinter {
5779 // A set of already visited (during graph walk) nodes.
5780 std::unordered_set<Node *> visitedNodes_{};
5781
  /// Recursively traverses inputs of node \p N using Depth First Search.
5783 /// Each node will be visited no more than once. The method also dumps
5784 /// edges with their port identifiers in dotty format.
5785 void visitNode(Node *N) {
5786 if (visitedNodes_.find(N) != visitedNodes_.end())
5787 return;
5788 visitedNodes_.insert(N);
5789
5790 dumpNode(N, false);
5791
5792 // Print edges for the predicate field, if it's used.
5793 if (N->hasPredicate()) {
5794 auto pred = N->getPredicate();
5795 size_t resNo = pred.getResNo();
5796 std::ostringstream edge;
5797 edge << pred.getNode()->getName().str() << ":"
5798 << pred.getNode()->getOutputName(resNo).str() << " -> "
5799 << N->getName().str() << ":w";
5800 dumpEdgeStyle(N, 0, pred, edge);
5801 edges_.insert(edge.str());
5802 visitNode(pred);
5803 }
5804
5805 for (size_t i = 0; i < N->getNumInputs(); i++) {
5806 Node *to = N->getNthInput(i).getNode();
5807 size_t resNo = N->getNthInput(i).getResNo();
5808
5809 std::ostringstream edge;
5810 edge << to->getName().str() << ":" << to->getOutputName(resNo).str()
5811 << " -> " << N->getName().str() << ":" << N->getInputName(i);
5812 dumpEdgeStyle(N, i, to, edge);
5813 edges_.insert(edge.str());
5814
5815 visitNode(to);
5816 }
5817 }
5818
5819public:
5820 void visitGraph(Function *F) {
5821 // Sort nodes before printing the dot so we can diff dot files.
5822 std::set<Node *, SortNamed> sorted;
5823 for (Node &N : F->getNodes()) {
5824 sorted.insert(&N);
5825 }
5826 for (auto *N : sorted) {
5827 visitNode(N);
5828 }
5829 }
5830};
5831
5832std::string Function::dumpDAG() {
5833 llvm::SmallString<64> dotPath;
5834 llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
5835 dumpDAG(dotPath);
5836
5837 return std::string(dotPath.begin(), dotPath.end());
5838}
5839
5840/// Local utility function to convert an existing DOT file \p dotFile to
5841/// the right format suggested by the file name extension. For example if
5842/// the file name ends with ".pdf" then the file will be converted to PDF.
5843/// For conversion we use the "dot" application (assumed available) and
5844/// the conversion is done in place.
5845static void convertDotFileToRightFormat(llvm::StringRef dotFile) {
5846 static const std::vector<std::string> supportedFormats = {"pdf", "svg",
5847 "png"};
5848 // Get file extension.
5849 llvm::StringRef extWithDot = llvm::sys::path::extension(dotFile);
5850 std::string ext = extWithDot.take_back(extWithDot.size() - 1).str();
5851 // Convert to new format if supported.
5852 if (std::find(supportedFormats.begin(), supportedFormats.end(), ext) !=
5853 supportedFormats.end()) {
5854 std::string cmd =
5855 "dot -T" + ext + " " + dotFile.str() + " -o " + dotFile.str();
5856 std::string cmdErr =
5857 "Error running DOT conversion application with command '" + cmd +
5858 "'! Check that you have the 'dot' application installed on your "
5859 "system in order to convert files from DOT format to other formats! "
        "Otherwise use the '.dot' extension for the output file!";
5861 CHECK(!system(cmd.c_str())) << cmdErr;
5862 }
5863}
5864
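// A minimal usage sketch (assuming "F" is a Function *): dumping to a ".dot"
// file keeps the raw DOT text, while a supported extension such as ".pdf",
// ".svg" or ".png" additionally runs the in-place conversion above (which
// requires the external "dot" tool to be installed):
//
//   F->dumpDAG("myGraph.dot");
//   F->dumpDAG("myGraph.pdf");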
5865void Function::dumpDAG(llvm::StringRef dotFilename) {
5866 llvm::StringRef legalDotFilename = dotFilename.take_back(255);
5867 llvm::outs() << "Writing dotty graph for Function to: " << legalDotFilename
5868 << '\n';
5869 if (dotFilename.size() > 255) {
5870 llvm::outs() << "WARNING: Filename " << dotFilename
5871 << " is longer than 255 characters, and so was truncated to "
5872 << legalDotFilename << '\n';
5873 }
5874
5875 FunctionDottyPrinter DP;
5876
5877 DP.visitGraph(this);
5878
5879 std::ofstream myfile;
5880 myfile.open(legalDotFilename.str());
5881 if (myfile.fail()) {
5882 LOG(ERROR) << "Unable to open " << legalDotFilename.str()
5883 << ", reason: " << strerror(errno);
5884 } else {
5885 DP.dumpAll(myfile);
5886 }
5887 myfile.close();
5888 convertDotFileToRightFormat(legalDotFilename);
5889}
5890
5891void Function::dumpDAG(const char *dotFilename) {
5892 dumpDAG(llvm::StringRef(dotFilename));
5893}
5894
5895Node *Function::getNodeByName(llvm::StringRef name) {
5896 for (auto &N : getNodes()) {
5897 if (N.getName().equals(name)) {
5898 return &N;
5899 }
5900 }
5901 return nullptr;
5902}
5903
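// A minimal usage sketch (hypothetical node names): the lookup string is
// either "nodeName" for nodes, constants or placeholders with a single
// result, or "nodeName:resNo" to select a specific result of a multi-output
// node:
//
//   NodeValue convOut = F->getNodeValueByName("conv1");
//   NodeValue topKIndices = F->getNodeValueByName("topk1:1");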
5904NodeValue Function::getNodeValueByName(llvm::StringRef name) {
5905 auto strPair = name.split(':');
5906 // Search node, constant or placeholder.
5907 auto nodeName = strPair.first;
5908 Node *node = getNodeByName(nodeName);
5909 node = node ? node : getParent()->getConstantByName(nodeName);
5910 node = node ? node : getParent()->getPlaceholderByNameSlow(nodeName);
5911 if (!node || (node->getNumResults() == 0)) {
5912 return NodeValue();
5913 }
5914 // Get result number.
5915 if (node->getNumResults() == 1) {
5916 return NodeValue(node);
5917 } else {
5918 unsigned resNo = 0;
5919 CHECK(!strPair.second.getAsInteger(0, resNo)) << "Invalid node value name!";
5920 return NodeValue(node, resNo);
5921 }
5922}
5923
5924void Module::eraseConstant(ConstList::iterator I) {
5925 if (I == constants_.end())
5926 return;
5927 logStorageDeletion(functions_, *I);
5928 delete *I;
5929 constants_.erase(I);
5930}
5931
5932void Module::erasePlaceholder(PlaceholderList::iterator I) {
5933 if (I == placeholders_.end()) {
5934 return;
5935 }
5936
5937 logStorageDeletion(functions_, *I);
5938 delete *I;
5939 placeholders_.erase(I);
5940}
5941
5942void Function::eraseNode(NodesList::iterator I) {
5943 // Log node deletion.
5944 logCtx_->logNodeDeletion(*I);
5945
5946 nodes_.erase(I);
5947}
5948
5949Constant *Module::getConstantByName(llvm::StringRef name) const {
5950 for (auto *V : getConstants()) {
5951 if (V->getName() == name)
5952 return V;
5953 }
5954 return nullptr;
5955}
5956
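// A minimal usage sketch (assuming "F" is a Function * whose constants may
// all safely be overwritten with random data, e.g. for benchmarking):
//
//   F->randomizeConstants({});
//
// The map argument lists node kinds and input indices whose constant inputs
// must be left untouched, e.g. indices tensors whose values have to stay
// within a valid range.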
5957void Function::randomizeConstants(
5958 const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants) {
5959 LOG(INFO) << "Randomize Constants............";
5960 for (Constant *c : getParent()->getConstants()) {
5961 bool usedHere = false;
5962 bool usedElsewhere = false;
5963 bool ignored = false;
5964
5965 for (auto &user : c->getUsers()) {
5966 auto *nodeUser = user.getUser();
5967 if (nodeUser->getParent() == this) {
5968 usedHere = true;
5969 } else {
5970 usedElsewhere = true;
5971 }
5972
5973 auto kind = nodeUser->getKind();
5974 if (ignoredConstants.count(kind)) {
5975 for (auto idx : ignoredConstants.at(kind)) {
5976 if (nodeUser->getNthInput(idx).getNode() == c) {
5977 ignored = true;
5978 break;
5979 }
5980 }
5981 }
5982 }
5983
5984 if (!usedHere) {
5985 continue;
5986 }
5987
5988 if (usedElsewhere) {
5989 LOG(FATAL) << "Can't randomize Constant \"" << c->getName().str()
5990 << "\" because it is used by another function";
5991 }
5992
5993 if (ignored) {
5994 continue;
5995 }
5996
5997 auto &payload = c->getPayloadMutable();
5998
5999 switch (c->getElementType()) {
6000 case ElemKind::FloatTy: {
6001 auto H = payload.getHandle<float>();
6002 auto minMaxArg = H.minMaxArg();
6003 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6004 break;
6005 }
6006 case ElemKind::Float16Ty: {
6007 auto H = payload.getHandle<float16_t>();
6008 auto minMaxArg = H.minMaxArg();
6009 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6010 break;
6011 }
6012 case ElemKind::BFloat16Ty: {
6013 auto H = payload.getHandle<bfloat16_t>();
6014 auto minMaxArg = H.minMaxArg();
6015 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6016 break;
6017 }
6018 case ElemKind::Int8QTy: {
6019 auto H = payload.getHandle<int8_t>();
6020 auto minMaxArg = H.minMaxArg();
6021 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6022 break;
6023 }
6024 case ElemKind::UInt8QTy: {
6025 auto H = payload.getHandle<uint8_t>();
6026 auto minMaxArg = H.minMaxArg();
6027 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6028 break;
6029 }
6030 case ElemKind::Int16QTy: {
6031 auto H = payload.getHandle<int16_t>();
6032 auto minMaxArg = H.minMaxArg();
6033 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6034 break;
6035 }
6036 case ElemKind::Int32QTy: {
6037 auto H = payload.getHandle<int32_t>();
6038 auto minMaxArg = H.minMaxArg();
6039 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6040 break;
6041 }
6042 case ElemKind::Int32ITy: {
6043 auto H = payload.getHandle<int32_t>();
6044 auto minMaxArg = H.minMaxArg();
6045 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6046 break;
6047 }
6048 case ElemKind::Int64ITy: {
6049 auto H = payload.getHandle<int64_t>();
6050 auto minMaxArg = H.minMaxArg();
6051 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
6052 break;
6053 }
6054 case ElemKind::UInt8FusedQTy:
6055 payload.getHandle<uint8_t>().randomize(
6056 std::numeric_limits<uint8_t>::lowest(),
6057 std::numeric_limits<uint8_t>::max(), getPRNG());
6058 break;
6059 case ElemKind::UInt8FusedFP16QTy:
6060 payload.getHandle<uint8_t>().randomize(
6061 std::numeric_limits<uint8_t>::lowest(),
6062 std::numeric_limits<uint8_t>::max(), getPRNG());
6063 break;
6064 case ElemKind::UInt4FusedFP16QTy:
6065 payload.getHandle<uint8_t>().randomize(
6066 std::numeric_limits<uint8_t>::lowest(),
6067 std::numeric_limits<uint8_t>::max(), getPRNG());
6068 break;
6069 case ElemKind::UInt4FusedQTy:
6070 payload.getHandle<uint8_t>().randomize(
6071 std::numeric_limits<uint8_t>::lowest(),
6072 std::numeric_limits<uint8_t>::max(), getPRNG());
6073 break;
6074 case ElemKind::BoolTy:
6075 payload.getHandle<bool>().randomize(false, true, getPRNG());
6076 break;
6077 default:
6078 LOG(FATAL) << "Unsupported ElemKind";
6079 }
6080 }
6081}
6082
6083Placeholder *Module::getPlaceholderByNameSlow(llvm::StringRef name) const {
6084 for (auto *P : getPlaceholders()) {
6085 if (P->getName() == name) {
6086 return P;
6087 }
6088 }
6089
6090 return nullptr;
6091}
6092
6093void Module::eraseConstant(Constant *N) {
6094 auto &vars = getConstants();
6095 auto I = std::find(vars.begin(), vars.end(), N);
6096 eraseConstant(I);
6097}
6098
6099void Function::eraseNode(Node *N) {
6100 if (Constant *V = dyn_cast<Constant>(N)) {
6101 return getParent()->eraseConstant(V);
6102 }
6103 assert(std::find_if(nodes_.begin(), nodes_.end(),
6104 [N](const Node &node) -> bool { return &node == N; }) !=
6105 nodes_.end() &&
6106 "Could not find node to delete!");
6107 eraseNode(N->getIterator());
6108}
6109
6110PlaceholderList Function::findPlaceholders() {
6111 PlaceholderList list;
6112 for (auto &PH : parent_->getPlaceholders()) {
6113 for (auto &user : PH->getUsers()) {
6114 if (user.getUser()->getParent() == this) {
6115 list.push_back(PH);
6116 break;
6117 }
6118 }
6119 }
6120 return list;
6121}
6122
6123PlaceholderList Function::findPlaceholders() const {
6124 PlaceholderList list;
6125 for (auto &PH : parent_->getPlaceholders()) {
6126 for (auto &user : PH->getUsers()) {
6127 if (user.getUser()->getParent() == this) {
6128 list.push_back(PH);
6129 break;
6130 }
6131 }
6132 }
6133 return list;
6134}
6135
6136ConstList Function::findConstants() {
6137 ConstList list;
6138 for (auto &constant : parent_->getConstants()) {
6139 for (auto &user : constant->getUsers()) {
6140 if (user.getUser()->getParent() == this) {
6141 list.push_back(constant);
6142 break;
6143 }
6144 }
6145 }
6146 return list;
6147}
6148
6149ConstList Function::findConstants() const {
6150 ConstList list;
6151 for (auto &constant : parent_->getConstants()) {
6152 for (auto &user : constant->getUsers()) {
6153 if (user.getUser()->getParent() == this) {
6154 list.push_back(constant);
6155 break;
6156 }
6157 }
6158 }
6159 return list;
6160}
6161
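// A minimal usage sketch for cloning (hypothetical names): clone "F" into a
// new function of the same module and optionally capture the original-node to
// cloned-node correspondence:
//
//   llvm::DenseMap<const Node *, Node *> origToClone;
//   Function *FC = F->clone("main_clone", &origToClone);
//
// Constants and placeholders are shared between the original and the clone;
// only the nodes of the function body are copied.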
6162Function *Function::clone(llvm::StringRef newName,
6163 llvm::DenseMap<const Node *, Node *> *map,
6164 llvm::DenseMap<const Node *, Node *> *currToNewMap) {
6165 Module *M = getParent();
6166 auto *newF = M->createFunction(newName);
6167 return clone(newF, map, currToNewMap);
6168}
6169
6170Function *
6171Function::clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map,
6172 llvm::DenseMap<const Node *, Node *> *currToNewMap) const {
6173 // Maps current nodes to new nodes.
6174 llvm::DenseMap<const Node *, Node *> currToNew;
6175
6176 // Initialize the map from a user-provided map.
6177 if (currToNewMap) {
6178 currToNew.insert(currToNewMap->begin(), currToNewMap->end());
6179 }
6180
6181 // Clone all of the nodes in the function.
6182 for (auto &N : getNodes()) {
6183 Node *copy = N.clone();
6184 // Record the copy relationship between the graphs.
6185 currToNew[&N] = copy;
6186 newF->addNode(copy);
6187 if (N.hasPredicate()) {
6188 copy->setPredicate(N.getPredicate());
6189 }
6190 }
6191
6192 // At this point we have a new invalid function that points into nodes in
6193 // the original function. Here we update the links between the nodes in the
6194 // new function.
6195 for (auto &N : newF->getNodes()) {
6196 // Fix each one of the inputs of this node.
6197 for (unsigned inp = 0, e = N.getNumInputs(); inp < e; inp++) {
6198 auto input = N.getNthInput(inp);
6199
6200 auto it = currToNew.find(input.getNode());
6201 if (it == currToNew.end()) {
6202 assert(isa<Storage>(input.getNode()) &&
6203 "Could not find a mapping for some node!");
6204 continue;
6205 }
6206
6207 // Update the node with the edge to the current graph.
6208 N.setNthInput(inp, NodeValue(it->second, input.getResNo()));
6209 }
6210
6211 if (N.hasPredicate()) {
6212 auto it = currToNew.find(N.getPredicate().getNode());
6213 if (it != currToNew.end()) {
6214 N.setPredicate(NodeValue(it->second, N.getPredicate().getResNo()));
6215 }
6216 }
6217 }
6218
6219 // Record the node mapping into the external map.
6220 if (map) {
6221 assert(map->empty() && "The external map must be empty");
6222 for (auto it : currToNew) {
6223 map->insert(it);
6224 }
6225 }
6226
6227 assert(newF->getNodes().size() == getNodes().size() && "Invalid func size");
6228 return newF;
6229}
6230
6231/// Verify the input \p idx of a node \p N. Check that the node \p N is in the
6232/// use-list of the corresponding input node.
6233static bool verifyNodeInput(const Node &N, size_t idx) {
6234 auto input = N.getNthInput(idx);
6235 auto *refN = input.getNode();
6236 // Check that N is in the use-list of the input node and there is a proper
6237 // entry for it.
6238 for (auto &U : refN->getUsers()) {
6239 if (U.getUser() == &N && *U.get() == input) {
6240 return true;
6241 }
6242 }
6243
6244 report("Any node referencing another node N must be in the use-list of the "
6245 "node N");
6246 return false;
6247}
6248
6249Module *Module::clone() const {
6250 auto *M = new Module;
6251 return clone(M);
6252}
6253
6254Module *Module::clone(Module *M) const {
6255 // Maps current nodes to new nodes.
6256 llvm::DenseMap<const Node *, Node *> currToNew;
6257 // Clone placeholders.
6258 for (auto &PH : getPlaceholders()) {
6259 auto *copyPH = M->createPlaceholder(PH->getType(), PH->getName(),
6260 PH->isTraining(), PH->getLayout());
6261 currToNew[PH] = copyPH;
6262 }
6263 // Clone constants.
6264 for (auto &C : getConstants()) {
6265 // Cloner cannot decide on its own what to do with constants having unowned
    // payloads. Some kind of policy/hook may be needed in the future for
6267 // deciding what needs to be done in such cases.
6268 DCHECK(!C->getPayload().isUnowned())
6269 << "Cannot copy constant " << C->getName().str()
6270 << ": Unowned payloads are not supported";
6271 auto *copyC = M->createConstant(C->getType(), C->getName(), C->getLayout());
6272 copyC->assign(&C->getPayload());
6273 currToNew[C] = copyC;
6274 }
6275 // Clone all functions.
6276 for (auto *F : getFunctions()) {
6277 // Create an empty clone function in the new module.
6278 auto *copyF = M->createFunction(F->getName());
6279 // Clone function's body into the newly created cloned function. Use the
6280 // currToNew to properly map constants and placeholders.
6281 F->clone(copyF, nullptr, &currToNew);
6282 // Update all types by cloned types.
6283 for (auto &N : copyF->getNodes()) {
6284 for (unsigned idx = 0, e = N.getNumResults(); idx < e; ++idx) {
6285 N.setType(idx, M->uniqueType(*N.getType(idx)));
6286 }
6287 }
6288 }
6289 return M;
6290}
6291
6292/// \returns True if \p n is a storage node (constant or placeholder) of the
6293/// function \p F.
6294static bool isGraphStorageNode(Node *n, const Function *F) {
6295 auto &vars = F->getParent()->getConstants();
6296 auto &placeholders = F->getParent()->getPlaceholders();
6297
6298 if (Constant *V = dyn_cast<Constant>(n)) {
6299 return std::find(vars.begin(), vars.end(), V) != vars.end();
6300 }
6301
6302 if (Placeholder *P = dyn_cast<Placeholder>(n)) {
6303 return std::find(placeholders.begin(), placeholders.end(), P) !=
6304 placeholders.end();
6305 }
6306
6307 return false;
6308}
6309
6310/// Insert \p node in \p nameToNode and report an error if the insertion fails.
6311/// \returns True if \p node was inserted into \p nameToNode. False otherwise.
/// When true is returned, it means that \p nameToNode had no other node
/// registered under \p node.getName().
6314static bool
6315insertAndReport(std::unordered_map<std::string, const Node *> &nameToNode,
6316 const Node &node, const Function &function) {
6317 bool inserted = expectCompareTrue(
6318 "Node is not unique",
6319 nameToNode.insert({node.getName().str(), &node}).second, true, &function);
6320 if (!inserted) {
6321 std::string storage;
6322 llvm::raw_string_ostream msg(storage);
6323 /// Output extra information helping to find the error.
6324 msg << "The node with name '" << node.getName()
6325 << "' conflicts with a previous definition:\n";
6326 msg << "Current definition: " << node.getDebugDesc() << "\n";
6327 msg << "Previous definition: "
6328 << nameToNode[node.getName().str()]->getDebugDesc();
6329 report(msg.str().c_str());
6330 return false;
6331 }
6332 return true;
6333}
6334
6335bool Function::verify(const Backend *backend) const {
6336 bool isValid = true;
  // Check if layout verification is disabled, in which case any layout is
  // accepted for all ops.
6339 VLOG(1) << "Layout requirements checking is "
6340 << (glow::flags::DisableLayoutVerifying ? "disabled" : "enabled");
6341 if (!glow::flags::DisableLayoutVerifying) {
6342 if (backend) {
6343 if (backend->getTensorLayoutRequirements().isEnabled()) {
6344 isValid &= expectCompareTrue(
6345 "Expected correct backend-specific layouts for the graph",
6346 verifyLayouts(*this, backend->getTensorLayoutRequirements()), true,
6347 this);
6348 }
6349 } else {
6350 // Always run verification pre-lowering / when we don't have backend:
6351 isValid &= expectCompareTrue(
6352 "Expected correct Glow canonical layouts for the graph",
6353 verifyLayouts(*this, CanonicalTensorLayout::getInstance()), true,
6354 this);
6355 }
6356 }
6357 std::unordered_map<std::string, const Node *> nameToNode;
6358
6359 for (auto *V : findConstants()) {
6360 isValid &= insertAndReport(nameToNode, *V, *this);
6361 isValid &= expectCompareTrue("Constant and its payload must have same type",
6362 *V->getType(), V->getPayload().getType(), V);
6363 }
6364
6365 nameToNode.clear();
6366 for (const auto &N : nodes_) {
6367 isValid &= insertAndReport(nameToNode, N, *this);
6368 }
6369
6370 // Any node referenced by one of the graph nodes should be part of the
6371 // Graph.
6372 for (const auto &N : nodes_) {
6373 for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
6374 auto &input = N.getNthInput(idx);
6375 // Verify each input of N.
6376 isValid &= verifyNodeInput(N, idx);
6377 bool foundNode =
6378 std::find(nodes_.begin(), nodes_.end(), *input) != nodes_.end();
6379 isValid &= expectCompareTrue(
6380 "Every node referenced by one of the graph nodes should be part of "
6381 "the graph",
6382 foundNode || isGraphStorageNode(input, this), true, &N);
6383 }
6384 }
6385
6386 // Check that all uses of each node refer to this node.
6387 for (const auto &N : nodes_) {
6388 for (const auto &U : N.getUsers()) {
6389 isValid &= expectCompareTrue<const Node *>(
6390 "All uses of a node should refer to this node", U.get()->getNode(),
6391 &N, &N);
6393 }
6394 }
6395
6396 // Check that all types used by nodes belong to the parent module.
6397 auto &types = getParent()->getTypes();
6398 for (const auto &N : nodes_) {
6399 for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) {
6400 auto ty = N.getType(idx);
6401 bool foundType =
6402 std::find(types.begin(), types.end(), *ty) != types.end();
6403 isValid &= expectCompareTrue(
6404 "Every type used by one of the graph nodes should be part of "
6405 "the graph",
6406 foundType, true, &N);
6407 }
6408 }
6409
6410 // Check that there are no zero volume tensors
6411 for (const auto &N : nodes_) {
6412 // Check inputs
6413 for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
6414 auto dims = N.getNthInput(idx).dims();
6415 for (auto dim : dims) {
6416 if (dim == 0) {
6417 LOG(ERROR) << "Found 0 volume input in the " << idx
6418 << " input to node " << N.toString() << " with dims "
6419 << dims;
6420 return false;
6421 }
6422 }
6423 }
6424
6425 // Check results
6426 for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) {
6427 auto dims = N.getNthResult(idx).dims();
6428 for (auto dim : dims) {
6429 if (dim == 0) {
6430 LOG(ERROR) << "Found 0 volume result in the " << idx
6431 << " result from node " << N.toString() << " with dims "
6432 << dims;
6433 return false;
6434 }
6435 }
6436 }
6437 }
6438
6439 std::unordered_map<const Placeholder *, const Node *> placeholderWrittenTo;
6440 for (const auto &N : nodes_) {
6441 isValid &=
        expectCompareTrue("Node is not linked to the function it belongs to",
6443 N.getParent(), this, &N);
6444 isValid &= N.verify();
6445 // Make sure all the placeholders are at most written once, and that
6446 // constants are never written to.
6447 for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
6448 // Placeholders and Constants have no input, so they can only be
6449 // written to via overwritten inputs.
6450 if (!N.isOverwrittenNthInput(idx)) {
6451 continue;
6452 }
6453
6454 const Node *nthInputNode = N.getNthInput(idx).getNode();
6455 isValid &= expectCompareTrue(
6456 "Constants can never be used as an overwritten input",
6457 isa<Constant>(nthInputNode), false, nthInputNode);
6458
6459 // Unlike Constants, Placeholders can be used at most once as
6460 // overwritten inputs. Keep a map of Placeholders to Nodes that used
6461 // them as overwritten inputs, which is also used later to check for
6462 // read-after-write dependence violations.
6463 const auto *ph = dyn_cast<Placeholder>(nthInputNode);
6464 if (!ph) {
6465 continue;
6466 }
6467 auto varToFirstDef = placeholderWrittenTo.find(ph);
6468 bool writtenOnce = expectCompareTrue(
6469 "Placeholder has more than one write",
6470 varToFirstDef == placeholderWrittenTo.end(), true, ph);
6471 if (!writtenOnce) {
6472 isValid = false;
6473 std::string storage;
6474 llvm::raw_string_ostream msg(storage);
6475
6476 msg << "Placeholder " << ph->getDebugDesc() << '\n';
6477 msg << "has more than one write; second writer found:\n";
6478 msg << N.getDebugDesc() << '\n';
6479 msg << varToFirstDef->second->getDebugDesc() << '\n';
6480
6481 report(msg.str().c_str());
6482 }
6483
6484 placeholderWrittenTo[ph] = &N;
6485 }
6486 }
6487
6488 // Now check that the placeholders that are written to are either:
6489 // - Written by a save node, or
6490 // - Are only used by the node that writes them
6491 // If this check fails, that means we have implicit memory
6492 // dependencies that may not be honored by the scheduler.
6493 // Either the input IR is incorrect or the scheduler needs
6494 // fixing.
6495 for (const auto &varToWrite : placeholderWrittenTo) {
6496 if (isa<SaveNode>(varToWrite.second)) {
6497 continue;
6498 }
6499 for (const NodeUse &use : varToWrite.first->getUsers()) {
6500 const Node *user = use.getUser();
6501 // Ignore users outside this function.
6502 if (user->getParent() != this) {
6503 continue;
6504 }
6505 isValid &= expectCompareTrue(
6506 "Implicit read after write memory dependency may not be honored",
6507 user, varToWrite.second, user);
6508 }
6509 }
6510 return isValid;
6511}
6512
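// A minimal usage sketch (assuming "F" is a Function * and "outPH" is the
// Placeholder that an inference result is saved into):
//
//   if (SaveNode *save = getOutputSave(F, outPH)) {
//     NodeValue finalResult = save->getInput();
//   }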
6513SaveNode *glow::getOutputSave(Function *F, Placeholder *PH) {
  // If the parent is set for PH, check that it is the same as the provided
  // Function.
6515 auto *PHP = PH->getParent();
6516 if (PHP != nullptr && F != PHP) {
6517 return nullptr;
6518 }
6519 for (auto &use : PH->getUsers()) {
6520 if (auto *save = llvm::dyn_cast<SaveNode>(use.getUser())) {
6521 if (save->getParent() == F && save->getPlaceholder() == PH) {
6522 return save;
6523 }
6524 }
6525 }
6526 return nullptr;
6527}
6528
6529Node *glow::recursiveClone(Function *newF, Node *node, NodeMap &currToNew) {
6530 Node *copy = node->clone();
6531 currToNew[node] = copy;
6532 newF->addNode(copy);
6533 for (unsigned inp = 0, e = copy->getNumInputs(); inp < e; inp++) {
6534 auto input = copy->getNthInput(inp);
6535 auto it = currToNew.find(input.getNode());
6536 Node *newInput;
6537 if (it != currToNew.end()) {
6538 newInput = it->second;
6539 } else if (llvm::isa<Storage>(input.getNode())) {
6540 continue;
6541 } else {
6542 newInput = recursiveClone(newF, input.getNode(), currToNew);
6543 }
6544 copy->setNthInput(inp, NodeValue(newInput, input.getResNo()));
6545 }
6546 return copy;
6547}
6548
6549namespace glow {
6550/// If \p PH is an output placeholder, \returns true.
6551/// This is determined by checking if the PH has a user which uses the PH as an
6552/// overwritten input.
6553bool isOutput(const Placeholder *PH, const Function &F) {
6554 for (const auto &use : PH->getUsers()) {
6555 // Look through the inputs of the PH's users. If an input is overwritten
6556 // check if it's the PH, if it is return true.
6557 auto *user = use.getUser();
6558 // Consider only users inside the same function.
6559 if (user->getParent() != &F) {
6560 continue;
6561 }
6562 for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) {
6563 // If the input is not overwritten we can continue.
6564 if (!user->isOverwrittenNthInput(i)) {
6565 continue;
6566 }
6567 auto input = use.getUser()->getNthInput(i);
6568 if (input.getNode() == PH) {
6569 return true;
6570 }
6571 }
6572 }
6573 return false;
6574}
6575
6576/// If \p PH is an input placeholder, \returns true.
6577bool isInput(const Placeholder *PH, const Function &F) {
  // Check that the PH is the input to a SaveNode or is used by a non-SaveNode.
6579 for (const auto &use : PH->getUsers()) {
6580 // Consider only users inside the same function.
6581 if (use.getUser()->getParent() != &F) {
6582 continue;
6583 }
    // Check if PH is an input to a SaveNode.
6585 if (auto *save = dyn_cast<SaveNode>(use.getUser())) {
6586 auto input = save->getInput();
      // If the PH is not an input to the SaveNode, we keep looking.
6588 if (input.getNode() != PH) {
6589 continue;
6590 }
6591 }
6592 return true;
6593 }
6594 return false;
6595}
6596
6597llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod) {
6598 mod.dump(os);
6599 return os;
6600}
6601
6602llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod) {
6603 assert(mod != nullptr && "Null Pointer.");
6604 mod->dump(os);
6605 return os;
6606}
6607
6608llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F) {
6609 F.dump(os);
6610 return os;
6611}
6612
6613llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F) {
6614 assert(F != nullptr && "Null Pointer.");
6615 F->dump(os);
6616 return os;
6617}
6618
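// isConvolutionSameAsFullyConnected below returns true when a 2D NHWC
// convolution is mathematically just a matrix multiply: with 1x1 filters,
// unit strides, zero padding, a single group, unit dilation and no fused
// activation, every output pixel is the dot product of its input pixel's
// channels with the filter weights, which is exactly a FullyConnected
// applied to the input reshaped to {N * H * W, C}.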
6619bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node,
6620 bool enforceInput1x1) {
6621 bool isConv2D = (node->getInput().getType()->dims().size() == 4);
6622 if (!(isConv2D && node->getLayout() == ConvolutionLayout::NHWC &&
6623 !node->hasFusedActivation())) {
6624 return false;
6625 }
6626 auto filterDims = ShapeNHWC(node->getFilter().getType()->dims());
6627 ShapeHW kernels = ShapeHW(node->getKernels());
6628 ShapeHW strides = ShapeHW(node->getStrides());
6629 PaddingTLBR pads = PaddingTLBR(node->getPads());
6630 auto group = node->getGroup();
6631 auto dilation = node->getDilation();
6632
6633 bool isSame = (filterDims.h == 1) && (filterDims.w == 1);
6634 isSame &= (kernels.height == 1) && (kernels.width == 1);
6635 isSame &= (strides.height == 1) && (strides.width == 1);
6636 isSame &= (pads.top == 0) && (pads.left == 0) && (pads.bottom == 0) &&
6637 (pads.right == 0);
6638 isSame &= (group == 1);
6639 isSame &= std::all_of(dilation.begin(), dilation.end(),
6640 [](unsigned_t i) { return i == 1; });
6641
6642 if (enforceInput1x1) {
6643 auto inputDims = ShapeNHWC(node->getInput().getType()->dims());
6644 isSame &= (inputDims.h == 1) && (inputDims.w == 1);
6645 }
6646 return isSame;
6647}
6648
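// Gemm computes alpha * A * B + beta * C; it degenerates into a
// FullyConnected (inputs * weights + bias) only when both scaling factors
// are 1.0 and C is a 1-D bias that broadcasts along the rows, which is what
// isGemmSameAsFullyConnected checks below.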
6649bool isGemmSameAsFullyConnected(const GemmNode *node) {
6650 NodeValue inpC = node->getC();
6651 return (node->getAlpha() == 1.0) && (node->getBeta() == 1.0) &&
6652 (inpC.getNode()) && (inpC.dims().size() == 1);
6653}
6654
6655} // namespace glow
6656