1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_GRAPH_GRAPH_H |
17 | #define GLOW_GRAPH_GRAPH_H |
18 | |
19 | #include "glow/Base/Type.h" |
20 | #include "glow/Graph/Log.h" |
21 | #include "glow/Graph/Nodes.h" |
22 | #include "glow/Quantization/Base/Base.h" |
23 | |
24 | #include "llvm/ADT/ArrayRef.h" |
25 | #include "llvm/ADT/DenseMap.h" |
26 | #include "llvm/ADT/StringSet.h" |
27 | #include "llvm/ADT/ilist.h" |
28 | #include "llvm/ADT/ilist_node.h" |
29 | |
30 | #include <list> |
31 | #include <vector> |
32 | |
33 | namespace glow { |
34 | class PlaceholderBindings; |
35 | |
/// List of Types.
using TypesList = std::list<Type>;
/// Intrusive list of Nodes.
using NodesList = llvm::iplist<glow::Node>;
/// List of pointers to Nodes. The nodes are not owned by the list.
using NodesPtrList = std::list<glow::Node *>;
/// List of Functions.
using FunctionList = std::list<Function *>;
/// List of pointers to Constants.
using ConstList = std::list<Constant *>;
/// List of pointers to Placeholders.
using PlaceholderList = std::list<Placeholder *>;
/// Array of dimension sizes.
using UnsignedArrayRef = llvm::ArrayRef<dim_t>;
/// Map from original Nodes to cloned Nodes.
using NodeMap = llvm::DenseMap<Node *, Node *>;
/// State of a function. This can be used to control optimizations which depend
/// on the state of the Function. This is a temporary workaround until GH Issue
/// #3213 is complete.
enum class FunctionState {
  /// Indicates that the function has been created but not completely loaded.
  FuncCreated,
  /// Indicates that the function has been completely loaded.
  FuncLoaded,
};
58 | |
59 | /// Helper names for common tensor layouts. |
60 | #define ANY_LAYOUT "*" |
61 | |
62 | class Module final { |
63 | /// Stores the functions in the module. |
64 | FunctionList functions_; |
65 | /// A uniqued list of types. Types in this list can be equated by comparing |
66 | /// their addresses. |
67 | TypesList types_{}; |
68 | /// Stores a list of unique Storage names that were used by the module at |
69 | /// some point. |
70 | llvm::StringSet<> usedStorageNames_{}; |
71 | /// Stores a list of node names that were used by Functions of this module at |
72 | /// some point. |
73 | llvm::StringSet<> usedNodeNames_{}; |
74 | /// Stores a list of node names that were present in the original model and |
75 | /// are good to be retained. |
76 | llvm::StringSet<> originalNames_{}; |
77 | /// A list of constants that the Module owns. |
78 | ConstList constants_; |
79 | /// A list of placeholder nodes that the Module owns. |
80 | PlaceholderList placeholders_; |
81 | /// Deterministic PRNG used to initialize weights in this module. |
82 | PseudoRNG PRNG_; |
83 | |
84 | /// Module log context that stores all logs related to this module. |
85 | LogContext moduleLogCtx_{nullptr}; |
86 | |
87 | /// Inserts the constant \p V to the list of constants. |
88 | Constant *addConstant(Constant *V); |
89 | |
90 | friend class Function; |
91 | |
92 | public: |
93 | Module() = default; |
94 | |
95 | ~Module(); |
96 | |
97 | /// \returns the prefix part of the provided \p name. E.g. for an input |
98 | /// of "relu__2" returns "relu". |
99 | static std::string getPrefix(llvm::StringRef name); |
100 | |
101 | /// \returns unique legal name that's based on the string \p name. Legal |
102 | /// names are legal C identifiers in the form: "[a-zA-Z_][a-zA-Z0-9_]*". |
103 | /// The name may not be in \p stringTable or \p updateTable and will be |
104 | /// inserted into \p updateTable. |
105 | static llvm::StringRef uniqueName(llvm::StringRef name, |
106 | const llvm::StringSet<> &stringTable, |
107 | llvm::StringSet<> &updateTable, |
108 | const llvm::StringSet<> &originalNames); |
109 | |
110 | /// Registers a \p name as used by some Node in this module. |
111 | void registerNodeName(llvm::StringRef name) { |
112 | // Don't care if it's already in the set. |
113 | usedNodeNames_.insert(name); |
114 | } |
115 | |
116 | /// Registers a \p name from the original model, good to be retained. |
117 | void registerOriginalName(llvm::StringRef name) { |
118 | // Don't care if it's already in the set. |
119 | if (name.size()) { |
120 | originalNames_.insert(name); |
121 | } |
122 | } |
123 | |
124 | /// \returns the pointer to list of original node names, good to be retained; |
125 | const llvm::StringSet<> *getOriginalNames() const { return &originalNames_; } |
126 | |
127 | /// Registers a name as used by a Storage node (Constant or Placeholder) in |
128 | /// this module. |
129 | void registerStorageName(llvm::StringRef name) { |
130 | usedStorageNames_.insert(name); |
131 | } |
132 | |
133 | /// \returns whether there's a Storage node already registered with \p name. |
134 | bool hasStorageName(llvm::StringRef name) { |
135 | return usedStorageNames_.count(name); |
136 | } |
137 | |
138 | /// Return a pointer to a uniqued type \p T. |
139 | TypeRef uniqueType(const Type &T); |
140 | |
141 | /// Return a pointer to a uniqued type \p T. |
142 | TypeRef uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims); |
143 | |
144 | /// Return a pointer to a uniqued type \p T. |
145 | TypeRef uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale, |
146 | int32_t offset); |
147 | |
148 | /// Return a pointer to a uniqued type \p T. |
149 | /// The new type is identical to \p T, with a new shape \p dims. |
150 | TypeRef uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims); |
151 | |
152 | /// The new type is identical to \p T, with a new shape \p dims and new \p |
153 | /// strides. |
154 | TypeRef uniqueTypeWithNewStrides(TypeRef T, llvm::ArrayRef<dim_t> dims, |
155 | llvm::ArrayRef<dim_t> strides); |
156 | |
157 | /// The new type is identical to \p T, with a new shape \p dims and new \p |
158 | /// alignments. |
159 | TypeRef uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims, |
160 | llvm::ArrayRef<dim_t> alignments); |
161 | |
162 | /// Return a pointer to a uniqued type \p T. |
163 | /// The new type is identical to \p T, with a new shape and strides taken from |
164 | /// the type \p shapeType. |
165 | TypeRef uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType); |
166 | |
167 | /// Return a pointer to a uniqued type \p T. |
168 | /// The new type is identical to \p T, with new scale and offset taken from |
169 | /// the type \p quantParamType. |
170 | TypeRef uniqueTypeWithNewQuantParams(TypeRef T, TypeRef quantParamType); |
171 | |
172 | /// Return the void type. |
173 | TypeRef getVoidTy(); |
174 | |
175 | /// \returns True if a function by the name \p name exists in the module. |
176 | bool hasFunction(llvm::StringRef name); |
177 | /// \returns the function with the name \p name, or nullptr if the function |
178 | /// does not exist. |
179 | Function *getFunction(llvm::StringRef name); |
180 | /// \returns a new function with the name \p name. |
181 | Function *createFunction(llvm::StringRef name); |
182 | /// \returns the list of Functions that the Module owns. |
183 | FunctionList &getFunctions() { return functions_; } |
184 | |
185 | const FunctionList &getFunctions() const { return functions_; } |
186 | |
187 | /// Clears out all Functions from \ref functions_. |
188 | void clearFunctions(); |
189 | |
190 | /// \returns the list of types that the Module owns. |
191 | const TypesList &getTypes() const { return types_; } |
192 | |
193 | /// Erase the constant \p N from the Module. |
194 | void eraseConstant(Constant *N); |
195 | |
196 | /// Erase the constant \p I from the Module. |
197 | void eraseConstant(ConstList::iterator I); |
198 | |
199 | /// Erase the placeholder \p I from the Module. |
200 | /// Note: we only provide an iterator version of this, as erasing Placeholders |
201 | /// is often unsafe. |
202 | void erasePlaceholder(PlaceholderList::iterator I); |
203 | |
204 | /// \returns a pointer to the first Constant with the name \p name or nullptr |
205 | /// if no node has this name. |
206 | Constant *getConstantByName(llvm::StringRef name) const; |
207 | |
208 | /// \returns the list of constants that the Module owns. |
209 | ConstList &getConstants() { return constants_; } |
210 | |
211 | const ConstList &getConstants() const { return constants_; } |
212 | |
213 | /// \returns the list of placeholders that the Module owns. |
214 | PlaceholderList &getPlaceholders() { return placeholders_; } |
215 | |
216 | const PlaceholderList &getPlaceholders() const { return placeholders_; } |
217 | |
218 | /// \returns a pointer to the placeholder with the name \p name or |
219 | /// nullptr if no placeholder has this name. |
220 | Placeholder *getPlaceholderByNameSlow(llvm::StringRef name) const; |
221 | |
222 | /// @name High-level Storage builders. |
223 | ///@{ |
224 | |
225 | Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims, |
226 | llvm::StringRef name, bool isTrainable, |
227 | const std::string &layout = ANY_LAYOUT); |
228 | |
229 | Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name, |
230 | bool isTrainable, |
231 | const std::string &layout = ANY_LAYOUT); |
232 | |
233 | Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims, |
234 | float scale, int32_t offset, |
235 | llvm::StringRef name, bool isTrainable, |
236 | const std::string &layout = ANY_LAYOUT); |
237 | |
238 | Constant *createConstant(TypeRef T, llvm::StringRef name, |
239 | const std::string &layout = ANY_LAYOUT); |
240 | |
241 | Constant *createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims, |
242 | llvm::StringRef name, |
243 | const std::string &layout = ANY_LAYOUT); |
244 | |
245 | Constant *createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims, float scale, |
246 | int32_t offset, llvm::StringRef name, |
247 | const std::string &layout = ANY_LAYOUT); |
248 | |
249 | Constant *createConstant(llvm::StringRef name, const Tensor &tensor, |
250 | const std::string &layout = ANY_LAYOUT); |
251 | |
252 | Constant *createConstant(llvm::StringRef name, Tensor &&tensor, |
253 | const std::string &layout = ANY_LAYOUT); |
254 | |
255 | ///@} |
256 | |
257 | /// Verify the correctness of the Module. |
258 | /// \returns true when the function is valid. False otherwise. |
259 | bool verify() const; |
260 | |
261 | /// Get the pseudo-random number generator used by this module. |
262 | PseudoRNG &getPRNG() { return PRNG_; } |
263 | |
264 | /// Dump a textual representation of the Module into default output stream. |
265 | void dump() const; |
266 | |
267 | /// Dump a textual representation of the Module to std::string. |
268 | std::string toString() const; |
269 | |
270 | /// Dump a textual representation of the Module into provided output stream. |
271 | void dump(llvm::raw_ostream &os) const; |
272 | |
273 | /// Dump a dotty graph that depicts the Module. |
274 | void dumpDAG(); |
275 | |
276 | /// Dump a dotty graph that depicts the Module. |
277 | void dumpDAG(llvm::StringRef dotFilename); |
278 | |
279 | /// Dump a dotty graph that depicts the Module. |
280 | void dumpDAG(const char *dotFilename); |
281 | |
282 | /// Erase all of the functions from the module. |
283 | void eraseFunctions(); |
284 | |
285 | /// Erase all the functions, Placeholders, Constants, etc. |
286 | void clear(); |
287 | |
288 | /// Clone a module. |
289 | /// \returns a new module that is a copy of the current module. |
290 | Module *clone() const; |
291 | |
292 | /// Clone the current module into a user-provided module \p M. |
293 | /// \returns the user-provided module \p M that now contains a clone of the |
294 | /// current module. |
295 | Module *clone(Module *M) const; |
296 | |
297 | /// Strips payloads from constants. This is useful when |
298 | /// the Module will be kept around for metadata but we want to reduce memory |
299 | /// use. Unlike clear this leaves PHs and Constants in the module. |
300 | void strip(); |
301 | |
302 | /// Erase a function \p F from the module. |
303 | void eraseFunction(Function *F); |
304 | |
305 | /// \Returns the size in bytes of data used by constants. |
306 | uint64_t getConstantsSize(); |
307 | |
308 | /// \Returns the module log context. |
309 | LogContext *getModuleLogContext() { return &moduleLogCtx_; }; |
310 | |
311 | /// \returns whether any Node in the module are non-fused quantized with |
312 | /// scale == or != dummyScale, depending on \p expectDummy. |
313 | Error verifyDummyQParams(bool expectDummies); |
314 | |
315 | // Don't copy or move this class around. |
316 | // The destructor will wipe the functions leaving |
317 | // the original Module only dangling pointers. |
318 | Module(const Module &) = delete; |
319 | Module(Module &&) = delete; |
320 | Module &operator=(const PlaceholderBindings &) = delete; |
321 | Module &operator=(PlaceholderBindings &&) = delete; |
322 | }; |
323 | |
324 | // Forward Declaration for verify's optional parameter |
325 | class Backend; |
326 | struct CompilationContext; |
327 | |
328 | /// Represents the compute graph. |
329 | class Function final : public IRContainer { |
330 | /// A list of nodes that the Function owns. |
331 | NodesList nodes_; |
332 | |
333 | /// A list of metadata PHs associated with the function. |
334 | std::vector<Placeholder *> metadataPlaceholders_; |
335 | |
336 | /// Stores a list of unique node names that were used by the module at some |
337 | /// point. |
338 | llvm::StringSet<> uniqueNodeNames_{}; |
339 | |
340 | /// A reference to the owner of the function. |
341 | Module *parent_; |
342 | |
343 | /// The log context associated with this function. |
344 | std::shared_ptr<LogContext> logCtx_; |
345 | |
346 | /// The state of this function. |
347 | FunctionState state_; |
348 | |
349 | public: |
  /// Constructs a Function named \p Name owned by the Module \p parent.
  /// The function starts in the FuncCreated state; callers mark it
  /// FuncLoaded via setState() once loading completes.
  Function(Module *parent, llvm::StringRef Name = {})
      : IRContainer(Name), parent_(parent), state_(FunctionState::FuncCreated) {
    // Create this function's own log context and seed it with the cloned
    // scope event from the parent module's log context.
    logCtx_ = std::make_shared<LogContext>(parent);
    logCtx_->pushEvent(parent->getModuleLogContext()->getClonedScope());
  }
355 | |
356 | ~Function(); |
357 | |
  /// \returns the kind of IR held by this container (graph-level Glow IR).
  IRKind getIRKind() const override { return IRKind::GlowGraphIRKind; };

  /// LLVM-style RTTI support: an IRContainer is a Function iff its IR kind
  /// is GlowGraphIRKind.
  static bool classof(const IRContainer *I) {
    return I->getIRKind() == IRKind::GlowGraphIRKind;
  }

  /// LLVM-style RTTI support: a Function is trivially a Function.
  static bool classof(const Function *F) { return true; }
365 | |
366 | /// Clear out \ref nodes_ and \ref uniqueNodeNames_. |
367 | void clear(); |
368 | |
369 | /// Sets the state of the function. |
370 | void setState(FunctionState state) { state_ = state; } |
371 | |
372 | /// Gets the state of the function. |
373 | FunctionState getState() { return state_; } |
374 | |
375 | std::string getFilename() { return getName().rsplit('/').second.str(); } |
376 | |
377 | /// Return the log context. |
378 | std::shared_ptr<LogContext> getLogContext() { return logCtx_; } |
379 | |
380 | /// Add placeholder for metadata such as profiling. |
381 | void addMetadataPlaceholder(Placeholder *PH) { |
382 | metadataPlaceholders_.push_back(PH); |
383 | } |
384 | |
385 | /// Get list of metadata placeholders. |
386 | const std::vector<Placeholder *> &getMetadataPlaceholders() const { |
387 | return metadataPlaceholders_; |
388 | } |
389 | |
390 | Module *getParent() override { return parent_; } |
391 | |
392 | /// Perform ordering of nodes_ based on node's name. |
393 | /// This is to make sure that performing optimizations have a deterministic |
394 | /// behavior on the graphs which have the same ops but different ordering in |
395 | /// nodes_. |
396 | /// Please do not call this in the middle of PyTorchModelLoading, since |
397 | /// constant propagation is heavily relied on the order of nodes in nodelist. |
398 | /// If the order is changed during model loading, the constant propagation may |
399 | /// cause unpredictable fatal error when building the graph. |
400 | void orderNodes() { |
401 | nodes_.sort( |
402 | [](const Node &a, const Node &b) { return a.getName() < b.getName(); }); |
403 | } |
404 | |
  /// Search the Module containing the function to gather and return a list of
  /// placeholders that are used by the Function.
  PlaceholderList findPlaceholders();
  PlaceholderList findPlaceholders() const;

  /// Search the Module containing the function to gather and return a list of
  /// constants that are used by the Function.
  ConstList findConstants();
  ConstList findConstants() const;

  /// \returns the Module that owns this Function.
  const Module *getParent() const override { return parent_; }
416 | |
  /// Inserts the node \p N to the list of nodes, and returns the inserted node.
  /// The node's name is made unique against the parent module's Storage names
  /// and this function's node names, and the final name is registered with
  /// the parent module.
  template <class NodeTy> NodeTy *addNode(NodeTy *N) {
    // Unique the name first so that the name registered below is the final
    // one.
    N->setName(Module::uniqueName(N->getName(), parent_->usedStorageNames_,
                                  uniqueNodeNames_, parent_->originalNames_));
    parent_->registerNodeName(N->getName());
    nodes_.push_back(N);

    // Log the node creation.
    logCtx_->logNodeCreation(*N);

    return N;
  }
429 | |
  /// Take ownership of \p N by removing it from its original Function, add it
  /// to the current Function, and also unique its name.
  /// Assumes \p N currently belongs to some Function (its parent is
  /// dereferenced unconditionally).
  void takeOwnershipOfNode(Node *N) {
    // Detach from the previous owner before renaming and re-inserting.
    N->getParent()->getNodes().remove(N);
    N->setName(Module::uniqueName(N->getName(), parent_->usedStorageNames_,
                                  uniqueNodeNames_, parent_->originalNames_));
    parent_->registerNodeName(N->getName());
    nodes_.push_back(N);
    // NOTE(review): unlike addNode(), this does not log the node into
    // logCtx_ -- confirm the asymmetry is intentional.
  }

  /// Get the pseudo-random number generator used by this module.
  PseudoRNG &getPRNG() { return getParent()->getPRNG(); }
442 | |
443 | /// @name High-level, operation-level IRBuilder. |
444 | ///@{ |
445 | |
446 | /// Creates a PadNode with the given \p name and output type \p outTy which |
447 | /// pads the given \p input with the explicit pads \p pads according to the |
448 | /// padding mode \p mode and with the given value \p value. The padding mode |
449 | /// \p mode is one of enumeration values from \ref PaddingMode. For an input |
450 | /// with N dimensions (rank N) the \p pads must be a vector with 2*N values |
451 | /// with the following format: |
452 | /// pads = [pad_before(D1), pad_before(D2), ..., pad_before(DN), |
453 | /// pad_after (D1), pad_after (D2), ..., pad_after (DN)]. |
454 | /// The mode PaddingMode::CONSTANT pads the input using the constant value |
455 | /// \p value and currently is the only mode supported. |
456 | PadNode *createPad(llvm::StringRef name, NodeValue input, TypeRef outTy, |
457 | unsigned_t mode, llvm::ArrayRef<int> pads, float value); |
458 | |
459 | /// Creates a ConvolutionNode with the given \p name which convolves the 4D |
/// \p input with \p filter and \p bias. \p kernels defines the size of the
461 | /// height and width dimensions of the filters. \p strides defines the number |
462 | /// of steps to take in the input for each output cell. \p pads defines how |
463 | /// many zero padding cells should be added to the input during convolution. |
464 | /// \p group defines the number of groups the input and output channels should |
465 | /// be divided into and convolved separately. \p dilation defines factor by |
466 | /// which gap between 2 neighboring kernel elements is expanded along each |
467 | /// axis. \p layout defines the Tensor layout and must be either NHWC or NCHW. |
468 | |
469 | ConvolutionNode * |
470 | createConv(llvm::StringRef name, NodeValue input, NodeValue filter, |
471 | NodeValue bias, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels, |
472 | llvm::ArrayRef<unsigned_t> strides, |
473 | llvm::ArrayRef<unsigned_t> pads, unsigned_t group, |
474 | llvm::ArrayRef<unsigned_t> dilation = {1, 1}, |
475 | ConvolutionLayout layout = ConvolutionLayout::NHWC); |
476 | |
477 | /// Creates a ConvolutionNode with the given \p name which convolves the 4D |
/// \p input with \p filter and \p bias. \p kernel defines the size of the
479 | /// height and width dimensions of the filters. \p stride defines the number |
480 | /// of steps to take in the input for each output cell. \p pad defines how |
481 | /// many zero padding cells should be added to the input during convolution. |
482 | /// \p group defines the number of groups the input and output channels should |
483 | /// be divided into and convolved separately. \p dilation defines factor by |
484 | /// which gap between 2 neighboring kernel elements is expanded along each |
485 | /// axis. \p layout defines the Tensor layout and must be either NHWC or NCHW. |
486 | |
487 | ConvolutionNode * |
488 | createConv(llvm::StringRef name, NodeValue input, NodeValue filter, |
489 | NodeValue bias, TypeRef outTy, unsigned_t kernel, |
490 | unsigned_t stride, unsigned_t pad, unsigned_t group, |
491 | llvm::ArrayRef<unsigned_t> dilation = {1, 1}, |
492 | ConvolutionLayout layout = ConvolutionLayout::NHWC); |
493 | |
494 | /// Creates a Convolution3DNode with the given \p name which convolves the 5D |
/// \p input with \p filter and \p bias. \p kernels defines the size of the
/// height, width, and depth dimensions of the filters. \p strides defines
/// the number of steps to take in the input for each output cell. \p pads
498 | /// defines how many zero padding cells should be added to the input during |
499 | /// convolution. \p group defines the number of groups the input and output |
500 | /// channels should be divided into and convolved separately. \p outTy defines |
501 | /// the type of the output of the 3d convolution. |
502 | Convolution3DNode *createConv3D(llvm::StringRef name, NodeValue input, |
503 | NodeValue filter, NodeValue bias, |
504 | TypeRef outTy, |
505 | llvm::ArrayRef<unsigned_t> kernels, |
506 | llvm::ArrayRef<unsigned_t> strides, |
507 | llvm::ArrayRef<unsigned_t> pads, |
508 | unsigned_t group); |
509 | |
510 | /// Creates a Convolution3DNode with the given \p name which convolves the 5D |
/// \p input with \p filter and \p bias. \p kernel defines the size of the
/// height, width, and depth dimensions of the filters. \p stride defines
/// the number of steps to take in the input for each output cell. \p pad
514 | /// defines how many zero padding cells should be added to the input during |
515 | /// convolution. \p group defines the number of groups the input and output |
516 | /// channels should be divided into and convolved separately. \p outTy defines |
517 | /// the type of the output of the 3d convolution. |
518 | Convolution3DNode *createConv3D(llvm::StringRef name, NodeValue input, |
519 | NodeValue filter, NodeValue bias, |
520 | TypeRef outTy, unsigned_t kernel, |
521 | unsigned_t stride, unsigned_t pad, |
522 | unsigned_t group); |
523 | |
524 | /// Creates a ChannelwiseQuantizedConvolutionNode with the given \p name which |
525 | /// convolves the 4D/5D \p input with \p filter and \p bias. \p filterScales |
526 | /// and \p filterOffsets provide individual quantization parameters for each |
527 | /// filter group in \p filter while \p biasScales and \p biasOffsets provide |
528 | /// individual quantization parameters for each bias element corresponding to |
529 | /// each output channel. \p kernels defines the size of the height and width |
530 | /// dimensions of the filters. \p strides defines the number of steps to take |
531 | /// in the input for each output cell. \p pads defines how many zero padding |
532 | /// cells should be added to the input during convolution. \p group defines |
533 | /// the number of groups the input and output channels should be divided into |
534 | /// and convolved separately. \p dilation defines the filter dilation. |
535 | /// This function is flexible and has the following features: |
536 | /// - it can be provided with a floating-point \p filter and the function will |
537 | /// quantize automatically the filter channelwise using the given schema |
538 | /// \p schema and type \p filterElemQTy. |
539 | /// - it can be provided with a floating-point \p bias and the function will |
540 | /// quantize automatically the bias channelwise using the given schema |
541 | /// \p schema and type \p biasElemQTy. |
542 | /// - if \p filter is floating-point and \p filterScales or \p filterOffsets |
543 | /// are not provided then this function will derive them automatically. |
544 | /// - if \p filter is quantized then \p filterScales or \p filterOffsets are |
545 | /// mandatory. |
546 | /// - if \p bias is floating-point and \p biasScales or \p biasOffsets are not |
547 | /// provided then this function will derive them automatically. |
548 | /// - if \p bias is quantized and \p biasScales or \p biasOffsets are not |
549 | /// provided then this function will assume the implicit parameters |
550 | /// biasScales[i] = inputScale * filterScales[i] and biasOffsets[i] = 0. |
551 | /// To be noted that this case can handle safely only INT32 bias data type |
552 | /// because for INT8 type the bias will almost certainly be saturated. |
553 | /// This function will only quantize the filter if \p quantizeFilter is set |
554 | /// to true and will only quantize the bias if \p quantizeBias is set to true |
555 | /// such that a floating-point filter/bias can be attached to the node as-is |
556 | /// without any modifications in order for the backends to perform their own |
557 | /// custom quantization later if desired. |
558 | /// This function requires \p filter and \p bias operands to be constants. |
559 | ChannelwiseQuantizedConvolutionNode *createChannelwiseQuantizedConv( |
560 | llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias, |
561 | NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales, |
562 | NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels, |
563 | llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads, |
564 | unsigned_t group, llvm::ArrayRef<unsigned_t> dilation = {1, 1}, |
565 | bool quantizeFilter = true, bool quantizeBias = true, |
566 | quantization::Schema schema = quantization::Schema::Asymmetric, |
567 | ElemKind filterElemQTy = ElemKind::Int8QTy, |
568 | ElemKind biasElemQTy = ElemKind::Int32QTy); |
569 | |
570 | /// Creates a ConvTransposeNode with the given \p name which does transposed |
/// convolution of the 4D \p input with \p filter and \p bias. \p kernels define
572 | /// the size of the height and width dimensions of the filters. \p strides |
573 | /// define the number of steps to take in the input for each output cell. |
574 | /// \p pads define how many zero padding cells should be added to the input |
575 | /// during convolution. \p group defines the number of groups the input and |
576 | /// output channels should be divided into and convolved separately. |
577 | ConvTransposeNode *createConvTranspose( |
578 | llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias, |
579 | TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels, |
580 | llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads, |
581 | unsigned_t group, llvm::ArrayRef<unsigned_t> dilation = {1, 1}); |
582 | |
/// Creates a ConvTransposeNode with the given \p name which does
/// transposed convolution of the 4D \p input with \p filter and \p bias. \p
585 | /// kernel defines the size of the height and width dimensions of the filters. |
586 | /// \p stride defines the number of steps to take in the input for each output |
587 | /// cell. \p pad defines how many zero padding cells should be added to the |
588 | /// input during convolution. \p group defines the number of groups the input |
589 | /// and output channels should be divided into and convolved separately. |
590 | ConvTransposeNode * |
591 | createConvTranspose(llvm::StringRef name, NodeValue input, NodeValue filter, |
592 | NodeValue bias, TypeRef outTy, unsigned_t kernel, |
593 | unsigned_t stride, unsigned_t pad, unsigned_t group, |
594 | llvm::ArrayRef<unsigned_t> dilation = {1, 1}); |
595 | |
596 | /// Creates and \returns a ConvertTo Node with name \p name of \p input to |
597 | /// output type \p outTy. |
598 | ConvertToNode *createConvertTo(llvm::StringRef name, NodeValue input, |
599 | TypeRef outTy); |
600 | |
601 | /// Creates and \returns a ConvertTo Node with name \p name of \p input to |
602 | /// output ElemKind \p k. |
603 | ConvertToNode *createConvertTo(llvm::StringRef name, NodeValue input, |
604 | ElemKind k); |
605 | |
606 | MaxPoolNode *createMaxPool(llvm::StringRef name, NodeValue input, |
607 | llvm::ArrayRef<unsigned_t> kernels, |
608 | llvm::ArrayRef<unsigned_t> strides, |
609 | llvm::ArrayRef<unsigned_t> pads, |
610 | ElemKind elemTyAMT = ElemKind::Int64ITy, |
611 | ConvolutionLayout layout = NHWC); |
612 | |
613 | MaxPoolNode *createMaxPool(llvm::StringRef name, NodeValue input, |
614 | unsigned_t kernel, unsigned_t stride, |
615 | unsigned_t pad, |
616 | ElemKind elemTyAMT = ElemKind::Int64ITy, |
617 | ConvolutionLayout layout = NHWC); |
618 | |
619 | AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input, |
620 | llvm::ArrayRef<unsigned_t> kernels, |
621 | llvm::ArrayRef<unsigned_t> strides, |
622 | llvm::ArrayRef<unsigned_t> pads, |
623 | ConvolutionLayout layout = NHWC, |
624 | bool countIncludePads = true); |
625 | |
626 | AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input, |
627 | TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels, |
628 | llvm::ArrayRef<unsigned_t> strides, |
629 | llvm::ArrayRef<unsigned_t> pads, |
630 | ConvolutionLayout layout = NHWC, |
631 | bool countIncludePads = true); |
632 | |
/// Creates and \returns an AvgPool node with \p name which average-pools
/// \p input using kernel size \p kernel, stride \p stride and padding
/// \p pad, with tensor layout \p layout. \p countIncludePads indicates
/// whether padded elements are included when computing each pooling
/// window's average (inferred from the parameter name; confirm against the
/// node implementation).
AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
unsigned_t kernel, unsigned_t stride,
unsigned_t pad, ConvolutionLayout layout = NHWC,
bool countIncludePads = true);

