/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_GRAPH_GRAPH_H
#define GLOW_GRAPH_GRAPH_H

#include "glow/Base/Type.h"
#include "glow/Graph/Log.h"
#include "glow/Graph/Nodes.h"
#include "glow/Quantization/Base/Base.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"

#include <list>
#include <vector>

namespace glow {
class PlaceholderBindings;

/// List of Types.
using TypesList = std::list<Type>;
/// Intrusive list of Nodes.
using NodesList = llvm::iplist<glow::Node>;
/// List of pointers to Nodes. The nodes are not owned by the list.
using NodesPtrList = std::list<glow::Node *>;
/// List of Functions.
using FunctionList = std::list<Function *>;
/// List of Constants.
using ConstList = std::list<Constant *>;
/// List of Placeholders.
using PlaceholderList = std::list<Placeholder *>;
using UnsignedArrayRef = llvm::ArrayRef<dim_t>;
/// Map from original Nodes to cloned Nodes.
using NodeMap = llvm::DenseMap<Node *, Node *>;
/// State of a function. This can be used to control optimizations which
/// depend on the state of the Function. This is a temporary workaround until
/// GH Issue #3213 is complete.
enum class FunctionState {
  /// Indicates that the function has been created but not completely loaded.
  FuncCreated,
  /// Indicates that the function has been completely loaded.
  FuncLoaded,
};

/// Helper names for common tensor layouts.
#define ANY_LAYOUT "*"

class Module final {
  /// Stores the functions in the module.
  FunctionList functions_;
  /// A uniqued list of types. Types in this list can be equated by comparing
  /// their addresses.
  TypesList types_{};
  /// Stores a list of unique Storage names that were used by the module at
  /// some point.
  llvm::StringSet<> usedStorageNames_{};
  /// Stores a list of node names that were used by Functions of this module
  /// at some point.
  llvm::StringSet<> usedNodeNames_{};
  /// Stores a list of node names that were present in the original model and
  /// are good to be retained.
  llvm::StringSet<> originalNames_{};
  /// A list of constants that the Module owns.
  ConstList constants_;
  /// A list of placeholder nodes that the Module owns.
  PlaceholderList placeholders_;
  /// Deterministic PRNG used to initialize weights in this module.
  PseudoRNG PRNG_;

  /// Module log context that stores all logs related to this module.
  LogContext moduleLogCtx_{nullptr};

  /// Inserts the constant \p V to the list of constants.
  Constant *addConstant(Constant *V);

  friend class Function;

public:
  Module() = default;

  ~Module();

  /// \returns the prefix part of the provided \p name. E.g. for an input
  /// of "relu__2" returns "relu".
  static std::string getPrefix(llvm::StringRef name);

  /// \returns a unique legal name based on the string \p name. Legal names
  /// are legal C identifiers of the form "[a-zA-Z_][a-zA-Z0-9_]*". The
  /// returned name is guaranteed not to be in \p stringTable or
  /// \p updateTable, and it is inserted into \p updateTable.
  static llvm::StringRef uniqueName(llvm::StringRef name,
                                    const llvm::StringSet<> &stringTable,
                                    llvm::StringSet<> &updateTable,
                                    const llvm::StringSet<> &originalNames);

  /// Registers a \p name as used by some Node in this module.
  void registerNodeName(llvm::StringRef name) {
    // Don't care if it's already in the set.
    usedNodeNames_.insert(name);
  }

  /// Registers a \p name from the original model, good to be retained.
  void registerOriginalName(llvm::StringRef name) {
    // Don't care if it's already in the set.
    if (name.size()) {
      originalNames_.insert(name);
    }
  }

  /// \returns a pointer to the list of original node names, good to be
  /// retained.
  const llvm::StringSet<> *getOriginalNames() const { return &originalNames_; }

  /// Registers a name as used by a Storage node (Constant or Placeholder) in
  /// this module.
  void registerStorageName(llvm::StringRef name) {
    usedStorageNames_.insert(name);
  }

  /// \returns whether there's a Storage node already registered with \p name.
  bool hasStorageName(llvm::StringRef name) {
    return usedStorageNames_.count(name);
  }

  /// Return a pointer to a uniqued type \p T.
  TypeRef uniqueType(const Type &T);
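
  /// A minimal usage sketch of type uniquing (illustrative; \c M is some
  /// Module): because types are uniqued, requesting the same type twice
  /// yields the same pointer, so uniqued types can be equated by address.
  /// \code
  ///   TypeRef a = M.uniqueType(Type(ElemKind::FloatTy, {2, 3}));
  ///   TypeRef b = M.uniqueType(Type(ElemKind::FloatTy, {2, 3}));
  ///   assert(a == b && "uniqued types are equated by address");
  /// \endcode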

  /// Return a pointer to a uniqued type \p T.
  TypeRef uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims);

  /// Return a pointer to a uniqued type \p T.
  TypeRef uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale,
                     int32_t offset);

  /// Return a pointer to a uniqued type \p T.
  /// The new type is identical to \p T, with a new shape \p dims.
  TypeRef uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims);

  /// The new type is identical to \p T, with a new shape \p dims and new
  /// \p strides.
  TypeRef uniqueTypeWithNewStrides(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                   llvm::ArrayRef<dim_t> strides);

  /// The new type is identical to \p T, with a new shape \p dims and new
  /// \p alignments.
  TypeRef uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                 llvm::ArrayRef<dim_t> alignments);

  /// Return a pointer to a uniqued type \p T.
  /// The new type is identical to \p T, with a new shape and strides taken
  /// from the type \p shapeType.
  TypeRef uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType);

  /// Return a pointer to a uniqued type \p T.
  /// The new type is identical to \p T, with new scale and offset taken from
  /// the type \p quantParamType.
  TypeRef uniqueTypeWithNewQuantParams(TypeRef T, TypeRef quantParamType);

  /// Return the void type.
  TypeRef getVoidTy();

  /// \returns True if a function by the name \p name exists in the module.
  bool hasFunction(llvm::StringRef name);
  /// \returns the function with the name \p name, or nullptr if the function
  /// does not exist.
  Function *getFunction(llvm::StringRef name);
  /// \returns a new function with the name \p name.
  Function *createFunction(llvm::StringRef name);
  /// \returns the list of Functions that the Module owns.
  FunctionList &getFunctions() { return functions_; }

  const FunctionList &getFunctions() const { return functions_; }

  /// Clears out all Functions from \ref functions_.
  void clearFunctions();

  /// \returns the list of types that the Module owns.
  const TypesList &getTypes() const { return types_; }

  /// Erase the constant \p N from the Module.
  void eraseConstant(Constant *N);

  /// Erase the constant \p I from the Module.
  void eraseConstant(ConstList::iterator I);

  /// Erase the placeholder \p I from the Module.
  /// Note: we only provide an iterator version of this, as erasing
  /// Placeholders is often unsafe.
  void erasePlaceholder(PlaceholderList::iterator I);

  /// \returns a pointer to the first Constant with the name \p name or
  /// nullptr if no node has this name.
  Constant *getConstantByName(llvm::StringRef name) const;

  /// \returns the list of constants that the Module owns.
  ConstList &getConstants() { return constants_; }

  const ConstList &getConstants() const { return constants_; }

  /// \returns the list of placeholders that the Module owns.
  PlaceholderList &getPlaceholders() { return placeholders_; }

  const PlaceholderList &getPlaceholders() const { return placeholders_; }

  /// \returns a pointer to the placeholder with the name \p name or
  /// nullptr if no placeholder has this name.
  Placeholder *getPlaceholderByNameSlow(llvm::StringRef name) const;

  /// @name High-level Storage builders.
  ///@{

  Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 llvm::StringRef name, bool isTrainable,
                                 const std::string &layout = ANY_LAYOUT);

  Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name,
                                 bool isTrainable,
                                 const std::string &layout = ANY_LAYOUT);

  Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 float scale, int32_t offset,
                                 llvm::StringRef name, bool isTrainable,
                                 const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(TypeRef T, llvm::StringRef name,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                           llvm::StringRef name,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims, float scale,
                           int32_t offset, llvm::StringRef name,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(llvm::StringRef name, const Tensor &tensor,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(llvm::StringRef name, Tensor &&tensor,
                           const std::string &layout = ANY_LAYOUT);
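
  /// A minimal sketch of the Storage builders above (shapes and names are
  /// illustrative only):
  /// \code
  ///   Placeholder *in = M.createPlaceholder(ElemKind::FloatTy, {8, 32},
  ///                                         "input", /* isTrainable */ false);
  ///   Constant *w = M.createConstant(ElemKind::FloatTy, {32, 16}, "weights");
  /// \endcode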

  ///@}

  /// Verify the correctness of the Module.
  /// \returns true when the Module is valid. False otherwise.
  bool verify() const;

  /// Get the pseudo-random number generator used by this module.
  PseudoRNG &getPRNG() { return PRNG_; }

  /// Dump a textual representation of the Module into the default output
  /// stream.
  void dump() const;

  /// Dump a textual representation of the Module to std::string.
  std::string toString() const;

  /// Dump a textual representation of the Module into the provided output
  /// stream.
  void dump(llvm::raw_ostream &os) const;

  /// Dump a dotty graph that depicts the Module.
  void dumpDAG();

  /// Dump a dotty graph that depicts the Module.
  void dumpDAG(llvm::StringRef dotFilename);

  /// Dump a dotty graph that depicts the Module.
  void dumpDAG(const char *dotFilename);

  /// Erase all of the functions from the module.
  void eraseFunctions();

  /// Erase all the functions, Placeholders, Constants, etc.
  void clear();

  /// Clone a module.
  /// \returns a new module that is a copy of the current module.
  Module *clone() const;

  /// Clone the current module into a user-provided module \p M.
  /// \returns the user-provided module \p M that now contains a clone of the
  /// current module.
  Module *clone(Module *M) const;

  /// Strips payloads from constants. This is useful when the Module will be
  /// kept around for metadata but we want to reduce memory use. Unlike
  /// \ref clear, this leaves PHs and Constants in the module.
  void strip();

  /// Erase a function \p F from the module.
  void eraseFunction(Function *F);

  /// \returns the size in bytes of data used by constants.
  uint64_t getConstantsSize();

  /// \returns the module log context.
  LogContext *getModuleLogContext() { return &moduleLogCtx_; }

  /// \returns whether any Nodes in the module are non-fused quantized with
  /// scale == or != dummyScale, depending on \p expectDummies.
  Error verifyDummyQParams(bool expectDummies);

  // Don't copy or move this class around. The destructor would wipe out the
  // functions, leaving the original Module with only dangling pointers.
  Module(const Module &) = delete;
  Module(Module &&) = delete;
  Module &operator=(const Module &) = delete;
  Module &operator=(Module &&) = delete;
};

// Forward declarations for verify's optional parameters.
class Backend;
struct CompilationContext;

/// Represents the compute graph.
class Function final : public IRContainer {
  /// A list of nodes that the Function owns.
  NodesList nodes_;

  /// A list of metadata PHs associated with the function.
  std::vector<Placeholder *> metadataPlaceholders_;

  /// Stores a list of unique node names that were used by the module at some
  /// point.
  llvm::StringSet<> uniqueNodeNames_{};

  /// A reference to the owner of the function.
  Module *parent_;

  /// The log context associated with this function.
  std::shared_ptr<LogContext> logCtx_;

  /// The state of this function.
  FunctionState state_;

public:
  Function(Module *parent, llvm::StringRef Name = {})
      : IRContainer(Name), parent_(parent),
        state_(FunctionState::FuncCreated) {
    logCtx_ = std::make_shared<LogContext>(parent);
    logCtx_->pushEvent(parent->getModuleLogContext()->getClonedScope());
  }

  ~Function();

  IRKind getIRKind() const override { return IRKind::GlowGraphIRKind; }

  static bool classof(const IRContainer *I) {
    return I->getIRKind() == IRKind::GlowGraphIRKind;
  }

  static bool classof(const Function *F) { return true; }
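
  /// The \ref classof hooks above enable LLVM-style RTTI; e.g. (sketch,
  /// where \c container is some IRContainer pointer):
  /// \code
  ///   if (auto *fn = llvm::dyn_cast<Function>(container)) {
  ///     // ... use fn as a Function ...
  ///   }
  /// \endcode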

  /// Clear out \ref nodes_ and \ref uniqueNodeNames_.
  void clear();

  /// Sets the state of the function.
  void setState(FunctionState state) { state_ = state; }

  /// Gets the state of the function.
  FunctionState getState() { return state_; }

  std::string getFilename() { return getName().rsplit('/').second.str(); }

  /// Return the log context.
  std::shared_ptr<LogContext> getLogContext() { return logCtx_; }

  /// Add a placeholder for metadata such as profiling.
  void addMetadataPlaceholder(Placeholder *PH) {
    metadataPlaceholders_.push_back(PH);
  }

  /// Get the list of metadata placeholders.
  const std::vector<Placeholder *> &getMetadataPlaceholders() const {
    return metadataPlaceholders_;
  }

  Module *getParent() override { return parent_; }

  /// Orders \ref nodes_ by node name. This makes sure that optimizations
  /// behave deterministically on graphs which have the same ops but a
  /// different ordering in \ref nodes_.
  /// Please do not call this in the middle of PyTorch model loading, since
  /// constant propagation relies heavily on the order of nodes in the node
  /// list. If the order is changed during model loading, constant
  /// propagation may cause unpredictable fatal errors when building the
  /// graph.
  void orderNodes() {
    nodes_.sort(
        [](const Node &a, const Node &b) { return a.getName() < b.getName(); });
  }

  /// Search the Module containing the function to gather and return a list
  /// of placeholders that are used by the Function.
  PlaceholderList findPlaceholders();
  PlaceholderList findPlaceholders() const;

  /// Search the Module containing the function to gather and return a list
  /// of constants that are used by the Function.
  ConstList findConstants();
  ConstList findConstants() const;

  const Module *getParent() const override { return parent_; }

  /// Inserts the node \p N to the list of nodes, and returns the inserted
  /// node.
  template <class NodeTy> NodeTy *addNode(NodeTy *N) {
    N->setName(Module::uniqueName(N->getName(), parent_->usedStorageNames_,
                                  uniqueNodeNames_, parent_->originalNames_));
    parent_->registerNodeName(N->getName());
    nodes_.push_back(N);

    // Log the node creation.
    logCtx_->logNodeCreation(*N);

    return N;
  }

  /// Take ownership of \p N by removing it from its original Function,
  /// adding it to the current Function, and uniquing its name.
  void takeOwnershipOfNode(Node *N) {
    N->getParent()->getNodes().remove(N);
    N->setName(Module::uniqueName(N->getName(), parent_->usedStorageNames_,
                                  uniqueNodeNames_, parent_->originalNames_));
    parent_->registerNodeName(N->getName());
    nodes_.push_back(N);
  }

  /// Get the pseudo-random number generator used by this module.
  PseudoRNG &getPRNG() { return getParent()->getPRNG(); }

  /// @name High-level, operation-level IRBuilder.
  ///@{

  /// Creates a PadNode with the given \p name and output type \p outTy which
  /// pads the given \p input with the explicit pads \p pads according to the
  /// padding mode \p mode and with the given value \p value. The padding
  /// mode \p mode is one of the enumeration values from \ref PaddingMode.
  /// For an input with N dimensions (rank N) the \p pads must be a vector
  /// with 2*N values with the following format:
  /// pads = [pad_before(D1), pad_before(D2), ..., pad_before(DN),
  ///         pad_after (D1), pad_after (D2), ..., pad_after (DN)].
  /// The mode PaddingMode::CONSTANT pads the input using the constant value
  /// \p value and currently is the only mode supported.
  PadNode *createPad(llvm::StringRef name, NodeValue input, TypeRef outTy,
                     unsigned_t mode, llvm::ArrayRef<int> pads, float value);
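
  /// For example (illustrative): to pad only the H and W dimensions of a
  /// rank-4 NHWC input by one cell on each side, use
  /// pads = {0, 1, 1, 0, 0, 1, 1, 0}.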

  /// Creates a ConvolutionNode with the given \p name which convolves the 4D
  /// \p input with \p filter and \p bias. \p kernels defines the size of the
  /// height and width dimensions of the filters. \p strides defines the
  /// number of steps to take in the input for each output cell. \p pads
  /// defines how many zero padding cells should be added to the input during
  /// convolution. \p group defines the number of groups the input and output
  /// channels should be divided into and convolved separately. \p dilation
  /// defines the factor by which the gap between 2 neighboring kernel
  /// elements is expanded along each axis. \p layout defines the Tensor
  /// layout and must be either NHWC or NCHW.
  ConvolutionNode *
  createConv(llvm::StringRef name, NodeValue input, NodeValue filter,
             NodeValue bias, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
             llvm::ArrayRef<unsigned_t> strides,
             llvm::ArrayRef<unsigned_t> pads, unsigned_t group,
             llvm::ArrayRef<unsigned_t> dilation = {1, 1},
             ConvolutionLayout layout = ConvolutionLayout::NHWC);
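
  /// A minimal call sketch (shapes and names illustrative): a 3x3, stride-1
  /// convolution with one cell of padding on each side and a single group:
  /// \code
  ///   auto *conv = F->createConv("conv1", input, filter, bias, outTy,
  ///                              /* kernels */ {3, 3}, /* strides */ {1, 1},
  ///                              /* pads */ {1, 1, 1, 1}, /* group */ 1);
  /// \endcode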

  /// Creates a ConvolutionNode with the given \p name which convolves the 4D
  /// \p input with \p filter and \p bias. \p kernel defines the size of the
  /// height and width dimensions of the filters. \p stride defines the
  /// number of steps to take in the input for each output cell. \p pad
  /// defines how many zero padding cells should be added to the input during
  /// convolution. \p group defines the number of groups the input and output
  /// channels should be divided into and convolved separately. \p dilation
  /// defines the factor by which the gap between 2 neighboring kernel
  /// elements is expanded along each axis. \p layout defines the Tensor
  /// layout and must be either NHWC or NCHW.
  ConvolutionNode *
  createConv(llvm::StringRef name, NodeValue input, NodeValue filter,
             NodeValue bias, TypeRef outTy, unsigned_t kernel,
             unsigned_t stride, unsigned_t pad, unsigned_t group,
             llvm::ArrayRef<unsigned_t> dilation = {1, 1},
             ConvolutionLayout layout = ConvolutionLayout::NHWC);

  /// Creates a Convolution3DNode with the given \p name which convolves the
  /// 5D \p input with \p filter and \p bias. \p kernels defines the size of
  /// the height, width, and depth dimensions of the filters. \p strides
  /// defines the number of steps to take in the input for each output cell.
  /// \p pads defines how many zero padding cells should be added to the
  /// input during convolution. \p group defines the number of groups the
  /// input and output channels should be divided into and convolved
  /// separately. \p outTy defines the type of the output of the 3d
  /// convolution.
  Convolution3DNode *createConv3D(llvm::StringRef name, NodeValue input,
                                  NodeValue filter, NodeValue bias,
                                  TypeRef outTy,
                                  llvm::ArrayRef<unsigned_t> kernels,
                                  llvm::ArrayRef<unsigned_t> strides,
                                  llvm::ArrayRef<unsigned_t> pads,
                                  unsigned_t group);

  /// Creates a Convolution3DNode with the given \p name which convolves the
  /// 5D \p input with \p filter and \p bias. \p kernel defines the size of
  /// the height, width, and depth dimensions of the filters. \p stride
  /// defines the number of steps to take in the input for each output cell.
  /// \p pad defines how many zero padding cells should be added to the input
  /// during convolution. \p group defines the number of groups the input and
  /// output channels should be divided into and convolved separately.
  /// \p outTy defines the type of the output of the 3d convolution.
  Convolution3DNode *createConv3D(llvm::StringRef name, NodeValue input,
                                  NodeValue filter, NodeValue bias,
                                  TypeRef outTy, unsigned_t kernel,
                                  unsigned_t stride, unsigned_t pad,
                                  unsigned_t group);

  /// Creates a ChannelwiseQuantizedConvolutionNode with the given \p name
  /// which convolves the 4D/5D \p input with \p filter and \p bias.
  /// \p filterScales and \p filterOffsets provide individual quantization
  /// parameters for each filter group in \p filter while \p biasScales and
  /// \p biasOffsets provide individual quantization parameters for each bias
  /// element corresponding to each output channel. \p kernels defines the
  /// size of the height and width dimensions of the filters. \p strides
  /// defines the number of steps to take in the input for each output cell.
  /// \p pads defines how many zero padding cells should be added to the
  /// input during convolution. \p group defines the number of groups the
  /// input and output channels should be divided into and convolved
  /// separately. \p dilation defines the filter dilation.
  /// This function is flexible and has the following features:
  /// - it can be provided with a floating-point \p filter and the function
  ///   will automatically quantize the filter channelwise using the given
  ///   schema \p schema and type \p filterElemQTy.
  /// - it can be provided with a floating-point \p bias and the function
  ///   will automatically quantize the bias channelwise using the given
  ///   schema \p schema and type \p biasElemQTy.
  /// - if \p filter is floating-point and \p filterScales or \p filterOffsets
  ///   are not provided then this function will derive them automatically.
  /// - if \p filter is quantized then \p filterScales and \p filterOffsets
  ///   are mandatory.
  /// - if \p bias is floating-point and \p biasScales or \p biasOffsets are
  ///   not provided then this function will derive them automatically.
  /// - if \p bias is quantized and \p biasScales or \p biasOffsets are not
  ///   provided then this function will assume the implicit parameters
  ///   biasScales[i] = inputScale * filterScales[i] and biasOffsets[i] = 0.
  ///   Note that this case can safely handle only the INT32 bias data type
  ///   because for the INT8 type the bias will almost certainly saturate.
  /// This function will only quantize the filter if \p quantizeFilter is set
  /// to true and will only quantize the bias if \p quantizeBias is set to
  /// true, such that a floating-point filter/bias can be attached to the
  /// node as-is without any modifications in order for the backends to
  /// perform their own custom quantization later if desired.
  /// This function requires the \p filter and \p bias operands to be
  /// constants.
  ChannelwiseQuantizedConvolutionNode *createChannelwiseQuantizedConv(
      llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
      NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales,
      NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
      llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
      unsigned_t group, llvm::ArrayRef<unsigned_t> dilation = {1, 1},
      bool quantizeFilter = true, bool quantizeBias = true,
      quantization::Schema schema = quantization::Schema::Asymmetric,
      ElemKind filterElemQTy = ElemKind::Int8QTy,
      ElemKind biasElemQTy = ElemKind::Int32QTy);

  /// Creates a ConvTransposeNode with the given \p name which does a
  /// transposed convolution of the 4D \p input with \p filter and \p bias.
  /// \p kernels define the size of the height and width dimensions of the
  /// filters. \p strides define the number of steps to take in the input for
  /// each output cell. \p pads define how many zero padding cells should be
  /// added to the input during convolution. \p group defines the number of
  /// groups the input and output channels should be divided into and
  /// convolved separately.
  ConvTransposeNode *createConvTranspose(
      llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
      TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
      llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
      unsigned_t group, llvm::ArrayRef<unsigned_t> dilation = {1, 1});

  /// Creates a ConvTransposeNode with the given \p name which does a
  /// transposed convolution of the 4D \p input with \p filter and \p bias.
  /// \p kernel defines the size of the height and width dimensions of the
  /// filters. \p stride defines the number of steps to take in the input for
  /// each output cell. \p pad defines how many zero padding cells should be
  /// added to the input during convolution. \p group defines the number of
  /// groups the input and output channels should be divided into and
  /// convolved separately.
  ConvTransposeNode *
  createConvTranspose(llvm::StringRef name, NodeValue input, NodeValue filter,
                      NodeValue bias, TypeRef outTy, unsigned_t kernel,
                      unsigned_t stride, unsigned_t pad, unsigned_t group,
                      llvm::ArrayRef<unsigned_t> dilation = {1, 1});

  /// Creates and \returns a ConvertToNode with name \p name converting
  /// \p input to the output type \p outTy.
  ConvertToNode *createConvertTo(llvm::StringRef name, NodeValue input,
                                 TypeRef outTy);

  /// Creates and \returns a ConvertToNode with name \p name converting
  /// \p input to the output ElemKind \p k.
  ConvertToNode *createConvertTo(llvm::StringRef name, NodeValue input,
                                 ElemKind k);

  MaxPoolNode *createMaxPool(llvm::StringRef name, NodeValue input,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             ElemKind elemTyAMT = ElemKind::Int64ITy,
                             ConvolutionLayout layout = NHWC);

  MaxPoolNode *createMaxPool(llvm::StringRef name, NodeValue input,
                             unsigned_t kernel, unsigned_t stride,
                             unsigned_t pad,
                             ElemKind elemTyAMT = ElemKind::Int64ITy,
                             ConvolutionLayout layout = NHWC);

  AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             ConvolutionLayout layout = NHWC,
                             bool countIncludePads = true);

  AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
                             TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             ConvolutionLayout layout = NHWC,
                             bool countIncludePads = true);

  AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
                             unsigned_t kernel, unsigned_t stride,
                             unsigned_t pad, ConvolutionLayout layout = NHWC,
                             bool countIncludePads = true);

  /// Creates and \returns an AdaptiveAvgPool node with \p name, \p input,
  /// and \p outTy. The AdaptiveAvgPoolNode will perform average pooling over
  /// the input so that the result is of the shape specified by \p outTy.
  AdaptiveAvgPoolNode *createAdaptiveAvgPool(llvm::StringRef name,
                                             NodeValue input, TypeRef outTy);

  /// Creates and \returns a General Matrix Multiplication (Gemm) node with
  /// the given \p name which computes Y = alpha * A * B + beta * C. The
  /// operands \p A and \p B are 2D matrices, the \p C operand is an optional
  /// 1D or 2D matrix (broadcastable to the size of Y) and \p alpha and
  /// \p beta are float scalars. The \p C operand is optional; if nullptr is
  /// given then it is not used. If \p transposeA or \p transposeB is true
  /// then \p A or \p B is additionally transposed prior to matrix
  /// multiplication.
  /// If the output shape of Y is [M, N] then:
  /// - The shape of \p A must be [M, K] or [K, M] (if transposed).
  /// - The shape of \p B must be [K, N] or [N, K] (if transposed).
  /// - The shape of \p C must be [N] (if 1D) or [M, N] (if 2D).
  GemmNode *createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
                       NodeValue C = nullptr, float alpha = 1.0,
                       float beta = 1.0, bool transposeA = false,
                       bool transposeB = false);

  GemmNode *createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
                       NodeValue B, NodeValue C = nullptr, float alpha = 1.0,
                       float beta = 1.0, bool transposeA = false,
                       bool transposeB = false);
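
  /// For example (illustrative shapes): with A = [4, 8], B = [8, 6], C = [6],
  /// alpha = beta = 1.0, and no transposition, the output Y has shape [4, 6]
  /// and computes Y = A * B + C, with C broadcast across the rows of Y.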

  /// Creates and \returns a DynamicQuantizedFullyConnectedNode with \p name,
  /// \p input, weights \p W, bias \p B, and the flag \p isSymmetric to
  /// indicate the mode. By default it is a dynamic quantized FC node, which
  /// takes fp16 inputs, quantizes them symmetrically, runs the FC on them,
  /// dequantizes the result, and produces an fp16 output. If \p isSymmetric
  /// is set to false, then the inputs are asymmetrically quantized. If
  /// \p isPerBatchElement is set to false, then the inputs are per-tensor
  /// quantized.
  DynamicQuantizedFullyConnectedNode *createDynamicQuantizedFullyConnected(
      llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B,
      bool isSymmetric = true, bool isPerBatchElement = true);

  /// Creates and \returns a DynamicRowwiseQuantizedFullyConnectedNode with
  /// \p name, \p input, weights \p W, bias \p B, rowwise weight qparams
  /// \p wScale and \p wOffset, and the flag \p isSymmetric to indicate the
  /// mode. By default it is a dynamic quantized FC node, which takes fp16
  /// inputs, quantizes them symmetrically, runs the FC on them, dequantizes
  /// the result, and produces an fp16 output. If \p isSymmetric is set to
  /// false, then the inputs are asymmetrically quantized. If
  /// \p isPerBatchElement is set to false, then the inputs are per-tensor
  /// quantized.
  DynamicRowwiseQuantizedFullyConnectedNode *
  createDynamicRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                              NodeValue input, NodeValue W,
                                              NodeValue B, NodeValue wScale,
                                              NodeValue wOffset,
                                              bool isSymmetric = true,
                                              bool isPerBatchElement = true);

  /// Creates and \returns a FullyConnectedNode with \p name, \p input,
  /// weights \p W, bias \p B. If \p input is not 2 dimensional then it is
  /// flattened along \p axis. Note: the output type and outputDepth are
  /// inferred based on the input types.
  FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                           NodeValue input, Storage *W,
                                           Storage *B, unsigned_t axis = 1);

  /// Creates and \returns a FullyConnectedNode with \p name, \p input,
  /// weights \p W, bias \p B. If \p input is not 2 dimensional then it is
  /// flattened along \p axis. Note: the output type and outputDepth are
  /// inferred based on the input types.
  FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                           NodeValue input, NodeValue W,
                                           NodeValue B, unsigned_t axis = 1);

  /// Creates and \returns a FullyConnectedNode with \p name, \p input,
  /// weights \p W, bias \p B, and \p outTy. If \p input is not 2 dimensional
  /// then it is flattened along \p axis. Note: outputDepth is inferred based
  /// on \p outTy.
  FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                           NodeValue input, NodeValue W,
                                           NodeValue B, TypeRef outTy,
                                           unsigned_t axis = 1);

  /// Create a row-wise quantized fully connected node. This node is only
  /// used in quantization. Args \p input and \p B are quantized in the
  /// regular way, \p W is the constant weights and is row-wise quantized
  /// using the given \p scales and \p offsets. The output is quantized in
  /// the regular way, and its type \p outTy is a quantized type.
  RowwiseQuantizedFullyConnectedNode *createRowwiseQuantizedFullyConnected(
      llvm::StringRef name, NodeValue input, NodeValue W, Constant *scales,
      Constant *offsets, NodeValue B, TypeRef outTy);

  /// Create a row-wise quantized fully connected node. This node is only
  /// used in quantization. Args \p input and \p B are quantized in the
  /// regular way, \p W is the constant weights and will be row-wise
  /// quantized during node creation time. The output is quantized in the
  /// regular way, and its type \p outTy is a quantized type. If
  /// \p transposeWeight is true, \p W needs to be transposed first.
  RowwiseQuantizedFullyConnectedNode *createRowwiseQuantizedFullyConnected(
      llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B,
      TypeRef outTy, quantization::Schema schema, bool transposeWeight = false);

  /// Implement an operation that computes the row-wise dot product of its
  /// inputs. Consequently, \p X and \p Y must be either 1D or 2D tensors.
  /// This is lowered to a Mul node, and is followed by a BatchedReduceAdd if
  /// \p X and \p Y are 2D. \returns either the Mul or BatchedReduceAdd node.
  Node *createDotProduct(llvm::StringRef name, NodeValue X, NodeValue Y);

  /// Create a node that computes the pairwise dot product of \p inputs,
  /// which must be a list of 2D tensors with identical shape. \returns the
  /// BatchedPairwiseDotProductNode.
  BatchedPairwiseDotProductNode *
  createBatchedPairwiseDotProduct(llvm::StringRef name,
                                  llvm::ArrayRef<NodeValue> inputs);

  /// Create a node that implements the elementwise linear operator. \p X is
  /// 2D and \p w and \p b are 1D. \p w and \p b are broadcasted to match the
  /// shape of \p X and then the output is computed by multiplying \p X and
  /// broadcasted \p w and adding broadcasted \p b. \returns the
  /// ElementwiseLinearNode. \p axis indicates the axis of the inputs (the
  /// other axis of \p X is assumed to be the batch index).
  Node *createElementwiseLinear(llvm::StringRef name, NodeValue X, NodeValue w,
                                NodeValue b, unsigned axis);

  /// Create a ReLU node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  ReluNode *createRelu(llvm::StringRef name, NodeValue input);
  // Deprecated.
  ReluNode *createRELU(llvm::StringRef name, NodeValue input);

  /// Create a ReLU node with the given \p name, \p input and
  /// output type \p outTy.
  ReluNode *createRelu(llvm::StringRef name, TypeRef outTy, NodeValue input);
  // Deprecated.
  ReluNode *createRELU(llvm::StringRef name, NodeValue input, TypeRef outTy);

  /// Create a series of nodes representing a GeLU with the given \p name and
  /// \p input. Result type will be implicitly set based on the \p input
  /// type.
  GeluNode *createGelu(llvm::StringRef name, NodeValue input);
  // Deprecated.
  GeluNode *createGELU(llvm::StringRef name, NodeValue input);

  /// Create a PReLU node with the given \p name, \p input and \p slope.
  /// Result type will be implicitly set based on the \p input type.
  PReluNode *createPRELU(llvm::StringRef name, NodeValue input,
                         NodeValue slope);

  /// Create a PReLU node with the given \p name, \p input, \p slope and
  /// output type \p outTy.
  PReluNode *createPRELU(llvm::StringRef name, NodeValue input, NodeValue slope,
                         TypeRef outTy);

  /// Create a Sigmoid node with the given \p name, \p input and
  /// output type \p outTy.
  SigmoidNode *createSigmoid(llvm::StringRef name, TypeRef outTy,
                             NodeValue input);

  /// Create a Sigmoid node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  SigmoidNode *createSigmoid(llvm::StringRef name, NodeValue input);

  /// Create a Swish node with the given \p name and \p input.
  /// If \p OT is nullptr, then the result type will be implicitly set based
  /// on the \p input type.
  SwishNode *createSwish(llvm::StringRef name, NodeValue input);
  SwishNode *createSwish(llvm::StringRef name, TypeRef OT, NodeValue input);
  // Deprecated.
  SwishNode *createSwish(llvm::StringRef name, NodeValue input, TypeRef OT);

  /// Create a HardSigmoid node with the given \p name, \p input, \p alpha
  /// and \p beta. Result type will be implicitly set based on the \p input
  /// type.
  ClipNode *createHardSigmoid(llvm::StringRef name, NodeValue input,
                              float alpha, float beta);

  /// Create a HardSigmoid node with the given \p name, \p input,
  /// \p alpha, \p beta and output type \p outTy.
  ClipNode *createHardSigmoid(llvm::StringRef name, TypeRef outTy,
                              NodeValue input, float alpha, float beta);

  /// Create a Tanh node with the given \p name, \p input and
  /// output type \p outTy.
  TanhNode *createTanh(llvm::StringRef name, TypeRef outTy, NodeValue input);

  /// Create a Tanh node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  TanhNode *createTanh(llvm::StringRef name, NodeValue input);

  /// Create an Exp node with \p name, which calculates the element-wise
  /// exponential of \p input.
  ExpNode *createExp(llvm::StringRef name, NodeValue input);

  /// Create an Exp node with \p name with output type \p outTy, which
  /// calculates the element-wise exponential of \p input.
  ExpNode *createExp(llvm::StringRef name, TypeRef outTy, NodeValue input);

  /// Create a Log node with \p name, which calculates the element-wise
  /// natural log of \p input, optionally with output type \p outTy.
  LogNode *createLog(llvm::StringRef name, NodeValue input);
  LogNode *createLog(llvm::StringRef name, TypeRef outTy, NodeValue input);
  // Deprecated.
  LogNode *createLog(llvm::StringRef name, NodeValue input, TypeRef outTy);

  /// \returns a LogitNode with \p name given \p input and \p eps.
  LogitNode *createLogit(llvm::StringRef name, NodeValue input, float eps);

  /// Create a SoftPlus node with the given \p name, \p input and
  /// output type \p outTy.
  SoftPlusNode *createSoftPlus(llvm::StringRef name, NodeValue input,
                               TypeRef outTy = nullptr);

  SoftMaxNode *createSoftMax(llvm::StringRef name, NodeValue input,
                             NodeValue selected, TypeRef outTy = nullptr,
                             float beta = 1.0);

  LogSoftMaxNode *createLogSoftMax(llvm::StringRef name, NodeValue input,
                                   NodeValue selected, TypeRef outTy = nullptr,
                                   float beta = 1.0);

  CrossEntropyLossNode *createCrossEntropyLoss(llvm::StringRef name,
                                               NodeValue input,
                                               NodeValue labels);

  RegressionNode *createRegression(llvm::StringRef name, NodeValue input,
                                   NodeValue expected);

  /// Creates a node which computes the sigmoid cross entropy between two
  /// inputs.
  SigmoidCrossEntropyWithLogitsNode *
  createSigmoidCrossEntropyWithLogits(llvm::StringRef name, NodeValue logits,
                                      NodeValue targets);

  ReshapeNode *createReshape(llvm::StringRef name, NodeValue input,
                             UnsignedArrayRef shape,
                             llvm::StringRef layout = ANY_LAYOUT);

  TransposeNode *createTranspose(llvm::StringRef name, NodeValue input,
                                 llvm::ArrayRef<unsigned_t> shuffle,
                                 const std::string &layout = ANY_LAYOUT);

  /// Create a node with the name \p name which flips (reorders) the elements
  /// of the input \p input along the given axis \p axis.
  FlipNode *createFlip(llvm::StringRef name, NodeValue input, unsigned_t axis);

  /// Create a Broadcast node that broadcasts the \p input Tensor based on
  /// \p newShape and along the \p axis, which defines the offset between the
  /// input dims and \p newShape.
  /// E.g. for input: [3] and newShape: [2, 3, 2], the axis will be 1.
  /// For input: [3] and newShape: [2, 2, 3], the axis will be 2.
  BroadcastNode *createBroadcast(llvm::StringRef name, NodeValue input,
                                 UnsignedArrayRef newShape, unsigned_t axis);
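
  /// A minimal sketch of the second example above (names illustrative):
  /// \code
  ///   // Broadcast a [3] input into shape [2, 2, 3]; the input dim lines up
  ///   // with newShape[2], so axis = 2.
  ///   auto *bc = F->createBroadcast("bc", input, {2, 2, 3}, /* axis */ 2);
  /// \endcode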

  /// Create a concat node which concatenates the input tensors along
  /// \p dimension.
  ConcatNode *createConcat(llvm::StringRef name,
                           llvm::ArrayRef<NodeValue> inputs,
                           unsigned_t dimension);

  /// Create a concat node with the given return type \p outTy.
  ConcatNode *createConcat(llvm::StringRef name,
                           llvm::ArrayRef<NodeValue> inputs,
                           unsigned_t dimension, TypeRef outTy);

  /// Create a TileNode with \p name, \p input, \p tiles, and \p axis.
  /// For example, an input tensor {{1,2,3,4}} of dimension 1x4 with
  /// tiles = 2 and axis = 0 would result in an output tensor
  /// {{1,2,3,4}, {1,2,3,4}} of dimension 2x4.
  TileNode *createTile(llvm::StringRef name, NodeValue input, unsigned_t tiles,
                       unsigned_t axis, TypeRef outTy = nullptr);

  /// Create a TileNode with \p name and \p input which repeats the input
  /// data the given count values \p tiles along the given \p axes.
  TileNode *createTile(llvm::StringRef name, NodeValue input,
                       llvm::ArrayRef<unsigned_t> tiles,
                       llvm::ArrayRef<unsigned_t> axes);

  /// Create an insert tensor node \p name, which inserts \p small into
  /// \p big at offset \p start, \p count times along \p axis.
  InsertTensorNode *createInsertTensor(llvm::StringRef name, NodeValue big,
                                       NodeValue small,
                                       llvm::ArrayRef<dim_t> start,
                                       unsigned_t count = 1,
                                       unsigned_t axis = 0);

  /// Create a slice node \p name with the given starting points for each
  /// dimension \p begin and end points \p end (exclusive).
  SliceNode *createSlice(llvm::StringRef name, NodeValue input,
                         UnsignedArrayRef begin, UnsignedArrayRef end);
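
  /// For example (illustrative): slicing a [4, 4] input with begin = {0, 0}
  /// and end = {2, 2} produces the top-left [2, 2] sub-tensor.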

  /// Create a slice node with the given starting point for each dimension.
  /// End points will be calculated based on the output type during
  /// execution.
  SliceNode *createSlice(llvm::StringRef name, NodeValue input,
                         llvm::ArrayRef<dim_t> start, TypeRef outTy);

  /// Shuffles dimension number \p kernel. Suppose the original size is D. It
  /// will be represented as a group x (D/group) matrix, transposed and
  /// concatenated back to size D. For example, a shuffle of
  /// {1, 2, 3, 4, 5, 6} with \p group = 2 is {1, 4, 2, 5, 3, 6}.
  Node *createChannelShuffle(llvm::StringRef name, NodeValue input,
                             size_t group, size_t kernel);

  /// Computes the indices of the max elements of the input tensor along the
  /// provided \p axis. The resulting tensor has the same rank as the input
  /// if \p keepDims equals 1. If \p keepDims equals 0, the resulting tensor
  /// has the reduced dimension pruned. The type of the output tensor is
  /// \p elemTy.
  ArgMaxNode *createArgMax(llvm::StringRef name, NodeValue input,
                           unsigned_t axis, bool keepDims,
                           ElemKind elemTy = ElemKind::Int64ITy);

  /// Computes the indices of the min elements of the input tensor along the
  /// provided \p axis. The resulting tensor has the same rank as the input
  /// if \p keepDims equals 1. If \p keepDims equals 0, the resulting tensor
  /// has the reduced dimension pruned. The type of the output tensor is
  /// \p elemTy.
  ArgMinNode *createArgMin(llvm::StringRef name, NodeValue input,
                           unsigned_t axis, bool keepDims,
                           ElemKind elemTy = ElemKind::Int64ITy);

  /// Removes single-dimensional entries from the shape of a tensor. The
  /// parameter \p axes is a list of positive integers, indicating the
  /// dimensions to squeeze. Implemented as a single ReshapeNode. This is the
  /// opposite of ExpandDims.
  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#squeeze
  ReshapeNode *createSqueeze(llvm::StringRef name, NodeValue input,
                             llvm::ArrayRef<dim_t> axes);
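
  /// For example (illustrative): squeezing a [1, 3, 1, 4] input with
  /// axes = {0, 2} produces a [3, 4] output.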

  /// Add single-dimensional entries to the shape of the \p input tensor at
  /// locations in \p axes. \p axes is listed as seen in the output tensor.
  /// Implemented as a single ReshapeNode. This is the opposite of Squeeze.
  ReshapeNode *createExpandDims(llvm::StringRef name, NodeValue input,
                                llvm::ArrayRef<dim_t> axes);

  /// Flattens the input tensor into a 2D matrix. If the input tensor has
  /// shape (d_0, d_1, ... d_n) then the output will have shape:
  /// (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X d_n).
  ReshapeNode *createFlatten(llvm::StringRef name, NodeValue input,
                             unsigned_t axis);

  /// Flattens the input tensor into a 2D matrix. If the input tensor has
  /// shape (d_0, d_1, ... d_n) then the output will have shape:
  /// ((d_0 X d_1 ... d_(axis-1) X d_(axis+1) ... X d_n), d_axis).
  ReshapeNode *createFlattenV1(llvm::StringRef name, NodeValue input,
                               unsigned_t axis);
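
  /// For example (illustrative): for an input of shape (2, 3, 4, 5) with
  /// axis = 2, createFlatten produces shape (6, 20) while createFlattenV1
  /// produces shape (30, 4).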

  /// Create \p outputNum slice nodes of \p input. Slices happen along
  /// dimension number \p axis. Array \p split defines the lengths of the
  /// slices. If \p split is empty, \p input is split into equal sized parts.
  void createSplit(llvm::StringRef name, NodeValue input, unsigned_t outputNum,
                   unsigned_t axis, llvm::ArrayRef<dim_t> split,
                   std::vector<SliceNode *> &outputs);

  BatchNormalizationNode *createBatchNormalization(
      llvm::StringRef name, TypeRef resType, NodeValue input, NodeValue beta,
      NodeValue scale, NodeValue mean, NodeValue var, unsigned_t channelIdx = 0,
      float epsilon = 1e-5, float momentum = 0.9);

  /// Create and \returns an InstanceNormalizationNode with result type of
  /// \p outTy that computes the instance normalization of \p input based on
  /// the \p scale and \p bias combined with the computed mean and stddev of
  /// each batch. \p epsilon is a small perturbation used to avoid division
  /// by 0 during normalization.
  InstanceNormalizationNode *
  createInstanceNormalization(llvm::StringRef name, NodeValue input,
                              NodeValue beta, NodeValue scale,
                              unsigned_t channelIdx = 0, float epsilon = 1e-5);

  /// Creates and \returns a LayerNormalizationNode with result type of
  /// \p outTy that computes the layer normalization of the innermost layers
  /// of \p input based on the shape of \p scale and \p bias. \p epsilon is a
  /// small perturbation used to avoid division by 0 during normalization.
  LayerNormalizationNode *
  createLayerNormalization(llvm::StringRef name, TypeRef outTy,
                           NodeValue input, NodeValue scale, NodeValue bias,
                           float epsilon = 1e-5);

  /// Bucketizes the input tensor based on monotonically increasing
  /// \p boundaries for each value in \p input. For each value x in input,
  /// the operator \returns index i given boundaries[i-1] < x <=
  /// boundaries[i]. If the value x is beyond the bounds of boundaries, 0 or
  /// len(boundaries) is returned as appropriate.
  BucketizeNode *createBucketizeNode(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<float> boundaries);
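
  /// For example (illustrative): with boundaries = {1.0, 3.0}, an input
  /// value of 0.5 maps to 0, a value of 2.0 maps to 1, and a value of 5.0
  /// maps to 2 (= len(boundaries)).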

  LocalResponseNormalizationNode *createLocalResponseNormalization(
      llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize = 2,
      float alpha = 1e-4, float beta = 0.75, float k = 2.0);

  /// Create a ModuloNode which performs the modulo operation elementwise on
  /// the \p input such that each element in the output is equal to the
  /// corresponding element in the input modulo \p divisor. If
  /// \p signFollowDivisor is true then any negative elements in the output
  /// will have \p divisor added to their final values.
  ModuloNode *createModulo(llvm::StringRef name, NodeValue input,
                           int64_t divisor, bool signFollowDivisor = false);
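
  /// For example (illustrative): with \p divisor = 5, an input value of -3
  /// yields -3 by default, but yields 2 when \p signFollowDivisor is true.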

  /// Create a logical NOT node with name \p name and input \p input.
  NotNode *createNot(llvm::StringRef name, NodeValue input);

  /// Create a BitwiseNot node with name \p name and input \p input.
  BitwiseNotNode *createBitwiseNot(llvm::StringRef name, NodeValue input);

#define UNARY_ARITHMETIC_FUN_DECL(NODE_NAME_)                                  \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue input); \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
                                       NodeValue input);
  UNARY_ARITHMETIC_FUN_DECL(Abs)
  UNARY_ARITHMETIC_FUN_DECL(Neg)
  UNARY_ARITHMETIC_FUN_DECL(Floor)
  UNARY_ARITHMETIC_FUN_DECL(Sign)
  UNARY_ARITHMETIC_FUN_DECL(Ceil)
  UNARY_ARITHMETIC_FUN_DECL(Round)
  UNARY_ARITHMETIC_FUN_DECL(Sqrt)
  UNARY_ARITHMETIC_FUN_DECL(Rsqrt)
  UNARY_ARITHMETIC_FUN_DECL(Reciprocal)
  UNARY_ARITHMETIC_FUN_DECL(Sin)
  UNARY_ARITHMETIC_FUN_DECL(Cos)
  UNARY_ARITHMETIC_FUN_DECL(Erf)
  UNARY_ARITHMETIC_FUN_DECL(Truncate)
  UNARY_ARITHMETIC_FUN_DECL(HardSwish)
#undef UNARY_ARITHMETIC_FUN_DECL

#define ARITHMETIC_FUN_DECL(NODE_NAME_)                                        \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue LHS,    \
                                       NodeValue RHS);                         \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
                                       NodeValue LHS, NodeValue RHS);
  ARITHMETIC_FUN_DECL(Add);
  ARITHMETIC_FUN_DECL(Mul);
  ARITHMETIC_FUN_DECL(Sub);
  ARITHMETIC_FUN_DECL(Div);
  ARITHMETIC_FUN_DECL(Max);
  ARITHMETIC_FUN_DECL(Min);
  ARITHMETIC_FUN_DECL(CmpEQ);
  ARITHMETIC_FUN_DECL(CmpNEQ);
  ARITHMETIC_FUN_DECL(CmpLT);
  ARITHMETIC_FUN_DECL(CmpLTE);
  ARITHMETIC_FUN_DECL(And);
  ARITHMETIC_FUN_DECL(Or);
  ARITHMETIC_FUN_DECL(Xor);
  ARITHMETIC_FUN_DECL(BitwiseAnd);
  ARITHMETIC_FUN_DECL(BitwiseOr);
  ARITHMETIC_FUN_DECL(BitwiseXor);
  ARITHMETIC_FUN_DECL(Pow);
  ARITHMETIC_FUN_DECL(Fmod);
#undef ARITHMETIC_FUN_DECL

#define TRIGONOMETRIC_FUN_DECL(NODE_NAME_)                                     \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue input); \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
                                       NodeValue input);
  TRIGONOMETRIC_FUN_DECL(Acos)
  TRIGONOMETRIC_FUN_DECL(Asin)
  TRIGONOMETRIC_FUN_DECL(Atan)
#undef TRIGONOMETRIC_FUN_DECL

  std::vector<NodeValue>
  broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs);

  template <class T, class U>
  using enable_if_same_t = std::enable_if<std::is_same<T, U>::value, U>;

#define BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                 \
  constexpr size_t numInputs = sizeof...(Args);                                \
  static_assert(numInputs == NUM_INPUTS,                                       \
                "Invalid input passed in to commonCreateBroadcast.");          \
  std::vector<NodeValue> inputs = broadcastInputs(axis, {inputArgs...});

#define DECLARE_BROADCAST_NODE(NODE_NAME, NUM_INPUTS)                          \
  template <class T, class... Args>                                            \
  typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
  createNodeWithBroadcast(const std::string &name, int axis,                   \
                          Args &&...inputArgs) {                               \
    BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                     \
    return create##NODE_NAME(name, inputs[0].getType(), inputs[0], inputs[1]); \
  }

  /// Template function that creates a node and normalizes its input shapes
  /// with the use of Broadcast nodes. If axis is -1, it is calculated
  /// automatically for multi-directional broadcast.
  DECLARE_BROADCAST_NODE(Add, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Sub, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Mul, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Div, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(And, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Xor, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Or, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(BitwiseAnd, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(BitwiseXor, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(BitwiseOr, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Pow, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Fmod, /* NUM_INPUTS */ 2)

#define DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(NODE_NAME, NUM_INPUTS,            \
                                             OUTTYPEREF)                       \
  template <class T, class... Args>                                            \
  typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
  createNodeWithBroadcastOutTy(const std::string &name, int axis,              \
                               TypeRef OUTTYPEREF, Args &&...inputArgs) {      \
    BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                     \
    return create##NODE_NAME(name, OUTTYPEREF, inputs[0], inputs[1]);          \
  }

  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Add, /* NUM_INPUTS */ 2, outTy)
  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Sub, /* NUM_INPUTS */ 2, outTy)
  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Mul, /* NUM_INPUTS */ 2, outTy)
  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Div, /* NUM_INPUTS */ 2, outTy)
  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Min, /* NUM_INPUTS */ 2, outTy)
  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Max, /* NUM_INPUTS */ 2, outTy)
  DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Fmod, /* NUM_INPUTS */ 2, outTy)

#define DECLARE_CMP_BROADCAST_NODE(NODE_NAME)                                  \
  template <class T, class... Args>                                            \
  typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
  createNodeWithBroadcast(const std::string &name, int axis,                   \
                          Args &&...inputArgs) {                               \
    BROADCAST_FUNC_COMMON_CODE(2)                                              \
    return create##NODE_NAME(name, inputs[0], inputs[1]);                      \
  }

  /// Template function that creates a node and normalizes its input shapes
  /// with the use of Broadcast nodes. If axis is -1, it is calculated
  /// automatically for multi-directional broadcast.
  DECLARE_CMP_BROADCAST_NODE(CmpLT)
  DECLARE_CMP_BROADCAST_NODE(CmpEQ)
  DECLARE_CMP_BROADCAST_NODE(CmpNEQ)
  DECLARE_CMP_BROADCAST_NODE(CmpLTE)
  DECLARE_CMP_BROADCAST_NODE(Min)
  DECLARE_CMP_BROADCAST_NODE(Max)

  /// Template function that creates a node and normalizes its input shapes
  /// with the use of Broadcast nodes. If axis is -1, it is calculated
  /// automatically for multi-directional broadcast.
  template <class T, class... Args>
  typename enable_if_same_t<T, SelectNode>::type *
  createNodeWithBroadcast(const std::string &name, int axis,
                          Args &&...inputArgs) {
    BROADCAST_FUNC_COMMON_CODE(3)
    return createSelect(name, inputs[1].getType(), inputs[0], inputs[1],
                        inputs[2]);
  }
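
  /// A minimal sketch of using these broadcasting creators (names
  /// illustrative); axis = -1 requests automatic multi-directional
  /// broadcast:
  /// \code
  ///   auto *add = F->createNodeWithBroadcast<AddNode>("add", /* axis */ -1,
  ///                                                   lhs, rhs);
  /// \endcode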
1169
1170#undef BROADCAST_FUNC_COMMON_CODE
1171#undef DECLARE_BROADCAST_NODE
1172#undef DECLARE_BROADCAST_NODE_WITH_OUT_TYPE
1173#undef DECLARE_CMP_BROADCAST_NODE
1174#undef BROADCAST_FUNC_COMMON_CODE
1175
1176 /// Create a FloorDivNode with given \p name which divides \p LHS with \p RHS
1177 /// and floors the quotient. If \p truncate is true then truncates the
1178 /// quotient instead of flooring.
1179 FloorDivNode *createFloorDiv(llvm::StringRef name, NodeValue LHS,
1180 NodeValue RHS, bool truncate = false);
1181
1182 /// Create a FloorDivNode with given \p name and output type \p outTy which
1183 /// divides \p LHS with \p RHS and floors the quotient. If \p truncate is true
1184 /// then truncates the quotient to zero instead of flooring.
1185 FloorDivNode *createFloorDiv(llvm::StringRef name, TypeRef outTy,
1186 NodeValue LHS, NodeValue RHS,
1187 bool truncate = false);
1188
1189 /// Create a FloorDivNode with given \p name which divides \p LHS with \p RHS
1190 /// and floors the quotient. If \p truncate is true then truncates the
1191 /// quotient to zero instead of flooring. The inputs are broadcasted based on
1192 /// \p axis.
1193 FloorDivNode *createFloorDivWithBroadcast(llvm::StringRef name, int axis,
1194 NodeValue LHS, NodeValue RHS,
1195 bool truncate = false);
1196
1197 /// Create a FloorDivNode with given \p name and output type \p outTy which
1198 /// divides \p LHS with \p RHS and floors the quotient. If \p truncate is true
1199 /// then truncates the quotient to zero instead of flooring. The inputs are
1200 /// broadcasted based on \p axis.
1201 FloorDivNode *createFloorDivWithBroadcast(llvm::StringRef name, int axis,
1202 TypeRef outTy, NodeValue LHS,
1203 NodeValue RHS,
1204 bool truncate = false);
1205
1206 /// Create an element-wise GREATER THAN comparison between \p LHS and \p RHS
1207 /// by creating a CmpLTNode with given \p name and swapped inputs.
1208 CmpLTNode *createCmpGT(llvm::StringRef name, NodeValue LHS, NodeValue RHS);
1209
1210 /// Create an element-wise GREATER THAN or EQUAL comparison between \p LHS and
1211 /// \p RHS by creating a CmpLTENode with given \p name and swapped inputs.
1212 CmpLTENode *createCmpGTE(llvm::StringRef name, NodeValue LHS, NodeValue RHS);
1213
1214 /// Create a MulNode with given \p name which multiplies \p input with itself
1215 /// to produce an equivalent Square node.
1216 MulNode *createSquare(llvm::StringRef name, NodeValue input);
1217
1218 /// Create a MulNode with given \p name and output type \p outTy which
1219 /// multiplies \p input with itself to produce an equivalent Square node.
1220 MulNode *createSquare(llvm::StringRef name, TypeRef outTy, NodeValue input);
1221
1222 /// Create a LeakyRELU with \p name, \p input and slope \p alpha.
1223 LeakyReluNode *createLeakyRELU(llvm::StringRef name, NodeValue input,
1224 float alpha);
1225
1226 /// Create a LeakyRELU with \p name, \p outTy, \p input and slope \p alpha.
1227 LeakyReluNode *createLeakyRELU(llvm::StringRef name, TypeRef outTy,
1228 NodeValue input, float alpha);
1229
  /// Create a node that produces a boolean output of the same shape as
  /// \p input in which each element indicates whether the corresponding
  /// element in \p input is NaN.
1233 IsNaNNode *createIsNaN(llvm::StringRef name, NodeValue input);
1234
1235 /// \returns a ReplaceNaNNode given \p name, \p input, and \p value.
1236 ReplaceNaNNode *createReplaceNaN(llvm::StringRef name, NodeValue input,
1237 float value);
1238
  /// Create a PowNode with given \p name which raises \p base element-wise to
  /// the power \p exp.
  PowNode *createPow(llvm::StringRef name, NodeValue base, float exp);

  /// Create a NonZeroNode with given \p name which returns the indices of the
  /// non-zero elements of \p Cond.
  NonZeroNode *createNonZero(llvm::StringRef name, NodeValue Cond);

  /// Create a SelectNode with given \p name which selects between \p LHS and
  /// \p RHS element-wise, based on the boolean condition \p Cond.
  SelectNode *createSelect(llvm::StringRef name, NodeValue Cond, NodeValue LHS,
                           NodeValue RHS);

  /// Same as \ref createSelect(), but with output type \p outTy specified.
  SelectNode *createSelect(llvm::StringRef name, TypeRef outTy, NodeValue Cond,
                           NodeValue LHS, NodeValue RHS);

  /// Create a SplatNode with given \p name which fills a tensor of type \p ty
  /// with the scalar \p value.
  SplatNode *createSplat(llvm::StringRef name, TypeRef ty, float value);

  /// Create a TouchNode with given \p name which produces a tensor of type
  /// \p ty with uninitialized contents.
  TouchNode *createTouch(llvm::StringRef name, TypeRef ty);

  /// Create a MatMulNode with given \p name which multiplies the 2D matrices
  /// \p lhs and \p rhs.
  MatMulNode *createMatMul(llvm::StringRef name, NodeValue lhs, NodeValue rhs);

  /// Same as \ref createMatMul(), but with output type \p outTy specified.
  MatMulNode *createMatMul(llvm::StringRef name, TypeRef outTy, NodeValue lhs,
                           NodeValue rhs);
1257
  /// \p lhs and \p rhs are 3D tensors, where the leading dimension is the
  /// batch size. For each batch element number i, lhs.slice(i) is multiplied
  /// by rhs.slice(i).
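  /// A minimal shape sketch (assuming an existing Function *F, \p lhs of
  /// shape [B, M, K] and \p rhs of shape [B, K, N]):
  /// \code
  ///   // The result has shape [B, M, N].
  ///   BatchMatMulNode *bmm = F->createBatchMatMul("bmm", lhs, rhs);
  /// \endcode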
1261 BatchMatMulNode *createBatchMatMul(llvm::StringRef name, NodeValue lhs,
1262 NodeValue rhs);
1263
  /// Create a node performing a VectorNorm operation. The output type is
  /// based on the \p input type with the dimension specified by \p axis
  /// removed. \p p is the order of the norm (default p = 2).
1266 VectorNormNode *createVectorNorm(llvm::StringRef name, NodeValue input,
1267 unsigned_t axis, unsigned_t p = 2);
1268
  /// Create a node performing a BatchedReduceAdd operation. The output type
  /// is based on the \p batch input type with the dimensions specified by
  /// \p axes removed.
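  /// A minimal usage sketch (assuming an existing Function *F and a NodeValue
  /// batch of shape [N, C]):
  /// \code
  ///   // Reduce over axis 1: the result has shape [N].
  ///   auto *sum = F->createBatchedReduceAdd("reduce.add", batch, {1});
  /// \endcode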
1272 BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
1273 NodeValue batch,
1274 llvm::ArrayRef<unsigned_t> axes);
1275
  /// Create a node performing a BatchedReduceAdd operation. The output type
  /// is set to the provided \p outTy.
1278 BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
1279 TypeRef outTy, NodeValue batch,
1280 llvm::ArrayRef<unsigned_t> axes);
1281
  /// Create a node performing a BatchedReduceSumSquare operation. The output
  /// type is based on the \p batch input type with the dimensions specified
  /// by \p axes removed.
1285 BatchedReduceSumSquareNode *
1286 createBatchedReduceSumSquare(llvm::StringRef name, NodeValue batch,
1287 llvm::ArrayRef<unsigned_t> axes);
1288
  /// Create a node performing a BatchedReduceSumSquare operation. The output
  /// type is set to the provided \p outTy.
1291 BatchedReduceSumSquareNode *
1292 createBatchedReduceSumSquare(llvm::StringRef name, TypeRef outTy,
1293 NodeValue batch,
1294 llvm::ArrayRef<unsigned_t> axes);
1295
  /// Create a node performing a BatchedReduceMin operation. The output type
  /// is set to the provided \p outTy.
1298 BatchedReduceMinNode *createBatchedReduceMin(llvm::StringRef name,
1299 TypeRef outTy, NodeValue batch,
1300 llvm::ArrayRef<unsigned_t> axes);
1301
  /// Create a node performing a BatchedReduceMin operation. The output type
  /// is based on the \p batch input type with the dimensions specified by
  /// \p axes removed.
1305 BatchedReduceMinNode *createBatchedReduceMin(llvm::StringRef name,
1306 NodeValue batch,
1307 llvm::ArrayRef<unsigned_t> axes);
1308
  /// Create a node performing a BatchedReduceMax operation. The output type
  /// is set to the provided \p outTy.
1311 BatchedReduceMaxNode *createBatchedReduceMax(llvm::StringRef name,
1312 TypeRef outTy, NodeValue batch,
1313 llvm::ArrayRef<unsigned_t> axes);
1314
  /// Create a node performing a BatchedReduceMax operation. The output type
  /// is based on the \p batch input type with the dimensions specified by
  /// \p axes removed.
1318 BatchedReduceMaxNode *createBatchedReduceMax(llvm::StringRef name,
1319 NodeValue batch,
1320 llvm::ArrayRef<unsigned_t> axes);
1321
  /// Create a node performing a BatchedReduceMean operation. The output type
  /// is set to the provided \p outTy.
1324 BatchedReduceMeanNode *
1325 createBatchedReduceMean(llvm::StringRef name, TypeRef outTy, NodeValue batch,
1326 llvm::ArrayRef<unsigned_t> axes);
1327
  /// Create a node performing a BatchedReduceMean operation. The output type
  /// is based on the \p batch input type with the dimensions specified by
  /// \p axes removed.
1331 BatchedReduceMeanNode *
1332 createBatchedReduceMean(llvm::StringRef name, NodeValue batch,
1333 llvm::ArrayRef<unsigned_t> axes);
1334
  /// Create a node performing a BatchedReduceProd operation. The output type
  /// is based on the \p batch input type with the dimensions specified by
  /// \p axes removed.
1338 BatchedReduceProdNode *
1339 createBatchedReduceProd(llvm::StringRef name, NodeValue batch,
1340 llvm::ArrayRef<unsigned_t> axes);
1341
  /// Create a node performing a BatchedReduceProd operation. The output type
  /// is set to the provided \p outTy.
1344 BatchedReduceProdNode *
1345 createBatchedReduceProd(llvm::StringRef name, TypeRef outTy, NodeValue batch,
1346 llvm::ArrayRef<unsigned_t> axes);
1347
  /// Create a BatchedAddNode with given \p name which adds \p slice to each
  /// slice of \p batch along the leading (batch) dimension.
  BatchedAddNode *createBatchedAdd(llvm::StringRef name, NodeValue batch,
                                   NodeValue slice);

  /// Same as \ref createBatchedAdd(), but with output type \p outTy
  /// specified.
  BatchedAddNode *createBatchedAdd(llvm::StringRef name, TypeRef outTy,
                                   NodeValue batch, NodeValue slice);

  /// Create a BatchedMulNode with given \p name which multiplies each slice
  /// of \p batch along the leading (batch) dimension by \p slice.
  BatchedMulNode *createBatchedMul(llvm::StringRef name, NodeValue batch,
                                   NodeValue slice);

  /// Same as \ref createBatchedMul(), but with output type \p outTy
  /// specified.
  BatchedMulNode *createBatchedMul(llvm::StringRef name, TypeRef outTy,
                                   NodeValue batch, NodeValue slice);
1359
1360 /// Create a node performing a Cumulative Sum operation, output type matches
1361 /// \p input type.
1362 CumSumNode *createCumSum(llvm::StringRef name, NodeValue input,
1363 int64_t dim = 0, bool exclusive = false,
1364 bool reverse = false);
1365
  /// Implements an operation that accumulates the values in \p data along the
  /// first dimension into len(\p lengths) entries by summing together the
  /// first lengths[0] values, then the subsequent lengths[1] values, etc.
  /// sum(\p lengths) must equal the first dimension of \p data. This operation
  /// is similar to SparseLengthsSum but the input is a dense representation
  /// instead of a sparse one. In other words, it has already been Gathered.
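  /// For example, with data = [[1], [2], [3], [4], [5]] and lengths = [2, 3]
  /// the result is [[3], [12]]. A minimal usage sketch (assuming an existing
  /// Function *F and NodeValues data and lengths):
  /// \code
  ///   LengthsSumNode *ls = F->createLengthsSum("lsum", data, lengths);
  /// \endcode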
1372 LengthsSumNode *createLengthsSum(llvm::StringRef name, NodeValue data,
1373 NodeValue lengths);
1374
  /// Create a node performing a SparseLengthsSum operation:
1376 /// Gathers slices of the outer-most dimension of Data indexed by Indices
1377 /// vector, and then accumulates them into len(Lengths) entries:
1378 /// first Lengths[0] slices are aggregated to Result[0], next Lengths[1]
1379 /// slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal
1380 /// to len(Indices). \p lengthsMode and \p avgLength represent meta
1381 /// information about the \p lengths, allowing the backend to use a
1382 /// specialized implementation.
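  /// A minimal worked sketch: with data rows D0..D3, indices = [2, 0, 1] and
  /// lengths = [2, 1], Result[0] = D2 + D0 and Result[1] = D1. Usage
  /// (assuming an existing Function *F and NodeValues data, indices and
  /// lengths):
  /// \code
  ///   auto *sls = F->createSparseLengthsSum("sls", data, indices, lengths);
  /// \endcode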
1383 SparseLengthsSumNode *
1384 createSparseLengthsSum(llvm::StringRef name, NodeValue data,
1385 NodeValue indices, NodeValue lengths,
1386 LengthsMode lengthsMode = LengthsMode::Variable,
1387 float avgLength = NAN);
1388
1389 /// Same as SparseLengthsSum, but i-th slice is multiplied by weights[i].
1390 /// len(weights) must be equal to len(indices).
1391 SparseLengthsWeightedSumNode *createSparseLengthsWeightedSum(
1392 llvm::StringRef name, NodeValue data, NodeValue weights,
1393 NodeValue indices, NodeValue lengths,
1394 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1395
  /// Create an Embedding node:
  /// \p weights is a 2D tensor capturing the embedding table;
  /// \p indices is a tensor of arbitrary shape containing the indices to
  /// extract;
  /// \p padIdx, if given, zeros the output vector whenever padIdx is
  /// encountered;
  /// \p scale, if true, scales gradients by the inverse of the frequency of
  /// words in the mini-batch (currently not supported, default = false);
  /// \p sparse, if true, makes the gradient w.r.t. the weight matrix a sparse
  /// tensor (currently not supported, default = false).
1404 EmbeddingNode *createEmbedding(llvm::StringRef name, NodeValue weights,
1405 NodeValue indices, int32_t padIdx, bool scale,
1406 bool sparse);
1407
1408 /// Create an EmbeddingBag node. If \p hasEndOffset is true then the node
1409 /// expects an extra offset to be appended to \p offsets which marks the end
1410 /// of the last range. \p lengthsMode and \p avgLength represent meta
1411 /// information about the \p lengths, allowing the backend to use a
1412 /// specialized implementation.
1413 EmbeddingBagNode *createEmbeddingBag(
1414 llvm::StringRef name, NodeValue data, NodeValue weights,
1415 NodeValue indices, NodeValue offsets, bool hasEndOffset = false,
1416 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1417
1418 /// Create an EmbeddingBagByteRowwiseOffsetsNode node. If \p hasEndOffset is
1419 /// true then the node expects an extra offset to be appended to \p offsets
1420 /// which marks the end of the last range. \p lengthsMode and \p avgLength
1421 /// represent meta information about the \p lengths, allowing the backend to
1422 /// use a specialized implementation.
1423 EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
1424 llvm::StringRef name, NodeValue data, NodeValue weights,
1425 NodeValue indices, NodeValue offsets, bool useFP16Accumulation = false,
1426 bool hasEndOffset = false,
1427 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1428
1429 /// Same as \ref createEmbeddingBagByteRowwiseOffsets(), but
1430 /// expects float input \p data, which is rowwise-quantized and fused
1431 /// internally. \p fusedElemKind represents the element kind to use for the
1432 /// final fused rowwise-quantized data. If \p hasEndOffset is true then the
1433 /// node expects an extra offset to be appended to \p offsets which marks the
1434 /// end of the last range.
1435 EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
1436 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
1437 NodeValue offsets, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
1438 bool useFP16Accumulation = false, bool hasEndOffset = false,
1439 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1440
1441 /// Same as \ref createSparseLengthsWeightedSum(), but with \p outTy
1442 /// specified.
1443 SparseLengthsWeightedSumNode *createSparseLengthsWeightedSum(
1444 llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights,
1445 NodeValue indices, NodeValue lengths,
1446 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1447
1448 /// Creates and \returns a node of \p name, performing the SparseLengthsSum
1449 /// operation, using rowwise quantization for the input \p data with the \p
1450 /// scales and \p offsets as separate input tensors. Gathers slices of the
1451 /// outer-most dimension of data indexed by the \p indices vector, and then
1452 /// accumulates them into len(\p lengths) entries: first Lengths[0] slices are
1453 /// aggregated to Result[0], next Lengths[1] slices are aggregated to
1454 /// Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices).
1455 /// \p precision represents what precision to use for Scale, Offset, and
1456 /// Result. If \p useFP16Accumulation, then internal arithmetic will use FP16
1457 /// accumulation; otherwise defaults to FP32. \p lengthsMode and \p avgLength
1458 /// represent meta information about the \p lengths, allowing the backend to
1459 /// use a specialized implementation.
1460 RowwiseQuantizedSparseLengthsWeightedSumNode *
1461 createRowwiseQuantizedSparseLengthsSum(
1462 llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets,
1463 NodeValue indices, NodeValue lengths,
1464 ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
1465 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1466
1467 /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but expects
1468 /// float input \p data, which is rowwise-quantized internally.
1469 RowwiseQuantizedSparseLengthsWeightedSumNode *
1470 createRowwiseQuantizedSparseLengthsSum(
1471 llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
1472 quantization::Schema schema, ElemKind precision = ElemKind::FloatTy,
1473 bool useFP16Accumulation = false,
1474 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1475
1476 /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but i-th slice is
1477 /// multiplied by weights[i]. len(weights) must be equal to len(indices).
1478 RowwiseQuantizedSparseLengthsWeightedSumNode *
1479 createRowwiseQuantizedSparseLengthsWeightedSum(
1480 llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets,
1481 NodeValue weights, NodeValue indices, NodeValue lengths,
1482 ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
1483 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1484
1485 /// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but expects
1486 /// float input \p data, which is rowwise-quantized internally.
1487 RowwiseQuantizedSparseLengthsWeightedSumNode *
1488 createRowwiseQuantizedSparseLengthsWeightedSum(
1489 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
1490 NodeValue lengths, quantization::Schema schema,
1491 ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
1492 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1493
1494 /// Creates and \returns a node of \p name, performing the SparseLengthsSum
1495 /// operation, using fused rowwise quantization for the input \p data wherein
1496 /// the scales and offsets are fused inline with each row of data. \p data
1497 /// must be of a fused ElemKind. Gathers slices of the outer-most dimension of
1498 /// data indexed by the \p indices vector, and then accumulates them into
1499 /// len(\p lengths) entries: first Lengths[0] slices are aggregated to
1500 /// Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e.
1501 /// sum(Lengths) must be equal to len(Indices). The precision for the Result
1502 /// is determined by the \p data input's ElemKind used for Scale and
1503 /// Offset. If \p useFP16Accumulation, then internal arithmetic will use FP16
1504 /// accumulation; otherwise defaults to FP32. \p lengthsMode and \p avgLength
1505 /// represent meta information about the \p lengths, allowing the backend to
1506 /// use a specialized implementation.
1507 FusedRowwiseQuantizedSparseLengthsSumNode *
1508 createFusedRowwiseQuantizedSparseLengthsSum(
1509 llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths,
1510 bool useFP16Accumulation = false,
1511 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1512
1513 /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but expects
1514 /// float input \p data, which is rowwise-quantized and fused internally.
1515 /// \p fusedElemKind represents the element kind to use for the final fused
1516 /// rowwise-quantized data.
1517 FusedRowwiseQuantizedSparseLengthsSumNode *
1518 createFusedRowwiseQuantizedSparseLengthsSum(
1519 llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
1520 ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
1521 bool useFP16Accumulation = false,
1522 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1523
1524 /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but i-th slice
1525 /// is multiplied by weights[i]. len(weights) must be equal to len(indices).
1526 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
1527 createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1528 llvm::StringRef name, NodeValue data, NodeValue weights,
1529 NodeValue indices, NodeValue lengths, bool useFP16Accumulation = false,
1530 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1531
1532 /// Same as \ref createFusedRowwiseQuantizedSparseLengthsWeightedSum(), but
1533 /// expects float input \p data, which is rowwise-quantized and fused
1534 /// internally. \p fusedElemKind represents the element kind to use for the
1535 /// final fused rowwise-quantized data.
1536 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
1537 createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1538 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
1539 NodeValue lengths, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
1540 bool useFP16Accumulation = false,
1541 LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1542
  /// Given a vector of segment lengths, calculates offsets of each segment
  /// and packs them next to the lengths. For an input vector of length N the
  /// output is an Nx2 matrix with (offset, length) pairs for each segment.
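  /// For example, lengths = [2, 3, 1] yields [[0, 2], [2, 3], [5, 1]]. A
  /// minimal usage sketch (assuming an existing Function *F and a NodeValue
  /// lengths):
  /// \code
  ///   auto *ranges = F->createLengthsToRanges("ranges", lengths);
  /// \endcode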
1546 LengthsToRangesNode *createLengthsToRanges(llvm::StringRef name,
1547 NodeValue lengths);
1548
1549 /// Given a vector of \p lengths, \returns a LengthsRangeFillNode. This Node
1550 /// calculates a range sequence given \p lengths, where the sum of the
1551 /// elements of \p lengths must be no greater than \p maxOutputSize, which is
1552 /// used to set the output type.
1553 LengthsRangeFillNode *createLengthsRangeFill(llvm::StringRef name,
1554 NodeValue lengths,
1555 unsigned_t maxOutputSize);
1556
  /// Implements an operation that converts the sparse representation given by
  /// the \p lengths, \p indices and \p values into a dense representation.
  /// This representation contains \p lengths[i] indices in batch i, and each
  /// batch contains the value of \p values at the corresponding index given by
  /// \p indices. All indices that are not present in \p indices are filled
  /// with \p defaultValue. \p indices within the same batch should not contain
  /// duplicates. \p denseLastDim gives the last dimension of the output dense
  /// representation (i.e. the second dimension).
1565 BatchSparseToDenseNode *
1566 createBatchSparseToDense(llvm::StringRef name, NodeValue lengths,
1567 NodeValue indices, NodeValue values,
1568 float defaultValue, unsigned_t denseLastDim);
1569
1570 /// Implements an operation that inserts zeros into \p data along axis=0 for
1571 /// indices where \p indicator is zero.
1572 FillExamplesWithIndicatorNode *
1573 createFillExamplesWithIndicator(llvm::StringRef name, NodeValue data,
1574 NodeValue indicator);
1575
1576 /// Implements an operation that converts the sparse representation given by
1577 /// the pair of \p indices and \p values into a dense representation, which
1578 /// only contains IDs from given \p mask. Indices cannot contain duplicates.
1579 /// \p lengths is used to distinguish elements that belong to different
1580 /// examples of one batch. That is, first \p lengths[0] index-value pairs
1581 /// belong to batch's example 0, next \p lengths[1] pairs belong to example 1
1582 /// and so on.
1583 SparseToDenseMaskNode *
1584 createSparseToDenseMask(llvm::StringRef name, NodeValue indices,
1585 NodeValue values, NodeValue defaultValue,
1586 NodeValue lengths, llvm::ArrayRef<dim_t> mask);
1587
1588 // TODO: add description
1589 SparseLabelSplitNode *
1590 createSparseLabelSplit(llvm::StringRef name, NodeValue lengths,
1591 NodeValue indices, NodeValue values, dim_t numLabels);
1592
  /// Given a floating-point input node \p input, floats \p mean and \p scale,
  /// and a float \p seed, \returns a GaussianFillNode. The output shape is
  /// the same as that of \p input, filled with values drawn from a normal
  /// distribution with mean \p mean and standard deviation \p scale, seeded
  /// with \p seed.
1597 GaussianFillNode *createGaussianFill(llvm::StringRef name, NodeValue input,
1598 float mean, float scale, float seed);
1599
  /// Creates and \returns a SaveNode of \p input, creating an output
  /// Placeholder automatically.
  SaveNode *createSave(llvm::StringRef name, NodeValue input);

  /// Creates and \returns a SaveNode of \p input to \p output. If
  /// \p skipSuffix is true then the name used is \p name, otherwise the
  /// suffix "_save" is appended.
  SaveNode *createSave(llvm::StringRef name, NodeValue input,
                       Placeholder *output, bool skipSuffix = false);
1606
  /// Create a quantization profile node named \p name for the output tensor
  /// from \p input in PlaceholderBindings \p bindings. The observed node name
  /// is captured in the quantization profile node, since the original node
  /// can be replaced during the lowering phase. The histogram is computed
  /// during profiling with \p numHistogramBins.
1611 QuantizationProfileNode *
1612 createQuantizationProfile(PlaceholderBindings &bindings, llvm::StringRef name,
1613 NodeValue input, dim_t numHistogramBins = 10);
1614
1615 /// Create lookup table for mapping between quantized operands. \p input and
1616 /// \p outTy must be quantized types. The table contains all numbers from the
1617 /// quantized range, e.g. 256 entries for int8 input. First position in the
1618 /// \p initValues corresponds to the minimum input number and the last
1619 /// position corresponds to the maximum input number.
1620 template <typename T = int8_t>
1621 IntLookupTableNode *
1622 createIntLookupTable(llvm::StringRef name, NodeValue input,
1623 llvm::ArrayRef<T> initValues, TypeRef outTy);
1624
1625 /// Create lookup table for mapping between quantized operands based on the
1626 /// floating point function \p func. \p input and \p outTy must be quantized
1627 /// types.
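  /// A minimal usage sketch (assuming an existing Function *F, a quantized
  /// NodeValue input, a quantized TypeRef outTy, and <cmath> for std::exp):
  /// \code
  ///   IntLookupTableNode *lut = F->createIntLookupTable(
  ///       "exp.lut", input, [](float x) { return std::exp(x); }, outTy);
  /// \endcode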
1628 IntLookupTableNode *createIntLookupTable(llvm::StringRef name,
1629 NodeValue input,
1630 std::function<float(float)> func,
1631 TypeRef outTy);
1632
1633 /// Create lookup table for operator \p lutOperator using the provided lookup
1634 /// \p table.
1635 LookupTableNode *createLookupTable(llvm::StringRef name, NodeValue input,
1636 LUTOperator lutOperator,
1637 std::vector<float> &lutOperatorArgs,
1638 NodeValue table, NodeValue idxTable,
1639 TypeRef outTy);
1640
1641 /// Create quantized log.
1642 IntLookupTableNode *createIntLog(llvm::StringRef name, NodeValue input,
1643 TypeRef outTy);
1644
1645 /// Create quantized exp.
1646 IntLookupTableNode *createIntExp(llvm::StringRef name, NodeValue input,
1647 TypeRef outTy);
1648
1649 /// Create quantized tanh.
1650 IntLookupTableNode *createIntTanh(llvm::StringRef name, NodeValue input,
1651 TypeRef outTy);
1652
1653 /// Create quantized sigmoid.
1654 IntLookupTableNode *createIntSigmoid(llvm::StringRef name, NodeValue input,
1655 TypeRef outTy);
1656
  /// Create a TopKNode with given \p name which computes the top \p k values
  /// and their indices along the last dimension of \p input.
  TopKNode *createTopK(llvm::StringRef name, NodeValue input, unsigned_t k);

  /// Same as \ref createTopK(), but with the element kind of the indices
  /// output set to \p outIndicesTyKind.
  TopKNode *createTopK(llvm::StringRef name, NodeValue input, unsigned_t k,
                       ElemKind outIndicesTyKind);
1661
  /// Given \p rpnMaxLevel, \p rpnMinLevel and \p rpnPostNmsTopN,
  /// CollectRpnProposals merges the RoIs in \p roisIn based on \p roiProbsIn
  /// and returns the top proposals, limited to \p rpnPostNmsTopN in total, of
  /// size (n x B), where B is the box dimension, derived from the dimension
  /// of the input RoIs.
  /// The format for upright boxes is (image_index, x1, y1, x2, y2) and for
  /// rotated boxes (image_index, ctr_x, ctr_y, w, h, angle).
  /// \p rpnPostNmsTopN must be greater than zero.
1669 CollectRpnProposalsNode *createCollectRpnProposals(
1670 llvm::StringRef name, std::vector<NodeValue> &roisIn,
1671 std::vector<NodeValue> &roiProbsIn, int64_t rpnMaxLevel,
1672 int64_t rpnMinLevel, unsigned_t rpnPostNmsTopN);
1673
1674 /// Given \p data tensor of rank r >= 1, and \p indices tensor of rank q,
1675 /// gather entries of the \p axis dimension of data (default outer-most for
1676 /// axis = 0) indexed by indices and concatenate them in the output tensor of
1677 /// rank q + (r - 1).
1678 /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#gather
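  /// A minimal shape sketch (assuming an existing Function *F, \p data of
  /// shape [4, 8] and integer \p indices of shape [2, 3]):
  /// \code
  ///   // With the default axis = 0 the result has shape [2, 3, 8].
  ///   GatherNode *g = F->createGather("gather", data, indices);
  /// \endcode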
1679 GatherNode *createGather(llvm::StringRef name, NodeValue data,
1680 NodeValue indices, unsigned_t axis = 0);
1681
  /// Given \p data tensor of rank r >= 1, \p indices tensor of rank q >= 1,
  /// and \p batchDims integer b, this operator gathers slices of data
  /// into an output tensor of rank q + r - indices_shape[-1] - 1 - b.
  /// \p indices is a q-dimensional integer tensor, best thought of as a
  /// (q-1)-dimensional tensor of index-tuples into \p data, where each
  /// element defines a slice of data.
  /// \p batchDims is an integer b indicating the number of leading
  /// dimensions of \p data and \p indices that represent the batches, so that
  /// the gather starts from the (b+1)-th dimension.
  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#gathernd
1693 GatherNDNode *createGatherND(llvm::StringRef name, NodeValue data,
1694 NodeValue indices, unsigned_t batchDims = 0);
1695
  /// Create a node performing a GatherElements operation:
  /// GatherElements takes inputs \p data and \p indices of the same rank
  /// r >= 1 and a \p dim attribute that identifies an axis of data. It is an
  /// indexing operation that produces its output by indexing into the input
  /// data tensor with the elements of the indices tensor. Its output shape
  /// is the same as the shape of indices and consists of one value (gathered
  /// from the data) for each element in indices.
1703 GatherElementsNode *createGatherElements(llvm::StringRef name, NodeValue data,
1704 NodeValue indices, unsigned_t dim);
1705
  /// Create a node performing a GatherRanges operation:
1707 /// Gathers entries of \p data in groups specified by the "examples" in
1708 /// \p ranges. Each example in \p ranges contains a list of pairs of
1709 /// indices of the form (index, length) which specify which entries of \p
1710 /// data to gather. The ordering of elements in \p ranges and of pairs
1711 /// within an element is preserved in the output. In addition to the result
1712 /// of gathering ("output"), the lengths of the ranges gathered by each
  /// example in \p ranges are also produced as an output ("lengths").
1714 /// \p maxOutputSize is the maximum possible size of "output" and is used to
1715 /// set its type. Users must use "lengths" to interpret "output" correctly.
1716 /// \returns the GatherRangesNode.
1717 GatherRangesNode *createGatherRanges(llvm::StringRef name, NodeValue data,
1718 NodeValue ranges,
1719 unsigned_t maxOutputSize);
1720
1721 /// Copies each slice from \p slices into \p data at the corresponding index
1722 /// in \p indices, and \returns this new version of data. For example, given
1723 /// input data {{1,2},{3,4},{5,6}}, slices {{-3,-4}}, and indices {1}, the
1724 /// result is {{1,2},{-3,-4},{5,6}}. If \p cumulative is true, this node adds
1725 /// values instead of copying.
1726 ScatterDataNode *createScatterData(llvm::StringRef name, NodeValue data,
1727 NodeValue indices, NodeValue slices,
1728 bool cumulative = false);
1729
  /// Given a 2D matrix \p data, a 1D vector \p lengths (of the same size as
  /// the width of \p data), and a 1D vector \p values (of the same size as
  /// the sum of \p lengths), expand each row of \p data to a row of zeros and
  /// ones, according to one-hot encoding. The j-th element of the resulting
  /// i-th row is one iff \p values[j] == \p data[i][some index within the
  /// range of j].
1735 BatchOneHotNode *createBatchOneHot(llvm::StringRef name, NodeValue data,
1736 NodeValue lengths, NodeValue values);
1737
  /// Given an input tensor of shape [N,H,W,C], where N is the batch axis, H
  /// is the height, W is the width and C is the channel or depth, this
  /// produces an output tensor of shape [N, H/blockSize, W/blockSize,
  /// C * blockSize * blockSize].
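  /// A minimal shape sketch (assuming an existing Function *F and an input of
  /// shape [1, 4, 4, 3]):
  /// \code
  ///   // With blockSize = 2 the result has shape [1, 2, 2, 12].
  ///   SpaceToDepthNode *s2d = F->createSpaceToDepth("s2d", input, 2);
  /// \endcode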
1742 SpaceToDepthNode *createSpaceToDepth(llvm::StringRef name, NodeValue input,
1743 unsigned blockSize);
1744
1745 /// Create a sequence of Reshape and Transpose nodes representing DepthToSpace
1746 /// operator with \p blockSize in DCR or CRD mode based on \p isCRD flag.
1747 /// Assumes input layout to be NHWC. \returns the last node in the sequence.
1748 ReshapeNode *createDepthToSpace(llvm::StringRef name, NodeValue input,
1749 unsigned blockSize, bool isCRD = false);
1750
  /// Given an \p input tensor of shape [N,H,W,C], where N is the batch, C is
  /// the channel or depth, H is the height and W is the width, and a \p scale
  /// array with the same layout as \p input, ResizeNearest generates an
  /// output tensor with resized spatial dimensions using nearest neighbor
  /// interpolation. The output tensor is of shape [floor(N * \p scale[0]),
  /// floor(H * \p scale[1]), floor(W * \p scale[2]),
  /// floor(C * \p scale[3])].
1758 ResizeNearestNode *createResizeNearest(llvm::StringRef name, NodeValue input,
1759 llvm::ArrayRef<float> scale);
1760
  /// Given an \p input tensor of shape [N,H,W,C], where N is the batch, C is
  /// the channel or depth, H is the height and W is the width, ResizeNearest
  /// generates an output tensor with resized spatial dimensions using nearest
  /// neighbor interpolation. The output tensor shape is specified with
  /// \p outTy.
1766 ResizeNearestNode *createResizeNearest(llvm::StringRef name, NodeValue input,
1767 TypeRef outTy);
1768
  /// Given an \p input tensor of shape [N,H,W,C], where N is the batch, C is
  /// the channel or depth, H is the height and W is the width, and a \p scale
  /// array with the same layout as \p input, ResizeBilinear generates an
  /// output tensor with resized spatial dimensions using bilinear
  /// interpolation. The output tensor is of shape [floor(N * \p scale[0]),
  /// floor(H * \p scale[1]), floor(W * \p scale[2]),
  /// floor(C * \p scale[3])].
1776 ResizeBilinearNode *createResizeBilinear(llvm::StringRef name,
1777 NodeValue input,
1778 llvm::ArrayRef<float> scale);
1779
  /// Given an \p input tensor of shape [N,H,W,C], where N is the batch, C is
  /// the channel or depth, H is the height and W is the width, ResizeBilinear
  /// generates an output tensor with resized spatial dimensions using
  /// bilinear interpolation. The output tensor shape is specified with
  /// \p outTy.
1785 ResizeBilinearNode *createResizeBilinear(llvm::StringRef name,
1786 NodeValue input, TypeRef outTy);
1787
  /// Create a quantization node which transforms a floating point tensor to a
  /// quantized one with the given Scale and Offset. The Scale and Offset
  /// params are part of \p outTy.
1791 QuantizeNode *createQuantize(llvm::StringRef name, NodeValue input,
1792 TypeRef outTy);
1793
  /// Create a quantization node which transforms a floating point tensor to a
  /// quantized one of kind \p q with the given \p scale and \p offset.
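  /// A minimal usage sketch (assuming an existing Function *F and a float
  /// NodeValue input; the scale and offset values are illustrative):
  /// \code
  ///   QuantizeNode *q = F->createQuantize("quant", input, ElemKind::Int8QTy,
  ///                                       /* scale */ 0.05f, /* offset */ 0);
  /// \endcode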
1796 QuantizeNode *createQuantize(llvm::StringRef name, NodeValue input,
1797 ElemKind q, float scale, int32_t offset);
1798
  /// Create a dequantization node which transforms a quantized tensor to a
  /// floating point one with the given Scale and Offset. The Scale and Offset
  /// params are part of the \p input. The result element kind is \p k.
1802 DequantizeNode *createDequantize(llvm::StringRef name, NodeValue input,
1803 ElemKind k);
1804
  /// Create a dequantization node which transforms a quantized tensor to a
  /// floating point tensor of type \p outTy. The Scale and Offset params are
  /// part of the \p input.
1808 DequantizeNode *createDequantize(llvm::StringRef name, NodeValue input,
1809 TypeRef outTy);
1810
1811 /// Create transformation for quantized tensors to rescale based on the new
1812 /// Scale and Offset.
1813 RescaleQuantizedNode *createRescaleQuantized(llvm::StringRef name,
1814 NodeValue input, TypeRef outTy);
1815
1816 /// Create a series of nodes that implement a weighted sum. \p data and \p
1817 /// weights should have the same number of elements. The nodes in \p weights
1818 /// should all be of size 1. Each node d_i in \p data is element-wise
1819 /// multiplied by the corresponding weight value w_i found in \p weights,
1820 /// broadcasted to the same shape as d_i, and resulting in r_i. All r_i are
1821 /// element-wise summed, and the final add node in this sum is returned.
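  /// A minimal usage sketch (assuming an existing Function *F, NodeValues a
  /// and b of the same shape, and size-1 NodeValues w0 and w1):
  /// \code
  ///   Node *ws = F->createWeightedSum("wsum", {a, b}, {w0, w1});
  /// \endcode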
1822 Node *createWeightedSum(llvm::StringRef name, llvm::ArrayRef<NodeValue> data,
1823 llvm::ArrayRef<NodeValue> weights);
1824
1825 /// Create a series of nodes that implements a two-parameter
1826 /// rowwise Box-Cox transform. For each element of the \p input x, this is
1827 /// defined as:
1828 ///
1829 /// y = ln(max(x + lambda2, 1e-6)), if lambda1 == 0
1830 /// (max(x + lambda2, 1e-6)^lambda1 - 1)/lambda1, if lambda1 != 0
1831 ///
1832 /// The transform parameters \p lambda1 and \p lambda2 are vectors of size D
1833 /// that are broadcasted to match the size of \p input (NxD). The transform
1834 /// itself is implemented using elementwise Max, Add, Log (if lambda1 == 0),
1835 /// Pow, Splat, Sub, and Div (if lambda1 != 0) nodes with a Splat and Select
1836 /// node to select between the two cases listed above. \returns the final
  /// Select node. \p epsilon is used to ensure we do not divide by zero when
  /// calculating the lambda1 == 0 case, as we use a Select to choose which
  /// result to use, and so both paths are executed.
1840 Node *createBatchBoxCox(llvm::StringRef name, NodeValue input,
1841 NodeValue lambda1, NodeValue lambda2,
1842 float epsilon = std::numeric_limits<float>::min());
1843
1844 /// Create a Clip node with the given \p name, \p input, minimum clip value
1845 /// \p min, maximum clip value \p max and output type \p outTy.
1846 ClipNode *createClip(llvm::StringRef name, NodeValue input, TypeRef outTy,
1847 float min, float max);
1848
1849 /// Create a Clip node with the given \p name, \p input, minimum clip value
1850 /// \p min, maximum clip value \p max. Result type will be implicitly set
1851 /// based on the \p input type.
1852 ClipNode *createClip(llvm::StringRef name, NodeValue input, float min,
1853 float max);
1854
1855 /// Creates and \returns a ClipNode to the min/max range of FP16 with \p name
1856 /// of \p input. Result type will be implicitly set based on the \p input
1857 /// type.
1858 ClipNode *createClipMinMaxFP16(llvm::StringRef name, NodeValue input);
1859
1860 /// Creates and \returns a ClipNode to the min/max range of BFloat16 with \p
1861 /// name of \p input. Result type will be implicitly set based on the \p input
1862 /// type.
1863 ClipNode *createClipMinMaxBFloat16(llvm::StringRef name, NodeValue input);
1864
1865 /// @name The builder functions below are identical to the builder functions
1866 /// above except that they create nodes that use Placeholder instead of
1867 /// Variables. The methods create and initialize the tensors in the
1868 /// PlaceholderBindings. As soon as we finish the Placeholder migration we'll
1869 /// delete these methods and merge them with the builder methods above. See
1870 /// issue #1334.
1871 ///@{
1872
1873 BatchNormalizationNode *
1874 createBatchNormalization(PlaceholderBindings &bindings, llvm::StringRef name,
1875 NodeValue input, unsigned_t channelIdx = 0,
1876 float epsilon = 1e-5, float momentum = 0.9);
1877
  /// Create a BatchedUnaryEmbeddingsBags node. The last entry of \p offsets
  /// marks the end of the last range.
1880 BatchedUnaryEmbeddingsBagsNode *
1881 createBatchedUnaryEmbeddingsBags(llvm::StringRef name, NodeValue weights,
1882 NodeValue tableOffsets, NodeValue indices,
1883 NodeValue offsets);
1884
1885 /// Create an IntNBitSplitEmbeddingBags node.
1886 IntNBitSplitEmbeddingBagsNode *createIntNBitSplitEmbeddingBags(
1887 llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights,
1888 NodeValue weightsPlacements, NodeValue weightsOffsets,
1889 NodeValue weightsTys, NodeValue dimOffsets, int64_t totalDims,
1890 NodeValue indices, NodeValue offsets,
1891 SplitEmbeddingPoolingMode poolingMode,
1892 SplitEmbeddingSparseType outputDtype);
1893
1894 /// Create an IntNBitSplitEmbeddingWeightedBags node.
1895 IntNBitSplitEmbeddingWeightedBagsNode *
1896 createIntNBitSplitEmbeddingWeightedBags(
1897 llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights,
1898 NodeValue weightsPlacements, NodeValue weightsOffsets,
1899 NodeValue weightsTys, NodeValue dimOffsets, int64_t totalDims,
1900 NodeValue indices, NodeValue offsets,
1901 SplitEmbeddingPoolingMode poolingMode,
1902 SplitEmbeddingSparseType outputDtype, NodeValue indiceWeights);
1903
  /// Creates a ConvolutionNode with the given \p name which convolves the 4D
  /// \p input. \p kernels defines the size of the height and width dimensions
  /// of the convolutional filters. \p strides defines the number of steps
  /// to take in the input for each output cell. \p pads defines how many zero
  /// padding cells should be added to the input during convolution. \p group
  /// defines the number of groups the input and output channels should be
  /// divided into and convolved separately. \p dilation defines the factor by
  /// which the gap between two neighboring kernel elements is expanded along
  /// each axis. \p layout defines the Tensor layout and must be either NHWC
  /// or NCHW.
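  /// A minimal usage sketch (assuming an existing Function *F, a
  /// PlaceholderBindings bindings and a float NHWC input of shape
  /// [1, 28, 28, 3]; all numbers are illustrative):
  /// \code
  ///   ConvolutionNode *conv =
  ///       F->createConv(bindings, "conv", input, /* outChannels */ 16,
  ///                     /* kernels */ {3, 3}, /* strides */ {1, 1},
  ///                     /* pads */ {1, 1, 1, 1}, /* group */ 1);
  /// \endcode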
1913 ConvolutionNode *createConv(PlaceholderBindings &bindings,
1914 llvm::StringRef name, NodeValue input,
1915 dim_t outChannels,
1916 llvm::ArrayRef<unsigned_t> kernels,
1917 llvm::ArrayRef<unsigned_t> strides,
1918 llvm::ArrayRef<unsigned_t> pads, unsigned_t group,
1919 llvm::ArrayRef<unsigned_t> dilation = {1, 1},
1920 ConvolutionLayout layout = NHWC);
1921
  /// Creates a ConvolutionNode with the given \p name which convolves the 4D
  /// \p input. \p kernel defines the size of the height and width dimensions
  /// of the convolutional filters. \p stride defines the number of steps to
  /// take in the input for each output cell. \p pad defines how many zero
  /// padding cells should be added to the input during convolution. \p group
  /// defines the number of groups the input and output channels should be
  /// divided into and convolved separately. \p dilation defines the factor by
  /// which the gap between two neighboring kernel elements is expanded along
  /// each axis. \p layout defines the Tensor layout and must be either NHWC
  /// or NCHW.
1931 ConvolutionNode *createConv(PlaceholderBindings &bindings,
1932 llvm::StringRef name, NodeValue input,
1933 dim_t outChannels, unsigned_t kernel,
1934 unsigned_t stride, unsigned_t pad,
1935 unsigned_t group,
1936 llvm::ArrayRef<unsigned_t> dilation = {1, 1},
1937 ConvolutionLayout layout = NHWC);
1938
  /// Creates a Convolution3DNode with the given \p name which convolves the
  /// 5D \p input. \p kernels defines the size of the height, width, and depth
  /// dimensions of the convolutional filters. \p strides defines the number
  /// of steps to take in the input for each output cell. \p pads defines how
  /// many zero padding cells should be added to the input during convolution.
  /// \p group defines the number of groups the input and output channels
  /// should be divided into and convolved separately.
1946 Convolution3DNode *createConv3D(PlaceholderBindings &bindings,
1947 llvm::StringRef name, NodeValue input,
1948 dim_t outChannels,
1949 llvm::ArrayRef<unsigned_t> kernels,
1950 llvm::ArrayRef<unsigned_t> strides,
1951 llvm::ArrayRef<unsigned_t> pads,
1952 unsigned_t group);
1953
  /// Creates a Convolution3DNode with the given \p name which convolves the
  /// 5D \p input. \p kernel defines the size of the height, width, and depth
  /// dimensions of the convolutional filters. \p stride defines the number of
  /// steps to take in the input for each output cell. \p pad defines how many
  /// zero padding cells should be added to the input during convolution.
  /// \p group defines the number of groups the input and output channels
  /// should be divided into and convolved separately.
1961 Convolution3DNode *createConv3D(PlaceholderBindings &bindings,
1962 llvm::StringRef name, NodeValue input,
1963 size_t outChannels, unsigned_t kernel,
1964 unsigned_t stride, unsigned_t pad,
1965 unsigned_t group);
1966
1967 /// Creates a ConvTransposeNode with the given \p name which does transposed
1968 /// convolution on the 4D \p input. \p kernels define the size of the height
1969 /// and width dimensions of the convolution filters. \p strides define the
1970 /// number of steps to take in the input for each output cell. \p pads define
1971 /// how many zero padding cells should be added to the input during
1972 /// convolution. \p group defines the number of groups the input and output
1973 /// channels should be divided into and convolved separately.
1974 ConvTransposeNode *createConvTranspose(
1975 PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
1976 dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
1977 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
1978 unsigned_t group, llvm::ArrayRef<unsigned_t> dilation = {1, 1});
1979
1980 /// Creates a ConvTransposeNode with the given \p name which does transposed
1981 /// convolution on the 4D \p input. \p kernel defines the size of the height
1982 /// and width dimensions of the convolution filters. \p stride defines the
1983 /// number of steps to take in the input for each output cell. \p pad defines
1984 /// how many zero padding cells should be added to the input during
1985 /// convolution. \p group defines the number of groups the input and output
1986 /// channels should be divided into and convolved separately.
1987 ConvTransposeNode *
1988 createConvTranspose(PlaceholderBindings &bindings, llvm::StringRef name,
1989 NodeValue input, dim_t outChannels, unsigned_t kernel,
1990 unsigned_t stride, unsigned_t pad, unsigned_t group,
1991 llvm::ArrayRef<unsigned_t> dilation = {1, 1});
1992
1993 /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
1994 /// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened
1995 /// along \p axis. Note, output type is inferred based on the input
1996 /// types. Trainable weight and bias variables are created implicitly.
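  /// A minimal usage sketch (assuming an existing Function *F, a
  /// PlaceholderBindings bindings and a 2D float input of shape [batch, 64];
  /// the output depth is illustrative):
  /// \code
  ///   FullyConnectedNode *fc =
  ///       F->createFullyConnected(bindings, "fc", input, /* outDepth */ 10);
  /// \endcode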
1997 FullyConnectedNode *createFullyConnected(PlaceholderBindings &bindings,
1998 llvm::StringRef name,
1999 NodeValue input, dim_t outDepth,
2000 unsigned_t axis = 1);
2001
2002 /// Creates an RMSNorm pair. \p X should be a 2D tensor, \p gamma and \p beta
2003 /// should be 1D tensors.
2004 std::array<Node *, 2> createRMSNorm(llvm::StringRef name, NodeValue X,
2005 NodeValue gamma, NodeValue beta,
2006 float epsilon = .0f);
2007
2008 /// Create an unrolled single-layer Simple RNN cell with \p hiddenSize
2009 /// dimensionality of the hidden state and \p outputSize dimensionality of the
2010 /// output state. \p inputs define the input for the cell at each time step
2011 /// and the number of time steps is equal to the size of the \p inputs. The
2012 /// names of the created variables are prefixed by \p namePrefix.
  /// The output variables are written to \p outputs; they represent the
  /// activations of the output layer, unrolled over time.
  /// The dimensionality of the output variables is \p batchSize x \p outputSize.
2016 void createSimpleRNN(PlaceholderBindings &bindings,
2017 llvm::StringRef namePrefix,
2018 const llvm::ArrayRef<NodeValue> inputs,
2019 unsigned batchSize, unsigned hiddenSize,
2020 unsigned outputSize, std::vector<NodeValue> &outputs);
2021
2022 /// Create an unrolled single-layer GRU cell with \p hiddenSize
2023 /// dimensionality of the hidden state and \p outputSize dimensionality of the
2024 /// output state. \p inputs define the input for the cell at each time step
2025 /// and the number of time steps is equal to the size of the \p inputs. The
2026 /// names of the created variables are prefixed by \p namePrefix.
  /// The output variables are written to \p outputs; they represent the
  /// activations of the output layer, unrolled over time.
  /// The dimensionality of the output variables is \p batchSize x \p outputSize.
2030 void createGRU(PlaceholderBindings &bindings, llvm::StringRef namePrefix,
2031 const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
2032 unsigned hiddenSize, unsigned outputSize,
2033 std::vector<NodeValue> &outputs);
2034
2035 /// Create an unrolled single-layer LSTM cell with \p hiddenSize
2036 /// dimensionality of the hidden state and \p outputSize dimensionality of the
2037 /// output state. \p inputs define the input for the cell at each time step
2038 /// and the number of time steps is equal to the size of the \p inputs. The
2039 /// names of the created variables are prefixed by \p namePrefix.
  /// The output variables are written to \p outputs; they represent the
  /// activations of the output layer, unrolled over time.
  /// The dimensionality of the output variables is \p batchSize x \p outputSize.
2043 void createLSTM(PlaceholderBindings &bindings, llvm::StringRef namePrefix,
2044 const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
2045 unsigned hiddenSize, unsigned outputSize,
2046 std::vector<NodeValue> &outputs);
2047
  /// Create an LSTM Unit Node with \p Input, whose shape is [batch,
  /// 4 * hiddenSize] and which follows the gate order i, f, g, o, and with
  /// \p C as the current cell state.
2051 LSTMUnitNode *createLSTMUnit(llvm::StringRef namePrefix, NodeValue Input,
2052 NodeValue C);
2053
  /// Helper function to create a PyTorch style LSTM for one direction,
  /// returning every output in a vector. \p T should be an iterator or
  /// reverse_iterator of a NodeValue vector, and \p inputItr is an iterator
  /// pointing into the input vector. \p Wx, \p Wh, \p Bx and \p Bh are the
  /// weights and biases in gate order i, f, g, o, and \p H and \p C are the
  /// hidden state and cell state; their shapes should match those described
  /// for \ref createPyTorchLSTM().
2060 template <class T>
2061 std::vector<NodeValue> createSingleDirectionLSTM(
2062 std::string nameBase, T inputItr, const int timeSteps, NodeValue Wx,
2063 NodeValue Wh, NodeValue Bx, NodeValue Bh, NodeValue &H, NodeValue &C);
2064
  /// Helper function to create a PyTorch style multi-layer LSTM for one
  /// direction.
2067 std::vector<NodeValue> createMultipleLayerSingleDirectionLSTM(
2068 std::string nameBase, NodeValue input, unsigned batchSize,
2069 unsigned inputSize, const int timeSteps, std::vector<NodeValue> &Wx,
2070 std::vector<NodeValue> &Wh, std::vector<NodeValue> &Bx,
2071 std::vector<NodeValue> &Bh, NodeValue &H, NodeValue &C);
2072
  /// Helper function to create sliced inputs for the LSTM.
2074 std::vector<NodeValue>
2075 createSlicedInput(NodeValue input, std::string &nameBase, unsigned batchSize,
2076 unsigned inputSize, const int timeSteps);
2077
  /// Create a PyTorch style LSTM with fixed weights and biases.
  /// The order of \p Wx, \p Wh, \p Bx and \p Bh is i, f, g, o.
  /// The \p inputs shape should be (numSteps, batchSize, inputSize),
  /// while the \p Wx shape should be (inputSize, hiddenSize * 4),
  /// the \p Wh shape should be (hiddenSize, hiddenSize * 4), and the
  /// \p Bx and \p Bh shapes should be (hiddenSize * 4).
  /// If \p isBidirectional == true, \p WxR, \p WhR, \p BxR and \p BhR
  /// also need to be provided, indicating the reversed weights and biases.
  /// \p Ht and \p Ct are the initial hidden state and cell state.
  /// For more details, please read:
  /// https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
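  /// A minimal usage sketch for the unidirectional case (assuming an existing
  /// Function *F and NodeValues/vectors with the shapes described above;
  /// \p outputs is written by the call):
  /// \code
  ///   NodeValue outputs;
  ///   F->createPyTorchLSTM("lstm", inputs, Wx, Wh, Bx, Bh, Ht, Ct, outputs);
  /// \endcode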
2089 void createPyTorchLSTM(llvm::StringRef namePrefix, NodeValue inputs,
2090 std::vector<NodeValue> &Wx, std::vector<NodeValue> &Wh,
2091 std::vector<NodeValue> &Bx, std::vector<NodeValue> &Bh,
2092 NodeValue &Ht, NodeValue &Ct, NodeValue &outputs,
2093 bool isBidirectional = false,
2094 NodeValue WxR = NodeValue(),
2095 NodeValue WhR = NodeValue(),
2096 NodeValue BxR = NodeValue(),
2097 NodeValue BhR = NodeValue());
2098
2099 /// Type definition for the direction of an RNN module (RNN, GRU, LSTM).
2100 enum class RnnDirection {
2101 Forward,
2102 Reverse,
2103 Bidirectional,
2104 };
2105
2106 /// Definition for a lambda used to create an activation node for RNN modules.
2107 using RnnActivation = std::function<Node *(llvm::StringRef, Node *)>;
2108
  /// Create an unrolled multi-layer RNN according to the ONNX definition:
  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
  /// The RNN has the following inputs:
  /// - input \p X with size [S, B, ISize].
  /// - weights \p W with size [N, HSize, ISize].
  /// - recurrence weights \p R with size [N, HSize, HSize].
  /// - bias weights \p B with size [N, 2 * HSize].
  /// - initial hidden state \p initial_h with size [N, B, HSize].
  /// where S is the sequence length, N is the number of directions, B is the
  /// batch size, ISize is the input size and HSize is the hidden size.
  /// The RNN has the following outputs:
  /// - output \p Y with size [S, N, B, HSize].
  /// - final hidden state \p Y_h with size [N, B, HSize].
  /// The direction of the instantiated RNN is given by \p direction. The RNN
  /// will use the activation functions defined by the \p activations array:
  /// - [f] in case the RNN is unidirectional (1 function).
  /// - [f] for the forward cell followed by [f] for the reverse cell in
  ///   case the RNN is bidirectional (2 functions).
  /// The inputs \p B and \p initial_h are optional (assumed 0 if nullptr is
  /// provided). The names of all the nodes created are prefixed with
  /// \p namePrefix.
2130 void createOnnxRNN(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
2131 NodeValue R, NodeValue B, NodeValue initial_h,
2132 NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
2133 RnnDirection direction,
2134 std::vector<RnnActivation> &activations);
2135
  /// Create an unrolled multi-layer GRU according to the ONNX definition:
  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
  /// The GRU has the following inputs:
  /// - input \p X with size [S, B, ISize].
  /// - weights \p W with size [N, 3 * HSize, ISize].
  /// - recurrence weights \p R with size [N, 3 * HSize, HSize].
  /// - bias weights \p B with size [N, 6 * HSize].
  /// - initial hidden state \p initial_h with size [N, B, HSize].
  /// where S is the sequence length, N is the number of directions, B is the
  /// batch size, ISize is the input size and HSize is the hidden size.
  /// The GRU has the following outputs:
  /// - output \p Y with size [S, N, B, HSize].
  /// - final hidden state \p Y_h with size [N, B, HSize].
  /// The direction of the instantiated GRU is given by \p direction. The GRU
  /// will use the activation functions defined by the \p activations array:
  /// - [f,g] in case the GRU is unidirectional (2 functions).
  /// - [f,g] for the forward cell followed by [f,g] for the reverse cell in
  ///   case the GRU is bidirectional (4 functions).
  /// The inputs \p B and \p initial_h are optional (assumed 0 if nullptr is
  /// provided). The names of all the nodes created are prefixed with
  /// \p namePrefix. The boolean parameter \p linearBeforeReset defines
  /// whether the reset for the previous hidden state occurs before/after the
  /// linear expression.
2159 void createOnnxGRU(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
2160 NodeValue R, NodeValue B, NodeValue initial_h,
2161 NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
2162 RnnDirection direction,
2163 std::vector<RnnActivation> &activations,
2164 bool linearBeforeReset = false);
2165
  /// Create an unrolled multi-layer LSTM according to the ONNX definition:
  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
  /// The LSTM has the following inputs:
  /// - input \p X with size [S, B, ISize].
  /// - weights \p W with size [N, 4 * HSize, ISize].
  /// - recurrence weights \p R with size [N, 4 * HSize, HSize].
  /// - bias weights \p B with size [N, 8 * HSize].
  /// - initial hidden state \p initial_h with size [N, B, HSize].
  /// - initial cell state \p initial_c with size [N, B, HSize].
  /// - peephole weights \p P with size [N, 3 * HSize].
  /// where S is the sequence length, N is the number of directions, B is the
  /// batch size, ISize is the input size and HSize is the hidden size.
  /// The LSTM has the following outputs:
  /// - output \p Y with size [S, N, B, HSize].
  /// - final hidden state \p Y_h with size [N, B, HSize].
  /// - final cell state \p Y_c with size [N, B, HSize].
  /// The direction of the instantiated LSTM is given by \p direction. The
  /// LSTM will use the activation functions defined by the \p activations
  /// array:
  /// - [f,g,h] in case the LSTM is unidirectional (3 functions).
  /// - [f,g,h] for the forward cell followed by [f,g,h] for the reverse cell
  ///   in case the LSTM is bidirectional (6 functions).
  /// The inputs \p B, \p initial_h, \p initial_c and \p P are optional
  /// (assumed 0 if nullptr is provided). The names of all the nodes created
  /// are prefixed with \p namePrefix. The boolean parameter \p inputForget
  /// defines whether the input and forget gates should be coupled (compute
  /// the input gate from the forget gate).
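  /// A minimal usage sketch for a forward LSTM with the standard sigmoid/tanh
  /// activations (assuming an existing Function *F, inputs X, W, R, B,
  /// initial_h and initial_c with the shapes described above, and an unsigned
  /// hiddenSize):
  /// \code
  ///   auto actSigmoid = [&](llvm::StringRef name, Node *in) -> Node * {
  ///     return F->createSigmoid(name, in->getNthResult(0));
  ///   };
  ///   auto actTanh = [&](llvm::StringRef name, Node *in) -> Node * {
  ///     return F->createTanh(name, in->getNthResult(0));
  ///   };
  ///   std::vector<RnnActivation> acts = {actSigmoid, actTanh, actTanh};
  ///   NodeValue Y, Y_h, Y_c;
  ///   F->createOnnxLSTM("lstm", X, W, R, B, initial_h, initial_c,
  ///                     /* P */ NodeValue(), Y, Y_h, Y_c, hiddenSize,
  ///                     RnnDirection::Forward, acts);
  /// \endcode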
2192 void createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
2193 NodeValue R, NodeValue B, NodeValue initial_h,
2194 NodeValue initial_c, NodeValue P, NodeValue &Y,
2195 NodeValue &Y_h, NodeValue &Y_c, unsigned hiddenSize,
2196 RnnDirection direction,
2197 std::vector<RnnActivation> &activations,
2198 bool inputForget = false);
2199 /// @}
2200
2201 /// Create a TraceEvent in the runtime profile, which triggers collection of
2202 /// runtime statistics.
2203 TraceEventNode *createTraceEvent(llvm::StringRef eventName,
2204 llvm::StringRef eventType, Node *data,
2205 unsigned index);
2206
2207 /// Creates NMSv4 node that does NMS for one class.
2208 /// Inputs
2209 /// - \p boxes Tensor with box coordinates.
2210 /// - \p scores Tensor with scores per box.
2211 /// - \p centerPointBox Indicates format of the box per ONNX spec.
2212 /// - \p iouThreshold Threshold for box overlap.
2213 /// - \p scoreThreshold Threshold for box scores.
2214 NonMaxSuppressionNode *
2215 createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
2216 NodeValue scores, int64_t centerPointBox,
2217 int64_t maxOutputBoxesPerClass, float iouThreshold,
2218 float scoreThreshold);
2219
2220 /// Creates NMSv4 node that does NMS for one class.
2221 /// Inputs
2222 /// - \p boxes Tensor with box coordinates.
2223 /// - \p scores Tensor with scores per box.
2224 /// - \p centerPointBox Indicates format of the box per ONNX spec.
2225 /// - \p iouThreshold Threshold for box overlap.
2226 /// - \p scoreThreshold Threshold for box scores.
  /// - \p elTy Output ElemKind.
2228 NonMaxSuppressionNode *
2229 createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
2230 NodeValue scores, int64_t centerPointBox,
2231 int64_t maxOutputBoxesPerClass, float iouThreshold,
2232 float scoreThreshold, ElemKind elTy);
2233
2234 /// Creates NMSv4 node that does NMS for one class.
2235 /// Inputs
2236 /// - \p boxes Tensor with box coordinates.
2237 /// - \p scores Tensor with scores per box.
2238 /// - \p centerPointBox Indicates format of the box per ONNX spec.
2239 /// - \p iouThreshold Threshold for box overlap.
2240 /// - \p scoreThreshold Threshold for box scores.
2241 /// - \p indicesTy Type of indices output.
  /// - \p numberOfSelectedIndicesTy Type of the second output for the number
  ///   of boxes detected.
  NonMaxSuppressionNode *
  createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
                            NodeValue scores, int64_t centerPointBox,
                            int64_t maxOutputBoxesPerClass, float iouThreshold,
                            float scoreThreshold, TypeRef indicesTy,
                            TypeRef numberOfSelectedIndicesTy);

  /// Performs class-wise NMS based on the ONNX specification, with padding
  /// and ONNX layout output.
  /// Inputs
  /// - \p boxes Tensor with box coordinates.
  /// - \p scores Tensor with scores per box.
  /// - \p centerPointBox Indicates the format of the box per the ONNX spec.
  /// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
  /// - \p iouThreshold Threshold for box overlap.
  /// - \p scoreThreshold Threshold for box scores.
  NonMaxSuppressionNode *
  createNonMaxSuppressionONNX(llvm::StringRef name, NodeValue boxes,
                              NodeValue scores, int64_t centerPointBox,
                              int64_t maxOutputBoxesPerClass,
                              float iouThreshold, float scoreThreshold);

  /// Performs class-wise NMS based on the ONNX specification, with padding
  /// and ONNX layout output.
  /// Inputs
  /// - \p boxes Tensor with box coordinates.
  /// - \p scores Tensor with scores per box.
  /// - \p centerPointBox Indicates the format of the box per the ONNX spec.
  /// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
  /// - \p iouThreshold Threshold for box overlap.
  /// - \p scoreThreshold Threshold for box scores.
  /// - \p elTy Output ElemKind.
  NonMaxSuppressionNode *createNonMaxSuppressionONNX(
      llvm::StringRef name, NodeValue boxes, NodeValue scores,
      int64_t centerPointBox, int64_t maxOutputBoxesPerClass,
      float iouThreshold, float scoreThreshold, ElemKind elTy);

  /// Performs class-wise NMS based on the ONNX specification, with padding
  /// and ONNX layout output.
  /// Inputs
  /// - \p boxes Tensor with box coordinates.
  /// - \p scores Tensor with scores per box.
  /// - \p centerPointBox Indicates the format of the box per the ONNX spec.
  /// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
  /// - \p iouThreshold Threshold for box overlap.
  /// - \p scoreThreshold Threshold for box scores.
  /// - \p indicesTy Type of the indices output.
  NonMaxSuppressionNode *createNonMaxSuppressionONNX(
      llvm::StringRef name, NodeValue boxes, NodeValue scores,
      int64_t centerPointBox, int64_t maxOutputBoxesPerClass,
      float iouThreshold, float scoreThreshold, TypeRef indicesTy);

  /// Create a TensorFlowLite custom node called "DetectionPostProcess" which
  /// corresponds to a custom NonMaxSuppression node.
  /// The node has the following inputs:
  /// - \p boxes with size [N, B, 4]
  /// - \p scores with size [N, B, C]
  /// - \p anchors with size [B, 4]
  /// where N is the batch size, B is the number of boxes and C is the number
  /// of classes.
  /// The node has the following attributes (parameters):
  /// - \p numClasses - Number of effective classes (without the background).
  /// - \p maxDetections - The maximum number of detections.
  /// - \p maxClassesPerDetection - Maximum classes per detection (Fast NMS).
  /// - \p maxDetectionsPerClass - Maximum detections per class (Regular NMS).
  /// - \p iouThreshold - Detection threshold for the IoU metric.
  /// - \p scoreThreshold - Detection threshold for scores.
  /// - \p xScale - X scale used for decoding the boxes.
  /// - \p yScale - Y scale used for decoding the boxes.
  /// - \p hScale - H scale used for decoding the boxes.
  /// - \p wScale - W scale used for decoding the boxes.
  /// - \p regularNMS - Whether the NMS is "Regular" or "Fast".
  /// The node will have the following outputs:
  /// - DetectionBoxes - the chosen boxes (float)
  /// - DetectionClasses - the classes of the chosen boxes (int32)
  /// - DetectionScores - the scores of the chosen boxes (float)
  /// - NumDetections - number of chosen (detected) boxes (int32)
  /// The first three output tensors will be allocated using the maximum
  /// number of possible detections (worst case scenario) but the actual
  /// usage is given by the 'NumDetections' output.
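  ///
  /// A usage sketch with typical SSD-style parameters (assumes \p boxes,
  /// \p scores and \p anchors are NodeValues in the Function \p F; all the
  /// numeric values below are illustrative, not defaults):
  /// \code
  ///   TFLiteDetectionPostProcessNode *dpp =
  ///       F->createTFLiteDetectionPostProcess(
  ///           "dpp", boxes, scores, anchors, /* numClasses */ 90,
  ///           /* maxDetections */ 10, /* maxClassesPerDetection */ 1,
  ///           /* maxDetectionsPerClass */ 100, /* iouThreshold */ 0.6f,
  ///           /* scoreThreshold */ 0.3f, /* xScale */ 10.0f,
  ///           /* yScale */ 10.0f, /* hScale */ 5.0f, /* wScale */ 5.0f,
  ///           /* regularNMS */ false);
  /// \endcode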
  TFLiteDetectionPostProcessNode *createTFLiteDetectionPostProcess(
      llvm::StringRef name, NodeValue boxes, NodeValue scores,
      NodeValue anchors, int32_t numClasses, int32_t maxDetections,
      int32_t maxClassesPerDetection, int32_t maxDetectionsPerClass,
      float iouThreshold, float scoreThreshold, float xScale, float yScale,
      float hScale, float wScale, bool regularNMS);

  /// Create a constant node with a 1D cosine windowing function defined as:
  /// w[n] = 0.5 - 0.5 * cos(2 * pi * n / N) for n = 0 .. N - 1 where N
  /// is the window \p length. The node name will be \p name.
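  ///
  /// For example, for \p length N = 4 the window values are
  /// [0.0, 0.5, 1.0, 0.5]. A reference computation of the payload (a sketch,
  /// not the actual implementation; uses M_PI from <cmath>):
  /// \code
  ///   std::vector<float> w(length);
  ///   for (dim_t n = 0; n < length; n++)
  ///     w[n] = 0.5f - 0.5f * std::cos(2.0 * M_PI * n / length);
  /// \endcode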
  Constant *createCosineWindow(llvm::StringRef name, dim_t length);

  /// Create a constant node with the twiddle factors for a 1D complex FFT:
  /// W(N, k) = exp(-j * 2 * pi * k / N) for k = 0 .. N - 1, where N is the
  /// \p fftLength. The constant node will contain 2 * \p fftLength real float
  /// values corresponding to \p fftLength complex values with the real and
  /// imaginary parts interleaved: real[0], imag[0], real[1], imag[1], etc.
  /// The node name will be \p name.
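  ///
  /// A reference computation of the interleaved payload (a sketch, not the
  /// actual implementation):
  /// \code
  ///   std::vector<float> data(2 * fftLength);
  ///   for (dim_t k = 0; k < fftLength; k++) {
  ///     data[2 * k + 0] = std::cos(2.0 * M_PI * k / fftLength);  // real
  ///     data[2 * k + 1] = -std::sin(2.0 * M_PI * k / fftLength); // imag
  ///   }
  /// \endcode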
  Constant *createFFTTwiddleFactors(llvm::StringRef name, dim_t fftLength);

  /// Create a constant node with the bit-reverse indices for a 1D FFT, that
  /// is, for each of the values k = 0 .. N - 1 (N being the \p fftLength) the
  /// corresponding index obtained after reversing its bit order. The node
  /// will contain \p fftLength int32 values. The node name will be \p name.
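  ///
  /// A reference computation for a power-of-2 \p fftLength (a sketch, not the
  /// actual implementation):
  /// \code
  ///   unsigned bits = 0;
  ///   while ((1u << bits) < fftLength)
  ///     bits++;
  ///   std::vector<int32_t> idx(fftLength);
  ///   for (dim_t k = 0; k < fftLength; k++) {
  ///     int32_t rev = 0;
  ///     for (unsigned b = 0; b < bits; b++)
  ///       rev |= ((k >> b) & 1) << (bits - 1 - b);
  ///     idx[k] = rev;
  ///   }
  /// \endcode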
  Constant *createFFTBitReverseIndices(llvm::StringRef name, dim_t fftLength);

  /// Create a constant node with the complex weights used to map the results
  /// of an N/2 point complex FFT to an N point real FFT. This allows an
  /// efficient implementation of the N point FFT for real data x[n] with
  /// n = 0 .. N-1 by first computing the N/2 point complex FFT G[k] for the
  /// complex signal g[n] defined as g[n] = x[2*n+0] + j * x[2*n+1] with
  /// n = 0 .. N/2-1 and then computing the final N point FFT X[k] for the
  /// original data x[n] using X[k] = G[k] * A[k] + conj(G[N/2-k]) * (1 - A[k])
  /// for k = 0 .. N/2 (for a real signal the FFT is conjugate symmetrical and
  /// therefore only the first N/2+1 output points of X[k] need to be computed,
  /// the others being redundant). The relation uses the definitions
  /// G[N/2] = G[0] and A[k] = 1/2 * (1 - j * exp(-j * 2 * pi * k / N)) for
  /// k = 0 .. N/2. The FFT length parameter N is given as \p fftLength. This
  /// constant node will contain the complex values of A[k] for k = 0 .. L-1
  /// where L is the sequence length given as \p outLength (the required
  /// length L is smaller than N/2+1 since A[k] has such properties that the
  /// second half of the sequence can be easily deduced from the first half).
  /// This constant node will contain 2 * \p outLength real float values
  /// corresponding to \p outLength complex values A[k] with the real and
  /// imaginary parts interleaved.
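  ///
  /// Expanding the definition above gives A[k] = 0.5 * (1 - sin(t)) - j *
  /// 0.5 * cos(t) with t = 2 * pi * k / N, so the interleaved payload can be
  /// computed as follows (a sketch, not the actual implementation):
  /// \code
  ///   std::vector<float> data(2 * outLength);
  ///   for (dim_t k = 0; k < outLength; k++) {
  ///     double t = 2.0 * M_PI * k / fftLength;
  ///     data[2 * k + 0] = 0.5 * (1.0 - std::sin(t)); // Re{A[k]}
  ///     data[2 * k + 1] = -0.5 * std::cos(t);        // Im{A[k]}
  ///   }
  /// \endcode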
  Constant *createFFTComplexToRealWeights(llvm::StringRef name, dim_t fftLength,
                                          dim_t outLength);

  /// This node computes the spectrogram of a 1D mono audio signal \p input by
  /// extracting windows of size \p windowSize with stride \p windowStride and
  /// computing for each window the spectrum power (magnitude squared) or
  /// simply the magnitude, depending on the flag \p magnitudeSquared. If the
  /// length of the \p input is [inputLength] samples then the size of the
  /// spectrogram is [windowCount, fftLength/2+1] where:
  /// - windowCount = floor((inputLength-windowSize)/windowStride)+1 is the
  ///   number of windows extracted from the input.
  /// - fftLength is the FFT length used to compute the spectrogram, which is
  ///   the next power of 2 greater than or equal to \p windowSize (e.g. for a
  ///   window size of 640 the fftLength is 1024).
  /// The input audio data values are commonly float values scaled in the range
  /// [-1.0, 1.0]. If the audio data is decoded from a WAV file into int8/int16
  /// values then those values are commonly scaled down with 2^7/2^15 before
  /// using this node. The node name will be \p name. This node is inspired by
  /// TensorFlow (tensorflow.python.ops.gen_audio_ops.audio_spectrogram).
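  ///
  /// For example, 1 second of 16 kHz audio with \p windowSize = 640 and
  /// \p windowStride = 320 uses fftLength = 1024 and yields a spectrogram of
  /// size [floor((16000 - 640) / 320) + 1, 1024 / 2 + 1] = [49, 513]
  /// (a usage sketch; assumes \p audio is a [16000] float NodeValue in the
  /// Function \p F):
  /// \code
  ///   AudioSpectrogramNode *spec = F->createAudioSpectrogram(
  ///       "spec", audio, /* windowSize */ 640, /* windowStride */ 320,
  ///       /* magnitudeSquared */ true);
  /// \endcode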
  AudioSpectrogramNode *createAudioSpectrogram(llvm::StringRef name,
                                               NodeValue input,
                                               int64_t windowSize,
                                               int64_t windowStride,
                                               bool magnitudeSquared = true);

  /// Create as constants the Mel weights \p melWeights and ranges \p melRanges
  /// required for the MFCC (Mel Frequency Cepstral Coefficient) transform for
  /// a spectrogram of length \p spectrogramLength (which must be of the form
  /// 2 ^ N + 1) obtained for an audio signal with the given \p sampleRate
  /// (in Hertz) by mapping the spectrogram coefficients in \p filterBankCount
  /// bins on a Mel scale between \p lowerFrequency and \p upperFrequency
  /// (in Hertz) using a filterbank of triangular windows. The constant nodes
  /// will be named using \p prefix.
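  ///
  /// A usage sketch for a 16 kHz speech pipeline (the frequency range and
  /// bin count below are illustrative; 513 = 2^9 + 1 matches a 1024-point
  /// FFT spectrogram):
  /// \code
  ///   Constant *melWeights = nullptr;
  ///   Constant *melRanges = nullptr;
  ///   F->createMelWeights("mel", /* spectrogramLength */ 513,
  ///                       /* sampleRate */ 16000.0f,
  ///                       /* lowerFrequency */ 20.0f,
  ///                       /* upperFrequency */ 4000.0f,
  ///                       /* filterBankCount */ 40, melWeights, melRanges);
  /// \endcode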
  void createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength,
                        float sampleRate, float lowerFrequency,
                        float upperFrequency, dim_t filterBankCount,
                        Constant *&melWeights, Constant *&melRanges);

  /// Create the DCT-II transform matrix coefficients as a constant defined as:
  /// d[k][n] = sqrt(2 / N) * cos(pi / N * (n + 1/2) * k) with n = 0 .. N - 1
  /// and k = 0 .. K - 1 where \p N is the input data length and \p K is the
  /// output data length. Commonly the output length \p K is equal to the
  /// input length \p N, but choosing \p K <= \p N creates a partial DCT
  /// matrix used to compute only the first \p K results of the full DCT-II
  /// transform. The DCT matrix size will be \p K x \p N. The node name will
  /// be \p name.
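  ///
  /// A reference computation of the matrix coefficients (a sketch, not the
  /// actual implementation):
  /// \code
  ///   std::vector<float> dct(K * N);
  ///   for (dim_t k = 0; k < K; k++)
  ///     for (dim_t n = 0; n < N; n++)
  ///       dct[k * N + n] =
  ///           std::sqrt(2.0 / N) * std::cos(M_PI / N * (n + 0.5) * k);
  /// \endcode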
  Constant *createDCTMat(llvm::StringRef name, dim_t N, dim_t K);

  /// Computes the MFCC (Mel Frequency Cepstral Coefficient) for the given
  /// \p spectrogram; it is commonly used as a feature extractor for
  /// voice/speech audio data in voice command or keyword spotting
  /// applications. The input \p spectrogram is a power spectrogram and not a
  /// magnitude (computed using the 'AudioSpectrogram' node with the
  /// 'magnitudeSquared' flag set to True). The MFCC transform is computed
  /// using the given \p sampleRate (in Hertz) by mapping the spectrogram
  /// coefficients in \p filterBankCount bins on a Mel scale between
  /// \p lowerFrequency and \p upperFrequency (in Hertz) using a filterbank
  /// of triangular windows, taking the natural logarithm and then keeping
  /// the first \p numCoefficients from the DCT-II transform. If the input
  /// \p spectrogram size is [windowCount, spectrogramLen] then the output
  /// node size will be [windowCount, numCoefficients] since the MFCC
  /// transform is performed separately for each window of [spectrogramLen]
  /// input samples, yielding \p numCoefficients output samples. This node is
  /// inspired by TensorFlow (tensorflow.python.ops.gen_audio_ops.mfcc).
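  ///
  /// A usage sketch of the common spectrogram -> MFCC feature pipeline
  /// (assumes \p audio is a mono float NodeValue in \p F; getNthResult(0)
  /// grabs the single output of the AudioSpectrogram node):
  /// \code
  ///   AudioSpectrogramNode *spec = F->createAudioSpectrogram(
  ///       "spec", audio, /* windowSize */ 640, /* windowStride */ 320,
  ///       /* magnitudeSquared */ true);
  ///   MFCCNode *mfcc = F->createMFCC("mfcc", spec->getNthResult(0),
  ///                                  /* sampleRate */ 16000.0f,
  ///                                  /* lowerFrequency */ 20.0f,
  ///                                  /* upperFrequency */ 4000.0f,
  ///                                  /* filterBankCount */ 40,
  ///                                  /* numCoefficients */ 13);
  /// \endcode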
  MFCCNode *createMFCC(llvm::StringRef name, NodeValue spectrogram,
                       float sampleRate, float lowerFrequency,
                       float upperFrequency, int64_t filterBankCount,
                       int64_t numCoefficients);

  /// Performs the ROIAlign operation given the \p featureMap and the \p boxes.
  /// ROIAlign is similar to crop-and-resize followed by pooling. The
  /// coordinates to extract the crops are specified in \p boxes. Each cropped
  /// image is resized to the shape specified by \p outputHeight and
  /// \p outputWidth. The \p samplingRatio specifies the number of samples to
  /// take from each bin (along both dimensions) for the purpose of pooling.
  /// This node is defined in:
  /// (https://github.com/onnx/onnx/blob/master/docs/Operators.md#RoiAlign).
  /// The \p aligned flag is an addition to the ONNX definition indicating
  /// whether the box coordinates are aligned to the center of a pixel (vs.
  /// the top-left corner).
  ROIAlignNode *createROIAlign(llvm::StringRef name, NodeValue featureMap,
                               NodeValue boxes, NodeValue batchIndices,
                               uint32_t outputHeight, uint32_t outputWidth,
                               uint32_t samplingRatio, float spatialScale,
                               bool aligned, bool rotated = false,
                               PoolingMode mode = PoolingMode::AVG);

  /// Transform proposal bounding boxes to target bounding boxes using
  /// bounding box regression deltas.
  /// Inputs:
  /// \p rois - Bounding box proposals in pixel coordinates.
  /// Size (M, 4), format [x1, y1, x2, y2], or
  /// Size (M, 5), format [batch_index, x1, y1, x2, y2].
  /// If proposals from multiple images in a batch are present, they
  /// should be grouped sequentially and in incremental order.
  /// For rotated boxes, this would have an additional angle (in degrees)
  /// in the format [<optional_batch_id>, ctr_x, ctr_y, w, h, angle].
  /// \p deltas - Bounding box translations and scales,
  /// size (M, 4*K), format [dx, dy, dw, dh], K = # classes.
  /// For rotated boxes, size (M, 5*K), format [dx, dy, dw, dh, da].
  /// \p imInfo - Image dimensions, size (batch_size, 3),
  /// format [img_height, img_width, img_scale].
  /// Arguments:
  /// \p weights - vector<float> weights [wx, wy, ww, wh] for the deltas.
  /// \p applyScale - Transform the boxes to the scaled image space after
  /// applying the bbox deltas. Set to false to match the detectron code, set
  /// to true for keypoint models and for backward compatibility.
  /// \p rotated - If true, then boxes (rois and deltas) include angle info to
  /// handle rotation. The format will be
  /// [ctr_x, ctr_y, width, height, angle (in degrees)].
  /// \p angleBoundOn - If set, for rotated boxes, angle is normalized to be
  /// within [angle_bound_lo, angle_bound_hi].
  /// \p angleBoundLo - If set, for rotated boxes, angle is normalized to be
  /// within [angle_bound_lo, angle_bound_hi].
  /// \p angleBoundHi - If set, for rotated boxes, angle is normalized to be
  /// within [angle_bound_lo, angle_bound_hi].
  /// \p clipAngleThresh - For RRPN, clip almost horizontal boxes within this
  /// threshold of tolerance for backward compatibility. Set to a negative
  /// value for no clipping.
  /// \p legacyPlusOne - If true, box widths and heights are computed with an
  /// extra +1 (legacy Caffe2 behavior).
  /// Outputs:
  /// boxOut - Pixel coordinates of the transformed bounding boxes,
  /// Size (M, 4*K), format [x1, y1, x2, y2]. For rotated boxes, size (M, 5*K),
  /// format [ctr_x, ctr_y, w, h, angle].
  /// roiBatchSplits - Tensor of shape (batch_size) with each element
  /// denoting the number of RoIs belonging to the corresponding image in the
  /// batch.
  /// See definition:
  /// https://github.com/pytorch/pytorch/blob/master/caffe2/operators/bbox_transform_op.cc#L10
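  ///
  /// For the non-rotated case, each delta set [dx, dy, dw, dh] is applied to
  /// a box roughly as in the Caffe2 reference linked above (a scalar sketch;
  /// the variable names are illustrative):
  /// \code
  ///   float w = x2 - x1 + (legacyPlusOne ? 1.0f : 0.0f);
  ///   float h = y2 - y1 + (legacyPlusOne ? 1.0f : 0.0f);
  ///   float ctrX = x1 + 0.5f * w;
  ///   float ctrY = y1 + 0.5f * h;
  ///   float predCtrX = (dx / wx) * w + ctrX; // wx..wh are the weights
  ///   float predCtrY = (dy / wy) * h + ctrY;
  ///   float predW = std::exp(dw / ww) * w;
  ///   float predH = std::exp(dh / wh) * h;
  /// \endcode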
  BBoxTransformNode *
  createBBoxTransform(llvm::StringRef name, NodeValue rois, NodeValue deltas,
                      NodeValue imInfo, llvm::ArrayRef<float> weights,
                      bool applyScale, bool rotated, bool angleBoundOn,
                      int64_t angleBoundLo, int64_t angleBoundHi,
                      float clipAngleThresh, bool legacyPlusOne);

  /// Create an ExternalFunctionCall node. \p funcImpl contains the body of,
  /// or a reference to, the function which can be invoked. \p funcKind
  /// contains the kind of function: it could be source code (e.g. OpenCL or
  /// CUDA), a binary, or a handle to an external function.
  ExternalFunctionCallNode *
  createExternalFunctionCall(llvm::StringRef name, TypeRef outTy,
                             llvm::ArrayRef<glow::NodeValue> inputs,
                             llvm::StringRef funcName, llvm::StringRef funcImpl,
                             llvm::StringRef funcKind);

  /// Erase the node \p N from the Function.
  void eraseNode(Node *N);

  /// Erase the node referenced by the iterator \p I from the Function.
  void eraseNode(NodesList::iterator I);

  /// Clone the current function into a new function with the name \p newName
  /// in the same module. If \p map is non-null then the procedure records the
  /// mapping from each old node to its new node in \p map. If \p currToNewMap
  /// is non-null it is used as the initial state of the currToNew map inside
  /// the cloner.
  /// \returns a new function that is a copy of the current function.
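  ///
  /// A usage sketch (clones \p F within its module and records the
  /// old-to-new node mapping):
  /// \code
  ///   llvm::DenseMap<const Node *, Node *> map;
  ///   Function *copy = F->clone("F_clone", &map);
  /// \endcode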
  Function *clone(llvm::StringRef newName,
                  llvm::DenseMap<const Node *, Node *> *map = nullptr,
                  llvm::DenseMap<const Node *, Node *> *currToNewMap = nullptr);

  /// Clone the current function into a user-provided function \p newF. The
  /// function \p newF is not automatically added to a module by the clone
  /// call. If \p map is non-null then the procedure records the mapping from
  /// each old node to its new node in \p map. If \p currToNewMap is non-null
  /// it is used as the initial state of the currToNew map inside the cloner.
  /// \returns the user-provided function \p newF that now contains a clone of
  /// the current function.
  Function *
  clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map = nullptr,
        llvm::DenseMap<const Node *, Node *> *currToNewMap = nullptr) const;

  /// Verify the correctness of the Function. If \p backend is provided, checks
  /// backend-specific layout requirements; otherwise checks the requirements
  /// based on Glow's "canonical" layout. \returns true when the function is
  /// valid, false otherwise.
  bool verify(const Backend *backend = nullptr) const;

  /// Dump a textual representation of the Function to the default output
  /// stream.
  void dump() const;

  /// Dump a textual representation of the Function to std::string. If
  /// \p skipUsersForStorage then user counts for Storage will not be dumped.
  /// If \p skipName then the name of the Function will not be dumped.
  std::string toString(bool skipUsersForStorage = false,
                       bool skipName = false) const;

  /// \returns a hash code of the function.
  llvm::hash_code getHash() const;

  /// Dump a textual representation of the Function to the provided output
  /// stream \p os. If \p skipUsersForStorage then user counts for Storage
  /// will not be dumped. If \p skipName then the name of the Function will
  /// not be dumped.
  void dump(llvm::raw_ostream &os, bool skipUsersForStorage = false,
            bool skipName = false) const;

  /// Dump a dotty graph that depicts the function into a file.
  /// \returns the full path to the file.
  std::string dumpDAG();

  /// Dump a dotty graph that depicts the function into \p dotFilename.
  void dumpDAG(llvm::StringRef dotFilename);

  /// Dump a dotty graph that depicts the function into \p dotFilename.
  void dumpDAG(const char *dotFilename);

  /// \returns the list of nodes that the Function owns.
  NodesList &getNodes() { return nodes_; }

  const NodesList &getNodes() const { return nodes_; }

  /// \returns a node with the name \p name or nullptr if no such node exists.
  Node *getNodeByName(llvm::StringRef name);

  /// \returns a node value using the \p name which has the same format as the
  /// one used by \ref NodeValue::generateNodeOutputName, namely
  /// "nodeName:outputNumber". The returned node value has a nullptr for the
  /// node if it is not found in the Function or if the node has no outputs
  /// (for example SaveNode). The searched node value can be that of a graph
  /// node, a constant or a placeholder.
  NodeValue getNodeValueByName(llvm::StringRef name);

  /// \returns a pointer to the class member for the nodes list.
  static NodesList Function::*getNodesMemberPtr() { return &Function::nodes_; }

  /// Randomize all of the Constants in the function. If a Constant with users
  /// in this Function also has users in other Functions then this will result
  /// in a fatal error. \p ignoredConstants maps the Kind of a node to the set
  /// of input indices of that node that should be ignored (not randomized).
  void randomizeConstants(
      const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants = {});
};

struct TrainingConfig;

using VariableGradientsList =
    std::list<std::pair<Placeholder *, Placeholder *>>;

/// Create a new Function that 'trains' the input Function. We differentiate
/// the nodes and insert code to update the weights based on the \p config
/// parameters.
/// If \p varGrads is set then instead of inserting code to update the weights,
/// the procedure adds code to record the last gradient value: a list of
/// (var, grad_var) pairs associating variables with their gradient variables.
/// This feature is used by the gradient-check unit tests.
/// \returns a new function with the name \p newFuncName.
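///
/// A usage sketch (assumes \p F is a fully built inference Function; the
/// TrainingConfig fields are left at their defaults here):
/// \code
///   TrainingConfig config;
///   Function *TF = glow::differentiate(F, config, "main_train");
/// \endcode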
Function *differentiate(Function *F, const TrainingConfig &config,
                        llvm::StringRef newFuncName = "",
                        VariableGradientsList *varGrads = nullptr);

/// \returns the first SaveNode user of the placeholder \p PH or
/// nullptr if none are found.
SaveNode *getOutputSave(Function *F, Placeholder *PH);

/// Clone \p node and its sources into \p newF using the old-to-new mapping
/// \p currToNew.
Node *recursiveClone(Function *newF, Node *node, NodeMap &currToNew);

/// \returns true if \p PH is an output placeholder in the Function \p F.
/// This is determined by checking if the PH has a user which uses the PH as
/// an overwritten input.
bool isOutput(const Placeholder *PH, const Function &F);

/// \returns true if \p PH is an input placeholder in the Function \p F.
/// This is determined by checking if the PH is the input to a SaveNode or is
/// used by a non-SaveNode.
bool isInput(const Placeholder *PH, const Function &F);

/// Helper vectors for common transpose shuffles.
#define NCH2NHC                                                                \
  { 0u, 2u, 1u }
#define NCHW2NHWC                                                              \
  { 0u, 2u, 3u, 1u }
#define NCTHW2NTHWC                                                            \
  { 0u, 2u, 3u, 4u, 1u }
#define NHWC2NCHW                                                              \
  { 0u, 3u, 1u, 2u }
#define NTHWC2NCTHW                                                            \
  { 0u, 4u, 1u, 2u, 3u }
#define HWCN2NHWC                                                              \
  { 3u, 0u, 1u, 2u }
#define NHWC2HWNC                                                              \
  { 1u, 2u, 0u, 3u }
#define CNHW2NHWC                                                              \
  { 1u, 2u, 3u, 0u }
#define NHWC2CHWN                                                              \
  { 3u, 1u, 2u, 0u }
#define CHWN2NHWC                                                              \
  { 3u, 1u, 2u, 0u }
#define D2S_DCR                                                                \
  { 0u, 1u, 3u, 2u, 4u, 5u }
#define D2S_CRD                                                                \
  { 0u, 1u, 4u, 2u, 5u, 3u }
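
// Example: using one of the shuffles above to transpose an NHWC value to
// NCHW (a usage sketch; assumes F is a Function* and input is a 4D NHWC
// NodeValue):
//   TransposeNode *T = F->createTranspose("toNCHW", input, NHWC2NCHW);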

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F);

/// \returns whether the Convolution node \p node is equivalent to a
/// FullyConnected node. This happens for a 2D NHWC Convolution with a 1x1
/// filter, strides 1, pads 0, group 1 and dilations 1.
bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node,
                                       bool enforceInput1x1 = false);

/// \returns whether the Gemm node \p node is equivalent to a FullyConnected
/// node. This happens when alpha and beta are 1.0 and the C operand is 1D.
bool isGemmSameAsFullyConnected(const GemmNode *node);

} // namespace glow

#endif // GLOW_GRAPH_GRAPH_H