/// Creates and \returns an AdaptiveAvgPool node with \p name, \p input, and
/// \p outTy. The AdaptiveAvgPoolNode will perform average pooling over the
/// input so that the result is of the shape specified by \p outTy.
AdaptiveAvgPoolNode *createAdaptiveAvgPool(llvm::StringRef name,
NodeValue input, TypeRef outTy);

/// Creates and \returns a General Matrix Multiplication (Gemm) node with
/// given \p name which computes Y = alpha * A * B + beta * C. The operands
/// \p A and \p B are 2D matrices, the \p C operand is an optional 1D or 2D
/// matrix (broadcastable to the size of Y) and \p alpha and \p beta are float
/// scalars. The \p C operand is optional, if nullptr is given then it is not
/// used. If \p transposeA or \p transposeB is true then \p A or \p B is
/// additionally transposed prior to matrix multiplication.
/// If the output shape of Y is [M,N] then:
/// - The shape of \p A must be [M,K] or [K,M] (if transposed).
/// - The shape of \p B must be [K,N] or [N,K] (if transposed).
/// - The shape of \p C must be [N] (if 1D) or [M,N] (if 2D).
GemmNode *createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
NodeValue C = nullptr, float alpha = 1.0,
float beta = 1.0, bool transposeA = false,
bool transposeB = false);

/// Same as \ref createGemm above, but takes an explicit result type
/// \p outTy instead of inferring it from the inputs.
GemmNode *createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
NodeValue B, NodeValue C = nullptr, float alpha = 1.0,
float beta = 1.0, bool transposeA = false,
bool transposeB = false);
664 | |
/// Create and \returns a DynamicQuantizedFullyConnectedNode with \p name,
/// \p input, weights \p W, bias \p B, and a flag \p isSymmetric indicating
/// the quantization mode. By default it is a dynamic quantized FC node which
/// takes fp16 inputs, quantizes them symmetrically, runs FC on them,
/// dequantizes the result and produces fp16 output. If \p isSymmetric is set
/// to false, then inputs are asymmetrically quantized. If
/// \p isPerBatchElement is set to false, then inputs are per tensor
/// quantized.
DynamicQuantizedFullyConnectedNode *createDynamicQuantizedFullyConnected(
llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B,
bool isSymmetric = true, bool isPerBatchElement = true);

/// Create and \returns a DynamicRowwiseQuantizedFullyConnectedNode with \p
/// name, \p input, weights \p W, bias \p B, rowwise weight qparams \p wScale
/// and \p wOffset, and a flag \p isSymmetric indicating the quantization
/// mode. By default it is a dynamic quantized FC node which takes fp16
/// inputs, quantizes them symmetrically, runs FC on them, dequantizes the
/// result and produces fp16 output. If \p isSymmetric is set to false, then
/// inputs are asymmetrically quantized. If \p isPerBatchElement is set to
/// false, then inputs are per tensor quantized.
DynamicRowwiseQuantizedFullyConnectedNode *
createDynamicRowwiseQuantizedFullyConnected(llvm::StringRef name,
NodeValue input, NodeValue W,
NodeValue B, NodeValue wScale,
NodeValue wOffset,
bool isSymmetric = true,
bool isPerBatchElement = true);

/// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
/// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened
/// along \p axis. Note, output type and outputDepth are inferred based on
/// the input types.
FullyConnectedNode *createFullyConnected(llvm::StringRef name,
NodeValue input, Storage *W,
Storage *B, unsigned_t axis = 1);

/// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
/// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened
/// along \p axis. Note, output type and outputDepth are inferred based on
/// the input types.
FullyConnectedNode *createFullyConnected(llvm::StringRef name,
NodeValue input, NodeValue W,
NodeValue B, unsigned_t axis = 1);

/// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
/// \p W, bias \p B, and \p outTy. If \p input is not 2 dimensional then it is
/// flattened along \p axis. Note, outputDepth is inferred based on \p outTy.
FullyConnectedNode *createFullyConnected(llvm::StringRef name,
NodeValue input, NodeValue W,
NodeValue B, TypeRef outTy,
unsigned_t axis = 1);

/// Create a row-wise quantized fully connected node. This node is only used
/// in quantization. Args \p input and \p B are quantized in the regular way,
/// \p W is the constant weights and is row-wise quantized using the given
/// \p scales and \p offsets. The output is quantized in the regular way, and
/// its type \p outTy is a quantized type.
RowwiseQuantizedFullyConnectedNode *createRowwiseQuantizedFullyConnected(
llvm::StringRef name, NodeValue input, NodeValue W, Constant *scales,
Constant *offsets, NodeValue B, TypeRef outTy);

/// Create a row-wise quantized fully connected node. This node is only used
/// in quantization. Args \p input and \p B are quantized in the regular way,
/// \p W is the constant weights and will be row-wise quantized during node
/// creation time according to \p schema. The output is quantized in the
/// regular way, and its type \p outTy is a quantized type. If
/// \p transposeWeight is true, \p W needs to be transposed first.
RowwiseQuantizedFullyConnectedNode *createRowwiseQuantizedFullyConnected(
llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B,
TypeRef outTy, quantization::Schema schema, bool transposeWeight = false);
734 | |
/// Implement an operation that computes the row-wise dot product of its
/// inputs. Consequently, \p X and \p Y must be either 1D or 2D tensors. This
/// is lowered to a Mul node, and is followed by a BatchedReduceAdd if \p X
/// and \p Y are 2D. \returns either the Mul or BatchedReduceAdd node.
Node *createDotProduct(llvm::StringRef name, NodeValue X, NodeValue Y);

/// Create a node that computes the pairwise dot product of \p inputs, which
/// must be a list of 2D tensors with identical shape. \returns the
/// BatchedPairwiseDotProductNode.
BatchedPairwiseDotProductNode *
createBatchedPairwiseDotProduct(llvm::StringRef name,
llvm::ArrayRef<NodeValue> inputs);

/// Create a node that implements the elementwise linear operator. \p X is
/// 2D and \p w and \p b are 1D. \p w and \p b are broadcasted to match the
/// shape of \p X and then the output is computed by multiplying \p X and
/// broadcasted \p w and adding broadcasted \p b. \returns the
/// ElementwiseLinearNode. \p axis indicates the axis of the inputs (the other
/// axis of \p X is assumed to be the batch index).
Node *createElementwiseLinear(llvm::StringRef name, NodeValue X, NodeValue w,
NodeValue b, unsigned axis);

/// Create a ReLU node with the given \p name and \p input.
/// Result type will be implicitly set based on the \p input type.
ReluNode *createRelu(llvm::StringRef name, NodeValue input);
// Deprecated: use createRelu instead.
ReluNode *createRELU(llvm::StringRef name, NodeValue input);

/// Create a ReLU node with the given \p name, \p input and
/// output type \p outTy.
ReluNode *createRelu(llvm::StringRef name, TypeRef outTy, NodeValue input);
// Deprecated: use createRelu instead.
ReluNode *createRELU(llvm::StringRef name, NodeValue input, TypeRef outTy);

/// Create a series of nodes representing a GeLU with the given \p name and \p
/// input. Result type will be implicitly set based on the \p input type.
GeluNode *createGelu(llvm::StringRef name, NodeValue input);
// Deprecated: use createGelu instead.
GeluNode *createGELU(llvm::StringRef name, NodeValue input);

/// Create a PReLU node with the given \p name, \p input and \p slope.
/// Result type will be implicitly set based on the \p input type.
PReluNode *createPRELU(llvm::StringRef name, NodeValue input,
NodeValue slope);

/// Create a PReLU node with the given \p name, \p input, \p slope and
/// output type \p outTy.
PReluNode *createPRELU(llvm::StringRef name, NodeValue input, NodeValue slope,
TypeRef outTy);

/// Create a Sigmoid node with the given \p name, \p input and
/// output type \p outTy.
SigmoidNode *createSigmoid(llvm::StringRef name, TypeRef outTy,
NodeValue input);

/// Create a Sigmoid node with the given \p name and \p input.
/// Result type will be implicitly set based on the \p input type.
SigmoidNode *createSigmoid(llvm::StringRef name, NodeValue input);

/// Create a Swish node with the given \p name and \p input.
/// If \p OT is nullptr, then result type will be implicitly set based on the
/// \p input type.
SwishNode *createSwish(llvm::StringRef name, NodeValue input);
SwishNode *createSwish(llvm::StringRef name, TypeRef OT, NodeValue input);
// Deprecated: use the (name, OT, input) overload above instead.
SwishNode *createSwish(llvm::StringRef name, NodeValue input, TypeRef OT);
801 | |
/// Create a HardSigmoid node with the given \p name, \p input, \p alpha and
/// \p beta. Result type will be implicitly set based on the \p input type.
ClipNode *createHardSigmoid(llvm::StringRef name, NodeValue input,
float alpha, float beta);

/// Create a HardSigmoid node with the given \p name, \p input,
/// \p alpha, \p beta and output type \p outTy.
ClipNode *createHardSigmoid(llvm::StringRef name, TypeRef outTy,
NodeValue input, float alpha, float beta);

/// Create a Tanh node with the given \p name, \p input and
/// output type \p outTy.
TanhNode *createTanh(llvm::StringRef name, TypeRef outTy, NodeValue input);

/// Create a Tanh node with the given \p name and \p input.
/// Result type will be implicitly set based on the \p input type.
TanhNode *createTanh(llvm::StringRef name, NodeValue input);

/// Create an Exp node with \p name, which calculates element-wise
/// exponential of \p input.
ExpNode *createExp(llvm::StringRef name, NodeValue input);

/// Create an Exp node with \p name with output type \p outTy, which
/// calculates element-wise exponential of \p input.
ExpNode *createExp(llvm::StringRef name, TypeRef outTy, NodeValue input);

/// Create a Log node with \p name, which calculates element-wise natural log
/// of \p input, optionally with output type \p outTy.
LogNode *createLog(llvm::StringRef name, NodeValue input);
LogNode *createLog(llvm::StringRef name, TypeRef outTy, NodeValue input);
// Deprecated: use the (name, outTy, input) overload above instead.
LogNode *createLog(llvm::StringRef name, NodeValue input, TypeRef outTy);

/// \returns a LogitNode with \p name given \p input and \p eps.
LogitNode *createLogit(llvm::StringRef name, NodeValue input, float eps);

/// Create a SoftPlus node with the given \p name, \p input and
/// output type \p outTy.
SoftPlusNode *createSoftPlus(llvm::StringRef name, NodeValue input,
TypeRef outTy = nullptr);

/// Create a SoftMax node with \p name applied to \p input, given the
/// \p selected values. If \p outTy is nullptr the result type is derived
/// from the input. NOTE(review): \p beta is presumably a scaling factor
/// applied before the softmax — confirm against the node implementation.
SoftMaxNode *createSoftMax(llvm::StringRef name, NodeValue input,
NodeValue selected, TypeRef outTy = nullptr,
float beta = 1.0);

/// Create a LogSoftMax node with \p name applied to \p input; parameters
/// mirror those of createSoftMax above.
LogSoftMaxNode *createLogSoftMax(llvm::StringRef name, NodeValue input,
NodeValue selected, TypeRef outTy = nullptr,
float beta = 1.0);

/// Create a CrossEntropyLoss node with \p name which computes the cross
/// entropy loss between \p input and \p labels.
CrossEntropyLossNode *createCrossEntropyLoss(llvm::StringRef name,
NodeValue input,
NodeValue labels);

/// Create a Regression node with \p name which compares \p input against
/// the \p expected values.
RegressionNode *createRegression(llvm::StringRef name, NodeValue input,
NodeValue expected);

/// Creates a node, which computes sigmoid cross entropy between two inputs.
SigmoidCrossEntropyWithLogitsNode *
createSigmoidCrossEntropyWithLogits(llvm::StringRef name, NodeValue logits,
NodeValue targets);
862 | |
/// Create a Reshape node with \p name which reshapes \p input into the new
/// shape \p shape. \p layout names the tensor layout of the result.
ReshapeNode *createReshape(llvm::StringRef name, NodeValue input,
UnsignedArrayRef shape,
llvm::StringRef layout = ANY_LAYOUT);

/// Create a Transpose node with \p name which permutes the dimensions of
/// \p input according to \p shuffle.
TransposeNode *createTranspose(llvm::StringRef name, NodeValue input,
llvm::ArrayRef<unsigned_t> shuffle,
const std::string &layout = ANY_LAYOUT);

/// Create a node with the name \p name which flips (reorders) the elements
/// of the input \p input along the given axis \p axis.
FlipNode *createFlip(llvm::StringRef name, NodeValue input, unsigned_t axis);

/// Create a Broadcast node that broadcasting the \p input Tensor based on
/// \p newShape and along the \p axis, which defines the offset between the
/// input dim and the newShape.
/// e.g. For input: [3] and newShape: [2, 3, 2], the axis will be 1.
/// For input: [3] and newShape: [2, 2, 3], the axis will be 2.
BroadcastNode *createBroadcast(llvm::StringRef name, NodeValue input,
UnsignedArrayRef newShape, unsigned_t axis);

/// Create concat node which concatenates input tensors along \p dimension.
ConcatNode *createConcat(llvm::StringRef name,
llvm::ArrayRef<NodeValue> inputs,
unsigned_t dimension);

/// Create concat node with the given return type \p outTy.
ConcatNode *createConcat(llvm::StringRef name,
llvm::ArrayRef<NodeValue> inputs,
unsigned_t dimension, TypeRef outTy);

/// Create a TileNode with \p name, \p input, \p tiles, and \p axis.
/// For example, an input tensor {{1,2,3,4}} of dimension 1x4 with tiles = 2
/// and axis = 0 would result in an output tensor {{1,2,3,4}, {1,2,3,4}} of
/// dimension 2x4.
TileNode *createTile(llvm::StringRef name, NodeValue input, unsigned_t tiles,
unsigned_t axis, TypeRef outTy = nullptr);

/// Create a TileNode with \p name, \p input which repeats the input data
/// the given count values \p tiles along the given \p axes.
TileNode *createTile(llvm::StringRef name, NodeValue input,
llvm::ArrayRef<unsigned_t> tiles,
llvm::ArrayRef<unsigned_t> axes);

/// Create an insert tensor node \p name, which inserts \p small into \p big
/// at offset into big \p start \p count times along \p axis.
InsertTensorNode *createInsertTensor(llvm::StringRef name, NodeValue big,
NodeValue small,
llvm::ArrayRef<dim_t> start,
unsigned_t count = 1,
unsigned_t axis = 0);

/// Create a slice node \p name with the given starting points for each
/// dimension \p begin and end points \p end (exclusive).
SliceNode *createSlice(llvm::StringRef name, NodeValue input,
UnsignedArrayRef begin, UnsignedArrayRef end);

/// Create a slice node with the given starting point for each dimension.
/// End points will be calculated based on the output type during execution.
SliceNode *createSlice(llvm::StringRef name, NodeValue input,
llvm::ArrayRef<dim_t> start, TypeRef outTy);

/// Shuffles dimension number \p kernel. Suppose original size is D. It will
/// be represented as groupX(D/group) matrix, transposed and concatenated back
/// to size D. For example, shuffle of {1, 2, 3, 4, 5, 6} with \p group = 2 is
/// {1, 4, 2, 5, 3, 6}
Node *createChannelShuffle(llvm::StringRef name, NodeValue input,
size_t group, size_t kernel);
930 | |
/// Computes the indices of the max elements of the input tensor along the
/// provided \p axis. The resulted tensor has the same rank as the input if \p
/// keepDims equal 1. If \p keepDims equals 0, the resulted tensor has the
/// reduced dimension pruned. The type of the output tensor is \p elemTy.
ArgMaxNode *createArgMax(llvm::StringRef name, NodeValue input,
unsigned_t axis, bool keepDims,
ElemKind elemTy = ElemKind::Int64ITy);

/// Computes the indices of the min elements of the input tensor along the
/// provided \p axis. The resulted tensor has the same rank as the input if \p
/// keepDims equal 1. If \p keepDims equals 0, the resulted tensor has the
/// reduced dimension pruned. The type of the output tensor is \p elemTy.
ArgMinNode *createArgMin(llvm::StringRef name, NodeValue input,
unsigned_t axis, bool keepDims,
ElemKind elemTy = ElemKind::Int64ITy);

/// Removes single-dimensional entries from the shape of a tensor. The
/// parameter \p axes is a list of positive integers, indicating the
/// dimensions to squeeze. Implemented as a single ReshapeNode. This is the
/// opposite of ExpandDims.
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#squeeze
ReshapeNode *createSqueeze(llvm::StringRef name, NodeValue input,
llvm::ArrayRef<dim_t> axes);

/// Add single-dimensional entries to the shape of the \p input tensor at
/// locations in \p axes. \p axes is listed as seen in the output tensor.
/// Implemented as a single ReshapeNode. This is the opposite of Squeeze.
ReshapeNode *createExpandDims(llvm::StringRef name, NodeValue input,
llvm::ArrayRef<dim_t> axes);

/// Flattens the input tensor into a 2D matrix. If input tensor has shape
/// (d_0, d_1, ... d_n) then the output will have shape:
/// (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X d_n).
ReshapeNode *createFlatten(llvm::StringRef name, NodeValue input,
unsigned_t axis);

/// Flattens the input tensor into a 2D matrix. If input tensor has shape
/// (d_0, d_1, ... d_n) then the output will have shape:
/// ((d_0 X d_1 ... d_(axis-1) X d_(axis+1) ... X d_n), d_axis).
ReshapeNode *createFlattenV1(llvm::StringRef name, NodeValue input,
unsigned_t axis);

/// Create \p outputNum slice nodes of \p input. Slices happen along dimension
/// number \p axis. Array \p split defines lengths of slices. If \p split is
/// empty, \p input is split to equal sized parts.
void createSplit(llvm::StringRef name, NodeValue input, unsigned_t outputNum,
unsigned_t axis, llvm::ArrayRef<dim_t> split,
std::vector<SliceNode *> &outputs);
979 | |
/// Creates and \returns a BatchNormalization node with result type
/// \p resType which normalizes \p input along the channel dimension
/// \p channelIdx using the supplied \p mean and \p var, applying \p scale
/// and \p beta. \p epsilon is a small perturbation used to avoid division by
/// 0 during normalization. NOTE(review): \p momentum is presumably used for
/// running-statistics updates during training — confirm against the node
/// implementation.
BatchNormalizationNode *createBatchNormalization(
llvm::StringRef name, TypeRef resType, NodeValue input, NodeValue beta,
NodeValue scale, NodeValue mean, NodeValue var, unsigned_t channelIdx = 0,
float epsilon = 1e-5, float momentum = 0.9);

/// Create and \returns an InstanceNormalizationNode with result type of \p
/// outTy that computes the instance normalization of \p input based on the \p
/// scale and \p bias combined with the computed mean and stddev of each
/// batch. \p epsilon is a small perturbation used to avoid division by 0
/// during normalization.
InstanceNormalizationNode *
createInstanceNormalization(llvm::StringRef name, NodeValue input,
NodeValue beta, NodeValue scale,
unsigned_t channelIdx = 0, float epsilon = 1e-5);

/// Creates and \returns a LayerNormalizationNode with result type of \p outTy
/// that computes the layer normalization of the inner most layers of \p input
/// based on the shape of \p scale and \p bias. \p epsilon is a small
/// perturbation used to avoid division by 0 during normalization.
LayerNormalizationNode *
createLayerNormalization(llvm::StringRef name, TypeRef outTy, NodeValue input,
NodeValue scale, NodeValue bias,
float epsilon = 1e-5);

/// Bucketizes the input tensor based on monotonically increasing \p
/// boundaries for each value in \p input. For each value x in input, the
/// operator \returns index i given boundaries[i-1] < x <= boundaries[i]. If
/// the value x is beyond the bounds of boundaries, 0 or len(boundaries) is
/// returned as appropriate.
BucketizeNode *createBucketizeNode(llvm::StringRef name, NodeValue input,
llvm::ArrayRef<float> boundaries);

/// Creates and \returns a LocalResponseNormalization node with \p name over
/// \p input, using a window half-size of \p halfWindowSize and parameters
/// \p alpha, \p beta and \p k.
LocalResponseNormalizationNode *createLocalResponseNormalization(
llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize = 2,
float alpha = 1e-4, float beta = 0.75, float k = 2.0);

/// Create a ModuloNode which performs the modulo operation elementwise on the
/// \p input such that each element in the output is equal to the
/// corresponding element in the input modulo \p divisor. If \p
/// signFollowDivisor is true then any negative elements in the output will
/// have divisor added to their final values.
ModuloNode *createModulo(llvm::StringRef name, NodeValue input,
int64_t divisor, bool signFollowDivisor = false);

/// Create a logical NOT node with name \p name and input \p input.
NotNode *createNot(llvm::StringRef name, NodeValue input);

// Create a BitwiseNot node with name \p name and input \p input.
BitwiseNotNode *createBitwiseNot(llvm::StringRef name, NodeValue input);
1029 | |
/// Declares a pair of creator methods for the unary arithmetic node
/// NODE_NAME_: one inferring the result type from the input, and one taking
/// an explicit result type.
#define UNARY_ARITHMETIC_FUN_DECL(NODE_NAME_)                                    \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue input); \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,        \
NodeValue input);
UNARY_ARITHMETIC_FUN_DECL(Abs)
UNARY_ARITHMETIC_FUN_DECL(Neg)
UNARY_ARITHMETIC_FUN_DECL(Floor)
UNARY_ARITHMETIC_FUN_DECL(Sign)
UNARY_ARITHMETIC_FUN_DECL(Ceil)
UNARY_ARITHMETIC_FUN_DECL(Round)
UNARY_ARITHMETIC_FUN_DECL(Sqrt)
UNARY_ARITHMETIC_FUN_DECL(Rsqrt)
UNARY_ARITHMETIC_FUN_DECL(Reciprocal)
UNARY_ARITHMETIC_FUN_DECL(Sin)
UNARY_ARITHMETIC_FUN_DECL(Cos)
UNARY_ARITHMETIC_FUN_DECL(Erf)
UNARY_ARITHMETIC_FUN_DECL(Truncate)
UNARY_ARITHMETIC_FUN_DECL(HardSwish)
#undef UNARY_ARITHMETIC_FUN_DECL

/// Declares a pair of creator methods for the binary arithmetic node
/// NODE_NAME_: one inferring the result type and one taking an explicit
/// result type.
#define ARITHMETIC_FUN_DECL(NODE_NAME_)                                        \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue LHS,    \
NodeValue RHS);                         \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
NodeValue LHS, NodeValue RHS);
ARITHMETIC_FUN_DECL(Add);
ARITHMETIC_FUN_DECL(Mul);
ARITHMETIC_FUN_DECL(Sub);
ARITHMETIC_FUN_DECL(Div);
ARITHMETIC_FUN_DECL(Max);
ARITHMETIC_FUN_DECL(Min);
ARITHMETIC_FUN_DECL(CmpEQ);
ARITHMETIC_FUN_DECL(CmpNEQ);
ARITHMETIC_FUN_DECL(CmpLT);
ARITHMETIC_FUN_DECL(CmpLTE);
ARITHMETIC_FUN_DECL(And);
ARITHMETIC_FUN_DECL(Or);
ARITHMETIC_FUN_DECL(Xor);
ARITHMETIC_FUN_DECL(BitwiseAnd);
ARITHMETIC_FUN_DECL(BitwiseOr);
ARITHMETIC_FUN_DECL(BitwiseXor);
ARITHMETIC_FUN_DECL(Pow);
ARITHMETIC_FUN_DECL(Fmod);
#undef ARITHMETIC_FUN_DECL

/// Declares a pair of creator methods for the trigonometric node NODE_NAME_:
/// one inferring the result type from the input and one taking an explicit
/// result type.
#define TRIGONOMETRIC_FUN_DECL(NODE_NAME_)                                     \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue input); \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
NodeValue input);
TRIGONOMETRIC_FUN_DECL(Acos)
TRIGONOMETRIC_FUN_DECL(Asin)
TRIGONOMETRIC_FUN_DECL(Atan)
#undef TRIGONOMETRIC_FUN_DECL
1083 | |
/// Broadcast \p inputs to a common shape along \p axis and \returns the
/// broadcasted values. If \p axis is -1, it is calculated automatically for
/// multi directional broadcast (see the broadcast creators below).
std::vector<NodeValue>
broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs);

/// SFINAE helper: provides a nested `type` (equal to U) only when T and U
/// are the same type; used to select the matching createNodeWithBroadcast
/// overload for a requested node class T.
template <class T, class U>
using enable_if_same_t = std::enable_if<std::is_same<T, U>::value, U>;

/// Shared prologue for the broadcast creators below: statically checks the
/// argument count, then broadcasts the inputs along `axis` into `inputs`.
#define BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                 \
constexpr size_t numInputs = sizeof...(Args);                                \
static_assert(numInputs == NUM_INPUTS,                                       \
"Invalid input passed in to commonCreateBroadcast.");          \
std::vector<NodeValue> inputs = broadcastInputs(axis, {inputArgs...});

/// Declares a createNodeWithBroadcast overload for NODE_NAME that broadcasts
/// its NUM_INPUTS inputs and derives the result type from the first input.
#define DECLARE_BROADCAST_NODE(NODE_NAME, NUM_INPUTS)                          \
template <class T, class... Args>                                            \
typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
createNodeWithBroadcast(const std::string &name, int axis,                   \
Args &&...inputArgs) {                               \
BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                       \
return create##NODE_NAME(name, inputs[0].getType(), inputs[0], inputs[1]); \
}

/// Template function that creates a node and normalizes its input shapes
/// with the use of BroadCast nodes. If axis is -1, it calculates it
/// automatically for multi directional broadcast.
DECLARE_BROADCAST_NODE(Add, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Sub, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Mul, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Div, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(And, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Xor, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Or, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(BitwiseAnd, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(BitwiseXor, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(BitwiseOr, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Pow, /* NUM_INPUTS */ 2)
DECLARE_BROADCAST_NODE(Fmod, /* NUM_INPUTS */ 2)

/// Same as DECLARE_BROADCAST_NODE, but the creator takes an explicit output
/// type instead of deriving it from the first input.
#define DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(NODE_NAME, NUM_INPUTS,            \
OUTTYPEREF)                       \
template <class T, class... Args>                                            \
typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
createNodeWithBroadcastOutTy(const std::string &name, int axis,              \
TypeRef OUTTYPEREF, Args &&...inputArgs) {      \
BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                       \
return create##NODE_NAME(name, OUTTYPEREF, inputs[0], inputs[1]);          \
}

DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Add, /* NUM_INPUTS */ 2, outTy)
DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Sub, /* NUM_INPUTS */ 2, outTy)
DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Mul, /* NUM_INPUTS */ 2, outTy)
DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Div, /* NUM_INPUTS */ 2, outTy)
DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Min, /* NUM_INPUTS */ 2, outTy)
DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Max, /* NUM_INPUTS */ 2, outTy)
DECLARE_BROADCAST_NODE_WITH_OUT_TYPE(Fmod, /* NUM_INPUTS */ 2, outTy)

/// Same as DECLARE_BROADCAST_NODE for two-input nodes whose creator takes no
/// output type (e.g. comparisons, whose result type is fixed by the node).
#define DECLARE_CMP_BROADCAST_NODE(NODE_NAME)                                  \
template <class T, class... Args>                                            \
typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
createNodeWithBroadcast(const std::string &name, int axis,                   \
Args &&...inputArgs) {                               \
BROADCAST_FUNC_COMMON_CODE(2)                                                \
return create##NODE_NAME(name, inputs[0], inputs[1]);                      \
}

/// Template function that creates a node and normalizes its input shapes
/// with the use of BroadCast nodes. If axis is -1, it calculates it
/// automatically for multi directional broadcast.
DECLARE_CMP_BROADCAST_NODE(CmpLT)
DECLARE_CMP_BROADCAST_NODE(CmpEQ)
DECLARE_CMP_BROADCAST_NODE(CmpNEQ)
DECLARE_CMP_BROADCAST_NODE(CmpLTE)
DECLARE_CMP_BROADCAST_NODE(Min)
DECLARE_CMP_BROADCAST_NODE(Max)

/// Template function that creates a node and normalizes its input shapes
/// with the use of BroadCast nodes. If axis is -1, it calculates it
/// automatically for multi directional broadcast.
template <class T, class... Args>
typename enable_if_same_t<T, SelectNode>::type *
createNodeWithBroadcast(const std::string &name, int axis,
Args &&...inputArgs) {
BROADCAST_FUNC_COMMON_CODE(3)
return createSelect(name, inputs[1].getType(), inputs[0], inputs[1],
inputs[2]);
}
1169 | |
// Tear down the helper macros used to declare the broadcast creators above;
// they are implementation details of this header section. (The previous
// duplicate #undef of BROADCAST_FUNC_COMMON_CODE was removed — undefining an
// already-undefined macro is a no-op, so one #undef per macro suffices.)
#undef BROADCAST_FUNC_COMMON_CODE
#undef DECLARE_BROADCAST_NODE
#undef DECLARE_BROADCAST_NODE_WITH_OUT_TYPE
#undef DECLARE_CMP_BROADCAST_NODE
1175 | |
/// Create a FloorDivNode with given \p name which divides \p LHS with \p RHS
/// and floors the quotient. If \p truncate is true then truncates the
/// quotient instead of flooring.
FloorDivNode *createFloorDiv(llvm::StringRef name, NodeValue LHS,
NodeValue RHS, bool truncate = false);

/// Create a FloorDivNode with given \p name and output type \p outTy which
/// divides \p LHS with \p RHS and floors the quotient. If \p truncate is true
/// then truncates the quotient to zero instead of flooring.
FloorDivNode *createFloorDiv(llvm::StringRef name, TypeRef outTy,
NodeValue LHS, NodeValue RHS,
bool truncate = false);

/// Create a FloorDivNode with given \p name which divides \p LHS with \p RHS
/// and floors the quotient. If \p truncate is true then truncates the
/// quotient to zero instead of flooring. The inputs are broadcasted based on
/// \p axis.
FloorDivNode *createFloorDivWithBroadcast(llvm::StringRef name, int axis,
NodeValue LHS, NodeValue RHS,
bool truncate = false);

/// Create a FloorDivNode with given \p name and output type \p outTy which
/// divides \p LHS with \p RHS and floors the quotient. If \p truncate is true
/// then truncates the quotient to zero instead of flooring. The inputs are
/// broadcasted based on \p axis.
FloorDivNode *createFloorDivWithBroadcast(llvm::StringRef name, int axis,
TypeRef outTy, NodeValue LHS,
NodeValue RHS,
bool truncate = false);

/// Create an element-wise GREATER THAN comparison between \p LHS and \p RHS
/// by creating a CmpLTNode with given \p name and swapped inputs.
CmpLTNode *createCmpGT(llvm::StringRef name, NodeValue LHS, NodeValue RHS);

/// Create an element-wise GREATER THAN or EQUAL comparison between \p LHS and
/// \p RHS by creating a CmpLTENode with given \p name and swapped inputs.
CmpLTENode *createCmpGTE(llvm::StringRef name, NodeValue LHS, NodeValue RHS);

/// Create a MulNode with given \p name which multiplies \p input with itself
/// to produce an equivalent Square node.
MulNode *createSquare(llvm::StringRef name, NodeValue input);

/// Create a MulNode with given \p name and output type \p outTy which
/// multiplies \p input with itself to produce an equivalent Square node.
MulNode *createSquare(llvm::StringRef name, TypeRef outTy, NodeValue input);

/// Create a LeakyRELU with \p name, \p input and slope \p alpha.
LeakyReluNode *createLeakyRELU(llvm::StringRef name, NodeValue input,
float alpha);

/// Create a LeakyRELU with \p name, \p outTy, \p input and slope \p alpha.
LeakyReluNode *createLeakyRELU(llvm::StringRef name, TypeRef outTy,
NodeValue input, float alpha);

/// Create a node that produces a boolean output of the same shape as
/// \p input in which each element indicates whether or not the corresponding
/// element in \p input is NaN or not.
IsNaNNode *createIsNaN(llvm::StringRef name, NodeValue input);

/// \returns a ReplaceNaNNode given \p name, \p input, and \p value.
ReplaceNaNNode *createReplaceNaN(llvm::StringRef name, NodeValue input,
float value);

/// Create a Pow node with \p name which raises \p base element-wise to the
/// scalar power \p exp.
PowNode *createPow(llvm::StringRef name, NodeValue base, float exp);

/// Create a NonZero node with \p name given boolean input \p Cond.
/// NOTE(review): presumably yields the indices of the non-zero (true)
/// elements — confirm the output layout in the node definition.
NonZeroNode *createNonZero(llvm::StringRef name, NodeValue Cond);

/// Create a Select node with \p name which picks element-wise between
/// \p LHS and \p RHS according to \p Cond.
SelectNode *createSelect(llvm::StringRef name, NodeValue Cond, NodeValue LHS,
NodeValue RHS);

/// Same as createSelect above, but with an explicit output type \p outTy.
SelectNode *createSelect(llvm::StringRef name, TypeRef outTy, NodeValue Cond,
NodeValue LHS, NodeValue RHS);

/// Create a Splat node with \p name which produces a tensor of type \p ty
/// filled with the scalar \p value.
SplatNode *createSplat(llvm::StringRef name, TypeRef ty, float value);

/// Create a Touch node with \p name which materializes a tensor of type
/// \p ty; contents are unspecified (see the TouchNode definition).
TouchNode *createTouch(llvm::StringRef name, TypeRef ty);

/// Create a MatMul node with \p name which multiplies matrix \p lhs by
/// matrix \p rhs.
MatMulNode *createMatMul(llvm::StringRef name, NodeValue lhs, NodeValue rhs);

/// Same as createMatMul above, but with an explicit output type \p outTy.
MatMulNode *createMatMul(llvm::StringRef name, TypeRef outTy, NodeValue lhs,
NodeValue rhs);
1257 | |
1258 | /// \p lhs and \p rhs are 3d matrices, where the leading dimension is the |
1259 | /// batch size. For each batch element number i, lhs.slice(i) is multiplied by |
1260 | /// rhs.slice(i). |
1261 | BatchMatMulNode *createBatchMatMul(llvm::StringRef name, NodeValue lhs, |
1262 | NodeValue rhs); |
1263 | |
1264 | /// Create a node, performing Norm operation. Output type is based on the |
1265 | /// input \p p type with dimensions specified with \p axes removed. |
1266 | VectorNormNode *createVectorNorm(llvm::StringRef name, NodeValue input, |
1267 | unsigned_t axis, unsigned_t p = 2); |
1268 | |
/// Create a node, performing BatchedReduceAdd operation. Output type is
/// based on the input \p batch type with dimensions specified with \p axes
/// removed.
BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
                                             NodeValue batch,
                                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceAdd operation. Output type
/// matches the provided \p outTy.
BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
                                             TypeRef outTy, NodeValue batch,
                                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceSumSquare operation. Output type is
/// based on the input \p batch type with dimensions specified with \p axes
/// removed.
BatchedReduceSumSquareNode *
createBatchedReduceSumSquare(llvm::StringRef name, NodeValue batch,
                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceSumSquare operation. Output type
/// matches the provided \p outTy.
BatchedReduceSumSquareNode *
createBatchedReduceSumSquare(llvm::StringRef name, TypeRef outTy,
                             NodeValue batch,
                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceMin operation. Output type
/// matches the provided \p outTy.
BatchedReduceMinNode *createBatchedReduceMin(llvm::StringRef name,
                                             TypeRef outTy, NodeValue batch,
                                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceMin operation. Output type is
/// based on the input \p batch type with dimensions specified with \p axes
/// removed.
BatchedReduceMinNode *createBatchedReduceMin(llvm::StringRef name,
                                             NodeValue batch,
                                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceMax operation. Output type
/// matches the provided \p outTy.
BatchedReduceMaxNode *createBatchedReduceMax(llvm::StringRef name,
                                             TypeRef outTy, NodeValue batch,
                                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceMax operation. Output type is
/// based on the input \p batch type with dimensions specified with \p axes
/// removed.
BatchedReduceMaxNode *createBatchedReduceMax(llvm::StringRef name,
                                             NodeValue batch,
                                             llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceMean operation. Output type
/// matches the provided \p outTy.
BatchedReduceMeanNode *
createBatchedReduceMean(llvm::StringRef name, TypeRef outTy, NodeValue batch,
                        llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceMean operation. Output type is
/// based on the input \p batch type with dimensions specified with \p axes
/// removed.
BatchedReduceMeanNode *
createBatchedReduceMean(llvm::StringRef name, NodeValue batch,
                        llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceProd operation. Output type is
/// based on the input \p batch type with dimensions specified with \p axes
/// removed.
BatchedReduceProdNode *
createBatchedReduceProd(llvm::StringRef name, NodeValue batch,
                        llvm::ArrayRef<unsigned_t> axes);

/// Create a node, performing BatchedReduceProd operation. Output type
/// matches the provided \p outTy.
BatchedReduceProdNode *
createBatchedReduceProd(llvm::StringRef name, TypeRef outTy, NodeValue batch,
                        llvm::ArrayRef<unsigned_t> axes);
1347 | |
/// Create a BatchedAddNode with \p name which adds \p slice to each slice
/// of the outer-most dimension of \p batch.
BatchedAddNode *createBatchedAdd(llvm::StringRef name, NodeValue batch,
                                 NodeValue slice);

/// Same as \ref createBatchedAdd(), but with output type \p outTy explicitly
/// provided.
BatchedAddNode *createBatchedAdd(llvm::StringRef name, TypeRef outTy,
                                 NodeValue batch, NodeValue slice);

/// Create a BatchedMulNode with \p name which multiplies each slice of the
/// outer-most dimension of \p batch by \p slice.
BatchedMulNode *createBatchedMul(llvm::StringRef name, NodeValue batch,
                                 NodeValue slice);

/// Same as \ref createBatchedMul(), but with output type \p outTy explicitly
/// provided.
BatchedMulNode *createBatchedMul(llvm::StringRef name, TypeRef outTy,
                                 NodeValue batch, NodeValue slice);

/// Create a node performing a Cumulative Sum operation along dimension
/// \p dim; output type matches \p input type. If \p exclusive is true the
/// running sum excludes the current element; if \p reverse is true the sum
/// is computed in the reverse direction.
CumSumNode *createCumSum(llvm::StringRef name, NodeValue input,
                         int64_t dim = 0, bool exclusive = false,
                         bool reverse = false);
1365 | |
/// Implements an operation that accumulates the values in \p data along the
/// first dimension into len(\p lengths) entries by summing together the first
/// lengths[0] values, then the subsequent lengths[1] values, etc.
/// sum(\p lengths) must equal the first dimension of \p data. This operation
/// is similar to SparseLengthsSum but the input is a dense representation
/// instead of a sparse one. In other words, it has already been Gathered.
LengthsSumNode *createLengthsSum(llvm::StringRef name, NodeValue data,
                                 NodeValue lengths);

/// Create a node, performing SparseLengthsSum operation:
/// Gathers slices of the outer-most dimension of Data indexed by Indices
/// vector, and then accumulates them into len(Lengths) entries:
/// first Lengths[0] slices are aggregated to Result[0], next Lengths[1]
/// slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal
/// to len(Indices). \p lengthsMode and \p avgLength represent meta
/// information about the \p lengths, allowing the backend to use a
/// specialized implementation.
SparseLengthsSumNode *
createSparseLengthsSum(llvm::StringRef name, NodeValue data,
                       NodeValue indices, NodeValue lengths,
                       LengthsMode lengthsMode = LengthsMode::Variable,
                       float avgLength = NAN);

/// Same as SparseLengthsSum, but i-th slice is multiplied by weights[i].
/// len(weights) must be equal to len(indices).
SparseLengthsWeightedSumNode *createSparseLengthsWeightedSum(
    llvm::StringRef name, NodeValue data, NodeValue weights,
    NodeValue indices, NodeValue lengths,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1395 | |
/// Create an Embedding node:
/// \p weights is a 2D tensor capturing the embedding table.
/// \p indices is a tensor of arbitrary shape containing the indices to
/// extract.
/// \p padIdx, if given, zeros the output vector when it encounters padIdx.
/// \p scale, if true, will scale gradients by the inverse of the frequency of
/// words in mini-batch (currently not supported, default=false).
/// \p sparse, if true, gradient w.r.t. weight matrix will be a sparse tensor
/// (currently not supported, default=false).
EmbeddingNode *createEmbedding(llvm::StringRef name, NodeValue weights,
                               NodeValue indices, int32_t padIdx, bool scale,
                               bool sparse);
1407 | |
/// Create an EmbeddingBag node. If \p hasEndOffset is true then the node
/// expects an extra offset to be appended to \p offsets which marks the end
/// of the last range. \p lengthsMode and \p avgLength represent meta
/// information about the \p lengths, allowing the backend to use a
/// specialized implementation.
EmbeddingBagNode *createEmbeddingBag(
    llvm::StringRef name, NodeValue data, NodeValue weights,
    NodeValue indices, NodeValue offsets, bool hasEndOffset = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Create an EmbeddingBagByteRowwiseOffsetsNode. If \p hasEndOffset is
/// true then the node expects an extra offset to be appended to \p offsets
/// which marks the end of the last range. \p lengthsMode and \p avgLength
/// represent meta information about the \p lengths, allowing the backend to
/// use a specialized implementation.
EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
    llvm::StringRef name, NodeValue data, NodeValue weights,
    NodeValue indices, NodeValue offsets, bool useFP16Accumulation = false,
    bool hasEndOffset = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createEmbeddingBagByteRowwiseOffsets(), but
/// expects float input \p data, which is rowwise-quantized and fused
/// internally. \p fusedElemKind represents the element kind to use for the
/// final fused rowwise-quantized data. If \p hasEndOffset is true then the
/// node expects an extra offset to be appended to \p offsets which marks the
/// end of the last range.
EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
    llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
    NodeValue offsets, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
    bool useFP16Accumulation = false, bool hasEndOffset = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createSparseLengthsWeightedSum(), but with \p outTy
/// specified.
SparseLengthsWeightedSumNode *createSparseLengthsWeightedSum(
    llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights,
    NodeValue indices, NodeValue lengths,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1447 | |
/// Creates and \returns a node of \p name, performing the SparseLengthsSum
/// operation, using rowwise quantization for the input \p data with the \p
/// scales and \p offsets as separate input tensors. Gathers slices of the
/// outer-most dimension of data indexed by the \p indices vector, and then
/// accumulates them into len(\p lengths) entries: first Lengths[0] slices are
/// aggregated to Result[0], next Lengths[1] slices are aggregated to
/// Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices).
/// \p precision represents what precision to use for Scale, Offset, and
/// Result. If \p useFP16Accumulation, then internal arithmetic will use FP16
/// accumulation; otherwise defaults to FP32. \p lengthsMode and \p avgLength
/// represent meta information about the \p lengths, allowing the backend to
/// use a specialized implementation.
RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsSum(
    llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets,
    NodeValue indices, NodeValue lengths,
    ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but expects
/// float input \p data, which is rowwise-quantized internally according to
/// \p schema.
RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsSum(
    llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
    quantization::Schema schema, ElemKind precision = ElemKind::FloatTy,
    bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but i-th slice is
/// multiplied by weights[i]. len(weights) must be equal to len(indices).
RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsWeightedSum(
    llvm::StringRef name, Storage *data, NodeValue scales, NodeValue offsets,
    NodeValue weights, NodeValue indices, NodeValue lengths,
    ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but expects
/// float input \p data, which is rowwise-quantized internally according to
/// \p schema.
RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsWeightedSum(
    llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
    NodeValue lengths, quantization::Schema schema,
    ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1493 | |
/// Creates and \returns a node of \p name, performing the SparseLengthsSum
/// operation, using fused rowwise quantization for the input \p data wherein
/// the scales and offsets are fused inline with each row of data. \p data
/// must be of a fused ElemKind. Gathers slices of the outer-most dimension of
/// data indexed by the \p indices vector, and then accumulates them into
/// len(\p lengths) entries: first Lengths[0] slices are aggregated to
/// Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e.
/// sum(Lengths) must be equal to len(Indices). The precision for the Result
/// is determined by the \p data input's ElemKind used for Scale and
/// Offset. If \p useFP16Accumulation, then internal arithmetic will use FP16
/// accumulation; otherwise defaults to FP32. \p lengthsMode and \p avgLength
/// represent meta information about the \p lengths, allowing the backend to
/// use a specialized implementation.
FusedRowwiseQuantizedSparseLengthsSumNode *
createFusedRowwiseQuantizedSparseLengthsSum(
    llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths,
    bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but expects
/// float input \p data, which is rowwise-quantized and fused internally.
/// \p fusedElemKind represents the element kind to use for the final fused
/// rowwise-quantized data.
FusedRowwiseQuantizedSparseLengthsSumNode *
createFusedRowwiseQuantizedSparseLengthsSum(
    llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
    ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
    bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but i-th slice
/// is multiplied by weights[i]. len(weights) must be equal to len(indices).
/// \p data must already be of a fused ElemKind.
FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
createFusedRowwiseQuantizedSparseLengthsWeightedSum(
    llvm::StringRef name, NodeValue data, NodeValue weights,
    NodeValue indices, NodeValue lengths, bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);

/// Same as \ref createFusedRowwiseQuantizedSparseLengthsWeightedSum(), but
/// expects float input \p data, which is rowwise-quantized and fused
/// internally. \p fusedElemKind represents the element kind to use for the
/// final fused rowwise-quantized data.
FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
createFusedRowwiseQuantizedSparseLengthsWeightedSum(
    llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
    NodeValue lengths, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
    bool useFP16Accumulation = false,
    LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1542 | |
/// Given a vector of segment lengths, calculates offsets of each segment and
/// packs them next to the lengths. For the input vector of length N the
/// output is an Nx2 matrix with (offset, length) packed for each segment.
LengthsToRangesNode *createLengthsToRanges(llvm::StringRef name,
                                           NodeValue lengths);

/// Given a vector of \p lengths, \returns a LengthsRangeFillNode. This Node
/// calculates a range sequence given \p lengths, where the sum of the
/// elements of \p lengths must be no greater than \p maxOutputSize, which is
/// used to set the output type.
LengthsRangeFillNode *createLengthsRangeFill(llvm::StringRef name,
                                             NodeValue lengths,
                                             unsigned_t maxOutputSize);

/// Implements an operation that converts the sparse representation given by
/// the \p lengths, \p indices and \p values into a dense representation.
/// This representation contains \p lengths[i] indices in batch i, and in each
/// batch contains the value of \p values at the corresponding index given by
/// \p indices. All indices that are not present in \p indices are filled with
/// defaultValue. \p indices within the same batch should not contain
/// duplicates. \p denseLastDim gives the last dimension of the output dense
/// representation (i.e. the second dimension).
BatchSparseToDenseNode *
createBatchSparseToDense(llvm::StringRef name, NodeValue lengths,
                         NodeValue indices, NodeValue values,
                         float defaultValue, unsigned_t denseLastDim);

/// Implements an operation that inserts zeros into \p data along axis=0 for
/// indices where \p indicator is zero.
FillExamplesWithIndicatorNode *
createFillExamplesWithIndicator(llvm::StringRef name, NodeValue data,
                                NodeValue indicator);
1575 | |
/// Implements an operation that converts the sparse representation given by
/// the pair of \p indices and \p values into a dense representation, which
/// only contains IDs from given \p mask. Indices cannot contain duplicates.
/// \p lengths is used to distinguish elements that belong to different
/// examples of one batch. That is, first \p lengths[0] index-value pairs
/// belong to batch's example 0, next \p lengths[1] pairs belong to example 1
/// and so on.
SparseToDenseMaskNode *
createSparseToDenseMask(llvm::StringRef name, NodeValue indices,
                        NodeValue values, NodeValue defaultValue,
                        NodeValue lengths, llvm::ArrayRef<dim_t> mask);

// TODO: add description
// NOTE(review): semantics are undocumented here — presumably splits the
// sparse (\p lengths, \p indices, \p values) representation into \p numLabels
// per-label outputs; confirm against the SparseLabelSplit node definition.
SparseLabelSplitNode *
createSparseLabelSplit(llvm::StringRef name, NodeValue lengths,
                       NodeValue indices, NodeValue values, dim_t numLabels);
1592 | |
/// Given a float input node \p input, floats \p mean and \p scale, and seed
/// \p seed, \returns a GaussianFillNode. The output shape is the same as that
/// of \p input, filled with values drawn from a normal distribution with
/// mean \p mean and std dev \p scale, seeded with \p seed.
GaussianFillNode *createGaussianFill(llvm::StringRef name, NodeValue input,
                                     float mean, float scale, float seed);

/// Create a SaveNode with \p name which saves \p input.
SaveNode *createSave(llvm::StringRef name, NodeValue input);

/// Creates and \returns a SaveNode of \p input to \p output. If \p skipSuffix
/// then the name used is \p name, otherwise suffix "_save" is appended.
SaveNode *createSave(llvm::StringRef name, NodeValue input,
                     Placeholder *output, bool skipSuffix = false);
1606 | |
/// Create quantization profile node named \p name for the output tensor from
/// \p input in PlaceholderBindings \p bindings. Captures the observed node
/// name in the quantization profile node, as the original node can be
/// replaced during the lowering phase. Computes the histogram during
/// profiling with \p numHistogramBins.
QuantizationProfileNode *
createQuantizationProfile(PlaceholderBindings &bindings, llvm::StringRef name,
                          NodeValue input, dim_t numHistogramBins = 10);

/// Create lookup table for mapping between quantized operands. \p input and
/// \p outTy must be quantized types. The table contains all numbers from the
/// quantized range, e.g. 256 entries for int8 input. First position in the
/// \p initValues corresponds to the minimum input number and the last
/// position corresponds to the maximum input number.
template <typename T = int8_t>
IntLookupTableNode *
createIntLookupTable(llvm::StringRef name, NodeValue input,
                     llvm::ArrayRef<T> initValues, TypeRef outTy);

/// Create lookup table for mapping between quantized operands based on the
/// floating point function \p func. \p input and \p outTy must be quantized
/// types.
IntLookupTableNode *createIntLookupTable(llvm::StringRef name,
                                         NodeValue input,
                                         std::function<float(float)> func,
                                         TypeRef outTy);
1632 | |
/// Create lookup table for operator \p lutOperator using the provided lookup
/// \p table with operator arguments \p lutOperatorArgs and index table
/// \p idxTable. Output type is \p outTy.
LookupTableNode *createLookupTable(llvm::StringRef name, NodeValue input,
                                   LUTOperator lutOperator,
                                   std::vector<float> &lutOperatorArgs,
                                   NodeValue table, NodeValue idxTable,
                                   TypeRef outTy);

/// Create quantized log.
IntLookupTableNode *createIntLog(llvm::StringRef name, NodeValue input,
                                 TypeRef outTy);

/// Create quantized exp.
IntLookupTableNode *createIntExp(llvm::StringRef name, NodeValue input,
                                 TypeRef outTy);

/// Create quantized tanh.
IntLookupTableNode *createIntTanh(llvm::StringRef name, NodeValue input,
                                  TypeRef outTy);

/// Create quantized sigmoid.
IntLookupTableNode *createIntSigmoid(llvm::StringRef name, NodeValue input,
                                     TypeRef outTy);

/// Create a TopK node with \p name which computes the top \p k values of
/// \p input, together with their indices.
TopKNode *createTopK(llvm::StringRef name, NodeValue input, unsigned_t k);

/// Same as \ref createTopK(), but with the element kind of the output
/// indices explicitly set to \p outIndicesTyKind.
TopKNode *createTopK(llvm::StringRef name, NodeValue input, unsigned_t k,
                     ElemKind outIndicesTyKind);
1661 | |
/// Given \p rpnMaxLevel , \p rpnMinLevel and \p rpnPostNmsTopN
/// CollectRpnProposals merges rois in the \p roisIn based on \p roiProbsIn
/// and returns top proposals limited to rpnPostNmsTopN total, size (n x B),
/// where B is box dimensions and based on dimension of input rois
/// Format for upright boxes is (image_index, x1, y1, x2, y2).
/// Format for rotated boxes (image_index, ctr_x, ctr_y, w, h, angle)
/// rpnPostNmsTopN should be greater than zero.
CollectRpnProposalsNode *createCollectRpnProposals(
    llvm::StringRef name, std::vector<NodeValue> &roisIn,
    std::vector<NodeValue> &roiProbsIn, int64_t rpnMaxLevel,
    int64_t rpnMinLevel, unsigned_t rpnPostNmsTopN);

/// Given \p data tensor of rank r >= 1, and \p indices tensor of rank q,
/// gather entries of the \p axis dimension of data (default outer-most for
/// axis = 0) indexed by indices and concatenate them in the output tensor of
/// rank q + (r - 1).
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#gather
GatherNode *createGather(llvm::StringRef name, NodeValue data,
                         NodeValue indices, unsigned_t axis = 0);

/// Given \p data tensor of rank r >= 1, \p indices tensor of rank q >= 1,
/// and \p batchDims integer b, this operator gathers slices of data
/// into an output tensor of rank q + r - indices_shape[-1] - 1 - b.
/// \p indices is a q-dimensional integer tensor, best thought of as a (q-1)
/// dimensional tensor of index-tuples into \p data, where each element
/// defines a slice of data.
/// \p batchDims is an integer b indicating the number of batch dimensions,
/// that is the leading number of dimensions of \p data tensor and
/// \p indices that are representing the batches such that the gather starts
/// from the b+1 dimension.
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#gathernd
GatherNDNode *createGatherND(llvm::StringRef name, NodeValue data,
                             NodeValue indices, unsigned_t batchDims = 0);
1695 | |
/// Create a node, performing GatherElements operation:
/// GatherElements takes inputs \p data and \p indices of the same rank r >=
/// 1 and a \p dim attribute that identifies an axis of data. It is an
/// indexing operation that produces its output by indexing into the input
/// data tensor by elements of the indices tensor. Its output shape is
/// the same as the shape of indices and consists of one value (gathered
/// from the data) for each element in indices.
GatherElementsNode *createGatherElements(llvm::StringRef name, NodeValue data,
                                         NodeValue indices, unsigned_t dim);

/// Create a node, performing GatherRanges operation:
/// Gathers entries of \p data in groups specified by the "examples" in
/// \p ranges. Each example in \p ranges contains a list of pairs of
/// indices of the form (index, length) which specify which entries of \p
/// data to gather. The ordering of elements in \p ranges and of pairs
/// within an element is preserved in the output. In addition to the result
/// of gathering ("output"), the lengths of the ranges gathered by each
/// example in \p ranges is also produced as an output ("lengths").
/// \p maxOutputSize is the maximum possible size of "output" and is used to
/// set its type. Users must use "lengths" to interpret "output" correctly.
/// \returns the GatherRangesNode.
GatherRangesNode *createGatherRanges(llvm::StringRef name, NodeValue data,
                                     NodeValue ranges,
                                     unsigned_t maxOutputSize);
1720 | |
/// Copies each slice from \p slices into \p data at the corresponding index
/// in \p indices, and \returns this new version of data. For example, given
/// input data {{1,2},{3,4},{5,6}}, slices {{-3,-4}}, and indices {1}, the
/// result is {{1,2},{-3,-4},{5,6}}. If \p cumulative is true, this node adds
/// values instead of copying.
ScatterDataNode *createScatterData(llvm::StringRef name, NodeValue data,
                                   NodeValue indices, NodeValue slices,
                                   bool cumulative = false);

/// Given 2D matrix \p data, 1D vector \p lengths (of the same size as width
/// of \p data), and 1D vector \p values (of the same size as sum of
/// \p lengths), expand each row of the \p data to a row of zeros and ones,
/// according to One Hot Encoding. j-th element of resulting i-th row is one
/// iff \p values[j] == \p data[i][some index within range of j].
BatchOneHotNode *createBatchOneHot(llvm::StringRef name, NodeValue data,
                                   NodeValue lengths, NodeValue values);

/// Given an input tensor of [N,H,W,C], where N is the batch axis, H is the
/// height, W is the width, and C is the channel or depth, this produces an
/// output tensor of [N, H/blockSize, W/blockSize, C * blockSize * blockSize].
SpaceToDepthNode *createSpaceToDepth(llvm::StringRef name, NodeValue input,
                                     unsigned blockSize);

/// Create a sequence of Reshape and Transpose nodes representing DepthToSpace
/// operator with \p blockSize in DCR or CRD mode based on \p isCRD flag.
/// Assumes input layout to be NHWC. \returns the last node in the sequence.
ReshapeNode *createDepthToSpace(llvm::StringRef name, NodeValue input,
                                unsigned blockSize, bool isCRD = false);
1750 | |
/// Given \p input tensor of [N,H,W,C], where N is the batch, C is the channel
/// or depth, H is the height and W is the width, and \p scale tensor with
/// tensor format same as \p input then ResizeNearest generates an Output
/// tensor with resized spatial dimensions using nearest neighbor
/// interpolation. The Output tensor is of shape [floor(N * \p scale[0]),
/// floor(H * \p scale[1]), floor(W * \p scale[2]),
/// floor(C * \p scale[3])]
ResizeNearestNode *createResizeNearest(llvm::StringRef name, NodeValue input,
                                       llvm::ArrayRef<float> scale);

/// Given \p input tensor of [N,H,W,C], where N is the batch, C is the channel
/// or depth, H is the height and W is the width, with tensor format same as
/// \p input then ResizeNearest generates an Output tensor with resized
/// spatial dimensions using nearest neighbor interpolation. The Output tensor
/// shape is specified with \p outTy.
ResizeNearestNode *createResizeNearest(llvm::StringRef name, NodeValue input,
                                       TypeRef outTy);

/// Given \p input tensor of [N,H,W,C], where N is the batch, C is the channel
/// or depth, H is the height and W is the width, and \p scale tensor with
/// tensor format same as \p input then ResizeBilinear generates an Output
/// tensor with resized spatial dimensions using bilinear
/// interpolation. The Output tensor is of shape [floor(N * \p scale[0]),
/// floor(H * \p scale[1]), floor(W * \p scale[2]),
/// floor(C * \p scale[3])]
ResizeBilinearNode *createResizeBilinear(llvm::StringRef name,
                                         NodeValue input,
                                         llvm::ArrayRef<float> scale);

/// Given \p input tensor of [N,H,W,C], where N is the batch, C is the channel
/// or depth, H is the height and W is the width, with tensor format same as
/// \p input then ResizeBilinear generates an Output tensor with resized
/// spatial dimensions using bilinear interpolation. The Output
/// tensor shape is specified with \p outTy.
ResizeBilinearNode *createResizeBilinear(llvm::StringRef name,
                                         NodeValue input, TypeRef outTy);
1787 | |
/// Create quantization node which transforms floating point tensor to a
/// quantized one with given Scale and Offset. Scale and Offset params are
/// part of the \p outTy.
QuantizeNode *createQuantize(llvm::StringRef name, NodeValue input,
TypeRef outTy);

/// Create quantization node which transforms floating point tensor to a
/// quantized one of kind \p q with the given \p scale and \p offset (the
/// output type is constructed internally from them).
QuantizeNode *createQuantize(llvm::StringRef name, NodeValue input,
ElemKind q, float scale, int32_t offset);

/// Create dequantization node which transforms quantized tensor to a
/// floating point one with given Scale and Offset. Scale and Offset params
/// are part of the \p input. The result element kind is \p k.
DequantizeNode *createDequantize(llvm::StringRef name, NodeValue input,
ElemKind k);

/// Create dequantization node which transforms quantized tensor to a
/// floating point one with type \p outTy. Scale and
/// Offset params are part of the \p input.
DequantizeNode *createDequantize(llvm::StringRef name, NodeValue input,
TypeRef outTy);

/// Create transformation for quantized tensors to rescale based on the new
/// Scale and Offset carried by \p outTy.
RescaleQuantizedNode *createRescaleQuantized(llvm::StringRef name,
NodeValue input, TypeRef outTy);
1815 | |
/// Create a series of nodes that implement a weighted sum. \p data and \p
/// weights should have the same number of elements. The nodes in \p weights
/// should all be of size 1. Each node d_i in \p data is element-wise
/// multiplied by the corresponding weight value w_i found in \p weights,
/// broadcasted to the same shape as d_i, and resulting in r_i. All r_i are
/// element-wise summed, and the final add node in this sum is returned.
Node *createWeightedSum(llvm::StringRef name, llvm::ArrayRef<NodeValue> data,
llvm::ArrayRef<NodeValue> weights);

/// Create a series of nodes that implement a two-parameter
/// rowwise Box-Cox transform. For each element of the \p input x, this is
/// defined as:
///
/// y = ln(max(x + lambda2, 1e-6)), if lambda1 == 0
/// (max(x + lambda2, 1e-6)^lambda1 - 1)/lambda1, if lambda1 != 0
///
/// The transform parameters \p lambda1 and \p lambda2 are vectors of size D
/// that are broadcasted to match the size of \p input (NxD). The transform
/// itself is implemented using elementwise Max, Add, Log (if lambda1 == 0),
/// Pow, Splat, Sub, and Div (if lambda1 != 0) nodes with a Splat and Select
/// node to select between the two cases listed above. \returns the final
/// Select node. \p epsilon is used to ensure we do not divide by zero when
/// computing the lambda1 == 0 case: the Select evaluates both branches, so
/// the division branch would otherwise divide by lambda1 == 0.
Node *createBatchBoxCox(llvm::StringRef name, NodeValue input,
NodeValue lambda1, NodeValue lambda2,
float epsilon = std::numeric_limits<float>::min());
1843 | |
/// Create a Clip node with the given \p name, \p input, minimum clip value
/// \p min, maximum clip value \p max and output type \p outTy.
ClipNode *createClip(llvm::StringRef name, NodeValue input, TypeRef outTy,
float min, float max);

/// Create a Clip node with the given \p name, \p input, minimum clip value
/// \p min and maximum clip value \p max. The result type is implicitly set
/// based on the \p input type.
ClipNode *createClip(llvm::StringRef name, NodeValue input, float min,
float max);

/// Creates and \returns a ClipNode with \p name that clips \p input to the
/// min/max range of FP16. The result type is implicitly set based on the
/// \p input type.
ClipNode *createClipMinMaxFP16(llvm::StringRef name, NodeValue input);

/// Creates and \returns a ClipNode with \p name that clips \p input to the
/// min/max range of BFloat16. The result type is implicitly set based on
/// the \p input type.
ClipNode *createClipMinMaxBFloat16(llvm::StringRef name, NodeValue input);
1864 | |
/// @name The builder functions below are identical to the builder functions
/// above except that they create nodes that use Placeholder instead of
/// Variables. The methods create and initialize the tensors in the
/// PlaceholderBindings. As soon as we finish the Placeholder migration we'll
/// delete these methods and merge them with the builder methods above. See
/// issue #1334.
///@{

/// Creates a BatchNormalizationNode with the given \p name which normalizes
/// \p input along the \p channelIdx dimension. \p epsilon is added to the
/// variance for numerical stability; \p momentum is presumably the decay
/// factor for the running statistics -- confirm against the implementation.
/// The scale/bias/mean/var tensors are created and initialized in
/// \p bindings.
BatchNormalizationNode *
createBatchNormalization(PlaceholderBindings &bindings, llvm::StringRef name,
NodeValue input, unsigned_t channelIdx = 0,
float epsilon = 1e-5, float momentum = 0.9);

/// Create a BatchedUnaryEmbeddingsBags node gathering from \p weights with
/// per-table offsets \p tableOffsets and bag ranges given by \p indices and
/// \p offsets, where \p offsets also marks the end of the last range.
BatchedUnaryEmbeddingsBagsNode *
createBatchedUnaryEmbeddingsBags(llvm::StringRef name, NodeValue weights,
NodeValue tableOffsets, NodeValue indices,
NodeValue offsets);

/// Create an IntNBitSplitEmbeddingBags node.
IntNBitSplitEmbeddingBagsNode *createIntNBitSplitEmbeddingBags(
llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights,
NodeValue weightsPlacements, NodeValue weightsOffsets,
NodeValue weightsTys, NodeValue dimOffsets, int64_t totalDims,
NodeValue indices, NodeValue offsets,
SplitEmbeddingPoolingMode poolingMode,
SplitEmbeddingSparseType outputDtype);

/// Create an IntNBitSplitEmbeddingWeightedBags node. Same inputs as
/// IntNBitSplitEmbeddingBags plus per-index weights \p indiceWeights.
IntNBitSplitEmbeddingWeightedBagsNode *
createIntNBitSplitEmbeddingWeightedBags(
llvm::StringRef name, NodeValue devWeights, NodeValue uvmWeights,
NodeValue weightsPlacements, NodeValue weightsOffsets,
NodeValue weightsTys, NodeValue dimOffsets, int64_t totalDims,
NodeValue indices, NodeValue offsets,
SplitEmbeddingPoolingMode poolingMode,
SplitEmbeddingSparseType outputDtype, NodeValue indiceWeights);
1903 | |
/// Creates a ConvolutionNode with the given \p name which convolves the 4D
/// \p input. \p kernels defines the size of the height and width dimensions
/// of the convolutional filters. \p strides defines the number of steps
/// to take in the input for each output cell. \p pads defines how many zero
/// padding cells should be added to the input during convolution. \p group
/// defines the number of groups the input and output channels should be
/// divided into and convolved separately. \p dilation defines the factor by
/// which the gap between 2 neighboring kernel elements is expanded along each
/// axis. \p layout defines the Tensor layout and must be either NHWC or NCHW.
ConvolutionNode *createConv(PlaceholderBindings &bindings,
llvm::StringRef name, NodeValue input,
dim_t outChannels,
llvm::ArrayRef<unsigned_t> kernels,
llvm::ArrayRef<unsigned_t> strides,
llvm::ArrayRef<unsigned_t> pads, unsigned_t group,
llvm::ArrayRef<unsigned_t> dilation = {1, 1},
ConvolutionLayout layout = NHWC);

/// Creates a ConvolutionNode with the given \p name which convolves the 4D
/// \p input. \p kernel defines the size of the height and width dimensions of
/// the convolutional filters. \p stride defines the number of steps to
/// take in the input for each output cell. \p pad defines how many zero
/// padding cells should be added to the input during convolution. \p group
/// defines the number of groups the input and output channels should be
/// divided into and convolved separately. \p dilation defines the factor by
/// which the gap between 2 neighboring kernel elements is expanded along each
/// axis. \p layout defines the Tensor layout and must be either NHWC or NCHW.
ConvolutionNode *createConv(PlaceholderBindings &bindings,
llvm::StringRef name, NodeValue input,
dim_t outChannels, unsigned_t kernel,
unsigned_t stride, unsigned_t pad,
unsigned_t group,
llvm::ArrayRef<unsigned_t> dilation = {1, 1},
ConvolutionLayout layout = NHWC);
1938 | |
/// Creates a Convolution3DNode with the given \p name which convolves the 5D
/// \p input. \p kernels defines the size of the height, width, and depth
/// dimensions of the convolutional filters. \p strides defines the number
/// of steps to take in the input for each output cell. \p pads defines how
/// many zero padding cells should be added to the input during convolution.
/// \p group defines the number of groups the input and output channels should
/// be divided into and convolved separately.
Convolution3DNode *createConv3D(PlaceholderBindings &bindings,
llvm::StringRef name, NodeValue input,
dim_t outChannels,
llvm::ArrayRef<unsigned_t> kernels,
llvm::ArrayRef<unsigned_t> strides,
llvm::ArrayRef<unsigned_t> pads,
unsigned_t group);

/// Creates a Convolution3DNode with the given \p name which convolves the 5D
/// \p input. \p kernel defines the size of the height, width, and depth
/// dimensions of the convolutional filters. \p stride defines the number
/// of steps to take in the input for each output cell. \p pad defines how
/// many zero padding cells should be added to the input during convolution.
/// \p group defines the number of groups the input and output channels should
/// be divided into and convolved separately.
/// NOTE(review): \p outChannels is size_t here but dim_t in the other
/// conv creators -- consider unifying the types.
Convolution3DNode *createConv3D(PlaceholderBindings &bindings,
llvm::StringRef name, NodeValue input,
size_t outChannels, unsigned_t kernel,
unsigned_t stride, unsigned_t pad,
unsigned_t group);
1966 | |
/// Creates a ConvTransposeNode with the given \p name which does transposed
/// convolution on the 4D \p input. \p kernels define the size of the height
/// and width dimensions of the convolution filters. \p strides define the
/// number of steps to take in the input for each output cell. \p pads define
/// how many zero padding cells should be added to the input during
/// convolution. \p group defines the number of groups the input and output
/// channels should be divided into and convolved separately. \p dilation
/// defines the factor by which the gap between 2 neighboring kernel elements
/// is expanded along each axis.
ConvTransposeNode *createConvTranspose(
PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
unsigned_t group, llvm::ArrayRef<unsigned_t> dilation = {1, 1});

/// Creates a ConvTransposeNode with the given \p name which does transposed
/// convolution on the 4D \p input. \p kernel defines the size of the height
/// and width dimensions of the convolution filters. \p stride defines the
/// number of steps to take in the input for each output cell. \p pad defines
/// how many zero padding cells should be added to the input during
/// convolution. \p group defines the number of groups the input and output
/// channels should be divided into and convolved separately. \p dilation
/// defines the factor by which the gap between 2 neighboring kernel elements
/// is expanded along each axis.
ConvTransposeNode *
createConvTranspose(PlaceholderBindings &bindings, llvm::StringRef name,
NodeValue input, dim_t outChannels, unsigned_t kernel,
unsigned_t stride, unsigned_t pad, unsigned_t group,
llvm::ArrayRef<unsigned_t> dilation = {1, 1});
1992 | |
/// Creates and \returns a FullyConnectedNode with \p name and \p input,
/// producing \p outDepth output features. Trainable weight and bias
/// placeholders are created implicitly and initialized in \p bindings.
/// If \p input is not 2 dimensional then it is flattened along \p axis.
/// Note, output type is inferred based on the input types.
FullyConnectedNode *createFullyConnected(PlaceholderBindings &bindings,
llvm::StringRef name,
NodeValue input, dim_t outDepth,
unsigned_t axis = 1);

/// Creates an RMSNorm pair. \p X should be a 2D tensor, \p gamma and \p beta
/// should be 1D tensors. \p epsilon is presumably added for numerical
/// stability -- confirm against the implementation.
std::array<Node *, 2> createRMSNorm(llvm::StringRef name, NodeValue X,
NodeValue gamma, NodeValue beta,
float epsilon = .0f);
2007 | |
/// Create an unrolled single-layer Simple RNN cell with \p hiddenSize
/// dimensionality of the hidden state and \p outputSize dimensionality of the
/// output state. \p inputs define the input for the cell at each time step
/// and the number of time steps is equal to the size of the \p inputs. The
/// names of the created variables are prefixed by \p namePrefix.
/// The output variables are written to \p outputs, they represent the
/// activations of the output layer, unrolled over time.
/// The dimensionality of the output variables is \p batchSize x \p
/// outputSize.
void createSimpleRNN(PlaceholderBindings &bindings,
llvm::StringRef namePrefix,
const llvm::ArrayRef<NodeValue> inputs,
unsigned batchSize, unsigned hiddenSize,
unsigned outputSize, std::vector<NodeValue> &outputs);

/// Create an unrolled single-layer GRU cell with \p hiddenSize
/// dimensionality of the hidden state and \p outputSize dimensionality of the
/// output state. \p inputs define the input for the cell at each time step
/// and the number of time steps is equal to the size of the \p inputs. The
/// names of the created variables are prefixed by \p namePrefix.
/// The output variables are written to \p outputs, they represent the
/// activation of the output layer, unrolled over time.
/// The dimensionality of the output variables is \p batchSize x \p
/// outputSize.
void createGRU(PlaceholderBindings &bindings, llvm::StringRef namePrefix,
const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
unsigned hiddenSize, unsigned outputSize,
std::vector<NodeValue> &outputs);

/// Create an unrolled single-layer LSTM cell with \p hiddenSize
/// dimensionality of the hidden state and \p outputSize dimensionality of the
/// output state. \p inputs define the input for the cell at each time step
/// and the number of time steps is equal to the size of the \p inputs. The
/// names of the created variables are prefixed by \p namePrefix.
/// The output variables are written to \p outputs, they represent the
/// activation of the output layer, unrolled over time.
/// The dimensionality of the output variables is \p batchSize x \p
/// outputSize.
void createLSTM(PlaceholderBindings &bindings, llvm::StringRef namePrefix,
const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
unsigned hiddenSize, unsigned outputSize,
std::vector<NodeValue> &outputs);

/// Create an LSTM Unit Node with \p Input which shape is [batch,
/// 4*hiddenSize] and follows the gate order i, f, g, o, and \p C as the
/// current cell state.
LSTMUnitNode *createLSTMUnit(llvm::StringRef namePrefix, NodeValue Input,
NodeValue C);
2053 | |
/// Helper function to create a PyTorch style LSTM for one direction, and
/// return every output in a vector. \p T should be an iterator or
/// reverse_iterator of a NodeValue vector, and \p inputItr is an iterator
/// pointer of the input vector. \p Wx, \p Wh, \p Bx and \p Bh are the
/// weights/biases in the gate order i, f, g, o; \p H and \p C are the hidden
/// state and cell state, with the shapes described in createPyTorchLSTM.
template <class T>
std::vector<NodeValue> createSingleDirectionLSTM(
std::string nameBase, T inputItr, const int timeSteps, NodeValue Wx,
NodeValue Wh, NodeValue Bx, NodeValue Bh, NodeValue &H, NodeValue &C);

/// Helper function to create a PyTorch style multiple-layer LSTM for one
/// direction.
std::vector<NodeValue> createMultipleLayerSingleDirectionLSTM(
std::string nameBase, NodeValue input, unsigned batchSize,
unsigned inputSize, const int timeSteps, std::vector<NodeValue> &Wx,
std::vector<NodeValue> &Wh, std::vector<NodeValue> &Bx,
std::vector<NodeValue> &Bh, NodeValue &H, NodeValue &C);

/// Helper function to create per-time-step sliced inputs for LSTM.
std::vector<NodeValue>
createSlicedInput(NodeValue input, std::string &nameBase, unsigned batchSize,
unsigned inputSize, const int timeSteps);

/// Create PyTorch style LSTM with fixed weights and biases.
/// The order of \p Wx \p Wh \p Bx and \p Bh is i, f, g, o,
/// The \p inputs shape should be (numSteps, batchSize, hiddenSize),
/// while \p Wx shape should be (inputSize, hiddenSize * 4),
/// Wh shape should be (hiddenSize, hiddenSize * 4),
/// \p Bx and \p Bh shape should be (hiddenSize * 4).
/// If \p isBidirectional == true, \p WxR, \p WhR, \p BxR and \p BhR
/// also need to be provided, indicating the reversed weights and biases.
/// \p Ht and \p Ct are the initial hidden state and cell state.
/// For more details, please read:
/// https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
void createPyTorchLSTM(llvm::StringRef namePrefix, NodeValue inputs,
std::vector<NodeValue> &Wx, std::vector<NodeValue> &Wh,
std::vector<NodeValue> &Bx, std::vector<NodeValue> &Bh,
NodeValue &Ht, NodeValue &Ct, NodeValue &outputs,
bool isBidirectional = false,
NodeValue WxR = NodeValue(),
NodeValue WhR = NodeValue(),
NodeValue BxR = NodeValue(),
NodeValue BhR = NodeValue());
2098 | |
/// Type definition for the direction of an RNN module (RNN, GRU, LSTM).
enum class RnnDirection {
Forward,
Reverse,
Bidirectional,
};

/// Definition for a lambda used to create an activation node for RNN modules.
/// The lambda receives the name for the new node and its input node, and
/// returns the created activation node.
using RnnActivation = std::function<Node *(llvm::StringRef, Node *)>;
2108 | |
/// Create an unrolled multi-layer RNN according to the ONNX definition:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
/// The RNN has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, HSize, ISize].
/// - recurrence weights \p R with size [N, HSize, HSize].
/// - bias weights \p B with size [N, 2 * HSize].
/// - initial hidden state \p initial_h with size [N, B, HSize].
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The RNN has the following outputs:
/// - output \p Y with size [S, N, B, HSize].
/// - final hidden state \p Y_h with size [N, B, HSize].
/// The direction of the instantiated RNN is given by \p direction. The RNN
/// will use the activation functions defined by the \p activations array:
/// - [f] in case the RNN is unidirectional (1 function).
/// - [f] for the forward cell followed by [f] for the reverse cell in
/// case the RNN is bidirectional (2 functions).
/// The inputs \p B and \p initial_h are optional (assumed 0 if nullptr is
/// provided). The names of all the nodes created are prefixed with
/// \p namePrefix.
void createOnnxRNN(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
RnnDirection direction,
std::vector<RnnActivation> &activations);
2135 | |
/// Create an unrolled multi-layer GRU according to the ONNX definition:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
/// The GRU has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, 3 * HSize, ISize].
/// - recurrence weights \p R with size [N, 3 * HSize, HSize].
/// - bias weights \p B with size [N, 6 * HSize].
/// - initial hidden state \p initial_h with size [N, B, HSize].
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The GRU has the following outputs:
/// - output \p Y with size [S, N, B, HSize].
/// - final hidden state \p Y_h with size [N, B, HSize].
/// The direction of the instantiated GRU is given by \p direction. The GRU
/// will use the activation functions defined by the \p activations array:
/// - [f,g] in case the GRU is unidirectional (2 functions).
/// - [f,g] for the forward cell followed by [f,g] for the reverse cell in
/// case the GRU is bidirectional (4 functions).
/// The inputs \p B and \p initial_h are optional (assumed 0 if nullptr is
/// provided). The names of all the nodes created are prefixed with
/// \p namePrefix. The boolean parameter \p linearBeforeReset defines whether
/// the reset for the previous hidden state occurs before/after the linear
/// expression.
void createOnnxGRU(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
RnnDirection direction,
std::vector<RnnActivation> &activations,
bool linearBeforeReset = false);
2165 | |
/// Create an unrolled multi-layer LSTM according to the ONNX definition:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
/// The LSTM has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, 4 * HSize, ISize].
/// - recurrence weights \p R with size [N, 4 * HSize, HSize].
/// - bias weights \p B with size [N, 8 * HSize].
/// - initial hidden state \p initial_h with size [N, B, HSize].
/// - initial cell state \p initial_c with size [N, B, HSize].
/// - peephole weights \p P with size [N, 3 * HSize].
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The LSTM has the following outputs:
/// - output \p Y with size [S, N, B, HSize].
/// - final hidden state \p Y_h with size [N, B, HSize].
/// - final cell state \p Y_c with size [N, B, HSize].
/// The direction of the instantiated LSTM is given by \p direction. The LSTM
/// will use the activation functions defined by \p activations array:
/// - [f,g,h] in case the LSTM is unidirectional (3 functions).
/// - [f,g,h] for the forward cell followed by [f,g,h] for the reverse cell in
/// case the LSTM is bidirectional (6 functions).
/// The inputs \p B, \p initial_h, \p initial_c and \p P are optional (assumed
/// 0 if nullptr is provided). The names of all the nodes created are prefixed
/// with \p namePrefix. The boolean parameter \p inputForget defines whether
/// the input and forget gates should be coupled (compute the input gate from
/// the forget gate).
void createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue initial_c, NodeValue P, NodeValue &Y,
NodeValue &Y_h, NodeValue &Y_c, unsigned hiddenSize,
RnnDirection direction,
std::vector<RnnActivation> &activations,
bool inputForget = false);
/// @}
2200 | |
/// Create a TraceEvent in the runtime profile, which triggers collection of
/// runtime statistics. \p eventName and \p eventType identify the event and
/// \p data is the node the event is attached to; \p index is presumably an
/// index into \p data's outputs -- confirm against the implementation.
TraceEventNode *createTraceEvent(llvm::StringRef eventName,
llvm::StringRef eventType, Node *data,
unsigned index);

/// Creates NMSv4 node that does NMS for one class.
/// Inputs
/// - \p boxes Tensor with box coordinates.
/// - \p scores Tensor with scores per box.
/// - \p centerPointBox Indicates format of the box per ONNX spec.
/// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
/// - \p iouThreshold Threshold for box overlap.
/// - \p scoreThreshold Threshold for box scores.
NonMaxSuppressionNode *
createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
NodeValue scores, int64_t centerPointBox,
int64_t maxOutputBoxesPerClass, float iouThreshold,
float scoreThreshold);
2219 | |
/// Creates NMSv4 node that does NMS for one class.
/// Inputs
/// - \p boxes Tensor with box coordinates.
/// - \p scores Tensor with scores per box.
/// - \p centerPointBox Indicates format of the box per ONNX spec.
/// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
/// - \p iouThreshold Threshold for box overlap.
/// - \p scoreThreshold Threshold for box scores.
/// - \p elTy Output element kind.
NonMaxSuppressionNode *
createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
NodeValue scores, int64_t centerPointBox,
int64_t maxOutputBoxesPerClass, float iouThreshold,
float scoreThreshold, ElemKind elTy);

/// Creates NMSv4 node that does NMS for one class.
/// Inputs
/// - \p boxes Tensor with box coordinates.
/// - \p scores Tensor with scores per box.
/// - \p centerPointBox Indicates format of the box per ONNX spec.
/// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
/// - \p iouThreshold Threshold for box overlap.
/// - \p scoreThreshold Threshold for box scores.
/// - \p indicesTy Type of indices output.
/// - \p numberOfSelectedIndicesTy Type of second output for number of
/// boxes detected.
NonMaxSuppressionNode *
createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
NodeValue scores, int64_t centerPointBox,
int64_t maxOutputBoxesPerClass, float iouThreshold,
float scoreThreshold, TypeRef indicesTy,
TypeRef numberOfSelectedIndicesTy);
2250 | |
/// Performs class wise NMS based on ONNX specification, with padding and ONNX
/// layout output.
/// Inputs
/// - \p boxes Tensor with box coordinates.
/// - \p scores Tensor with scores per box.
/// - \p centerPointBox Indicates format of the box per ONNX spec.
/// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
/// - \p iouThreshold Threshold for box overlap.
/// - \p scoreThreshold Threshold for box scores.
NonMaxSuppressionNode *
createNonMaxSuppressionONNX(llvm::StringRef name, NodeValue boxes,
NodeValue scores, int64_t centerPointBox,
int64_t maxOutputBoxesPerClass,
float iouThreshold, float scoreThreshold);

/// Performs class wise NMS based on ONNX specification, with padding and ONNX
/// layout output.
/// Inputs
/// - \p boxes Tensor with box coordinates.
/// - \p scores Tensor with scores per box.
/// - \p centerPointBox Indicates format of the box per ONNX spec.
/// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
/// - \p iouThreshold Threshold for box overlap.
/// - \p scoreThreshold Threshold for box scores.
/// - \p elTy Output element kind.
NonMaxSuppressionNode *createNonMaxSuppressionONNX(
llvm::StringRef name, NodeValue boxes, NodeValue scores,
int64_t centerPointBox, int64_t maxOutputBoxesPerClass,
float iouThreshold, float scoreThreshold, ElemKind elTy);

/// Performs class wise NMS based on ONNX specification, with padding and ONNX
/// layout output.
/// Inputs
/// - \p boxes Tensor with box coordinates.
/// - \p scores Tensor with scores per box.
/// - \p centerPointBox Indicates format of the box per ONNX spec.
/// - \p maxOutputBoxesPerClass Maximum number of boxes selected per class.
/// - \p iouThreshold Threshold for box overlap.
/// - \p scoreThreshold Threshold for box scores.
/// - \p indicesTy Type of indices output.
NonMaxSuppressionNode *createNonMaxSuppressionONNX(
llvm::StringRef name, NodeValue boxes, NodeValue scores,
int64_t centerPointBox, int64_t maxOutputBoxesPerClass,
float iouThreshold, float scoreThreshold, TypeRef indicesTy);
2290 | |
/// Create a TensorFlowLite custom node called "DetectionPostProcess" which
/// corresponds to a custom NonMaxSuppression node.
/// The node has the following inputs:
/// - \p boxes with size [N, B, 4]
/// - \p scores with size [N, B, C]
/// - \p anchors with size [B, 4]
/// where N is the batch size, B is the number of boxes and C is the number
/// of classes.
/// The node has the following attributes (parameters):
/// - \p numClasses - Number of effective classes (without the background).
/// - \p maxDetections - The maximum number of detections.
/// - \p maxClassesPerDetection - Maximum classes per detection (Fast NMS).
/// - \p maxDetectionsPerClass - Maximum detections per class (Regular NMS).
/// - \p iouThreshold - Detection threshold for IoU metric.
/// - \p scoreThreshold - Detection threshold for scores.
/// - \p xScale - X scale used for decoding the boxes.
/// - \p yScale - Y scale used for decoding the boxes.
/// - \p hScale - H scale used for decoding the boxes.
/// - \p wScale - W scale used for decoding the boxes.
/// - \p regularNMS - Whether the NMS is "Regular" or "Fast".
/// The node will have the following outputs:
/// - DetectionBoxes - the chosen boxes (float)
/// - DetectionClasses - the classes of the chosen boxes (int32)
/// - DetectionScores - the scores of the chosen boxes (float)
/// - NumDetections - number of chosen (detected) boxes (int32)
/// The first three output tensors will be allocated using the maximum
/// number of possible detections (worst case scenario) but the actual
/// usage will be given by the 'NumDetections' output.
TFLiteDetectionPostProcessNode *createTFLiteDetectionPostProcess(
llvm::StringRef name, NodeValue boxes, NodeValue scores,
NodeValue anchors, int32_t numClasses, int32_t maxDetections,
int32_t maxClassesPerDetection, int32_t maxDetectionsPerClass,
float iouThreshold, float scoreThreshold, float xScale, float yScale,
float hScale, float wScale, bool regularNMS);
2325 | |
/// Create a constant node with a 1D cosine windowing function defined as:
/// w[n] = 0.5 - 0.5 * cos(2 * pi * n / N) for n = 0 .. N - 1 where N
/// is the window \p length (this is the periodic form of the Hann window).
/// The node name will be \p name.
Constant *createCosineWindow(llvm::StringRef name, dim_t length);

/// Create a constant node with the twiddle factors for a 1D complex FFT:
/// W(N, k) = exp(-j * 2 * pi * k / N) for k = 0 ... N -1, where N is the
/// \p fftLength. The constant node will contain 2 * \p fftLength real float
/// values corresponding to \p fftLength complex values with the real and
/// imaginary parts interleaved: real[0], imag[0], real[1], imag[1], etc.
/// The node name will be \p name.
Constant *createFFTTwiddleFactors(llvm::StringRef name, dim_t fftLength);

/// Create a constant node with the bit reverse indices for a 1D FFT, that
/// is the corresponding index obtained after reversing the bit order for
/// each of the values k = 0 ... N -1 where N is the \p fftLength. The node
/// will contain \p fftLength int32 values. The node name will be \p name.
Constant *createFFTBitReverseIndices(llvm::StringRef name, dim_t fftLength);
2344 | |
2345 | /// Create a constant node with the complex weights used to map the results |
2346 | /// of N/2 point complex FFT to a N point real FFT. This allows an efficient |
2347 | /// implementation of the N point FFT for a real data x[n] with n = 0 .. N-1 |
2348 | /// by first computing the N/2 complex FFT G[k] for the complex signal g[n] |
2349 | /// defined as g[n] = x[2*n+0] + j * x[2*n+1] with n = 0 ... N/2-1 and then |
2350 | /// computing the final N point FFT X[k] for the original data x[n] by using |
2351 | /// X[k] = G[k] * A[k] + conj(G[N/2-k]) * (1 - A[k]) for k = 0 ... N/2 (for |
2352 | /// a real signal the FFT is conjugate symmetrical and therefore only the |
2353 | /// first N/2+1 output points of X[k] should be computed, the others being |
2354 | /// redundant). The relation should also use the definitions G[N/2] = G[0] and |
2355 | /// then A[k] = 1/2 * (1 - j * exp(-j * 2 * pi * k / N)) for k = 0 ... N/2. |
2356 | /// The FFT length parameter N is given as \p fftLength. This constant node |
2357 | /// will contain the complex values of A[k] for k = 0 ... L-1 where L is the |
2358 | /// sequence length given as \p outLength (the required length L is smaller |
2359 | /// than N/2+1 since A[k] has such properties that the second half of the |
2360 | /// sequence can be easily deduced from first half). This constant node will |
2361 | /// contain 2 * \p outLength real float values corresponding to \p outLength |
2362 | /// complex values A[k] with the real and imaginary parts interleaved. |
2363 | Constant *createFFTComplexToRealWeights(llvm::StringRef name, dim_t fftLength, |
2364 | dim_t outLength); |
2365 | |
/// This node computes the spectrogram of a 1D mono audio signal \p input by
/// extracting windows of size \p windowSize with stride \p windowStride and
/// computing for each window the spectrum power (magnitude squared) or simply
/// the magnitude depending on the flag \p magnitudeSquared. If the length of
/// the \p input is [inputLength] samples then the size of the spectrogram is
/// [windowCount, fftLength/2+1] where:
/// - windowCount = floor((inputLength-windowSize)/windowStride)+1 is the
/// number of windows extracted from the input.
/// - fftLength is the FFT length used to compute the spectrogram which is the
/// next power of 2 (e.g. for a window size of 640 the fftLength is 1024).
/// The input audio data values are commonly float values scaled in the range
/// [-1.0, 1.0]. If the audio data is decoded from a WAV file into int8/int16
/// values then those values are commonly scaled down with 2^7/2^15 before
/// using this node. The node name will be \p name. This node is inspired from
/// TensorFlow (tensorflow.python.ops.gen_audio_ops.audio_spectrogram).
AudioSpectrogramNode *createAudioSpectrogram(llvm::StringRef name,
NodeValue input,
int64_t windowSize,
int64_t windowStride,
bool magnitudeSquared = true);

/// Create as constants the Mel weights \p melWeights and ranges \p melRanges
/// required for the MFCC (Mel Frequency Cepstral Coefficient) transform for a
/// spectrogram of length \p spectrogramLength (which must be of the form
/// 2 ^ N + 1) obtained for an audio signal with the given \p sampleRate
/// (in Hertz) by mapping the spectrogram coefficients in \p filterBankCount
/// bins on a Mel scale between \p lowerFrequency and \p upperFrequency
/// (in Hertz) using a filterbank of triangular windows. The created constants
/// are handed back through the output parameters \p melWeights and
/// \p melRanges and will be named using \p prefix.
void createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength,
float sampleRate, float lowerFrequency,
float upperFrequency, dim_t filterBankCount,
Constant *&melWeights, Constant *&melRanges);

/// Create the DCT-II transform matrix coefficients as a constant defined as:
/// d[k][n] = sqrt(2 / N) * cos(pi / N * (n + 1/2) * k) with n = 0 .. N - 1
/// and k = 0 .. K - 1 where \p N is the input data length and \p K is the
/// output data length. The common case is that for which the input length
/// \p N is equal to the output length \p K but a separate output length
/// argument \p K <= \p N allows creating a partial DCT matrix used to compute
/// only the first \p K results from the full DCT-II transform. The DCT matrix
/// size will be \p K x \p N. The node name will be \p name.
/// \returns the Constant holding the \p K x \p N DCT-II matrix.
Constant *createDCTMat(llvm::StringRef name, dim_t N, dim_t K);

/// Computes the MFCC (Mel Frequency Cepstral Coefficient) for the given
/// \p spectrogram and is commonly used as feature extractor for voice/speech
/// audio data in voice command or keyword spotting applications. The input
/// \p spectrogram is a power spectrogram and not a magnitude (computed using
/// the 'AudioSpectrogram' node with the 'magnitudeSquared' flag set to True).
/// The MFCC transform is computed using the given \p sampleRate (in Hertz)
/// by mapping the spectrogram coefficients in \p filterBankCount bins on a
/// Mel scale between \p lowerFrequency and \p upperFrequency (in Hertz) using
/// a filterbank of triangular windows, taking the natural logarithm and then
/// keeping the first \p numCoefficients from the DCT-II transform. If the
/// input \p spectrogram size is [windowCount, spectrogramLen] then the output
/// node size will be [windowCount, numCoefficients] since the MFCC transform
/// is performed separately for each window of [spectrogramLen] input samples
/// by yielding \p numCoefficients output samples. The node name will be
/// \p name. This node is inspired from
/// TensorFlow (tensorflow.python.ops.gen_audio_ops.mfcc).
MFCCNode *createMFCC(llvm::StringRef name, NodeValue spectrogram,
float sampleRate, float lowerFrequency,
float upperFrequency, int64_t filterBankCount,
int64_t numCoefficients);
2429 | |
/// Performs the ROIAlign operation given the \p featureMap and the \p boxes.
/// ROIAlign is similar to crop and resize followed by pooling. The
/// co-ordinates to extract the crops are specified in \p boxes. Each cropped
/// image has to be resized to have the shape specified by \p outputHeight and
/// \p outputWidth. The \p samplingRatio specifies the number of samples to
/// take from each bin (along both the dimensions) for the purpose of pooling.
/// \p batchIndices gives, for each box, the index of the batch image it
/// belongs to, and \p spatialScale scales the box coordinates to the
/// \p featureMap scale (per the ONNX definition referenced below). \p mode
/// selects the pooling reduction (average by default). If \p rotated is set
/// the boxes presumably carry rotation info -- confirm the exact rotated-box
/// format against the implementation.
/// This node is defined in:
/// (https://github.com/onnx/onnx/blob/master/docs/Operators.md#RoiAlign).
/// \p aligned flag is an addition to Onnx definition to indicate if box
/// coordinates are aligned to the center of a pixel (VS top-left corner).
ROIAlignNode *createROIAlign(llvm::StringRef name, NodeValue featureMap,
NodeValue boxes, NodeValue batchIndices,
uint32_t outputHeight, uint32_t outputWidth,
uint32_t samplingRatio, float spatialScale,
bool aligned, bool rotated = false,
PoolingMode mode = PoolingMode::AVG);
2446 | |
/// Transform proposal bounding boxes to target bounding box using bounding
/// box regression deltas.
/// Inputs:
/// \p rois - Bounding box proposals in pixel coordinates.
/// Size (M, 4), format [x1, y1, x2, y2], or
/// Size (M, 5), format [batch_index, x1, y1, x2, y2].
/// If proposals from multiple images in a batch are present, they
/// should be grouped sequentially and in incremental order.
/// For rotated boxes, this would have an additional angle (in degrees)
/// in the format [<optional_batch_id>, ctr_x, ctr_y, w, h, angle].
/// \p deltas - bounding box translations and scales,
/// size (M, 4*K), format [dx, dy, dw, dh], K = # classes.
/// For rotated boxes, size (M, 5*K), format [dx, dy, dw, dh, da].
/// \p imInfo - Image dimensions, size (batch_size, 3),
/// format [img_height, img_width, img_scale]
/// Arguments:
/// \p weights - vector<float> weights [wx, wy, ww, wh] for the deltas
/// \p applyScale - transform the boxes to the scaled image space after
/// applying the bbox deltas. Set to false to match the detectron code, set to
/// true for keypoint models and for backward compatibility.
/// \p rotated - If true, then boxes (rois and deltas) include angle info to
/// handle rotation. The format will be
/// [ctr_x, ctr_y, width, height, angle (in degrees)].
/// \p angleBoundOn - If set, for rotated boxes, angle is normalized to be
/// within [angle_bound_lo, angle_bound_hi].
/// \p angleBoundLo - If set, for rotated boxes, angle is normalized to be
/// within [angle_bound_lo, angle_bound_hi].
/// \p angleBoundHi - If set, for rotated boxes, angle is normalized to be
/// within [angle_bound_lo, angle_bound_hi].
/// \p clipAngleThresh - For RRPN, clip almost horizontal boxes within this
/// threshold of tolerance for backward compatibility. Set to negative value
/// for no clipping.
/// \p legacyPlusOne - when true, box width/height presumably use the legacy
/// "+ 1" convention for Caffe2 compatibility (see the Caffe2 operator
/// definition linked below -- confirm against the implementation).
/// Outputs:
/// boxOut - Pixel coordinates of the transformed bounding boxes,
/// Size (M, 4*K), format [x1, y1, x2, y2]. For rotated boxes, size (M, 5*K),
/// format [ctr_x, ctr_y, w, h, angle].
/// roiBatchSplits - Tensor of shape (batch_size) with each element
/// denoting the number of RoIs belonging to the corresponding image in batch
/// See definition:
/// https://github.com/pytorch/pytorch/blob/master/caffe2/operators/bbox_transform_op.cc#L10
BBoxTransformNode *
createBBoxTransform(llvm::StringRef name, NodeValue rois, NodeValue deltas,
NodeValue imInfo, llvm::ArrayRef<float> weights,
bool applyScale, bool rotated, bool angleBoundOn,
int64_t angleBoundLo, int64_t angleBoundHi,
float clipAngleThresh, bool legacyPlusOne);
2492 | |
/// Create an ExternFunctionCall node named \p name with result type \p outTy
/// and the given \p inputs. \p funcName is the name of the external function
/// to be invoked. \p funcImpl will contain body
/// of or reference to the function which can be invoked.
/// \p funcKind contains the type of function. The type of function could be
/// source code, like OpenCL, CUDA, or could be a binary or
/// a handle to an external function.
ExternalFunctionCallNode *
createExternalFunctionCall(llvm::StringRef name, TypeRef outTy,
llvm::ArrayRef<glow::NodeValue> inputs,
llvm::StringRef funcName, llvm::StringRef funcImpl,
llvm::StringRef funcKind);
2503 | |
/// Erase the node \p N from the Function.
void eraseNode(Node *N);

/// Erase the node pointed to by the iterator \p I from the Function.
void eraseNode(NodesList::iterator I);
2509 | |
/// Clone the current function into a new function with the name \p newName in
/// the same module. If \p map is non-null then the procedure records the
/// mapping between the old node to the new node in \p map. If \p currToNewMap
/// is non-null it is used as the initial state of the currToNew map inside
/// the cloner.
/// \returns a new function that is a copy of the current function.
Function *clone(llvm::StringRef newName,
llvm::DenseMap<const Node *, Node *> *map = nullptr,
llvm::DenseMap<const Node *, Node *> *currToNewMap = nullptr);

/// Clone the current function into a user-provided function \p newF. The
/// function \p newF is not automatically added to a module by the clone call.
/// If \p map is non-null then the procedure records the mapping between the
/// old node to the new node in \p map. If \p currToNewMap is non-null it is
/// used as the initial state of the currToNew map inside the cloner. This
/// overload is const: the current function is not modified. \returns
/// a user-provided function \p newF that now contains a clone of the current
/// function.
Function *
clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map = nullptr,
llvm::DenseMap<const Node *, Node *> *currToNewMap = nullptr) const;
2530 | |
/// Verify the correctness of the Function. If \p backend is provided, checks
/// backend-specific layout requirements. Else checks the requirements based
/// on Glow's "canonical" layout. \returns true when the function is valid,
/// false otherwise.
bool verify(const Backend *backend = nullptr) const;
2536 | |
/// Dump a textual representation of the Function into the default output
/// stream. (NOTE: this overload takes no stream argument; see the
/// llvm::raw_ostream overload to target a specific stream.)
void dump() const;
2539 | |
/// \returns a textual representation of the Function as a std::string. If
/// \p skipUsersForStorage then user counts for Storage will not be dumped.
/// If \p skipName then the name of the Function will not be dumped.
std::string toString(bool skipUsersForStorage = false,
bool skipName = false) const;

/// \returns a hash code of the function.
llvm::hash_code getHash() const;
2548 | |
/// Dump a textual representation of the Function into the provided output
/// stream \p os.
/// If \p skipUsersForStorage then user counts for Storage will not be dumped.
/// If \p skipName then the name of the Function will not be dumped.
void dump(llvm::raw_ostream &os, bool skipUsersForStorage = false,
bool skipName = false) const;
2554 | |
/// Dump a dotty graph that depicts the function into a file.
/// \returns full path to the file.
std::string dumpDAG();

/// Dump a dotty graph that depicts the function into the file named
/// \p dotFilename.
void dumpDAG(llvm::StringRef dotFilename);

/// Dump a dotty graph that depicts the function into the file named
/// \p dotFilename.
void dumpDAG(const char *dotFilename);
2564 | |
/// \returns the list of nodes that the Function owns.
NodesList &getNodes() { return nodes_; }

/// \returns the list of nodes that the Function owns (const overload).
const NodesList &getNodes() const { return nodes_; }

/// \returns a node with the name \p name or nullptr if no node was found.
Node *getNodeByName(llvm::StringRef name);

/// \returns a node value using the \p name which has the same format as the
/// one used by the \ref NodeValue::generateNodeOutputName which is
/// "nodeName:outputNumber". The returned node value has a nullptr for the
/// node if not found in the Function or if the node has no outputs (for
/// example SaveNode). The searched node value can be one of a graph node,
/// constant or placeholder.
NodeValue getNodeValueByName(llvm::StringRef name);

/// \returns pointer to the class member for the nodes list (\ref nodes_).
static NodesList Function::*getNodesMemberPtr() { return &Function::nodes_; }
2583 | |
/// Randomize all of the Constants in the function. If a Constant with users
/// in this Function also has users in other Functions then this will result
/// in a FATAL. \p ignoredConstants is a map from Kinds of nodes to the input
/// indices for that node that should be ignored (not randomized).
void randomizeConstants(
const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants = {});
2590 | }; |
2591 | |
2592 | struct TrainingConfig; |
2593 | |
/// List of (variable, gradient variable) Placeholder pairs, as recorded by
/// \ref differentiate when \p varGrads is requested.
using VariableGradientsList =
std::list<std::pair<Placeholder *, Placeholder *>>;
2596 | |
2597 | /// Create a new Function that 'trains' the input Function. We differentiate the |
2598 | /// nodes and insert code to update the weights based on the \p config |
2599 | /// parameters. |
2600 | /// If \p varGrads is set then instead of inserting code to update the weights, |
2601 | /// the procedure adds code to record the last gradient value: a list of |
2602 | /// (var, grad_var) pairs associating variables with their gradient variables. |
2603 | /// This feature is used by the gradient-check unit tests. |
2604 | /// \returns a new function with the name \p newFuncName. |
2605 | Function *differentiate(Function *F, const TrainingConfig &config, |
2606 | llvm::StringRef newFuncName = "" , |
2607 | VariableGradientsList *varGrads = nullptr); |
2608 | |
/// \returns the first SaveNode user of the placeholder \p PH or
/// nullptr if none are found.
SaveNode *getOutputSave(Function *F, Placeholder *PH);

/// Clone \p node and its sources into \p newF using old-to-new mapping \p
/// currToNew.
/// \returns the clone of \p node created inside \p newF.
Node *recursiveClone(Function *newF, Node *node, NodeMap &currToNew);

/// If \p PH is an output placeholder in the Function \p F,
/// \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F);
2622 | |
/// If \p PH is an input placeholder in the Function \p F,
/// \returns true.
/// This is determined by checking if the PH is the input to a SaveNode or is
/// used by a non-SaveNode.
bool isInput(const Placeholder *PH, const Function &F);
2628 | |
/// Helper vectors for common transpose shuffles. Each shuffle lists, for
/// every output dimension, the index of the input dimension it is taken
/// from; e.g. NHWC2NCHW == {0u, 3u, 1u, 2u}: output dim 1 (C) is taken
/// from input index 3 (the C of NHWC).
#define NCH2NHC \
{ 0u, 2u, 1u }
#define NCHW2NHWC \
{ 0u, 2u, 3u, 1u }
#define NCTHW2NTHWC \
{ 0u, 2u, 3u, 4u, 1u }
#define NHWC2NCHW \
{ 0u, 3u, 1u, 2u }
#define NTHWC2NCTHW \
{ 0u, 4u, 1u, 2u, 3u }
#define HWCN2NHWC \
{ 3u, 0u, 1u, 2u }
#define NHWC2HWNC \
{ 1u, 2u, 0u, 3u }
#define CNHW2NHWC \
{ 1u, 2u, 3u, 0u }
#define NHWC2CHWN \
{ 3u, 1u, 2u, 0u }
#define CHWN2NHWC \
{ 3u, 1u, 2u, 0u }
#define D2S_DCR \
{ 0u, 1u, 3u, 2u, 4u, 5u }
#define D2S_CRD \
{ 0u, 1u, 4u, 2u, 5u, 3u }
2654 | |
/// Prints the Module \p mod to the output stream \p os.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod);

/// Prints the Module pointed to by \p mod to the output stream \p os.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod);

/// Prints the Function \p F to the output stream \p os.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F);

/// Prints the Function pointed to by \p F to the output stream \p os.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F);
2662 | |
2663 | /// \returns whether the Convolution node \p node is equivalent with a |
2664 | /// FullyConnected node. This happens for a 2D NHWC Convolution with 1x1 filter |
2665 | /// with strides 1, pads 0, group 1 and dilations 1. |
2666 | bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node, |
2667 | bool enfoceInput1x1 = false); |
2668 | |
/// \returns whether the Gemm node \p node is equivalent with a FullyConnected
/// node. This happens when alpha and beta are 1.0 and the C operand is 1D
/// (so C can act directly as the FullyConnected bias).
bool isGemmSameAsFullyConnected(const GemmNode *node);
2672 | |
2673 | } // namespace glow |
2674 | |
2675 | #endif // GLOW_GRAPH_GRAPH_H |
2676 | |