#include "glow/Graph/Nodes.h"
5namespace glow {
6/// Specifies a node whose Input will be copied to Output.This node prevents graph optimizations from eliminating this node and all of its ancestor nodes. Generally intended to save the final result of a network.
7class SaveNode final : public Node {
8 NodeHandle Input_;
9 NodeHandle Output_;
10
11 public:
12 enum InputIndices {
13 InputIdx = 0,
14 OutputIdx = 1,
15 };
16
17 enum ResultIndices {
18 };
19
20 SaveNode(llvm::StringRef name, NodeValue Input, NodeValue Output)
21 : Node(Kinded::Kind::SaveNodeKind, name), Input_(this, Input), Output_(this, Output) {
22 }
23 const NodeValue getInput() const { return Input_; }
24 const NodeValue getOutput() const { return Output_; }
25
26 static bool classof(const Kinded *k) {
27 return k->getKind() == Kinded::Kind::SaveNodeKind;
28 }
29
30
31 bool isOverwrittenNthInput(unsigned idx) const {
32 if (idx == 1) return true;
33 return false;
34 }
35
36 unsigned getNumInputs() const;
37 std::string getInputName(unsigned idx) const;
38 NodeValue getNthInput(unsigned idx);
39 void setNthInput(unsigned idx, NodeValue val);
40 llvm::StringRef getOutputName(unsigned idx) const;
41 bool hasSideEffects() const { return 1; }
42 bool isCanonical() const { return 1; }
43 bool isDataParallel() const { return 1; }
44 std::string getDebugDesc() const;
45 bool isEqual(const SaveNode &other) const;
46 llvm::hash_code getHash() const;
47 void visit(Node *parent, NodeWalker *visitor);
48 Node* clone() const;
49 bool verify() const;
50 Placeholder *getPlaceholder() const;};
51} // namespace glow
52
53
54namespace glow {
55/// Performs padding of a given input tensor. The Padding information must be specified for each dimension of the tensor in Pads (start and end padding). In case the padding is negative, it means that the tensor must be cropped. Mode defines how extra padding elements are created. Supported modes are defined in the PaddingMode enum: CONSTANT, REFLECT, EDGE. Value is only used with the CONSTANT mode.
56class PadNode final : public Node {
57 NodeHandle Input_;
58 unsigned_t Mode_;
59 std::vector<int> Pads_;
60 float Value_;
61
62 public:
63 enum InputIndices {
64 InputIdx = 0,
65 };
66
67 enum ResultIndices {
68 ResultIdx = 0,
69 };
70
71 PadNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Mode, std::vector<int> Pads, float Value)
72 : Node(Kinded::Kind::PadNodeKind, name), Input_(this, Input), Mode_(Mode), Pads_(Pads), Value_(Value) {
73 addResult(Result);
74 }
75 const NodeValue getInput() const { return Input_; }
76 NodeValue getResult() { return getNthResult(0); }
77 const NodeValue getResult() const { return getNthResult(0); }
78 unsigned_t getMode() const { return Mode_; }
79 llvm::ArrayRef<int> getPads() const { return Pads_; }
80 float getValue() const { return Value_; }
81
82 static bool classof(const Kinded *k) {
83 return k->getKind() == Kinded::Kind::PadNodeKind;
84 }
85
86
87 bool isOverwrittenNthInput(unsigned idx) const {
88 return false;
89 }
90
91 unsigned getNumInputs() const;
92 std::string getInputName(unsigned idx) const;
93 NodeValue getNthInput(unsigned idx);
94 void setNthInput(unsigned idx, NodeValue val);
95 llvm::StringRef getOutputName(unsigned idx) const;
96 bool hasSideEffects() const { return 0; }
97 bool isCanonical() const { return 1; }
98 bool isDataParallel() const { return 0; }
99 std::string getDebugDesc() const;
100 bool isEqual(const PadNode &other) const;
101 llvm::hash_code getHash() const;
102 void visit(Node *parent, NodeWalker *visitor);
103 Node* clone() const;
104 bool verify() const;
105};
106} // namespace glow
107
108
109namespace glow {
110class ConvolutionGradNode final : public Node {
111 NodeHandle Input_;
112 NodeHandle Filter_;
113 NodeHandle Bias_;
114 NodeHandle OriginalOutputForResult_;
115 NodeHandle GradOfOriginalOutputNamedResult_;
116 std::vector<unsigned_t> Kernels_;
117 std::vector<unsigned_t> Strides_;
118 std::vector<unsigned_t> Pads_;
119 unsigned_t Group_;
120 std::vector<unsigned_t> Dilation_;
121 glow::ConvolutionLayout Layout_;
122 glow::FusedActivation FusedActivation_;
123 std::vector<float> FusedActivationArgs_;
124
125 public:
126 enum InputIndices {
127 InputIdx = 0,
128 FilterIdx = 1,
129 BiasIdx = 2,
130 OriginalOutputForResultIdx = 3,
131 GradOfOriginalOutputNamedResultIdx = 4,
132 };
133
134 enum ResultIndices {
135 GradOfInputNamedInputIdx = 0,
136 GradOfInputNamedFilterIdx = 1,
137 GradOfInputNamedBiasIdx = 2,
138 };
139
140 ConvolutionGradNode(llvm::StringRef name, NodeValue Input, NodeValue Filter, NodeValue Bias, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation, glow::ConvolutionLayout Layout, glow::FusedActivation FusedActivation, std::vector<float> FusedActivationArgs)
141 : Node(Kinded::Kind::ConvolutionGradNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation), Layout_(Layout), FusedActivation_(FusedActivation), FusedActivationArgs_(FusedActivationArgs) {
142 addResult(Input.getType());
143 addResult(Filter.getType());
144 addResult(Bias.getType());
145 }
146 const NodeValue getInput() const { return Input_; }
147 const NodeValue getFilter() const { return Filter_; }
148 const NodeValue getBias() const { return Bias_; }
149 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
150 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
151 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
152 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
153 NodeValue getGradOfInputNamedFilter() { return getNthResult(1); }
154 const NodeValue getGradOfInputNamedFilter() const { return getNthResult(1); }
155 NodeValue getGradOfInputNamedBias() { return getNthResult(2); }
156 const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); }
157 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
158 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
159 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
160 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
161 unsigned_t getGroup() const { return Group_; }
162 void setGroup(unsigned_t a) {Group_ = a; }
163 llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; }
164 glow::ConvolutionLayout getLayout() const { return Layout_; }
165 glow::FusedActivation getFusedActivation() const { return FusedActivation_; }
166 void setFusedActivation(glow::FusedActivation a) {FusedActivation_ = a; }
167 llvm::ArrayRef<float> getFusedActivationArgs() const { return FusedActivationArgs_; }
168 void setFusedActivationArgs(llvm::ArrayRef<float> a) {FusedActivationArgs_ = a; }
169
170 static bool classof(const Kinded *k) {
171 return k->getKind() == Kinded::Kind::ConvolutionGradNodeKind;
172 }
173
174
175 bool isOverwrittenNthInput(unsigned idx) const {
176 return false;
177 }
178
179 unsigned getNumInputs() const;
180 std::string getInputName(unsigned idx) const;
181 NodeValue getNthInput(unsigned idx);
182 void setNthInput(unsigned idx, NodeValue val);
183 llvm::StringRef getOutputName(unsigned idx) const;
184 bool hasSideEffects() const { return 0; }
185 bool isCanonical() const { return 1; }
186 bool isDataParallel() const { return 0; }
187 std::string getDebugDesc() const;
188 bool isEqual(const ConvolutionGradNode &other) const;
189 llvm::hash_code getHash() const;
190 void visit(Node *parent, NodeWalker *visitor);
191 Node* clone() const;
192 bool verify() const;
193};
194} // namespace glow
195
196
197namespace glow {
198/// Performs 2D Convolution using a given Input, Filter, and Bias tensors, as well as provided Kernels, Strides, Pads, Group and Dilation. Supported Layouts are defined in the ConvolutionLayout enum: NHWC and NCHW. Supported FusedActivations are defined in the FusedActivation enum.
199class ConvolutionNode final : public Node {
200 NodeHandle Input_;
201 NodeHandle Filter_;
202 NodeHandle Bias_;
203 std::vector<unsigned_t> Kernels_;
204 std::vector<unsigned_t> Strides_;
205 std::vector<unsigned_t> Pads_;
206 unsigned_t Group_;
207 std::vector<unsigned_t> Dilation_;
208 glow::ConvolutionLayout Layout_;
209 glow::FusedActivation FusedActivation_;
210 std::vector<float> FusedActivationArgs_;
211
212 public:
213 enum InputIndices {
214 InputIdx = 0,
215 FilterIdx = 1,
216 BiasIdx = 2,
217 };
218
219 enum ResultIndices {
220 ResultIdx = 0,
221 };
222
223 ConvolutionNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation, glow::ConvolutionLayout Layout, glow::FusedActivation FusedActivation, std::vector<float> FusedActivationArgs)
224 : Node(Kinded::Kind::ConvolutionNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation), Layout_(Layout), FusedActivation_(FusedActivation), FusedActivationArgs_(FusedActivationArgs) {
225 addResult(Result);
226 }
227 const NodeValue getInput() const { return Input_; }
228 const NodeValue getFilter() const { return Filter_; }
229 const NodeValue getBias() const { return Bias_; }
230 NodeValue getResult() { return getNthResult(0); }
231 const NodeValue getResult() const { return getNthResult(0); }
232 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
233 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
234 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
235 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
236 unsigned_t getGroup() const { return Group_; }
237 void setGroup(unsigned_t a) {Group_ = a; }
238 llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; }
239 glow::ConvolutionLayout getLayout() const { return Layout_; }
240 glow::FusedActivation getFusedActivation() const { return FusedActivation_; }
241 void setFusedActivation(glow::FusedActivation a) {FusedActivation_ = a; }
242 llvm::ArrayRef<float> getFusedActivationArgs() const { return FusedActivationArgs_; }
243 void setFusedActivationArgs(llvm::ArrayRef<float> a) {FusedActivationArgs_ = a; }
244
245 static bool classof(const Kinded *k) {
246 return k->getKind() == Kinded::Kind::ConvolutionNodeKind;
247 }
248
249
250 bool isOverwrittenNthInput(unsigned idx) const {
251 return false;
252 }
253
254 unsigned getNumInputs() const;
255 std::string getInputName(unsigned idx) const;
256 NodeValue getNthInput(unsigned idx);
257 void setNthInput(unsigned idx, NodeValue val);
258 llvm::StringRef getOutputName(unsigned idx) const;
259 bool hasSideEffects() const { return 0; }
260 bool isCanonical() const { return 1; }
261 bool isDataParallel() const { return 0; }
262 std::string getDebugDesc() const;
263 bool isEqual(const ConvolutionNode &other) const;
264 llvm::hash_code getHash() const;
265 void visit(Node *parent, NodeWalker *visitor);
266 Node* clone() const;
267 bool verify() const;
268 bool hasFusedActivation() const; ConvolutionGradNode *getGrad(GraphGradMapper &builder);
269};
270} // namespace glow
271
272
273namespace glow {
274/// Performs 2D Convolution using a given Input, Filter, and Bias tensors, as well as provided Kernels, Strides, Pads, and Group. The filter channel wise quantization parameters are provided by FilterScales and FilterOffsets while the bias channel wise quantization parameters are provided by BiasScales and BiasOffsets.
275class ChannelwiseQuantizedConvolutionNode final : public Node {
276 NodeHandle Input_;
277 NodeHandle Filter_;
278 NodeHandle Bias_;
279 NodeHandle FilterScales_;
280 NodeHandle FilterOffsets_;
281 NodeHandle BiasScales_;
282 NodeHandle BiasOffsets_;
283 std::vector<unsigned_t> Kernels_;
284 std::vector<unsigned_t> Strides_;
285 std::vector<unsigned_t> Pads_;
286 unsigned_t Group_;
287 std::vector<unsigned_t> Dilation_;
288 glow::FusedActivation FusedActivation_;
289 std::vector<float> FusedActivationArgs_;
290
291 public:
292 enum InputIndices {
293 InputIdx = 0,
294 FilterIdx = 1,
295 BiasIdx = 2,
296 FilterScalesIdx = 3,
297 FilterOffsetsIdx = 4,
298 BiasScalesIdx = 5,
299 BiasOffsetsIdx = 6,
300 };
301
302 enum ResultIndices {
303 ResultIdx = 0,
304 };
305
306 ChannelwiseQuantizedConvolutionNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, NodeValue FilterScales, NodeValue FilterOffsets, NodeValue BiasScales, NodeValue BiasOffsets, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation, glow::FusedActivation FusedActivation, std::vector<float> FusedActivationArgs)
307 : Node(Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), FilterScales_(this, FilterScales), FilterOffsets_(this, FilterOffsets), BiasScales_(this, BiasScales), BiasOffsets_(this, BiasOffsets), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation), FusedActivation_(FusedActivation), FusedActivationArgs_(FusedActivationArgs) {
308 addResult(Result);
309 }
310 const NodeValue getInput() const { return Input_; }
311 const NodeValue getFilter() const { return Filter_; }
312 const NodeValue getBias() const { return Bias_; }
313 const NodeValue getFilterScales() const { return FilterScales_; }
314 const NodeValue getFilterOffsets() const { return FilterOffsets_; }
315 const NodeValue getBiasScales() const { return BiasScales_; }
316 const NodeValue getBiasOffsets() const { return BiasOffsets_; }
317 NodeValue getResult() { return getNthResult(0); }
318 const NodeValue getResult() const { return getNthResult(0); }
319 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
320 void setKernels(llvm::ArrayRef<unsigned_t> a) {Kernels_ = a; }
321 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
322 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
323 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
324 unsigned_t getGroup() const { return Group_; }
325 void setGroup(unsigned_t a) {Group_ = a; }
326 llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; }
327 glow::FusedActivation getFusedActivation() const { return FusedActivation_; }
328 void setFusedActivation(glow::FusedActivation a) {FusedActivation_ = a; }
329 llvm::ArrayRef<float> getFusedActivationArgs() const { return FusedActivationArgs_; }
330 void setFusedActivationArgs(llvm::ArrayRef<float> a) {FusedActivationArgs_ = a; }
331
332 static bool classof(const Kinded *k) {
333 return k->getKind() == Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind;
334 }
335
336
337 bool isOverwrittenNthInput(unsigned idx) const {
338 return false;
339 }
340
341 unsigned getNumInputs() const;
342 std::string getInputName(unsigned idx) const;
343 NodeValue getNthInput(unsigned idx);
344 void setNthInput(unsigned idx, NodeValue val);
345 llvm::StringRef getOutputName(unsigned idx) const;
346 bool hasSideEffects() const { return 0; }
347 bool isCanonical() const { return 1; }
348 bool isDataParallel() const { return 0; }
349 std::string getDebugDesc() const;
350 bool isEqual(const ChannelwiseQuantizedConvolutionNode &other) const;
351 llvm::hash_code getHash() const;
352 void visit(Node *parent, NodeWalker *visitor);
353 Node* clone() const;
354 bool verify() const;
355 bool hasFusedActivation() const;};
356} // namespace glow
357
358
359namespace glow {
360/// Performs 2D Transposed Convolution using a given Input,Filter, and Bias tensors, as well as provided Kernels,Strides, Pads, and Group.
361class ConvTransposeNode final : public Node {
362 NodeHandle Input_;
363 NodeHandle Filter_;
364 NodeHandle Bias_;
365 std::vector<unsigned_t> Kernels_;
366 std::vector<unsigned_t> Strides_;
367 std::vector<unsigned_t> Pads_;
368 unsigned_t Group_;
369 std::vector<unsigned_t> Dilation_;
370
371 public:
372 enum InputIndices {
373 InputIdx = 0,
374 FilterIdx = 1,
375 BiasIdx = 2,
376 };
377
378 enum ResultIndices {
379 ResultIdx = 0,
380 };
381
382 ConvTransposeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation)
383 : Node(Kinded::Kind::ConvTransposeNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation) {
384 addResult(Result);
385 }
386 const NodeValue getInput() const { return Input_; }
387 const NodeValue getFilter() const { return Filter_; }
388 const NodeValue getBias() const { return Bias_; }
389 NodeValue getResult() { return getNthResult(0); }
390 const NodeValue getResult() const { return getNthResult(0); }
391 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
392 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
393 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
394 unsigned_t getGroup() const { return Group_; }
395 llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; }
396
397 static bool classof(const Kinded *k) {
398 return k->getKind() == Kinded::Kind::ConvTransposeNodeKind;
399 }
400
401
402 bool isOverwrittenNthInput(unsigned idx) const {
403 return false;
404 }
405
406 unsigned getNumInputs() const;
407 std::string getInputName(unsigned idx) const;
408 NodeValue getNthInput(unsigned idx);
409 void setNthInput(unsigned idx, NodeValue val);
410 llvm::StringRef getOutputName(unsigned idx) const;
411 bool hasSideEffects() const { return 0; }
412 bool isCanonical() const { return 1; }
413 bool isDataParallel() const { return 0; }
414 std::string getDebugDesc() const;
415 bool isEqual(const ConvTransposeNode &other) const;
416 llvm::hash_code getHash() const;
417 void visit(Node *parent, NodeWalker *visitor);
418 Node* clone() const;
419 bool verify() const;
420};
421} // namespace glow
422
423
424namespace glow {
425class Convolution3DGradNode final : public Node {
426 NodeHandle Input_;
427 NodeHandle Filter_;
428 NodeHandle Bias_;
429 NodeHandle OriginalOutputForResult_;
430 NodeHandle GradOfOriginalOutputNamedResult_;
431 std::vector<unsigned_t> Kernels_;
432 std::vector<unsigned_t> Strides_;
433 std::vector<unsigned_t> Pads_;
434 unsigned_t Group_;
435
436 public:
437 enum InputIndices {
438 InputIdx = 0,
439 FilterIdx = 1,
440 BiasIdx = 2,
441 OriginalOutputForResultIdx = 3,
442 GradOfOriginalOutputNamedResultIdx = 4,
443 };
444
445 enum ResultIndices {
446 GradOfInputNamedInputIdx = 0,
447 GradOfInputNamedFilterIdx = 1,
448 GradOfInputNamedBiasIdx = 2,
449 };
450
451 Convolution3DGradNode(llvm::StringRef name, NodeValue Input, NodeValue Filter, NodeValue Bias, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group)
452 : Node(Kinded::Kind::Convolution3DGradNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group) {
453 addResult(Input.getType());
454 addResult(Filter.getType());
455 addResult(Bias.getType());
456 }
457 const NodeValue getInput() const { return Input_; }
458 const NodeValue getFilter() const { return Filter_; }
459 const NodeValue getBias() const { return Bias_; }
460 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
461 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
462 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
463 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
464 NodeValue getGradOfInputNamedFilter() { return getNthResult(1); }
465 const NodeValue getGradOfInputNamedFilter() const { return getNthResult(1); }
466 NodeValue getGradOfInputNamedBias() { return getNthResult(2); }
467 const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); }
468 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
469 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
470 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
471 unsigned_t getGroup() const { return Group_; }
472
473 static bool classof(const Kinded *k) {
474 return k->getKind() == Kinded::Kind::Convolution3DGradNodeKind;
475 }
476
477
478 bool isOverwrittenNthInput(unsigned idx) const {
479 return false;
480 }
481
482 unsigned getNumInputs() const;
483 std::string getInputName(unsigned idx) const;
484 NodeValue getNthInput(unsigned idx);
485 void setNthInput(unsigned idx, NodeValue val);
486 llvm::StringRef getOutputName(unsigned idx) const;
487 bool hasSideEffects() const { return 0; }
488 bool isCanonical() const { return 1; }
489 bool isDataParallel() const { return 0; }
490 std::string getDebugDesc() const;
491 bool isEqual(const Convolution3DGradNode &other) const;
492 llvm::hash_code getHash() const;
493 void visit(Node *parent, NodeWalker *visitor);
494 Node* clone() const;
495 bool verify() const;
496};
497} // namespace glow
498
499
500namespace glow {
501/// Performs 3D Convolution using a given Input, Filter, and Bias tensors, as well as provided Kernels, Strides, Pads, and Group.
502class Convolution3DNode final : public Node {
503 NodeHandle Input_;
504 NodeHandle Filter_;
505 NodeHandle Bias_;
506 std::vector<unsigned_t> Kernels_;
507 std::vector<unsigned_t> Strides_;
508 std::vector<unsigned_t> Pads_;
509 unsigned_t Group_;
510
511 public:
512 enum InputIndices {
513 InputIdx = 0,
514 FilterIdx = 1,
515 BiasIdx = 2,
516 };
517
518 enum ResultIndices {
519 ResultIdx = 0,
520 };
521
522 Convolution3DNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group)
523 : Node(Kinded::Kind::Convolution3DNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group) {
524 addResult(Result);
525 }
526 const NodeValue getInput() const { return Input_; }
527 const NodeValue getFilter() const { return Filter_; }
528 const NodeValue getBias() const { return Bias_; }
529 NodeValue getResult() { return getNthResult(0); }
530 const NodeValue getResult() const { return getNthResult(0); }
531 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
532 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
533 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
534 unsigned_t getGroup() const { return Group_; }
535
536 static bool classof(const Kinded *k) {
537 return k->getKind() == Kinded::Kind::Convolution3DNodeKind;
538 }
539
540
541 bool isOverwrittenNthInput(unsigned idx) const {
542 return false;
543 }
544
545 unsigned getNumInputs() const;
546 std::string getInputName(unsigned idx) const;
547 NodeValue getNthInput(unsigned idx);
548 void setNthInput(unsigned idx, NodeValue val);
549 llvm::StringRef getOutputName(unsigned idx) const;
550 bool hasSideEffects() const { return 0; }
551 bool isCanonical() const { return 1; }
552 bool isDataParallel() const { return 0; }
553 std::string getDebugDesc() const;
554 bool isEqual(const Convolution3DNode &other) const;
555 llvm::hash_code getHash() const;
556 void visit(Node *parent, NodeWalker *visitor);
557 Node* clone() const;
558 bool verify() const;
559 Convolution3DGradNode *getGrad(GraphGradMapper &builder);
560};
561} // namespace glow
562
563
564namespace glow {
565class MaxPoolGradNode final : public Node {
566 NodeHandle Input_;
567 NodeHandle OriginalOutputForResult_;
568 NodeHandle GradOfOriginalOutputNamedResult_;
569 NodeHandle OriginalOutputForArgmax_;
570 NodeHandle GradOfOriginalOutputNamedArgmax_;
571 std::vector<unsigned_t> Kernels_;
572 std::vector<unsigned_t> Strides_;
573 std::vector<unsigned_t> Pads_;
574 unsigned_t Layout_;
575
576 public:
577 enum InputIndices {
578 InputIdx = 0,
579 OriginalOutputForResultIdx = 1,
580 GradOfOriginalOutputNamedResultIdx = 2,
581 OriginalOutputForArgmaxIdx = 3,
582 GradOfOriginalOutputNamedArgmaxIdx = 4,
583 };
584
585 enum ResultIndices {
586 GradOfInputNamedInputIdx = 0,
587 };
588
589 MaxPoolGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, NodeValue OriginalOutputForArgmax, NodeValue GradOfOriginalOutputNamedArgmax, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout)
590 : Node(Kinded::Kind::MaxPoolGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), OriginalOutputForArgmax_(this, OriginalOutputForArgmax), GradOfOriginalOutputNamedArgmax_(this, GradOfOriginalOutputNamedArgmax), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout) {
591 addResult(Input.getType());
592 }
593 const NodeValue getInput() const { return Input_; }
594 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
595 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
596 const NodeValue getOriginalOutputForArgmax() const { return OriginalOutputForArgmax_; }
597 const NodeValue getGradOfOriginalOutputNamedArgmax() const { return GradOfOriginalOutputNamedArgmax_; }
598 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
599 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
600 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
601 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
602 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
603 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
604 unsigned_t getLayout() const { return Layout_; }
605
606 static bool classof(const Kinded *k) {
607 return k->getKind() == Kinded::Kind::MaxPoolGradNodeKind;
608 }
609
610
611 bool isOverwrittenNthInput(unsigned idx) const {
612 return false;
613 }
614
615 unsigned getNumInputs() const;
616 std::string getInputName(unsigned idx) const;
617 NodeValue getNthInput(unsigned idx);
618 void setNthInput(unsigned idx, NodeValue val);
619 llvm::StringRef getOutputName(unsigned idx) const;
620 bool hasSideEffects() const { return 0; }
621 bool isCanonical() const { return 1; }
622 bool isDataParallel() const { return 0; }
623 std::string getDebugDesc() const;
624 bool isEqual(const MaxPoolGradNode &other) const;
625 llvm::hash_code getHash() const;
626 void visit(Node *parent, NodeWalker *visitor);
627 Node* clone() const;
628 bool verify() const;
629};
630} // namespace glow
631
632
633namespace glow {
634/// Performs a Max Pool with Argmax operation on the Input given provided Kernels, Strides, and Pads. Argmax is a flattened index corresponding to respective max element. Supported layouts are defined in the ConvolutionLayout enum: NHWC and NCHW.
635class MaxPoolNode final : public Node {
636 NodeHandle Input_;
637 std::vector<unsigned_t> Kernels_;
638 std::vector<unsigned_t> Strides_;
639 std::vector<unsigned_t> Pads_;
640 unsigned_t Layout_;
641
642 public:
643 enum InputIndices {
644 InputIdx = 0,
645 };
646
647 enum ResultIndices {
648 ResultIdx = 0,
649 ArgmaxIdx = 1,
650 };
651
652 MaxPoolNode(llvm::StringRef name, TypeRef Result , TypeRef Argmax , NodeValue Input, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout)
653 : Node(Kinded::Kind::MaxPoolNodeKind, name), Input_(this, Input), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout) {
654 addResult(Result);
655 addResult(Argmax);
656 }
657 const NodeValue getInput() const { return Input_; }
658 NodeValue getResult() { return getNthResult(0); }
659 const NodeValue getResult() const { return getNthResult(0); }
660 NodeValue getArgmax() { return getNthResult(1); }
661 const NodeValue getArgmax() const { return getNthResult(1); }
662 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
663 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
664 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
665 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
666 unsigned_t getLayout() const { return Layout_; }
667
668 static bool classof(const Kinded *k) {
669 return k->getKind() == Kinded::Kind::MaxPoolNodeKind;
670 }
671
672
673 bool isOverwrittenNthInput(unsigned idx) const {
674 return false;
675 }
676
677 unsigned getNumInputs() const;
678 std::string getInputName(unsigned idx) const;
679 NodeValue getNthInput(unsigned idx);
680 void setNthInput(unsigned idx, NodeValue val);
681 llvm::StringRef getOutputName(unsigned idx) const;
682 bool hasSideEffects() const { return 0; }
683 bool isCanonical() const { return 1; }
684 bool isDataParallel() const { return 0; }
685 std::string getDebugDesc() const;
686 bool isEqual(const MaxPoolNode &other) const;
687 llvm::hash_code getHash() const;
688 void visit(Node *parent, NodeWalker *visitor);
689 Node* clone() const;
690 bool verify() const;
691 MaxPoolGradNode *getGrad(GraphGradMapper &builder);
692};
693} // namespace glow
694
695
696namespace glow {
697/// Finds index of a maximum element along Axis. If KeepDims is not true, the axis is removed from output
698class ArgMaxNode final : public Node {
699 NodeHandle Input_;
700 unsigned_t Axis_;
701 bool KeepDims_;
702
703 public:
704 enum InputIndices {
705 InputIdx = 0,
706 };
707
708 enum ResultIndices {
709 ResultIdx = 0,
710 };
711
712 ArgMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, bool KeepDims)
713 : Node(Kinded::Kind::ArgMaxNodeKind, name), Input_(this, Input), Axis_(Axis), KeepDims_(KeepDims) {
714 addResult(Result);
715 }
716 const NodeValue getInput() const { return Input_; }
717 NodeValue getResult() { return getNthResult(0); }
718 const NodeValue getResult() const { return getNthResult(0); }
719 unsigned_t getAxis() const { return Axis_; }
720 bool getKeepDims() const { return KeepDims_; }
721
722 static bool classof(const Kinded *k) {
723 return k->getKind() == Kinded::Kind::ArgMaxNodeKind;
724 }
725
726
727 bool isOverwrittenNthInput(unsigned idx) const {
728 return false;
729 }
730
731 unsigned getNumInputs() const;
732 std::string getInputName(unsigned idx) const;
733 NodeValue getNthInput(unsigned idx);
734 void setNthInput(unsigned idx, NodeValue val);
735 llvm::StringRef getOutputName(unsigned idx) const;
736 bool hasSideEffects() const { return 0; }
737 bool isCanonical() const { return 1; }
738 bool isDataParallel() const { return 0; }
739 std::string getDebugDesc() const;
740 bool isEqual(const ArgMaxNode &other) const;
741 llvm::hash_code getHash() const;
742 void visit(Node *parent, NodeWalker *visitor);
743 Node* clone() const;
744 bool verify() const;
745};
746} // namespace glow
747
748
749namespace glow {
750/// Finds index of a minimum element along Axis. If KeepDims is not true, the axis is removed from output
751class ArgMinNode final : public Node {
752 NodeHandle Input_;
753 unsigned_t Axis_;
754 bool KeepDims_;
755
756 public:
757 enum InputIndices {
758 InputIdx = 0,
759 };
760
761 enum ResultIndices {
762 ResultIdx = 0,
763 };
764
765 ArgMinNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, bool KeepDims)
766 : Node(Kinded::Kind::ArgMinNodeKind, name), Input_(this, Input), Axis_(Axis), KeepDims_(KeepDims) {
767 addResult(Result);
768 }
769 const NodeValue getInput() const { return Input_; }
770 NodeValue getResult() { return getNthResult(0); }
771 const NodeValue getResult() const { return getNthResult(0); }
772 unsigned_t getAxis() const { return Axis_; }
773 bool getKeepDims() const { return KeepDims_; }
774
775 static bool classof(const Kinded *k) {
776 return k->getKind() == Kinded::Kind::ArgMinNodeKind;
777 }
778
779
780 bool isOverwrittenNthInput(unsigned idx) const {
781 return false;
782 }
783
784 unsigned getNumInputs() const;
785 std::string getInputName(unsigned idx) const;
786 NodeValue getNthInput(unsigned idx);
787 void setNthInput(unsigned idx, NodeValue val);
788 llvm::StringRef getOutputName(unsigned idx) const;
789 bool hasSideEffects() const { return 0; }
790 bool isCanonical() const { return 1; }
791 bool isDataParallel() const { return 0; }
792 std::string getDebugDesc() const;
793 bool isEqual(const ArgMinNode &other) const;
794 llvm::hash_code getHash() const;
795 void visit(Node *parent, NodeWalker *visitor);
796 Node* clone() const;
797 bool verify() const;
798};
799} // namespace glow
800
801
802namespace glow {
803class AvgPoolGradNode final : public Node {
804 NodeHandle Input_;
805 NodeHandle OriginalOutputForResult_;
806 NodeHandle GradOfOriginalOutputNamedResult_;
807 std::vector<unsigned_t> Kernels_;
808 std::vector<unsigned_t> Strides_;
809 std::vector<unsigned_t> Pads_;
810 unsigned_t Layout_;
811 bool CountIncludePads_;
812
813 public:
814 enum InputIndices {
815 InputIdx = 0,
816 OriginalOutputForResultIdx = 1,
817 GradOfOriginalOutputNamedResultIdx = 2,
818 };
819
820 enum ResultIndices {
821 GradOfInputNamedInputIdx = 0,
822 };
823
824 AvgPoolGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout, bool CountIncludePads)
825 : Node(Kinded::Kind::AvgPoolGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout), CountIncludePads_(CountIncludePads) {
826 addResult(Input.getType());
827 }
828 const NodeValue getInput() const { return Input_; }
829 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
830 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
831 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
832 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
833 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
834 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
835 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
836 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
837 unsigned_t getLayout() const { return Layout_; }
838 bool getCountIncludePads() const { return CountIncludePads_; }
839
840 static bool classof(const Kinded *k) {
841 return k->getKind() == Kinded::Kind::AvgPoolGradNodeKind;
842 }
843
844
845 bool isOverwrittenNthInput(unsigned idx) const {
846 return false;
847 }
848
849 unsigned getNumInputs() const;
850 std::string getInputName(unsigned idx) const;
851 NodeValue getNthInput(unsigned idx);
852 void setNthInput(unsigned idx, NodeValue val);
853 llvm::StringRef getOutputName(unsigned idx) const;
854 bool hasSideEffects() const { return 0; }
855 bool isCanonical() const { return 1; }
856 bool isDataParallel() const { return 0; }
857 std::string getDebugDesc() const;
858 bool isEqual(const AvgPoolGradNode &other) const;
859 llvm::hash_code getHash() const;
860 void visit(Node *parent, NodeWalker *visitor);
861 Node* clone() const;
862 bool verify() const;
863};
864} // namespace glow
865
866
867namespace glow {
868/// Performs an Average Pool operation on the Input given provided Kernels, Strides, and Pads. Supported layouts are defined in the ConvolutionLayout enum: NHWC, NCHW, NTHWC and NCTHW.
869class AvgPoolNode final : public Node {
870 NodeHandle Input_;
871 std::vector<unsigned_t> Kernels_;
872 std::vector<unsigned_t> Strides_;
873 std::vector<unsigned_t> Pads_;
874 unsigned_t Layout_;
875 bool CountIncludePads_;
876
877 public:
878 enum InputIndices {
879 InputIdx = 0,
880 };
881
882 enum ResultIndices {
883 ResultIdx = 0,
884 };
885
886 AvgPoolNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout, bool CountIncludePads)
887 : Node(Kinded::Kind::AvgPoolNodeKind, name), Input_(this, Input), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout), CountIncludePads_(CountIncludePads) {
888 addResult(Result);
889 }
890 const NodeValue getInput() const { return Input_; }
891 NodeValue getResult() { return getNthResult(0); }
892 const NodeValue getResult() const { return getNthResult(0); }
893 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
894 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
895 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
896 void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; }
897 unsigned_t getLayout() const { return Layout_; }
898 bool getCountIncludePads() const { return CountIncludePads_; }
899
900 static bool classof(const Kinded *k) {
901 return k->getKind() == Kinded::Kind::AvgPoolNodeKind;
902 }
903
904
905 bool isOverwrittenNthInput(unsigned idx) const {
906 return false;
907 }
908
909 unsigned getNumInputs() const;
910 std::string getInputName(unsigned idx) const;
911 NodeValue getNthInput(unsigned idx);
912 void setNthInput(unsigned idx, NodeValue val);
913 llvm::StringRef getOutputName(unsigned idx) const;
914 bool hasSideEffects() const { return 0; }
915 bool isCanonical() const { return 1; }
916 bool isDataParallel() const { return 0; }
917 std::string getDebugDesc() const;
918 bool isEqual(const AvgPoolNode &other) const;
919 llvm::hash_code getHash() const;
920 void visit(Node *parent, NodeWalker *visitor);
921 Node* clone() const;
922 bool verify() const;
923 AvgPoolGradNode *getGrad(GraphGradMapper &builder);
924};
925} // namespace glow
926
927
928namespace glow {
929class AdaptiveAvgPoolGradNode final : public Node {
930 NodeHandle Input_;
931 NodeHandle OriginalOutputForResult_;
932 NodeHandle GradOfOriginalOutputNamedResult_;
933
934 public:
935 enum InputIndices {
936 InputIdx = 0,
937 OriginalOutputForResultIdx = 1,
938 GradOfOriginalOutputNamedResultIdx = 2,
939 };
940
941 enum ResultIndices {
942 GradOfInputNamedInputIdx = 0,
943 };
944
945 AdaptiveAvgPoolGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
946 : Node(Kinded::Kind::AdaptiveAvgPoolGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
947 addResult(Input.getType());
948 }
949 const NodeValue getInput() const { return Input_; }
950 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
951 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
952 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
953 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
954
955 static bool classof(const Kinded *k) {
956 return k->getKind() == Kinded::Kind::AdaptiveAvgPoolGradNodeKind;
957 }
958
959
960 bool isOverwrittenNthInput(unsigned idx) const {
961 return false;
962 }
963
964 unsigned getNumInputs() const;
965 std::string getInputName(unsigned idx) const;
966 NodeValue getNthInput(unsigned idx);
967 void setNthInput(unsigned idx, NodeValue val);
968 llvm::StringRef getOutputName(unsigned idx) const;
969 bool hasSideEffects() const { return 0; }
970 bool isCanonical() const { return 1; }
971 bool isDataParallel() const { return 0; }
972 std::string getDebugDesc() const;
973 bool isEqual(const AdaptiveAvgPoolGradNode &other) const;
974 llvm::hash_code getHash() const;
975 void visit(Node *parent, NodeWalker *visitor);
976 Node* clone() const;
977 bool verify() const;
978};
979} // namespace glow
980
981
982namespace glow {
983/// Performs an Adaptive Average Pool operation on the Input given
984class AdaptiveAvgPoolNode final : public Node {
985 NodeHandle Input_;
986
987 public:
988 enum InputIndices {
989 InputIdx = 0,
990 };
991
992 enum ResultIndices {
993 ResultIdx = 0,
994 };
995
996 AdaptiveAvgPoolNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
997 : Node(Kinded::Kind::AdaptiveAvgPoolNodeKind, name), Input_(this, Input) {
998 addResult(Result);
999 }
1000 const NodeValue getInput() const { return Input_; }
1001 NodeValue getResult() { return getNthResult(0); }
1002 const NodeValue getResult() const { return getNthResult(0); }
1003
1004 static bool classof(const Kinded *k) {
1005 return k->getKind() == Kinded::Kind::AdaptiveAvgPoolNodeKind;
1006 }
1007
1008
1009 bool isOverwrittenNthInput(unsigned idx) const {
1010 return false;
1011 }
1012
1013 unsigned getNumInputs() const;
1014 std::string getInputName(unsigned idx) const;
1015 NodeValue getNthInput(unsigned idx);
1016 void setNthInput(unsigned idx, NodeValue val);
1017 llvm::StringRef getOutputName(unsigned idx) const;
1018 bool hasSideEffects() const { return 0; }
1019 bool isCanonical() const { return 1; }
1020 bool isDataParallel() const { return 0; }
1021 std::string getDebugDesc() const;
1022 bool isEqual(const AdaptiveAvgPoolNode &other) const;
1023 llvm::hash_code getHash() const;
1024 void visit(Node *parent, NodeWalker *visitor);
1025 Node* clone() const;
1026 bool verify() const;
1027 AdaptiveAvgPoolGradNode *getGrad(GraphGradMapper &builder);
1028};
1029} // namespace glow
1030
1031
1032namespace glow {
1033/// Computes Y = Alpha * A * B + Beta * C where Alpha, Beta are scalars and A, B, C are matrices. If TransposeA or TransposeB is used then A or B is additionally transposed.
1034class GemmNode final : public Node {
1035 NodeHandle A_;
1036 NodeHandle B_;
1037 NodeHandle C_;
1038 float Alpha_;
1039 float Beta_;
1040 bool TransposeA_;
1041 bool TransposeB_;
1042
1043 public:
1044 enum InputIndices {
1045 AIdx = 0,
1046 BIdx = 1,
1047 CIdx = 2,
1048 };
1049
1050 enum ResultIndices {
1051 ResultIdx = 0,
1052 };
1053
1054 GemmNode(llvm::StringRef name, TypeRef Result , NodeValue A, NodeValue B, NodeValue C, float Alpha, float Beta, bool TransposeA, bool TransposeB)
1055 : Node(Kinded::Kind::GemmNodeKind, name), A_(this, A), B_(this, B), C_(this, C), Alpha_(Alpha), Beta_(Beta), TransposeA_(TransposeA), TransposeB_(TransposeB) {
1056 addResult(Result);
1057 }
1058 const NodeValue getA() const { return A_; }
1059 const NodeValue getB() const { return B_; }
1060 const NodeValue getC() const { return C_; }
1061 NodeValue getResult() { return getNthResult(0); }
1062 const NodeValue getResult() const { return getNthResult(0); }
1063 float getAlpha() const { return Alpha_; }
1064 float getBeta() const { return Beta_; }
1065 bool getTransposeA() const { return TransposeA_; }
1066 bool getTransposeB() const { return TransposeB_; }
1067
1068 static bool classof(const Kinded *k) {
1069 return k->getKind() == Kinded::Kind::GemmNodeKind;
1070 }
1071
1072
1073 bool isOverwrittenNthInput(unsigned idx) const {
1074 return false;
1075 }
1076
1077 unsigned getNumInputs() const;
1078 std::string getInputName(unsigned idx) const;
1079 NodeValue getNthInput(unsigned idx);
1080 void setNthInput(unsigned idx, NodeValue val);
1081 llvm::StringRef getOutputName(unsigned idx) const;
1082 bool hasSideEffects() const { return 0; }
1083 bool isCanonical() const { return 1; }
1084 bool isDataParallel() const { return 0; }
1085 std::string getDebugDesc() const;
1086 bool isEqual(const GemmNode &other) const;
1087 llvm::hash_code getHash() const;
1088 void visit(Node *parent, NodeWalker *visitor);
1089 Node* clone() const;
1090 bool verify() const;
1091};
1092} // namespace glow
1093
1094
1095namespace glow {
1096class FullyConnectedGradNode final : public Node {
1097 NodeHandle Input_;
1098 NodeHandle Weights_;
1099 NodeHandle Bias_;
1100 NodeHandle OriginalOutputForResult_;
1101 NodeHandle GradOfOriginalOutputNamedResult_;
1102
1103 public:
1104 enum InputIndices {
1105 InputIdx = 0,
1106 WeightsIdx = 1,
1107 BiasIdx = 2,
1108 OriginalOutputForResultIdx = 3,
1109 GradOfOriginalOutputNamedResultIdx = 4,
1110 };
1111
1112 enum ResultIndices {
1113 GradOfInputNamedInputIdx = 0,
1114 GradOfInputNamedWeightsIdx = 1,
1115 GradOfInputNamedBiasIdx = 2,
1116 };
1117
1118 FullyConnectedGradNode(llvm::StringRef name, NodeValue Input, NodeValue Weights, NodeValue Bias, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
1119 : Node(Kinded::Kind::FullyConnectedGradNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
1120 addResult(Input.getType());
1121 addResult(Weights.getType());
1122 addResult(Bias.getType());
1123 }
1124 const NodeValue getInput() const { return Input_; }
1125 const NodeValue getWeights() const { return Weights_; }
1126 const NodeValue getBias() const { return Bias_; }
1127 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
1128 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
1129 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
1130 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
1131 NodeValue getGradOfInputNamedWeights() { return getNthResult(1); }
1132 const NodeValue getGradOfInputNamedWeights() const { return getNthResult(1); }
1133 NodeValue getGradOfInputNamedBias() { return getNthResult(2); }
1134 const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); }
1135
1136 static bool classof(const Kinded *k) {
1137 return k->getKind() == Kinded::Kind::FullyConnectedGradNodeKind;
1138 }
1139
1140
1141 bool isOverwrittenNthInput(unsigned idx) const {
1142 return false;
1143 }
1144
1145 unsigned getNumInputs() const;
1146 std::string getInputName(unsigned idx) const;
1147 NodeValue getNthInput(unsigned idx);
1148 void setNthInput(unsigned idx, NodeValue val);
1149 llvm::StringRef getOutputName(unsigned idx) const;
1150 bool hasSideEffects() const { return 0; }
1151 bool isCanonical() const { return 1; }
1152 bool isDataParallel() const { return 0; }
1153 std::string getDebugDesc() const;
1154 bool isEqual(const FullyConnectedGradNode &other) const;
1155 llvm::hash_code getHash() const;
1156 void visit(Node *parent, NodeWalker *visitor);
1157 Node* clone() const;
1158 bool verify() const;
1159};
1160} // namespace glow
1161
1162
1163namespace glow {
1164/// Creates a FullyConnected node where the Input tensor and Weights tensor are multiplied, and then the Bias tensor is added to it, producing the Output.
1165class FullyConnectedNode final : public Node {
1166 NodeHandle Input_;
1167 NodeHandle Weights_;
1168 NodeHandle Bias_;
1169
1170 public:
1171 enum InputIndices {
1172 InputIdx = 0,
1173 WeightsIdx = 1,
1174 BiasIdx = 2,
1175 };
1176
1177 enum ResultIndices {
1178 ResultIdx = 0,
1179 };
1180
1181 FullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Bias)
1182 : Node(Kinded::Kind::FullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias) {
1183 addResult(Result);
1184 }
1185 const NodeValue getInput() const { return Input_; }
1186 const NodeValue getWeights() const { return Weights_; }
1187 const NodeValue getBias() const { return Bias_; }
1188 NodeValue getResult() { return getNthResult(0); }
1189 const NodeValue getResult() const { return getNthResult(0); }
1190
1191 static bool classof(const Kinded *k) {
1192 return k->getKind() == Kinded::Kind::FullyConnectedNodeKind;
1193 }
1194
1195
1196 bool isOverwrittenNthInput(unsigned idx) const {
1197 return false;
1198 }
1199
1200 unsigned getNumInputs() const;
1201 std::string getInputName(unsigned idx) const;
1202 NodeValue getNthInput(unsigned idx);
1203 void setNthInput(unsigned idx, NodeValue val);
1204 llvm::StringRef getOutputName(unsigned idx) const;
1205 bool hasSideEffects() const { return 0; }
1206 bool isCanonical() const { return 1; }
1207 bool isDataParallel() const { return 0; }
1208 std::string getDebugDesc() const;
1209 bool isEqual(const FullyConnectedNode &other) const;
1210 llvm::hash_code getHash() const;
1211 void visit(Node *parent, NodeWalker *visitor);
1212 Node* clone() const;
1213 bool verify() const;
1214 FullyConnectedGradNode *getGrad(GraphGradMapper &builder);
1215};
1216} // namespace glow
1217
1218
1219namespace glow {
1220/// Creates a RowwiseQuantizedFullyConnected node where the Input matrix and the transpose of Weights matrix are multiplied, and then the Bias vector is broadcast-added to the result. Input, Bias and Result are regularly quantized, while Weights use row-wisequantization.
1221class RowwiseQuantizedFullyConnectedNode final : public Node {
1222 NodeHandle Input_;
1223 NodeHandle Weights_;
1224 NodeHandle Scales_;
1225 NodeHandle Offsets_;
1226 NodeHandle Bias_;
1227
1228 public:
1229 enum InputIndices {
1230 InputIdx = 0,
1231 WeightsIdx = 1,
1232 ScalesIdx = 2,
1233 OffsetsIdx = 3,
1234 BiasIdx = 4,
1235 };
1236
1237 enum ResultIndices {
1238 ResultIdx = 0,
1239 };
1240
1241 RowwiseQuantizedFullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Scales, NodeValue Offsets, NodeValue Bias)
1242 : Node(Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Scales_(this, Scales), Offsets_(this, Offsets), Bias_(this, Bias) {
1243 addResult(Result);
1244 }
1245 const NodeValue getInput() const { return Input_; }
1246 const NodeValue getWeights() const { return Weights_; }
1247 const NodeValue getScales() const { return Scales_; }
1248 const NodeValue getOffsets() const { return Offsets_; }
1249 const NodeValue getBias() const { return Bias_; }
1250 NodeValue getResult() { return getNthResult(0); }
1251 const NodeValue getResult() const { return getNthResult(0); }
1252
1253 static bool classof(const Kinded *k) {
1254 return k->getKind() == Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind;
1255 }
1256
1257
1258 bool isOverwrittenNthInput(unsigned idx) const {
1259 return false;
1260 }
1261
1262 unsigned getNumInputs() const;
1263 std::string getInputName(unsigned idx) const;
1264 NodeValue getNthInput(unsigned idx);
1265 void setNthInput(unsigned idx, NodeValue val);
1266 llvm::StringRef getOutputName(unsigned idx) const;
1267 bool hasSideEffects() const { return 0; }
1268 bool isCanonical() const { return 1; }
1269 bool isDataParallel() const { return 0; }
1270 std::string getDebugDesc() const;
1271 bool isEqual(const RowwiseQuantizedFullyConnectedNode &other) const;
1272 llvm::hash_code getHash() const;
1273 void visit(Node *parent, NodeWalker *visitor);
1274 Node* clone() const;
1275 bool verify() const;
1276};
1277} // namespace glow
1278
1279
1280namespace glow {
1281/// Creates a DynamicQuantizedFullyConnectedNode which implement the functionality of dynamic_quantization => quantized_fc => dequantize, which support symmteric/asymmetric quantization. Quantize parameters are automatically selected from range of input, while weights are pre-quantized to int8 and bias are whether float or int32
1282class DynamicQuantizedFullyConnectedNode final : public Node {
1283 NodeHandle Input_;
1284 NodeHandle Weights_;
1285 NodeHandle Bias_;
1286 bool IsSymmetric_;
1287 bool IsPerBatchElement_;
1288
1289 public:
1290 enum InputIndices {
1291 InputIdx = 0,
1292 WeightsIdx = 1,
1293 BiasIdx = 2,
1294 };
1295
1296 enum ResultIndices {
1297 ResultIdx = 0,
1298 };
1299
1300 DynamicQuantizedFullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Bias, bool IsSymmetric, bool IsPerBatchElement)
1301 : Node(Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias), IsSymmetric_(IsSymmetric), IsPerBatchElement_(IsPerBatchElement) {
1302 addResult(Result);
1303 }
1304 const NodeValue getInput() const { return Input_; }
1305 const NodeValue getWeights() const { return Weights_; }
1306 const NodeValue getBias() const { return Bias_; }
1307 NodeValue getResult() { return getNthResult(0); }
1308 const NodeValue getResult() const { return getNthResult(0); }
1309 bool getIsSymmetric() const { return IsSymmetric_; }
1310 bool getIsPerBatchElement() const { return IsPerBatchElement_; }
1311
1312 static bool classof(const Kinded *k) {
1313 return k->getKind() == Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind;
1314 }
1315
1316
1317 bool isOverwrittenNthInput(unsigned idx) const {
1318 return false;
1319 }
1320
1321 unsigned getNumInputs() const;
1322 std::string getInputName(unsigned idx) const;
1323 NodeValue getNthInput(unsigned idx);
1324 void setNthInput(unsigned idx, NodeValue val);
1325 llvm::StringRef getOutputName(unsigned idx) const;
1326 bool hasSideEffects() const { return 0; }
1327 bool isCanonical() const { return 1; }
1328 bool isDataParallel() const { return 0; }
1329 std::string getDebugDesc() const;
1330 bool isEqual(const DynamicQuantizedFullyConnectedNode &other) const;
1331 llvm::hash_code getHash() const;
1332 void visit(Node *parent, NodeWalker *visitor);
1333 Node* clone() const;
1334 bool verify() const;
1335};
1336} // namespace glow
1337
1338
1339namespace glow {
1340/// Creates a DynamicRowwiseQuantizedFullyConnectedNode which implement the functionality of dynamic_quantization => quantized_fc => dequantize, which support symmteric/asymmetric quantization. Quantize parameters are automatically selected from range of input, while weights are pre-rowwise-quantized to int8, whose rowwise params stored in Scales and Offsets, and bias are whether float or int32
1341class DynamicRowwiseQuantizedFullyConnectedNode final : public Node {
1342 NodeHandle Input_;
1343 NodeHandle Weights_;
1344 NodeHandle Bias_;
1345 NodeHandle Scales_;
1346 NodeHandle Offsets_;
1347 bool IsSymmetric_;
1348 bool IsPerBatchElement_;
1349
1350 public:
1351 enum InputIndices {
1352 InputIdx = 0,
1353 WeightsIdx = 1,
1354 BiasIdx = 2,
1355 ScalesIdx = 3,
1356 OffsetsIdx = 4,
1357 };
1358
1359 enum ResultIndices {
1360 ResultIdx = 0,
1361 };
1362
1363 DynamicRowwiseQuantizedFullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Bias, NodeValue Scales, NodeValue Offsets, bool IsSymmetric, bool IsPerBatchElement)
1364 : Node(Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias), Scales_(this, Scales), Offsets_(this, Offsets), IsSymmetric_(IsSymmetric), IsPerBatchElement_(IsPerBatchElement) {
1365 addResult(Result);
1366 }
1367 const NodeValue getInput() const { return Input_; }
1368 const NodeValue getWeights() const { return Weights_; }
1369 const NodeValue getBias() const { return Bias_; }
1370 const NodeValue getScales() const { return Scales_; }
1371 const NodeValue getOffsets() const { return Offsets_; }
1372 NodeValue getResult() { return getNthResult(0); }
1373 const NodeValue getResult() const { return getNthResult(0); }
1374 bool getIsSymmetric() const { return IsSymmetric_; }
1375 bool getIsPerBatchElement() const { return IsPerBatchElement_; }
1376
1377 static bool classof(const Kinded *k) {
1378 return k->getKind() == Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind;
1379 }
1380
1381
1382 bool isOverwrittenNthInput(unsigned idx) const {
1383 return false;
1384 }
1385
1386 unsigned getNumInputs() const;
1387 std::string getInputName(unsigned idx) const;
1388 NodeValue getNthInput(unsigned idx);
1389 void setNthInput(unsigned idx, NodeValue val);
1390 llvm::StringRef getOutputName(unsigned idx) const;
1391 bool hasSideEffects() const { return 0; }
1392 bool isCanonical() const { return 1; }
1393 bool isDataParallel() const { return 0; }
1394 std::string getDebugDesc() const;
1395 bool isEqual(const DynamicRowwiseQuantizedFullyConnectedNode &other) const;
1396 llvm::hash_code getHash() const;
1397 void visit(Node *parent, NodeWalker *visitor);
1398 Node* clone() const;
1399 bool verify() const;
1400};
1401} // namespace glow
1402
1403
1404namespace glow {
1405class BatchNormalizationGradNode final : public Node {
1406 NodeHandle Input_;
1407 NodeHandle Scale_;
1408 NodeHandle Bias_;
1409 NodeHandle Mean_;
1410 NodeHandle Var_;
1411 NodeHandle OriginalOutputForResult_;
1412 NodeHandle GradOfOriginalOutputNamedResult_;
1413 unsigned_t ChannelIdx_;
1414 float Epsilon_;
1415 float Momentum_;
1416
1417 public:
1418 enum InputIndices {
1419 InputIdx = 0,
1420 ScaleIdx = 1,
1421 BiasIdx = 2,
1422 MeanIdx = 3,
1423 VarIdx = 4,
1424 OriginalOutputForResultIdx = 5,
1425 GradOfOriginalOutputNamedResultIdx = 6,
1426 };
1427
1428 enum ResultIndices {
1429 GradOfInputNamedInputIdx = 0,
1430 GradOfInputNamedScaleIdx = 1,
1431 GradOfInputNamedBiasIdx = 2,
1432 GradOfInputNamedMeanIdx = 3,
1433 GradOfInputNamedVarIdx = 4,
1434 };
1435
1436 BatchNormalizationGradNode(llvm::StringRef name, NodeValue Input, NodeValue Scale, NodeValue Bias, NodeValue Mean, NodeValue Var, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, unsigned_t ChannelIdx, float Epsilon, float Momentum)
1437 : Node(Kinded::Kind::BatchNormalizationGradNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), Mean_(this, Mean), Var_(this, Var), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), ChannelIdx_(ChannelIdx), Epsilon_(Epsilon), Momentum_(Momentum) {
1438 addResult(Input.getType());
1439 addResult(Scale.getType());
1440 addResult(Bias.getType());
1441 addResult(Mean.getType());
1442 addResult(Var.getType());
1443 }
1444 const NodeValue getInput() const { return Input_; }
1445 const NodeValue getScale() const { return Scale_; }
1446 const NodeValue getBias() const { return Bias_; }
1447 const NodeValue getMean() const { return Mean_; }
1448 const NodeValue getVar() const { return Var_; }
1449 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
1450 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
1451 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
1452 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
1453 NodeValue getGradOfInputNamedScale() { return getNthResult(1); }
1454 const NodeValue getGradOfInputNamedScale() const { return getNthResult(1); }
1455 NodeValue getGradOfInputNamedBias() { return getNthResult(2); }
1456 const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); }
1457 NodeValue getGradOfInputNamedMean() { return getNthResult(3); }
1458 const NodeValue getGradOfInputNamedMean() const { return getNthResult(3); }
1459 NodeValue getGradOfInputNamedVar() { return getNthResult(4); }
1460 const NodeValue getGradOfInputNamedVar() const { return getNthResult(4); }
1461 unsigned_t getChannelIdx() const { return ChannelIdx_; }
1462 float getEpsilon() const { return Epsilon_; }
1463 float getMomentum() const { return Momentum_; }
1464
1465 static bool classof(const Kinded *k) {
1466 return k->getKind() == Kinded::Kind::BatchNormalizationGradNodeKind;
1467 }
1468
1469
1470 bool isOverwrittenNthInput(unsigned idx) const {
1471 return false;
1472 }
1473
1474 unsigned getNumInputs() const;
1475 std::string getInputName(unsigned idx) const;
1476 NodeValue getNthInput(unsigned idx);
1477 void setNthInput(unsigned idx, NodeValue val);
1478 llvm::StringRef getOutputName(unsigned idx) const;
1479 bool hasSideEffects() const { return 0; }
1480 bool isCanonical() const { return 1; }
1481 bool isDataParallel() const { return 0; }
1482 std::string getDebugDesc() const;
1483 bool isEqual(const BatchNormalizationGradNode &other) const;
1484 llvm::hash_code getHash() const;
1485 void visit(Node *parent, NodeWalker *visitor);
1486 Node* clone() const;
1487 bool verify() const;
1488};
1489} // namespace glow
1490
1491
1492namespace glow {
1493/// Performs batch normalization on the Input tensor with the provided Scale, Bias, Mean, Var, ChannelIdx, Epsilon, and Momentum. Similar to Caffe2 SpatialBN, and ONNX BatchNormalization operator.
1494class BatchNormalizationNode final : public Node {
1495 NodeHandle Input_;
1496 NodeHandle Scale_;
1497 NodeHandle Bias_;
1498 NodeHandle Mean_;
1499 NodeHandle Var_;
1500 unsigned_t ChannelIdx_;
1501 float Epsilon_;
1502 float Momentum_;
1503
1504 public:
1505 enum InputIndices {
1506 InputIdx = 0,
1507 ScaleIdx = 1,
1508 BiasIdx = 2,
1509 MeanIdx = 3,
1510 VarIdx = 4,
1511 };
1512
1513 enum ResultIndices {
1514 ResultIdx = 0,
1515 };
1516
1517 BatchNormalizationNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Scale, NodeValue Bias, NodeValue Mean, NodeValue Var, unsigned_t ChannelIdx, float Epsilon, float Momentum)
1518 : Node(Kinded::Kind::BatchNormalizationNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), Mean_(this, Mean), Var_(this, Var), ChannelIdx_(ChannelIdx), Epsilon_(Epsilon), Momentum_(Momentum) {
1519 addResult(Result);
1520 }
1521 const NodeValue getInput() const { return Input_; }
1522 const NodeValue getScale() const { return Scale_; }
1523 const NodeValue getBias() const { return Bias_; }
1524 const NodeValue getMean() const { return Mean_; }
1525 const NodeValue getVar() const { return Var_; }
1526 NodeValue getResult() { return getNthResult(0); }
1527 const NodeValue getResult() const { return getNthResult(0); }
1528 unsigned_t getChannelIdx() const { return ChannelIdx_; }
1529 float getEpsilon() const { return Epsilon_; }
1530 float getMomentum() const { return Momentum_; }
1531
1532 static bool classof(const Kinded *k) {
1533 return k->getKind() == Kinded::Kind::BatchNormalizationNodeKind;
1534 }
1535
1536
1537 bool isOverwrittenNthInput(unsigned idx) const {
1538 return false;
1539 }
1540
1541 unsigned getNumInputs() const;
1542 std::string getInputName(unsigned idx) const;
1543 NodeValue getNthInput(unsigned idx);
1544 void setNthInput(unsigned idx, NodeValue val);
1545 llvm::StringRef getOutputName(unsigned idx) const;
1546 bool hasSideEffects() const { return 0; }
1547 bool isCanonical() const { return 1; }
1548 bool isDataParallel() const { return 0; }
1549 std::string getDebugDesc() const;
1550 bool isEqual(const BatchNormalizationNode &other) const;
1551 llvm::hash_code getHash() const;
1552 void visit(Node *parent, NodeWalker *visitor);
1553 Node* clone() const;
1554 bool verify() const;
1555 BatchNormalizationGradNode *getGrad(GraphGradMapper &builder);
1556};
1557} // namespace glow
1558
1559
1560namespace glow {
1561/// Performs instance normalization on the Input tensor with the provided Scale, Bias, Epsilon. Similar to ONNX InstanceNormalization operator.
1562class InstanceNormalizationNode final : public Node {
1563 NodeHandle Input_;
1564 NodeHandle Scale_;
1565 NodeHandle Bias_;
1566 unsigned_t ChannelIdx_;
1567 float Epsilon_;
1568
1569 public:
1570 enum InputIndices {
1571 InputIdx = 0,
1572 ScaleIdx = 1,
1573 BiasIdx = 2,
1574 };
1575
1576 enum ResultIndices {
1577 ResultIdx = 0,
1578 };
1579
1580 InstanceNormalizationNode(llvm::StringRef name, NodeValue Input, NodeValue Scale, NodeValue Bias, unsigned_t ChannelIdx, float Epsilon)
1581 : Node(Kinded::Kind::InstanceNormalizationNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), ChannelIdx_(ChannelIdx), Epsilon_(Epsilon) {
1582 addResult(Input.getType());
1583 }
1584 const NodeValue getInput() const { return Input_; }
1585 const NodeValue getScale() const { return Scale_; }
1586 const NodeValue getBias() const { return Bias_; }
1587 NodeValue getResult() { return getNthResult(0); }
1588 const NodeValue getResult() const { return getNthResult(0); }
1589 unsigned_t getChannelIdx() const { return ChannelIdx_; }
1590 float getEpsilon() const { return Epsilon_; }
1591
1592 static bool classof(const Kinded *k) {
1593 return k->getKind() == Kinded::Kind::InstanceNormalizationNodeKind;
1594 }
1595
1596
1597 bool isOverwrittenNthInput(unsigned idx) const {
1598 return false;
1599 }
1600
1601 unsigned getNumInputs() const;
1602 std::string getInputName(unsigned idx) const;
1603 NodeValue getNthInput(unsigned idx);
1604 void setNthInput(unsigned idx, NodeValue val);
1605 llvm::StringRef getOutputName(unsigned idx) const;
1606 bool hasSideEffects() const { return 0; }
1607 bool isCanonical() const { return 1; }
1608 bool isDataParallel() const { return 0; }
1609 std::string getDebugDesc() const;
1610 bool isEqual(const InstanceNormalizationNode &other) const;
1611 llvm::hash_code getHash() const;
1612 void visit(Node *parent, NodeWalker *visitor);
1613 Node* clone() const;
1614 bool verify() const;
1615};
1616} // namespace glow
1617
1618
1619namespace glow {
1620/// Calculates new normalized mean and variance based on the input mean, variance, and input.
1621class MeanVarNormalizationNode final : public Node {
1622 NodeHandle Input_;
1623 NodeHandle Mean_;
1624 NodeHandle Var_;
1625 unsigned_t ChannelIdx_;
1626 float Momentum_;
1627
1628 public:
1629 enum InputIndices {
1630 InputIdx = 0,
1631 MeanIdx = 1,
1632 VarIdx = 2,
1633 };
1634
1635 enum ResultIndices {
1636 NewMeanIdx = 0,
1637 NewVarIdx = 1,
1638 };
1639
1640 MeanVarNormalizationNode(llvm::StringRef name, NodeValue Input, NodeValue Mean, NodeValue Var, unsigned_t ChannelIdx, float Momentum)
1641 : Node(Kinded::Kind::MeanVarNormalizationNodeKind, name), Input_(this, Input), Mean_(this, Mean), Var_(this, Var), ChannelIdx_(ChannelIdx), Momentum_(Momentum) {
1642 addResult(Mean.getType());
1643 addResult(Var.getType());
1644 }
1645 const NodeValue getInput() const { return Input_; }
1646 const NodeValue getMean() const { return Mean_; }
1647 const NodeValue getVar() const { return Var_; }
1648 NodeValue getNewMean() { return getNthResult(0); }
1649 const NodeValue getNewMean() const { return getNthResult(0); }
1650 NodeValue getNewVar() { return getNthResult(1); }
1651 const NodeValue getNewVar() const { return getNthResult(1); }
1652 unsigned_t getChannelIdx() const { return ChannelIdx_; }
1653 float getMomentum() const { return Momentum_; }
1654
1655 static bool classof(const Kinded *k) {
1656 return k->getKind() == Kinded::Kind::MeanVarNormalizationNodeKind;
1657 }
1658
1659
1660 bool isOverwrittenNthInput(unsigned idx) const {
1661 return false;
1662 }
1663
1664 unsigned getNumInputs() const;
1665 std::string getInputName(unsigned idx) const;
1666 NodeValue getNthInput(unsigned idx);
1667 void setNthInput(unsigned idx, NodeValue val);
1668 llvm::StringRef getOutputName(unsigned idx) const;
1669 bool hasSideEffects() const { return 0; }
1670 bool isCanonical() const { return 1; }
1671 bool isDataParallel() const { return 0; }
1672 std::string getDebugDesc() const;
1673 bool isEqual(const MeanVarNormalizationNode &other) const;
1674 llvm::hash_code getHash() const;
1675 void visit(Node *parent, NodeWalker *visitor);
1676 Node* clone() const;
1677 bool verify() const;
1678};
1679} // namespace glow
1680
1681
1682namespace glow {
1683class LocalResponseNormalizationGradNode final : public Node {
1684 NodeHandle Input_;
1685 NodeHandle OriginalOutputForResult_;
1686 NodeHandle GradOfOriginalOutputNamedResult_;
1687 unsigned_t HalfWindowSize_;
1688 float Alpha_;
1689 float Beta_;
1690 float K_;
1691
1692 public:
1693 enum InputIndices {
1694 InputIdx = 0,
1695 OriginalOutputForResultIdx = 1,
1696 GradOfOriginalOutputNamedResultIdx = 2,
1697 };
1698
1699 enum ResultIndices {
1700 GradOfInputNamedInputIdx = 0,
1701 };
1702
1703 LocalResponseNormalizationGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, unsigned_t HalfWindowSize, float Alpha, float Beta, float K)
1704 : Node(Kinded::Kind::LocalResponseNormalizationGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), HalfWindowSize_(HalfWindowSize), Alpha_(Alpha), Beta_(Beta), K_(K) {
1705 addResult(Input.getType());
1706 }
1707 const NodeValue getInput() const { return Input_; }
1708 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
1709 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
1710 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
1711 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
1712 unsigned_t getHalfWindowSize() const { return HalfWindowSize_; }
1713 float getAlpha() const { return Alpha_; }
1714 float getBeta() const { return Beta_; }
1715 float getK() const { return K_; }
1716
1717 static bool classof(const Kinded *k) {
1718 return k->getKind() == Kinded::Kind::LocalResponseNormalizationGradNodeKind;
1719 }
1720
1721
1722 bool isOverwrittenNthInput(unsigned idx) const {
1723 return false;
1724 }
1725
1726 unsigned getNumInputs() const;
1727 std::string getInputName(unsigned idx) const;
1728 NodeValue getNthInput(unsigned idx);
1729 void setNthInput(unsigned idx, NodeValue val);
1730 llvm::StringRef getOutputName(unsigned idx) const;
1731 bool hasSideEffects() const { return 0; }
1732 bool isCanonical() const { return 1; }
1733 bool isDataParallel() const { return 0; }
1734 std::string getDebugDesc() const;
1735 bool isEqual(const LocalResponseNormalizationGradNode &other) const;
1736 llvm::hash_code getHash() const;
1737 void visit(Node *parent, NodeWalker *visitor);
1738 Node* clone() const;
1739 bool verify() const;
1740};
1741} // namespace glow
1742
1743
1744namespace glow {
1745/// Performs local response normalization on the Input tensor with the provided Scale, Bias, Mean, Var, ChannelIdx, Epsilon, and Momentum. Similar to Caffe2 and ONNX LRN.
1746class LocalResponseNormalizationNode final : public Node {
1747 NodeHandle Input_;
1748 unsigned_t HalfWindowSize_;
1749 float Alpha_;
1750 float Beta_;
1751 float K_;
1752
1753 public:
1754 enum InputIndices {
1755 InputIdx = 0,
1756 };
1757
1758 enum ResultIndices {
1759 ResultIdx = 0,
1760 };
1761
1762 LocalResponseNormalizationNode(llvm::StringRef name, NodeValue Input, unsigned_t HalfWindowSize, float Alpha, float Beta, float K)
1763 : Node(Kinded::Kind::LocalResponseNormalizationNodeKind, name), Input_(this, Input), HalfWindowSize_(HalfWindowSize), Alpha_(Alpha), Beta_(Beta), K_(K) {
1764 addResult(Input.getType());
1765 }
1766 const NodeValue getInput() const { return Input_; }
1767 NodeValue getResult() { return getNthResult(0); }
1768 const NodeValue getResult() const { return getNthResult(0); }
1769 unsigned_t getHalfWindowSize() const { return HalfWindowSize_; }
1770 float getAlpha() const { return Alpha_; }
1771 float getBeta() const { return Beta_; }
1772 float getK() const { return K_; }
1773
1774 static bool classof(const Kinded *k) {
1775 return k->getKind() == Kinded::Kind::LocalResponseNormalizationNodeKind;
1776 }
1777
1778
1779 bool isOverwrittenNthInput(unsigned idx) const {
1780 return false;
1781 }
1782
1783 unsigned getNumInputs() const;
1784 std::string getInputName(unsigned idx) const;
1785 NodeValue getNthInput(unsigned idx);
1786 void setNthInput(unsigned idx, NodeValue val);
1787 llvm::StringRef getOutputName(unsigned idx) const;
1788 bool hasSideEffects() const { return 0; }
1789 bool isCanonical() const { return 1; }
1790 bool isDataParallel() const { return 0; }
1791 std::string getDebugDesc() const;
1792 bool isEqual(const LocalResponseNormalizationNode &other) const;
1793 llvm::hash_code getHash() const;
1794 void visit(Node *parent, NodeWalker *visitor);
1795 Node* clone() const;
1796 bool verify() const;
1797 LocalResponseNormalizationGradNode *getGrad(GraphGradMapper &builder);
1798};
1799} // namespace glow
1800
1801
1802namespace glow {
1803/// Performs layer normalization on the Input tensor with the provided Scale, Bias, and Epsilon. Layer sizes are determined by the dimensions of Scale and Bias. Similar to PyTorch layer_norm.
1804class LayerNormalizationNode final : public Node {
1805 NodeHandle Input_;
1806 NodeHandle Scale_;
1807 NodeHandle Bias_;
1808 float Epsilon_;
1809
1810 public:
1811 enum InputIndices {
1812 InputIdx = 0,
1813 ScaleIdx = 1,
1814 BiasIdx = 2,
1815 };
1816
1817 enum ResultIndices {
1818 ResultIdx = 0,
1819 };
1820
1821 LayerNormalizationNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Scale, NodeValue Bias, float Epsilon)
1822 : Node(Kinded::Kind::LayerNormalizationNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), Epsilon_(Epsilon) {
1823 addResult(Result);
1824 }
1825 const NodeValue getInput() const { return Input_; }
1826 const NodeValue getScale() const { return Scale_; }
1827 const NodeValue getBias() const { return Bias_; }
1828 NodeValue getResult() { return getNthResult(0); }
1829 const NodeValue getResult() const { return getNthResult(0); }
1830 float getEpsilon() const { return Epsilon_; }
1831
1832 static bool classof(const Kinded *k) {
1833 return k->getKind() == Kinded::Kind::LayerNormalizationNodeKind;
1834 }
1835
1836
1837 bool isOverwrittenNthInput(unsigned idx) const {
1838 return false;
1839 }
1840
1841 unsigned getNumInputs() const;
1842 std::string getInputName(unsigned idx) const;
1843 NodeValue getNthInput(unsigned idx);
1844 void setNthInput(unsigned idx, NodeValue val);
1845 llvm::StringRef getOutputName(unsigned idx) const;
1846 bool hasSideEffects() const { return 0; }
1847 bool isCanonical() const { return 1; }
1848 bool isDataParallel() const { return 0; }
1849 std::string getDebugDesc() const;
1850 bool isEqual(const LayerNormalizationNode &other) const;
1851 llvm::hash_code getHash() const;
1852 void visit(Node *parent, NodeWalker *visitor);
1853 Node* clone() const;
1854 bool verify() const;
1855};
1856} // namespace glow
1857
1858
1859namespace glow {
1860/// Apply box-cox transform for each column for each column in NxD input tensor
1861class BatchBoxCoxNode final : public Node {
1862 NodeHandle Input_;
1863 NodeHandle Lambda1_;
1864 NodeHandle Lambda2_;
1865 float Epsilon_;
1866
1867 public:
1868 enum InputIndices {
1869 InputIdx = 0,
1870 Lambda1Idx = 1,
1871 Lambda2Idx = 2,
1872 };
1873
1874 enum ResultIndices {
1875 ResultIdx = 0,
1876 };
1877
1878 BatchBoxCoxNode(llvm::StringRef name, NodeValue Input, NodeValue Lambda1, NodeValue Lambda2, float Epsilon)
1879 : Node(Kinded::Kind::BatchBoxCoxNodeKind, name), Input_(this, Input), Lambda1_(this, Lambda1), Lambda2_(this, Lambda2), Epsilon_(Epsilon) {
1880 addResult(Input.getType());
1881 }
1882 const NodeValue getInput() const { return Input_; }
1883 const NodeValue getLambda1() const { return Lambda1_; }
1884 const NodeValue getLambda2() const { return Lambda2_; }
1885 NodeValue getResult() { return getNthResult(0); }
1886 const NodeValue getResult() const { return getNthResult(0); }
1887 float getEpsilon() const { return Epsilon_; }
1888
1889 static bool classof(const Kinded *k) {
1890 return k->getKind() == Kinded::Kind::BatchBoxCoxNodeKind;
1891 }
1892
1893
1894 bool isOverwrittenNthInput(unsigned idx) const {
1895 return false;
1896 }
1897
1898 unsigned getNumInputs() const;
1899 std::string getInputName(unsigned idx) const;
1900 NodeValue getNthInput(unsigned idx);
1901 void setNthInput(unsigned idx, NodeValue val);
1902 llvm::StringRef getOutputName(unsigned idx) const;
1903 bool hasSideEffects() const { return 0; }
1904 bool isCanonical() const { return 1; }
1905 bool isDataParallel() const { return 0; }
1906 std::string getDebugDesc() const;
1907 bool isEqual(const BatchBoxCoxNode &other) const;
1908 llvm::hash_code getHash() const;
1909 void visit(Node *parent, NodeWalker *visitor);
1910 Node* clone() const;
1911 bool verify() const;
1912};
1913} // namespace glow
1914
1915
1916namespace glow {
1917/// Performs L2 norm of the Input operand based on Axis.
1918class VectorNormNode final : public Node {
1919 NodeHandle Input_;
1920 unsigned_t Axis_;
1921 unsigned_t P_;
1922
1923 public:
1924 enum InputIndices {
1925 InputIdx = 0,
1926 };
1927
1928 enum ResultIndices {
1929 ResultIdx = 0,
1930 };
1931
1932 VectorNormNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, unsigned_t P)
1933 : Node(Kinded::Kind::VectorNormNodeKind, name), Input_(this, Input), Axis_(Axis), P_(P) {
1934 addResult(Result);
1935 }
1936 const NodeValue getInput() const { return Input_; }
1937 NodeValue getResult() { return getNthResult(0); }
1938 const NodeValue getResult() const { return getNthResult(0); }
1939 unsigned_t getAxis() const { return Axis_; }
1940 unsigned_t getP() const { return P_; }
1941
1942 static bool classof(const Kinded *k) {
1943 return k->getKind() == Kinded::Kind::VectorNormNodeKind;
1944 }
1945
1946
1947 bool isOverwrittenNthInput(unsigned idx) const {
1948 return false;
1949 }
1950
1951 unsigned getNumInputs() const;
1952 std::string getInputName(unsigned idx) const;
1953 NodeValue getNthInput(unsigned idx);
1954 void setNthInput(unsigned idx, NodeValue val);
1955 llvm::StringRef getOutputName(unsigned idx) const;
1956 bool hasSideEffects() const { return 0; }
1957 bool isCanonical() const { return 1; }
1958 bool isDataParallel() const { return 0; }
1959 std::string getDebugDesc() const;
1960 bool isEqual(const VectorNormNode &other) const;
1961 llvm::hash_code getHash() const;
1962 void visit(Node *parent, NodeWalker *visitor);
1963 Node* clone() const;
1964 bool verify() const;
1965};
1966} // namespace glow
1967
1968
1969namespace glow {
1970/// Performs bucketization on the input given Boundaries
1971class BucketizeNode final : public Node {
1972 NodeHandle Input_;
1973 std::vector<float> Boundaries_;
1974
1975 public:
1976 enum InputIndices {
1977 InputIdx = 0,
1978 };
1979
1980 enum ResultIndices {
1981 ResultIdx = 0,
1982 };
1983
1984 BucketizeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<float> Boundaries)
1985 : Node(Kinded::Kind::BucketizeNodeKind, name), Input_(this, Input), Boundaries_(Boundaries) {
1986 addResult(Result);
1987 }
1988 const NodeValue getInput() const { return Input_; }
1989 NodeValue getResult() { return getNthResult(0); }
1990 const NodeValue getResult() const { return getNthResult(0); }
1991 llvm::ArrayRef<float> getBoundaries() const { return Boundaries_; }
1992
1993 static bool classof(const Kinded *k) {
1994 return k->getKind() == Kinded::Kind::BucketizeNodeKind;
1995 }
1996
1997
1998 bool isOverwrittenNthInput(unsigned idx) const {
1999 return false;
2000 }
2001
2002 unsigned getNumInputs() const;
2003 std::string getInputName(unsigned idx) const;
2004 NodeValue getNthInput(unsigned idx);
2005 void setNthInput(unsigned idx, NodeValue val);
2006 llvm::StringRef getOutputName(unsigned idx) const;
2007 bool hasSideEffects() const { return 0; }
2008 bool isCanonical() const { return 1; }
2009 bool isDataParallel() const { return 0; }
2010 std::string getDebugDesc() const;
2011 bool isEqual(const BucketizeNode &other) const;
2012 llvm::hash_code getHash() const;
2013 void visit(Node *parent, NodeWalker *visitor);
2014 Node* clone() const;
2015 bool verify() const;
2016};
2017} // namespace glow
2018
2019
2020namespace glow {
2021class SoftMaxGradNode final : public Node {
2022 NodeHandle Input_;
2023 NodeHandle Selected_;
2024 NodeHandle OriginalOutputForResult_;
2025 NodeHandle GradOfOriginalOutputNamedResult_;
2026
2027 public:
2028 enum InputIndices {
2029 InputIdx = 0,
2030 SelectedIdx = 1,
2031 OriginalOutputForResultIdx = 2,
2032 GradOfOriginalOutputNamedResultIdx = 3,
2033 };
2034
2035 enum ResultIndices {
2036 GradOfInputNamedInputIdx = 0,
2037 GradOfInputNamedSelectedIdx = 1,
2038 };
2039
2040 SoftMaxGradNode(llvm::StringRef name, NodeValue Input, NodeValue Selected, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2041 : Node(Kinded::Kind::SoftMaxGradNodeKind, name), Input_(this, Input), Selected_(this, Selected), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2042 addResult(Input.getType());
2043 addResult(Selected.getType());
2044 }
2045 const NodeValue getInput() const { return Input_; }
2046 const NodeValue getSelected() const { return Selected_; }
2047 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2048 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2049 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
2050 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
2051 NodeValue getGradOfInputNamedSelected() { return getNthResult(1); }
2052 const NodeValue getGradOfInputNamedSelected() const { return getNthResult(1); }
2053
2054 static bool classof(const Kinded *k) {
2055 return k->getKind() == Kinded::Kind::SoftMaxGradNodeKind;
2056 }
2057
2058
2059 bool isOverwrittenNthInput(unsigned idx) const {
2060 return false;
2061 }
2062
2063 unsigned getNumInputs() const;
2064 std::string getInputName(unsigned idx) const;
2065 NodeValue getNthInput(unsigned idx);
2066 void setNthInput(unsigned idx, NodeValue val);
2067 llvm::StringRef getOutputName(unsigned idx) const;
2068 bool hasSideEffects() const { return 0; }
2069 bool isCanonical() const { return 1; }
2070 bool isDataParallel() const { return 0; }
2071 std::string getDebugDesc() const;
2072 bool isEqual(const SoftMaxGradNode &other) const;
2073 llvm::hash_code getHash() const;
2074 void visit(Node *parent, NodeWalker *visitor);
2075 Node* clone() const;
2076 bool verify() const;
2077};
2078} // namespace glow
2079
2080
2081namespace glow {
2082/// Performs SoftMax normalization on the Input tensor.
2083class SoftMaxNode final : public Node {
2084 NodeHandle Input_;
2085 NodeHandle Selected_;
2086
2087 public:
2088 enum InputIndices {
2089 InputIdx = 0,
2090 SelectedIdx = 1,
2091 };
2092
2093 enum ResultIndices {
2094 ResultIdx = 0,
2095 };
2096
2097 SoftMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Selected)
2098 : Node(Kinded::Kind::SoftMaxNodeKind, name), Input_(this, Input), Selected_(this, Selected) {
2099 addResult(Result);
2100 }
2101 const NodeValue getInput() const { return Input_; }
2102 const NodeValue getSelected() const { return Selected_; }
2103 NodeValue getResult() { return getNthResult(0); }
2104 const NodeValue getResult() const { return getNthResult(0); }
2105
2106 static bool classof(const Kinded *k) {
2107 return k->getKind() == Kinded::Kind::SoftMaxNodeKind;
2108 }
2109
2110
2111 bool isOverwrittenNthInput(unsigned idx) const {
2112 return false;
2113 }
2114
2115 unsigned getNumInputs() const;
2116 std::string getInputName(unsigned idx) const;
2117 NodeValue getNthInput(unsigned idx);
2118 void setNthInput(unsigned idx, NodeValue val);
2119 llvm::StringRef getOutputName(unsigned idx) const;
2120 bool hasSideEffects() const { return 0; }
2121 bool isCanonical() const { return 1; }
2122 bool isDataParallel() const { return 0; }
2123 std::string getDebugDesc() const;
2124 bool isEqual(const SoftMaxNode &other) const;
2125 llvm::hash_code getHash() const;
2126 void visit(Node *parent, NodeWalker *visitor);
2127 Node* clone() const;
2128 bool verify() const;
2129 SoftMaxGradNode *getGrad(GraphGradMapper &builder);
2130};
2131} // namespace glow
2132
2133
2134namespace glow {
2135class LogSoftMaxGradNode final : public Node {
2136 NodeHandle Input_;
2137 NodeHandle Selected_;
2138 NodeHandle OriginalOutputForResult_;
2139 NodeHandle GradOfOriginalOutputNamedResult_;
2140
2141 public:
2142 enum InputIndices {
2143 InputIdx = 0,
2144 SelectedIdx = 1,
2145 OriginalOutputForResultIdx = 2,
2146 GradOfOriginalOutputNamedResultIdx = 3,
2147 };
2148
2149 enum ResultIndices {
2150 GradOfInputNamedInputIdx = 0,
2151 GradOfInputNamedSelectedIdx = 1,
2152 };
2153
2154 LogSoftMaxGradNode(llvm::StringRef name, NodeValue Input, NodeValue Selected, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2155 : Node(Kinded::Kind::LogSoftMaxGradNodeKind, name), Input_(this, Input), Selected_(this, Selected), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2156 addResult(Input.getType());
2157 addResult(Selected.getType());
2158 }
2159 const NodeValue getInput() const { return Input_; }
2160 const NodeValue getSelected() const { return Selected_; }
2161 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2162 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2163 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
2164 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
2165 NodeValue getGradOfInputNamedSelected() { return getNthResult(1); }
2166 const NodeValue getGradOfInputNamedSelected() const { return getNthResult(1); }
2167
2168 static bool classof(const Kinded *k) {
2169 return k->getKind() == Kinded::Kind::LogSoftMaxGradNodeKind;
2170 }
2171
2172
2173 bool isOverwrittenNthInput(unsigned idx) const {
2174 return false;
2175 }
2176
2177 unsigned getNumInputs() const;
2178 std::string getInputName(unsigned idx) const;
2179 NodeValue getNthInput(unsigned idx);
2180 void setNthInput(unsigned idx, NodeValue val);
2181 llvm::StringRef getOutputName(unsigned idx) const;
2182 bool hasSideEffects() const { return 0; }
2183 bool isCanonical() const { return 1; }
2184 bool isDataParallel() const { return 0; }
2185 std::string getDebugDesc() const;
2186 bool isEqual(const LogSoftMaxGradNode &other) const;
2187 llvm::hash_code getHash() const;
2188 void visit(Node *parent, NodeWalker *visitor);
2189 Node* clone() const;
2190 bool verify() const;
2191};
2192} // namespace glow
2193
2194
2195namespace glow {
2196/// Performs LogSoftMax normalization on the Input tensor.
2197class LogSoftMaxNode final : public Node {
2198 NodeHandle Input_;
2199 NodeHandle Selected_;
2200
2201 public:
2202 enum InputIndices {
2203 InputIdx = 0,
2204 SelectedIdx = 1,
2205 };
2206
2207 enum ResultIndices {
2208 ResultIdx = 0,
2209 };
2210
2211 LogSoftMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Selected)
2212 : Node(Kinded::Kind::LogSoftMaxNodeKind, name), Input_(this, Input), Selected_(this, Selected) {
2213 addResult(Result);
2214 }
2215 const NodeValue getInput() const { return Input_; }
2216 const NodeValue getSelected() const { return Selected_; }
2217 NodeValue getResult() { return getNthResult(0); }
2218 const NodeValue getResult() const { return getNthResult(0); }
2219
2220 static bool classof(const Kinded *k) {
2221 return k->getKind() == Kinded::Kind::LogSoftMaxNodeKind;
2222 }
2223
2224
2225 bool isOverwrittenNthInput(unsigned idx) const {
2226 return false;
2227 }
2228
2229 unsigned getNumInputs() const;
2230 std::string getInputName(unsigned idx) const;
2231 NodeValue getNthInput(unsigned idx);
2232 void setNthInput(unsigned idx, NodeValue val);
2233 llvm::StringRef getOutputName(unsigned idx) const;
2234 bool hasSideEffects() const { return 0; }
2235 bool isCanonical() const { return 1; }
2236 bool isDataParallel() const { return 0; }
2237 std::string getDebugDesc() const;
2238 bool isEqual(const LogSoftMaxNode &other) const;
2239 llvm::hash_code getHash() const;
2240 void visit(Node *parent, NodeWalker *visitor);
2241 Node* clone() const;
2242 bool verify() const;
2243 LogSoftMaxGradNode *getGrad(GraphGradMapper &builder);
2244};
2245} // namespace glow
2246
2247
2248namespace glow {
2249class CrossEntropyLossGradNode final : public Node {
2250 NodeHandle P_;
2251 NodeHandle Labels_;
2252 NodeHandle OriginalOutputForCE_;
2253 NodeHandle GradOfOriginalOutputNamedCE_;
2254
2255 public:
2256 enum InputIndices {
2257 PIdx = 0,
2258 LabelsIdx = 1,
2259 OriginalOutputForCEIdx = 2,
2260 GradOfOriginalOutputNamedCEIdx = 3,
2261 };
2262
2263 enum ResultIndices {
2264 GradOfInputNamedPIdx = 0,
2265 GradOfInputNamedLabelsIdx = 1,
2266 };
2267
2268 CrossEntropyLossGradNode(llvm::StringRef name, NodeValue P, NodeValue Labels, NodeValue OriginalOutputForCE, NodeValue GradOfOriginalOutputNamedCE)
2269 : Node(Kinded::Kind::CrossEntropyLossGradNodeKind, name), P_(this, P), Labels_(this, Labels), OriginalOutputForCE_(this, OriginalOutputForCE), GradOfOriginalOutputNamedCE_(this, GradOfOriginalOutputNamedCE) {
2270 addResult(P.getType());
2271 addResult(Labels.getType());
2272 }
2273 const NodeValue getP() const { return P_; }
2274 const NodeValue getLabels() const { return Labels_; }
2275 const NodeValue getOriginalOutputForCE() const { return OriginalOutputForCE_; }
2276 const NodeValue getGradOfOriginalOutputNamedCE() const { return GradOfOriginalOutputNamedCE_; }
2277 NodeValue getGradOfInputNamedP() { return getNthResult(0); }
2278 const NodeValue getGradOfInputNamedP() const { return getNthResult(0); }
2279 NodeValue getGradOfInputNamedLabels() { return getNthResult(1); }
2280 const NodeValue getGradOfInputNamedLabels() const { return getNthResult(1); }
2281
2282 static bool classof(const Kinded *k) {
2283 return k->getKind() == Kinded::Kind::CrossEntropyLossGradNodeKind;
2284 }
2285
2286
2287 bool isOverwrittenNthInput(unsigned idx) const {
2288 return false;
2289 }
2290
2291 unsigned getNumInputs() const;
2292 std::string getInputName(unsigned idx) const;
2293 NodeValue getNthInput(unsigned idx);
2294 void setNthInput(unsigned idx, NodeValue val);
2295 llvm::StringRef getOutputName(unsigned idx) const;
2296 bool hasSideEffects() const { return 0; }
2297 bool isCanonical() const { return 1; }
2298 bool isDataParallel() const { return 0; }
2299 std::string getDebugDesc() const;
2300 bool isEqual(const CrossEntropyLossGradNode &other) const;
2301 llvm::hash_code getHash() const;
2302 void visit(Node *parent, NodeWalker *visitor);
2303 Node* clone() const;
2304 bool verify() const;
2305};
2306} // namespace glow
2307
2308
2309namespace glow {
2310/// Computes the average cross entropy loss of the input.
2311class CrossEntropyLossNode final : public Node {
2312 NodeHandle P_;
2313 NodeHandle Labels_;
2314
2315 public:
2316 enum InputIndices {
2317 PIdx = 0,
2318 LabelsIdx = 1,
2319 };
2320
2321 enum ResultIndices {
2322 CEIdx = 0,
2323 };
2324
2325 CrossEntropyLossNode(llvm::StringRef name, TypeRef CE , NodeValue P, NodeValue Labels)
2326 : Node(Kinded::Kind::CrossEntropyLossNodeKind, name), P_(this, P), Labels_(this, Labels) {
2327 addResult(CE);
2328 }
2329 const NodeValue getP() const { return P_; }
2330 const NodeValue getLabels() const { return Labels_; }
2331 NodeValue getCE() { return getNthResult(0); }
2332 const NodeValue getCE() const { return getNthResult(0); }
2333
2334 static bool classof(const Kinded *k) {
2335 return k->getKind() == Kinded::Kind::CrossEntropyLossNodeKind;
2336 }
2337
2338
2339 bool isOverwrittenNthInput(unsigned idx) const {
2340 return false;
2341 }
2342
2343 unsigned getNumInputs() const;
2344 std::string getInputName(unsigned idx) const;
2345 NodeValue getNthInput(unsigned idx);
2346 void setNthInput(unsigned idx, NodeValue val);
2347 llvm::StringRef getOutputName(unsigned idx) const;
2348 bool hasSideEffects() const { return 0; }
2349 bool isCanonical() const { return 1; }
2350 bool isDataParallel() const { return 0; }
2351 std::string getDebugDesc() const;
2352 bool isEqual(const CrossEntropyLossNode &other) const;
2353 llvm::hash_code getHash() const;
2354 void visit(Node *parent, NodeWalker *visitor);
2355 Node* clone() const;
2356 bool verify() const;
2357 CrossEntropyLossGradNode *getGrad(GraphGradMapper &builder);
2358};
2359} // namespace glow
2360
2361
2362namespace glow {
2363class RegressionGradNode final : public Node {
2364 NodeHandle Input_;
2365 NodeHandle Expected_;
2366 NodeHandle OriginalOutputForResult_;
2367 NodeHandle GradOfOriginalOutputNamedResult_;
2368
2369 public:
2370 enum InputIndices {
2371 InputIdx = 0,
2372 ExpectedIdx = 1,
2373 OriginalOutputForResultIdx = 2,
2374 GradOfOriginalOutputNamedResultIdx = 3,
2375 };
2376
2377 enum ResultIndices {
2378 GradOfInputNamedInputIdx = 0,
2379 GradOfInputNamedExpectedIdx = 1,
2380 };
2381
2382 RegressionGradNode(llvm::StringRef name, NodeValue Input, NodeValue Expected, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2383 : Node(Kinded::Kind::RegressionGradNodeKind, name), Input_(this, Input), Expected_(this, Expected), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2384 addResult(Input.getType());
2385 addResult(Expected.getType());
2386 }
2387 const NodeValue getInput() const { return Input_; }
2388 const NodeValue getExpected() const { return Expected_; }
2389 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2390 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2391 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
2392 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
2393 NodeValue getGradOfInputNamedExpected() { return getNthResult(1); }
2394 const NodeValue getGradOfInputNamedExpected() const { return getNthResult(1); }
2395
2396 static bool classof(const Kinded *k) {
2397 return k->getKind() == Kinded::Kind::RegressionGradNodeKind;
2398 }
2399
2400
2401 bool isOverwrittenNthInput(unsigned idx) const {
2402 return false;
2403 }
2404
2405 unsigned getNumInputs() const;
2406 std::string getInputName(unsigned idx) const;
2407 NodeValue getNthInput(unsigned idx);
2408 void setNthInput(unsigned idx, NodeValue val);
2409 llvm::StringRef getOutputName(unsigned idx) const;
2410 bool hasSideEffects() const { return 0; }
2411 bool isCanonical() const { return 1; }
2412 bool isDataParallel() const { return 0; }
2413 std::string getDebugDesc() const;
2414 bool isEqual(const RegressionGradNode &other) const;
2415 llvm::hash_code getHash() const;
2416 void visit(Node *parent, NodeWalker *visitor);
2417 Node* clone() const;
2418 bool verify() const;
2419};
2420} // namespace glow
2421
2422
2423namespace glow {
2424/// Takes an Input tensor and creates a regression output layer.
2425class RegressionNode final : public Node {
2426 NodeHandle Input_;
2427 NodeHandle Expected_;
2428
2429 public:
2430 enum InputIndices {
2431 InputIdx = 0,
2432 ExpectedIdx = 1,
2433 };
2434
2435 enum ResultIndices {
2436 ResultIdx = 0,
2437 };
2438
2439 RegressionNode(llvm::StringRef name, NodeValue Input, NodeValue Expected)
2440 : Node(Kinded::Kind::RegressionNodeKind, name), Input_(this, Input), Expected_(this, Expected) {
2441 addResult(Input.getType());
2442 }
2443 const NodeValue getInput() const { return Input_; }
2444 const NodeValue getExpected() const { return Expected_; }
2445 NodeValue getResult() { return getNthResult(0); }
2446 const NodeValue getResult() const { return getNthResult(0); }
2447
2448 static bool classof(const Kinded *k) {
2449 return k->getKind() == Kinded::Kind::RegressionNodeKind;
2450 }
2451
2452
2453 bool isOverwrittenNthInput(unsigned idx) const {
2454 return false;
2455 }
2456
2457 unsigned getNumInputs() const;
2458 std::string getInputName(unsigned idx) const;
2459 NodeValue getNthInput(unsigned idx);
2460 void setNthInput(unsigned idx, NodeValue val);
2461 llvm::StringRef getOutputName(unsigned idx) const;
2462 bool hasSideEffects() const { return 0; }
2463 bool isCanonical() const { return 1; }
2464 bool isDataParallel() const { return 0; }
2465 std::string getDebugDesc() const;
2466 bool isEqual(const RegressionNode &other) const;
2467 llvm::hash_code getHash() const;
2468 void visit(Node *parent, NodeWalker *visitor);
2469 Node* clone() const;
2470 bool verify() const;
2471 RegressionGradNode *getGrad(GraphGradMapper &builder);
2472};
2473} // namespace glow
2474
2475
2476namespace glow {
2477/// Computes the sigmoid cross entropy between two inputs.
2478class SigmoidCrossEntropyWithLogitsNode final : public Node {
2479 NodeHandle Logits_;
2480 NodeHandle Targets_;
2481
2482 public:
2483 enum InputIndices {
2484 LogitsIdx = 0,
2485 TargetsIdx = 1,
2486 };
2487
2488 enum ResultIndices {
2489 ResultIdx = 0,
2490 };
2491
2492 SigmoidCrossEntropyWithLogitsNode(llvm::StringRef name, TypeRef Result , NodeValue Logits, NodeValue Targets)
2493 : Node(Kinded::Kind::SigmoidCrossEntropyWithLogitsNodeKind, name), Logits_(this, Logits), Targets_(this, Targets) {
2494 addResult(Result);
2495 }
2496 const NodeValue getLogits() const { return Logits_; }
2497 const NodeValue getTargets() const { return Targets_; }
2498 NodeValue getResult() { return getNthResult(0); }
2499 const NodeValue getResult() const { return getNthResult(0); }
2500
2501 static bool classof(const Kinded *k) {
2502 return k->getKind() == Kinded::Kind::SigmoidCrossEntropyWithLogitsNodeKind;
2503 }
2504
2505
2506 bool isOverwrittenNthInput(unsigned idx) const {
2507 return false;
2508 }
2509
2510 unsigned getNumInputs() const;
2511 std::string getInputName(unsigned idx) const;
2512 NodeValue getNthInput(unsigned idx);
2513 void setNthInput(unsigned idx, NodeValue val);
2514 llvm::StringRef getOutputName(unsigned idx) const;
2515 bool hasSideEffects() const { return 0; }
2516 bool isCanonical() const { return 1; }
2517 bool isDataParallel() const { return 0; }
2518 std::string getDebugDesc() const;
2519 bool isEqual(const SigmoidCrossEntropyWithLogitsNode &other) const;
2520 llvm::hash_code getHash() const;
2521 void visit(Node *parent, NodeWalker *visitor);
2522 Node* clone() const;
2523 bool verify() const;
2524};
2525} // namespace glow
2526
2527
2528namespace glow {
2529class AddGradNode final : public Node {
2530 NodeHandle LHS_;
2531 NodeHandle RHS_;
2532 NodeHandle OriginalOutputForResult_;
2533 NodeHandle GradOfOriginalOutputNamedResult_;
2534
2535 public:
2536 enum InputIndices {
2537 LHSIdx = 0,
2538 RHSIdx = 1,
2539 OriginalOutputForResultIdx = 2,
2540 GradOfOriginalOutputNamedResultIdx = 3,
2541 };
2542
2543 enum ResultIndices {
2544 GradOfInputNamedLHSIdx = 0,
2545 GradOfInputNamedRHSIdx = 1,
2546 };
2547
2548 AddGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2549 : Node(Kinded::Kind::AddGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2550 addResult(LHS.getType());
2551 addResult(RHS.getType());
2552 }
2553 const NodeValue getLHS() const { return LHS_; }
2554 const NodeValue getRHS() const { return RHS_; }
2555 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2556 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2557 NodeValue getGradOfInputNamedLHS() { return getNthResult(0); }
2558 const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); }
2559 NodeValue getGradOfInputNamedRHS() { return getNthResult(1); }
2560 const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); }
2561
2562 static bool classof(const Kinded *k) {
2563 return k->getKind() == Kinded::Kind::AddGradNodeKind;
2564 }
2565
2566
2567 bool isOverwrittenNthInput(unsigned idx) const {
2568 return false;
2569 }
2570
2571 unsigned getNumInputs() const;
2572 std::string getInputName(unsigned idx) const;
2573 NodeValue getNthInput(unsigned idx);
2574 void setNthInput(unsigned idx, NodeValue val);
2575 llvm::StringRef getOutputName(unsigned idx) const;
2576 bool hasSideEffects() const { return 0; }
2577 bool isCanonical() const { return 1; }
2578 bool isDataParallel() const { return 1; }
2579 std::string getDebugDesc() const;
2580 bool isEqual(const AddGradNode &other) const;
2581 llvm::hash_code getHash() const;
2582 void visit(Node *parent, NodeWalker *visitor);
2583 Node* clone() const;
2584 bool verify() const;
2585};
2586} // namespace glow
2587
2588
2589namespace glow {
2590/// Performs Add on the LHS and RHS operands.
2591class AddNode final : public Node {
2592 NodeHandle LHS_;
2593 NodeHandle RHS_;
2594
2595 public:
2596 enum InputIndices {
2597 LHSIdx = 0,
2598 RHSIdx = 1,
2599 };
2600
2601 enum ResultIndices {
2602 ResultIdx = 0,
2603 };
2604
2605 AddNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
2606 : Node(Kinded::Kind::AddNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
2607 addResult(Result);
2608 }
2609 const NodeValue getLHS() const { return LHS_; }
2610 const NodeValue getRHS() const { return RHS_; }
2611 NodeValue getResult() { return getNthResult(0); }
2612 const NodeValue getResult() const { return getNthResult(0); }
2613
2614 static bool classof(const Kinded *k) {
2615 return k->getKind() == Kinded::Kind::AddNodeKind;
2616 }
2617
2618
2619 bool isOverwrittenNthInput(unsigned idx) const {
2620 return false;
2621 }
2622
2623 unsigned getNumInputs() const;
2624 std::string getInputName(unsigned idx) const;
2625 NodeValue getNthInput(unsigned idx);
2626 void setNthInput(unsigned idx, NodeValue val);
2627 llvm::StringRef getOutputName(unsigned idx) const;
2628 bool hasSideEffects() const { return 0; }
2629 bool isCanonical() const { return 1; }
2630 bool isDataParallel() const { return 1; }
2631 std::string getDebugDesc() const;
2632 bool isEqual(const AddNode &other) const;
2633 llvm::hash_code getHash() const;
2634 void visit(Node *parent, NodeWalker *visitor);
2635 Node* clone() const;
2636 bool verify() const;
2637 AddGradNode *getGrad(GraphGradMapper &builder);
2638};
2639} // namespace glow
2640
2641
2642namespace glow {
2643class MulGradNode final : public Node {
2644 NodeHandle LHS_;
2645 NodeHandle RHS_;
2646 NodeHandle OriginalOutputForResult_;
2647 NodeHandle GradOfOriginalOutputNamedResult_;
2648
2649 public:
2650 enum InputIndices {
2651 LHSIdx = 0,
2652 RHSIdx = 1,
2653 OriginalOutputForResultIdx = 2,
2654 GradOfOriginalOutputNamedResultIdx = 3,
2655 };
2656
2657 enum ResultIndices {
2658 GradOfInputNamedLHSIdx = 0,
2659 GradOfInputNamedRHSIdx = 1,
2660 };
2661
2662 MulGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2663 : Node(Kinded::Kind::MulGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2664 addResult(LHS.getType());
2665 addResult(RHS.getType());
2666 }
2667 const NodeValue getLHS() const { return LHS_; }
2668 const NodeValue getRHS() const { return RHS_; }
2669 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2670 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2671 NodeValue getGradOfInputNamedLHS() { return getNthResult(0); }
2672 const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); }
2673 NodeValue getGradOfInputNamedRHS() { return getNthResult(1); }
2674 const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); }
2675
2676 static bool classof(const Kinded *k) {
2677 return k->getKind() == Kinded::Kind::MulGradNodeKind;
2678 }
2679
2680
2681 bool isOverwrittenNthInput(unsigned idx) const {
2682 return false;
2683 }
2684
2685 unsigned getNumInputs() const;
2686 std::string getInputName(unsigned idx) const;
2687 NodeValue getNthInput(unsigned idx);
2688 void setNthInput(unsigned idx, NodeValue val);
2689 llvm::StringRef getOutputName(unsigned idx) const;
2690 bool hasSideEffects() const { return 0; }
2691 bool isCanonical() const { return 1; }
2692 bool isDataParallel() const { return 1; }
2693 std::string getDebugDesc() const;
2694 bool isEqual(const MulGradNode &other) const;
2695 llvm::hash_code getHash() const;
2696 void visit(Node *parent, NodeWalker *visitor);
2697 Node* clone() const;
2698 bool verify() const;
2699};
2700} // namespace glow
2701
2702
2703namespace glow {
2704/// Performs Mul on the LHS and RHS operands.
2705class MulNode final : public Node {
2706 NodeHandle LHS_;
2707 NodeHandle RHS_;
2708
2709 public:
2710 enum InputIndices {
2711 LHSIdx = 0,
2712 RHSIdx = 1,
2713 };
2714
2715 enum ResultIndices {
2716 ResultIdx = 0,
2717 };
2718
2719 MulNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
2720 : Node(Kinded::Kind::MulNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
2721 addResult(Result);
2722 }
2723 const NodeValue getLHS() const { return LHS_; }
2724 const NodeValue getRHS() const { return RHS_; }
2725 NodeValue getResult() { return getNthResult(0); }
2726 const NodeValue getResult() const { return getNthResult(0); }
2727
2728 static bool classof(const Kinded *k) {
2729 return k->getKind() == Kinded::Kind::MulNodeKind;
2730 }
2731
2732
2733 bool isOverwrittenNthInput(unsigned idx) const {
2734 return false;
2735 }
2736
2737 unsigned getNumInputs() const;
2738 std::string getInputName(unsigned idx) const;
2739 NodeValue getNthInput(unsigned idx);
2740 void setNthInput(unsigned idx, NodeValue val);
2741 llvm::StringRef getOutputName(unsigned idx) const;
2742 bool hasSideEffects() const { return 0; }
2743 bool isCanonical() const { return 1; }
2744 bool isDataParallel() const { return 1; }
2745 std::string getDebugDesc() const;
2746 bool isEqual(const MulNode &other) const;
2747 llvm::hash_code getHash() const;
2748 void visit(Node *parent, NodeWalker *visitor);
2749 Node* clone() const;
2750 bool verify() const;
2751 MulGradNode *getGrad(GraphGradMapper &builder);
2752};
2753} // namespace glow
2754
2755
2756namespace glow {
2757class SubGradNode final : public Node {
2758 NodeHandle LHS_;
2759 NodeHandle RHS_;
2760 NodeHandle OriginalOutputForResult_;
2761 NodeHandle GradOfOriginalOutputNamedResult_;
2762
2763 public:
2764 enum InputIndices {
2765 LHSIdx = 0,
2766 RHSIdx = 1,
2767 OriginalOutputForResultIdx = 2,
2768 GradOfOriginalOutputNamedResultIdx = 3,
2769 };
2770
2771 enum ResultIndices {
2772 GradOfInputNamedLHSIdx = 0,
2773 GradOfInputNamedRHSIdx = 1,
2774 };
2775
2776 SubGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2777 : Node(Kinded::Kind::SubGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2778 addResult(LHS.getType());
2779 addResult(RHS.getType());
2780 }
2781 const NodeValue getLHS() const { return LHS_; }
2782 const NodeValue getRHS() const { return RHS_; }
2783 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2784 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2785 NodeValue getGradOfInputNamedLHS() { return getNthResult(0); }
2786 const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); }
2787 NodeValue getGradOfInputNamedRHS() { return getNthResult(1); }
2788 const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); }
2789
2790 static bool classof(const Kinded *k) {
2791 return k->getKind() == Kinded::Kind::SubGradNodeKind;
2792 }
2793
2794
2795 bool isOverwrittenNthInput(unsigned idx) const {
2796 return false;
2797 }
2798
2799 unsigned getNumInputs() const;
2800 std::string getInputName(unsigned idx) const;
2801 NodeValue getNthInput(unsigned idx);
2802 void setNthInput(unsigned idx, NodeValue val);
2803 llvm::StringRef getOutputName(unsigned idx) const;
2804 bool hasSideEffects() const { return 0; }
2805 bool isCanonical() const { return 1; }
2806 bool isDataParallel() const { return 1; }
2807 std::string getDebugDesc() const;
2808 bool isEqual(const SubGradNode &other) const;
2809 llvm::hash_code getHash() const;
2810 void visit(Node *parent, NodeWalker *visitor);
2811 Node* clone() const;
2812 bool verify() const;
2813};
2814} // namespace glow
2815
2816
2817namespace glow {
2818/// Performs Sub on the LHS and RHS operands.
2819class SubNode final : public Node {
2820 NodeHandle LHS_;
2821 NodeHandle RHS_;
2822
2823 public:
2824 enum InputIndices {
2825 LHSIdx = 0,
2826 RHSIdx = 1,
2827 };
2828
2829 enum ResultIndices {
2830 ResultIdx = 0,
2831 };
2832
2833 SubNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
2834 : Node(Kinded::Kind::SubNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
2835 addResult(Result);
2836 }
2837 const NodeValue getLHS() const { return LHS_; }
2838 const NodeValue getRHS() const { return RHS_; }
2839 NodeValue getResult() { return getNthResult(0); }
2840 const NodeValue getResult() const { return getNthResult(0); }
2841
2842 static bool classof(const Kinded *k) {
2843 return k->getKind() == Kinded::Kind::SubNodeKind;
2844 }
2845
2846
2847 bool isOverwrittenNthInput(unsigned idx) const {
2848 return false;
2849 }
2850
2851 unsigned getNumInputs() const;
2852 std::string getInputName(unsigned idx) const;
2853 NodeValue getNthInput(unsigned idx);
2854 void setNthInput(unsigned idx, NodeValue val);
2855 llvm::StringRef getOutputName(unsigned idx) const;
2856 bool hasSideEffects() const { return 0; }
2857 bool isCanonical() const { return 1; }
2858 bool isDataParallel() const { return 1; }
2859 std::string getDebugDesc() const;
2860 bool isEqual(const SubNode &other) const;
2861 llvm::hash_code getHash() const;
2862 void visit(Node *parent, NodeWalker *visitor);
2863 Node* clone() const;
2864 bool verify() const;
2865 SubGradNode *getGrad(GraphGradMapper &builder);
2866};
2867} // namespace glow
2868
2869
2870namespace glow {
2871class DivGradNode final : public Node {
2872 NodeHandle LHS_;
2873 NodeHandle RHS_;
2874 NodeHandle OriginalOutputForResult_;
2875 NodeHandle GradOfOriginalOutputNamedResult_;
2876
2877 public:
2878 enum InputIndices {
2879 LHSIdx = 0,
2880 RHSIdx = 1,
2881 OriginalOutputForResultIdx = 2,
2882 GradOfOriginalOutputNamedResultIdx = 3,
2883 };
2884
2885 enum ResultIndices {
2886 GradOfInputNamedLHSIdx = 0,
2887 GradOfInputNamedRHSIdx = 1,
2888 };
2889
2890 DivGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
2891 : Node(Kinded::Kind::DivGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
2892 addResult(LHS.getType());
2893 addResult(RHS.getType());
2894 }
2895 const NodeValue getLHS() const { return LHS_; }
2896 const NodeValue getRHS() const { return RHS_; }
2897 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
2898 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
2899 NodeValue getGradOfInputNamedLHS() { return getNthResult(0); }
2900 const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); }
2901 NodeValue getGradOfInputNamedRHS() { return getNthResult(1); }
2902 const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); }
2903
2904 static bool classof(const Kinded *k) {
2905 return k->getKind() == Kinded::Kind::DivGradNodeKind;
2906 }
2907
2908
2909 bool isOverwrittenNthInput(unsigned idx) const {
2910 return false;
2911 }
2912
2913 unsigned getNumInputs() const;
2914 std::string getInputName(unsigned idx) const;
2915 NodeValue getNthInput(unsigned idx);
2916 void setNthInput(unsigned idx, NodeValue val);
2917 llvm::StringRef getOutputName(unsigned idx) const;
2918 bool hasSideEffects() const { return 0; }
2919 bool isCanonical() const { return 1; }
2920 bool isDataParallel() const { return 1; }
2921 std::string getDebugDesc() const;
2922 bool isEqual(const DivGradNode &other) const;
2923 llvm::hash_code getHash() const;
2924 void visit(Node *parent, NodeWalker *visitor);
2925 Node* clone() const;
2926 bool verify() const;
2927};
2928} // namespace glow
2929
2930
2931namespace glow {
2932/// Performs Div on the LHS and RHS operands.
2933class DivNode final : public Node {
2934 NodeHandle LHS_;
2935 NodeHandle RHS_;
2936
2937 public:
2938 enum InputIndices {
2939 LHSIdx = 0,
2940 RHSIdx = 1,
2941 };
2942
2943 enum ResultIndices {
2944 ResultIdx = 0,
2945 };
2946
2947 DivNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
2948 : Node(Kinded::Kind::DivNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
2949 addResult(Result);
2950 }
2951 const NodeValue getLHS() const { return LHS_; }
2952 const NodeValue getRHS() const { return RHS_; }
2953 NodeValue getResult() { return getNthResult(0); }
2954 const NodeValue getResult() const { return getNthResult(0); }
2955
2956 static bool classof(const Kinded *k) {
2957 return k->getKind() == Kinded::Kind::DivNodeKind;
2958 }
2959
2960
2961 bool isOverwrittenNthInput(unsigned idx) const {
2962 return false;
2963 }
2964
2965 unsigned getNumInputs() const;
2966 std::string getInputName(unsigned idx) const;
2967 NodeValue getNthInput(unsigned idx);
2968 void setNthInput(unsigned idx, NodeValue val);
2969 llvm::StringRef getOutputName(unsigned idx) const;
2970 bool hasSideEffects() const { return 0; }
2971 bool isCanonical() const { return 1; }
2972 bool isDataParallel() const { return 1; }
2973 std::string getDebugDesc() const;
2974 bool isEqual(const DivNode &other) const;
2975 llvm::hash_code getHash() const;
2976 void visit(Node *parent, NodeWalker *visitor);
2977 Node* clone() const;
2978 bool verify() const;
2979 DivGradNode *getGrad(GraphGradMapper &builder);
2980};
2981} // namespace glow
2982
2983
2984namespace glow {
2985/// Performs Div on the LHS and RHS operands, then Floor. If Truncate is set to true then truncate the quotient to zero instead.
2986class FloorDivNode final : public Node {
2987 NodeHandle LHS_;
2988 NodeHandle RHS_;
2989 bool Truncate_;
2990
2991 public:
2992 enum InputIndices {
2993 LHSIdx = 0,
2994 RHSIdx = 1,
2995 };
2996
2997 enum ResultIndices {
2998 ResultIdx = 0,
2999 };
3000
3001 FloorDivNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS, bool Truncate)
3002 : Node(Kinded::Kind::FloorDivNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), Truncate_(Truncate) {
3003 addResult(Result);
3004 }
3005 const NodeValue getLHS() const { return LHS_; }
3006 const NodeValue getRHS() const { return RHS_; }
3007 NodeValue getResult() { return getNthResult(0); }
3008 const NodeValue getResult() const { return getNthResult(0); }
3009 bool getTruncate() const { return Truncate_; }
3010
3011 static bool classof(const Kinded *k) {
3012 return k->getKind() == Kinded::Kind::FloorDivNodeKind;
3013 }
3014
3015
3016 bool isOverwrittenNthInput(unsigned idx) const {
3017 return false;
3018 }
3019
3020 unsigned getNumInputs() const;
3021 std::string getInputName(unsigned idx) const;
3022 NodeValue getNthInput(unsigned idx);
3023 void setNthInput(unsigned idx, NodeValue val);
3024 llvm::StringRef getOutputName(unsigned idx) const;
3025 bool hasSideEffects() const { return 0; }
3026 bool isCanonical() const { return 1; }
3027 bool isDataParallel() const { return 1; }
3028 std::string getDebugDesc() const;
3029 bool isEqual(const FloorDivNode &other) const;
3030 llvm::hash_code getHash() const;
3031 void visit(Node *parent, NodeWalker *visitor);
3032 Node* clone() const;
3033 bool verify() const;
3034};
3035} // namespace glow
3036
3037
3038namespace glow {
3039/// Computes the element-wise remainder of division.
3040class FmodNode final : public Node {
3041 NodeHandle LHS_;
3042 NodeHandle RHS_;
3043
3044 public:
3045 enum InputIndices {
3046 LHSIdx = 0,
3047 RHSIdx = 1,
3048 };
3049
3050 enum ResultIndices {
3051 ResultIdx = 0,
3052 };
3053
3054 FmodNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3055 : Node(Kinded::Kind::FmodNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3056 addResult(Result);
3057 }
3058 const NodeValue getLHS() const { return LHS_; }
3059 const NodeValue getRHS() const { return RHS_; }
3060 NodeValue getResult() { return getNthResult(0); }
3061 const NodeValue getResult() const { return getNthResult(0); }
3062
3063 static bool classof(const Kinded *k) {
3064 return k->getKind() == Kinded::Kind::FmodNodeKind;
3065 }
3066
3067
3068 bool isOverwrittenNthInput(unsigned idx) const {
3069 return false;
3070 }
3071
3072 unsigned getNumInputs() const;
3073 std::string getInputName(unsigned idx) const;
3074 NodeValue getNthInput(unsigned idx);
3075 void setNthInput(unsigned idx, NodeValue val);
3076 llvm::StringRef getOutputName(unsigned idx) const;
3077 bool hasSideEffects() const { return 0; }
3078 bool isCanonical() const { return 1; }
3079 bool isDataParallel() const { return 1; }
3080 std::string getDebugDesc() const;
3081 bool isEqual(const FmodNode &other) const;
3082 llvm::hash_code getHash() const;
3083 void visit(Node *parent, NodeWalker *visitor);
3084 Node* clone() const;
3085 bool verify() const;
3086};
3087} // namespace glow
3088
3089
3090namespace glow {
3091/// Performs Max on the LHS and RHS operands.
3092class MaxNode final : public Node {
3093 NodeHandle LHS_;
3094 NodeHandle RHS_;
3095
3096 public:
3097 enum InputIndices {
3098 LHSIdx = 0,
3099 RHSIdx = 1,
3100 };
3101
3102 enum ResultIndices {
3103 ResultIdx = 0,
3104 };
3105
3106 MaxNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3107 : Node(Kinded::Kind::MaxNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3108 addResult(Result);
3109 }
3110 const NodeValue getLHS() const { return LHS_; }
3111 const NodeValue getRHS() const { return RHS_; }
3112 NodeValue getResult() { return getNthResult(0); }
3113 const NodeValue getResult() const { return getNthResult(0); }
3114
3115 static bool classof(const Kinded *k) {
3116 return k->getKind() == Kinded::Kind::MaxNodeKind;
3117 }
3118
3119
3120 bool isOverwrittenNthInput(unsigned idx) const {
3121 return false;
3122 }
3123
3124 unsigned getNumInputs() const;
3125 std::string getInputName(unsigned idx) const;
3126 NodeValue getNthInput(unsigned idx);
3127 void setNthInput(unsigned idx, NodeValue val);
3128 llvm::StringRef getOutputName(unsigned idx) const;
3129 bool hasSideEffects() const { return 0; }
3130 bool isCanonical() const { return 1; }
3131 bool isDataParallel() const { return 1; }
3132 std::string getDebugDesc() const;
3133 bool isEqual(const MaxNode &other) const;
3134 llvm::hash_code getHash() const;
3135 void visit(Node *parent, NodeWalker *visitor);
3136 Node* clone() const;
3137 bool verify() const;
3138};
3139} // namespace glow
3140
3141
3142namespace glow {
3143/// Performs Min on the LHS and RHS operands.
3144class MinNode final : public Node {
3145 NodeHandle LHS_;
3146 NodeHandle RHS_;
3147
3148 public:
3149 enum InputIndices {
3150 LHSIdx = 0,
3151 RHSIdx = 1,
3152 };
3153
3154 enum ResultIndices {
3155 ResultIdx = 0,
3156 };
3157
3158 MinNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3159 : Node(Kinded::Kind::MinNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3160 addResult(Result);
3161 }
3162 const NodeValue getLHS() const { return LHS_; }
3163 const NodeValue getRHS() const { return RHS_; }
3164 NodeValue getResult() { return getNthResult(0); }
3165 const NodeValue getResult() const { return getNthResult(0); }
3166
3167 static bool classof(const Kinded *k) {
3168 return k->getKind() == Kinded::Kind::MinNodeKind;
3169 }
3170
3171
3172 bool isOverwrittenNthInput(unsigned idx) const {
3173 return false;
3174 }
3175
3176 unsigned getNumInputs() const;
3177 std::string getInputName(unsigned idx) const;
3178 NodeValue getNthInput(unsigned idx);
3179 void setNthInput(unsigned idx, NodeValue val);
3180 llvm::StringRef getOutputName(unsigned idx) const;
3181 bool hasSideEffects() const { return 0; }
3182 bool isCanonical() const { return 1; }
3183 bool isDataParallel() const { return 1; }
3184 std::string getDebugDesc() const;
3185 bool isEqual(const MinNode &other) const;
3186 llvm::hash_code getHash() const;
3187 void visit(Node *parent, NodeWalker *visitor);
3188 Node* clone() const;
3189 bool verify() const;
3190};
3191} // namespace glow
3192
3193
3194namespace glow {
3195/// Performs an element-wise EQUAL comparison between the LHS and RHS operands.
3196class CmpEQNode final : public Node {
3197 NodeHandle LHS_;
3198 NodeHandle RHS_;
3199
3200 public:
3201 enum InputIndices {
3202 LHSIdx = 0,
3203 RHSIdx = 1,
3204 };
3205
3206 enum ResultIndices {
3207 ResultIdx = 0,
3208 };
3209
3210 CmpEQNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3211 : Node(Kinded::Kind::CmpEQNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3212 addResult(Result);
3213 }
3214 const NodeValue getLHS() const { return LHS_; }
3215 const NodeValue getRHS() const { return RHS_; }
3216 NodeValue getResult() { return getNthResult(0); }
3217 const NodeValue getResult() const { return getNthResult(0); }
3218
3219 static bool classof(const Kinded *k) {
3220 return k->getKind() == Kinded::Kind::CmpEQNodeKind;
3221 }
3222
3223
3224 bool isOverwrittenNthInput(unsigned idx) const {
3225 return false;
3226 }
3227
3228 unsigned getNumInputs() const;
3229 std::string getInputName(unsigned idx) const;
3230 NodeValue getNthInput(unsigned idx);
3231 void setNthInput(unsigned idx, NodeValue val);
3232 llvm::StringRef getOutputName(unsigned idx) const;
3233 bool hasSideEffects() const { return 0; }
3234 bool isCanonical() const { return 1; }
3235 bool isDataParallel() const { return 1; }
3236 std::string getDebugDesc() const;
3237 bool isEqual(const CmpEQNode &other) const;
3238 llvm::hash_code getHash() const;
3239 void visit(Node *parent, NodeWalker *visitor);
3240 Node* clone() const;
3241 bool verify() const;
3242};
3243} // namespace glow
3244
3245
3246namespace glow {
3247/// Performs an element-wise NOT EQUAL comparison between the LHS and RHS operands.
3248class CmpNEQNode final : public Node {
3249 NodeHandle LHS_;
3250 NodeHandle RHS_;
3251
3252 public:
3253 enum InputIndices {
3254 LHSIdx = 0,
3255 RHSIdx = 1,
3256 };
3257
3258 enum ResultIndices {
3259 ResultIdx = 0,
3260 };
3261
3262 CmpNEQNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3263 : Node(Kinded::Kind::CmpNEQNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3264 addResult(Result);
3265 }
3266 const NodeValue getLHS() const { return LHS_; }
3267 const NodeValue getRHS() const { return RHS_; }
3268 NodeValue getResult() { return getNthResult(0); }
3269 const NodeValue getResult() const { return getNthResult(0); }
3270
3271 static bool classof(const Kinded *k) {
3272 return k->getKind() == Kinded::Kind::CmpNEQNodeKind;
3273 }
3274
3275
3276 bool isOverwrittenNthInput(unsigned idx) const {
3277 return false;
3278 }
3279
3280 unsigned getNumInputs() const;
3281 std::string getInputName(unsigned idx) const;
3282 NodeValue getNthInput(unsigned idx);
3283 void setNthInput(unsigned idx, NodeValue val);
3284 llvm::StringRef getOutputName(unsigned idx) const;
3285 bool hasSideEffects() const { return 0; }
3286 bool isCanonical() const { return 1; }
3287 bool isDataParallel() const { return 1; }
3288 std::string getDebugDesc() const;
3289 bool isEqual(const CmpNEQNode &other) const;
3290 llvm::hash_code getHash() const;
3291 void visit(Node *parent, NodeWalker *visitor);
3292 Node* clone() const;
3293 bool verify() const;
3294};
3295} // namespace glow
3296
3297
3298namespace glow {
3299/// Performs an element-wise LESS THAN comparison between the LHS and RHS operands.
3300class CmpLTNode final : public Node {
3301 NodeHandle LHS_;
3302 NodeHandle RHS_;
3303
3304 public:
3305 enum InputIndices {
3306 LHSIdx = 0,
3307 RHSIdx = 1,
3308 };
3309
3310 enum ResultIndices {
3311 ResultIdx = 0,
3312 };
3313
3314 CmpLTNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3315 : Node(Kinded::Kind::CmpLTNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3316 addResult(Result);
3317 }
3318 const NodeValue getLHS() const { return LHS_; }
3319 const NodeValue getRHS() const { return RHS_; }
3320 NodeValue getResult() { return getNthResult(0); }
3321 const NodeValue getResult() const { return getNthResult(0); }
3322
3323 static bool classof(const Kinded *k) {
3324 return k->getKind() == Kinded::Kind::CmpLTNodeKind;
3325 }
3326
3327
3328 bool isOverwrittenNthInput(unsigned idx) const {
3329 return false;
3330 }
3331
3332 unsigned getNumInputs() const;
3333 std::string getInputName(unsigned idx) const;
3334 NodeValue getNthInput(unsigned idx);
3335 void setNthInput(unsigned idx, NodeValue val);
3336 llvm::StringRef getOutputName(unsigned idx) const;
3337 bool hasSideEffects() const { return 0; }
3338 bool isCanonical() const { return 1; }
3339 bool isDataParallel() const { return 1; }
3340 std::string getDebugDesc() const;
3341 bool isEqual(const CmpLTNode &other) const;
3342 llvm::hash_code getHash() const;
3343 void visit(Node *parent, NodeWalker *visitor);
3344 Node* clone() const;
3345 bool verify() const;
3346};
3347} // namespace glow
3348
3349
3350namespace glow {
3351/// Performs an element-wise LESS THAN OR EQUAL comparison between the LHS and RHS operands.
3352class CmpLTENode final : public Node {
3353 NodeHandle LHS_;
3354 NodeHandle RHS_;
3355
3356 public:
3357 enum InputIndices {
3358 LHSIdx = 0,
3359 RHSIdx = 1,
3360 };
3361
3362 enum ResultIndices {
3363 ResultIdx = 0,
3364 };
3365
3366 CmpLTENode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3367 : Node(Kinded::Kind::CmpLTENodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3368 addResult(Result);
3369 }
3370 const NodeValue getLHS() const { return LHS_; }
3371 const NodeValue getRHS() const { return RHS_; }
3372 NodeValue getResult() { return getNthResult(0); }
3373 const NodeValue getResult() const { return getNthResult(0); }
3374
3375 static bool classof(const Kinded *k) {
3376 return k->getKind() == Kinded::Kind::CmpLTENodeKind;
3377 }
3378
3379
3380 bool isOverwrittenNthInput(unsigned idx) const {
3381 return false;
3382 }
3383
3384 unsigned getNumInputs() const;
3385 std::string getInputName(unsigned idx) const;
3386 NodeValue getNthInput(unsigned idx);
3387 void setNthInput(unsigned idx, NodeValue val);
3388 llvm::StringRef getOutputName(unsigned idx) const;
3389 bool hasSideEffects() const { return 0; }
3390 bool isCanonical() const { return 1; }
3391 bool isDataParallel() const { return 1; }
3392 std::string getDebugDesc() const;
3393 bool isEqual(const CmpLTENode &other) const;
3394 llvm::hash_code getHash() const;
3395 void visit(Node *parent, NodeWalker *visitor);
3396 Node* clone() const;
3397 bool verify() const;
3398};
3399} // namespace glow
3400
3401
3402namespace glow {
3403/// Performs elementwise pow(LHS, RHS).
3404class PowNode final : public Node {
3405 NodeHandle LHS_;
3406 NodeHandle RHS_;
3407
3408 public:
3409 enum InputIndices {
3410 LHSIdx = 0,
3411 RHSIdx = 1,
3412 };
3413
3414 enum ResultIndices {
3415 ResultIdx = 0,
3416 };
3417
3418 PowNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3419 : Node(Kinded::Kind::PowNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3420 addResult(Result);
3421 }
3422 const NodeValue getLHS() const { return LHS_; }
3423 const NodeValue getRHS() const { return RHS_; }
3424 NodeValue getResult() { return getNthResult(0); }
3425 const NodeValue getResult() const { return getNthResult(0); }
3426
3427 static bool classof(const Kinded *k) {
3428 return k->getKind() == Kinded::Kind::PowNodeKind;
3429 }
3430
3431
3432 bool isOverwrittenNthInput(unsigned idx) const {
3433 return false;
3434 }
3435
3436 unsigned getNumInputs() const;
3437 std::string getInputName(unsigned idx) const;
3438 NodeValue getNthInput(unsigned idx);
3439 void setNthInput(unsigned idx, NodeValue val);
3440 llvm::StringRef getOutputName(unsigned idx) const;
3441 bool hasSideEffects() const { return 0; }
3442 bool isCanonical() const { return 1; }
3443 bool isDataParallel() const { return 1; }
3444 std::string getDebugDesc() const;
3445 bool isEqual(const PowNode &other) const;
3446 llvm::hash_code getHash() const;
3447 void visit(Node *parent, NodeWalker *visitor);
3448 Node* clone() const;
3449 bool verify() const;
3450};
3451} // namespace glow
3452
3453
3454namespace glow {
3455/// Performs an element-wise logical AND between the LHS and RHS operands.
3456class AndNode final : public Node {
3457 NodeHandle LHS_;
3458 NodeHandle RHS_;
3459
3460 public:
3461 enum InputIndices {
3462 LHSIdx = 0,
3463 RHSIdx = 1,
3464 };
3465
3466 enum ResultIndices {
3467 ResultIdx = 0,
3468 };
3469
3470 AndNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3471 : Node(Kinded::Kind::AndNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3472 addResult(Result);
3473 }
3474 const NodeValue getLHS() const { return LHS_; }
3475 const NodeValue getRHS() const { return RHS_; }
3476 NodeValue getResult() { return getNthResult(0); }
3477 const NodeValue getResult() const { return getNthResult(0); }
3478
3479 static bool classof(const Kinded *k) {
3480 return k->getKind() == Kinded::Kind::AndNodeKind;
3481 }
3482
3483
3484 bool isOverwrittenNthInput(unsigned idx) const {
3485 return false;
3486 }
3487
3488 unsigned getNumInputs() const;
3489 std::string getInputName(unsigned idx) const;
3490 NodeValue getNthInput(unsigned idx);
3491 void setNthInput(unsigned idx, NodeValue val);
3492 llvm::StringRef getOutputName(unsigned idx) const;
3493 bool hasSideEffects() const { return 0; }
3494 bool isCanonical() const { return 1; }
3495 bool isDataParallel() const { return 1; }
3496 std::string getDebugDesc() const;
3497 bool isEqual(const AndNode &other) const;
3498 llvm::hash_code getHash() const;
3499 void visit(Node *parent, NodeWalker *visitor);
3500 Node* clone() const;
3501 bool verify() const;
3502};
3503} // namespace glow
3504
3505
3506namespace glow {
3507/// Performs an element-wise bitwise AND between the LHS and RHS operands.
3508class BitwiseAndNode final : public Node {
3509 NodeHandle LHS_;
3510 NodeHandle RHS_;
3511
3512 public:
3513 enum InputIndices {
3514 LHSIdx = 0,
3515 RHSIdx = 1,
3516 };
3517
3518 enum ResultIndices {
3519 ResultIdx = 0,
3520 };
3521
3522 BitwiseAndNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3523 : Node(Kinded::Kind::BitwiseAndNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3524 addResult(Result);
3525 }
3526 const NodeValue getLHS() const { return LHS_; }
3527 const NodeValue getRHS() const { return RHS_; }
3528 NodeValue getResult() { return getNthResult(0); }
3529 const NodeValue getResult() const { return getNthResult(0); }
3530
3531 static bool classof(const Kinded *k) {
3532 return k->getKind() == Kinded::Kind::BitwiseAndNodeKind;
3533 }
3534
3535
3536 bool isOverwrittenNthInput(unsigned idx) const {
3537 return false;
3538 }
3539
3540 unsigned getNumInputs() const;
3541 std::string getInputName(unsigned idx) const;
3542 NodeValue getNthInput(unsigned idx);
3543 void setNthInput(unsigned idx, NodeValue val);
3544 llvm::StringRef getOutputName(unsigned idx) const;
3545 bool hasSideEffects() const { return 0; }
3546 bool isCanonical() const { return 1; }
3547 bool isDataParallel() const { return 1; }
3548 std::string getDebugDesc() const;
3549 bool isEqual(const BitwiseAndNode &other) const;
3550 llvm::hash_code getHash() const;
3551 void visit(Node *parent, NodeWalker *visitor);
3552 Node* clone() const;
3553 bool verify() const;
3554};
3555} // namespace glow
3556
3557
3558namespace glow {
3559/// Performs an element-wise logical OR between the LHS and RHS operands.
3560class OrNode final : public Node {
3561 NodeHandle LHS_;
3562 NodeHandle RHS_;
3563
3564 public:
3565 enum InputIndices {
3566 LHSIdx = 0,
3567 RHSIdx = 1,
3568 };
3569
3570 enum ResultIndices {
3571 ResultIdx = 0,
3572 };
3573
3574 OrNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3575 : Node(Kinded::Kind::OrNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3576 addResult(Result);
3577 }
3578 const NodeValue getLHS() const { return LHS_; }
3579 const NodeValue getRHS() const { return RHS_; }
3580 NodeValue getResult() { return getNthResult(0); }
3581 const NodeValue getResult() const { return getNthResult(0); }
3582
3583 static bool classof(const Kinded *k) {
3584 return k->getKind() == Kinded::Kind::OrNodeKind;
3585 }
3586
3587
3588 bool isOverwrittenNthInput(unsigned idx) const {
3589 return false;
3590 }
3591
3592 unsigned getNumInputs() const;
3593 std::string getInputName(unsigned idx) const;
3594 NodeValue getNthInput(unsigned idx);
3595 void setNthInput(unsigned idx, NodeValue val);
3596 llvm::StringRef getOutputName(unsigned idx) const;
3597 bool hasSideEffects() const { return 0; }
3598 bool isCanonical() const { return 1; }
3599 bool isDataParallel() const { return 1; }
3600 std::string getDebugDesc() const;
3601 bool isEqual(const OrNode &other) const;
3602 llvm::hash_code getHash() const;
3603 void visit(Node *parent, NodeWalker *visitor);
3604 Node* clone() const;
3605 bool verify() const;
3606};
3607} // namespace glow
3608
3609
3610namespace glow {
3611/// Performs an element-wise bitwise OR between the LHS and RHS operands.
3612class BitwiseOrNode final : public Node {
3613 NodeHandle LHS_;
3614 NodeHandle RHS_;
3615
3616 public:
3617 enum InputIndices {
3618 LHSIdx = 0,
3619 RHSIdx = 1,
3620 };
3621
3622 enum ResultIndices {
3623 ResultIdx = 0,
3624 };
3625
3626 BitwiseOrNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3627 : Node(Kinded::Kind::BitwiseOrNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3628 addResult(Result);
3629 }
3630 const NodeValue getLHS() const { return LHS_; }
3631 const NodeValue getRHS() const { return RHS_; }
3632 NodeValue getResult() { return getNthResult(0); }
3633 const NodeValue getResult() const { return getNthResult(0); }
3634
3635 static bool classof(const Kinded *k) {
3636 return k->getKind() == Kinded::Kind::BitwiseOrNodeKind;
3637 }
3638
3639
3640 bool isOverwrittenNthInput(unsigned idx) const {
3641 return false;
3642 }
3643
3644 unsigned getNumInputs() const;
3645 std::string getInputName(unsigned idx) const;
3646 NodeValue getNthInput(unsigned idx);
3647 void setNthInput(unsigned idx, NodeValue val);
3648 llvm::StringRef getOutputName(unsigned idx) const;
3649 bool hasSideEffects() const { return 0; }
3650 bool isCanonical() const { return 1; }
3651 bool isDataParallel() const { return 1; }
3652 std::string getDebugDesc() const;
3653 bool isEqual(const BitwiseOrNode &other) const;
3654 llvm::hash_code getHash() const;
3655 void visit(Node *parent, NodeWalker *visitor);
3656 Node* clone() const;
3657 bool verify() const;
3658};
3659} // namespace glow
3660
3661
3662namespace glow {
3663/// Performs an element-wise logical XOR between the LHS and RHS operands.
3664class XorNode final : public Node {
3665 NodeHandle LHS_;
3666 NodeHandle RHS_;
3667
3668 public:
3669 enum InputIndices {
3670 LHSIdx = 0,
3671 RHSIdx = 1,
3672 };
3673
3674 enum ResultIndices {
3675 ResultIdx = 0,
3676 };
3677
3678 XorNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3679 : Node(Kinded::Kind::XorNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3680 addResult(Result);
3681 }
3682 const NodeValue getLHS() const { return LHS_; }
3683 const NodeValue getRHS() const { return RHS_; }
3684 NodeValue getResult() { return getNthResult(0); }
3685 const NodeValue getResult() const { return getNthResult(0); }
3686
3687 static bool classof(const Kinded *k) {
3688 return k->getKind() == Kinded::Kind::XorNodeKind;
3689 }
3690
3691
3692 bool isOverwrittenNthInput(unsigned idx) const {
3693 return false;
3694 }
3695
3696 unsigned getNumInputs() const;
3697 std::string getInputName(unsigned idx) const;
3698 NodeValue getNthInput(unsigned idx);
3699 void setNthInput(unsigned idx, NodeValue val);
3700 llvm::StringRef getOutputName(unsigned idx) const;
3701 bool hasSideEffects() const { return 0; }
3702 bool isCanonical() const { return 1; }
3703 bool isDataParallel() const { return 1; }
3704 std::string getDebugDesc() const;
3705 bool isEqual(const XorNode &other) const;
3706 llvm::hash_code getHash() const;
3707 void visit(Node *parent, NodeWalker *visitor);
3708 Node* clone() const;
3709 bool verify() const;
3710};
3711} // namespace glow
3712
3713
3714namespace glow {
3715/// Performs an element-wise bitwise XOR between the LHS and RHS operands.
3716class BitwiseXorNode final : public Node {
3717 NodeHandle LHS_;
3718 NodeHandle RHS_;
3719
3720 public:
3721 enum InputIndices {
3722 LHSIdx = 0,
3723 RHSIdx = 1,
3724 };
3725
3726 enum ResultIndices {
3727 ResultIdx = 0,
3728 };
3729
3730 BitwiseXorNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
3731 : Node(Kinded::Kind::BitwiseXorNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
3732 addResult(Result);
3733 }
3734 const NodeValue getLHS() const { return LHS_; }
3735 const NodeValue getRHS() const { return RHS_; }
3736 NodeValue getResult() { return getNthResult(0); }
3737 const NodeValue getResult() const { return getNthResult(0); }
3738
3739 static bool classof(const Kinded *k) {
3740 return k->getKind() == Kinded::Kind::BitwiseXorNodeKind;
3741 }
3742
3743
3744 bool isOverwrittenNthInput(unsigned idx) const {
3745 return false;
3746 }
3747
3748 unsigned getNumInputs() const;
3749 std::string getInputName(unsigned idx) const;
3750 NodeValue getNthInput(unsigned idx);
3751 void setNthInput(unsigned idx, NodeValue val);
3752 llvm::StringRef getOutputName(unsigned idx) const;
3753 bool hasSideEffects() const { return 0; }
3754 bool isCanonical() const { return 1; }
3755 bool isDataParallel() const { return 1; }
3756 std::string getDebugDesc() const;
3757 bool isEqual(const BitwiseXorNode &other) const;
3758 llvm::hash_code getHash() const;
3759 void visit(Node *parent, NodeWalker *visitor);
3760 Node* clone() const;
3761 bool verify() const;
3762};
3763} // namespace glow
3764
3765
3766namespace glow {
3767/// Performs an element-wise logical NOT of the Input operand.
3768class NotNode final : public Node {
3769 NodeHandle Input_;
3770
3771 public:
3772 enum InputIndices {
3773 InputIdx = 0,
3774 };
3775
3776 enum ResultIndices {
3777 ResultIdx = 0,
3778 };
3779
3780 NotNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
3781 : Node(Kinded::Kind::NotNodeKind, name), Input_(this, Input) {
3782 addResult(Result);
3783 }
3784 const NodeValue getInput() const { return Input_; }
3785 NodeValue getResult() { return getNthResult(0); }
3786 const NodeValue getResult() const { return getNthResult(0); }
3787
3788 static bool classof(const Kinded *k) {
3789 return k->getKind() == Kinded::Kind::NotNodeKind;
3790 }
3791
3792
3793 bool isOverwrittenNthInput(unsigned idx) const {
3794 return false;
3795 }
3796
3797 unsigned getNumInputs() const;
3798 std::string getInputName(unsigned idx) const;
3799 NodeValue getNthInput(unsigned idx);
3800 void setNthInput(unsigned idx, NodeValue val);
3801 llvm::StringRef getOutputName(unsigned idx) const;
3802 bool hasSideEffects() const { return 0; }
3803 bool isCanonical() const { return 1; }
3804 bool isDataParallel() const { return 1; }
3805 std::string getDebugDesc() const;
3806 bool isEqual(const NotNode &other) const;
3807 llvm::hash_code getHash() const;
3808 void visit(Node *parent, NodeWalker *visitor);
3809 Node* clone() const;
3810 bool verify() const;
3811};
3812} // namespace glow
3813
3814
3815namespace glow {
3816/// Performs an element-wise bitwise NOT of the Input operand.
3817class BitwiseNotNode final : public Node {
3818 NodeHandle Input_;
3819
3820 public:
3821 enum InputIndices {
3822 InputIdx = 0,
3823 };
3824
3825 enum ResultIndices {
3826 ResultIdx = 0,
3827 };
3828
3829 BitwiseNotNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
3830 : Node(Kinded::Kind::BitwiseNotNodeKind, name), Input_(this, Input) {
3831 addResult(Result);
3832 }
3833 const NodeValue getInput() const { return Input_; }
3834 NodeValue getResult() { return getNthResult(0); }
3835 const NodeValue getResult() const { return getNthResult(0); }
3836
3837 static bool classof(const Kinded *k) {
3838 return k->getKind() == Kinded::Kind::BitwiseNotNodeKind;
3839 }
3840
3841
3842 bool isOverwrittenNthInput(unsigned idx) const {
3843 return false;
3844 }
3845
3846 unsigned getNumInputs() const;
3847 std::string getInputName(unsigned idx) const;
3848 NodeValue getNthInput(unsigned idx);
3849 void setNthInput(unsigned idx, NodeValue val);
3850 llvm::StringRef getOutputName(unsigned idx) const;
3851 bool hasSideEffects() const { return 0; }
3852 bool isCanonical() const { return 1; }
3853 bool isDataParallel() const { return 1; }
3854 std::string getDebugDesc() const;
3855 bool isEqual(const BitwiseNotNode &other) const;
3856 llvm::hash_code getHash() const;
3857 void visit(Node *parent, NodeWalker *visitor);
3858 Node* clone() const;
3859 bool verify() const;
3860};
3861} // namespace glow
3862
3863
3864namespace glow {
3865/// Performs an element-wise negation (sign flip) of the Input operand.
3866class NegNode final : public Node {
3867 NodeHandle Input_;
3868
3869 public:
3870 enum InputIndices {
3871 InputIdx = 0,
3872 };
3873
3874 enum ResultIndices {
3875 ResultIdx = 0,
3876 };
3877
3878 NegNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
3879 : Node(Kinded::Kind::NegNodeKind, name), Input_(this, Input) {
3880 addResult(Result);
3881 }
3882 const NodeValue getInput() const { return Input_; }
3883 NodeValue getResult() { return getNthResult(0); }
3884 const NodeValue getResult() const { return getNthResult(0); }
3885
3886 static bool classof(const Kinded *k) {
3887 return k->getKind() == Kinded::Kind::NegNodeKind;
3888 }
3889
3890
3891 bool isOverwrittenNthInput(unsigned idx) const {
3892 return false;
3893 }
3894
3895 unsigned getNumInputs() const;
3896 std::string getInputName(unsigned idx) const;
3897 NodeValue getNthInput(unsigned idx);
3898 void setNthInput(unsigned idx, NodeValue val);
3899 llvm::StringRef getOutputName(unsigned idx) const;
3900 bool hasSideEffects() const { return 0; }
3901 bool isCanonical() const { return 1; }
3902 bool isDataParallel() const { return 1; }
3903 std::string getDebugDesc() const;
3904 bool isEqual(const NegNode &other) const;
3905 llvm::hash_code getHash() const;
3906 void visit(Node *parent, NodeWalker *visitor);
3907 Node* clone() const;
3908 bool verify() const;
3909};
3910} // namespace glow
3911
3912
3913namespace glow {
3914/// Performs an element-wise ABS(x) of the Input operand.
3915class AbsNode final : public Node {
3916 NodeHandle Input_;
3917
3918 public:
3919 enum InputIndices {
3920 InputIdx = 0,
3921 };
3922
3923 enum ResultIndices {
3924 ResultIdx = 0,
3925 };
3926
3927 AbsNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
3928 : Node(Kinded::Kind::AbsNodeKind, name), Input_(this, Input) {
3929 addResult(Result);
3930 }
3931 const NodeValue getInput() const { return Input_; }
3932 NodeValue getResult() { return getNthResult(0); }
3933 const NodeValue getResult() const { return getNthResult(0); }
3934
3935 static bool classof(const Kinded *k) {
3936 return k->getKind() == Kinded::Kind::AbsNodeKind;
3937 }
3938
3939
3940 bool isOverwrittenNthInput(unsigned idx) const {
3941 return false;
3942 }
3943
3944 unsigned getNumInputs() const;
3945 std::string getInputName(unsigned idx) const;
3946 NodeValue getNthInput(unsigned idx);
3947 void setNthInput(unsigned idx, NodeValue val);
3948 llvm::StringRef getOutputName(unsigned idx) const;
3949 bool hasSideEffects() const { return 0; }
3950 bool isCanonical() const { return 1; }
3951 bool isDataParallel() const { return 1; }
3952 std::string getDebugDesc() const;
3953 bool isEqual(const AbsNode &other) const;
3954 llvm::hash_code getHash() const;
3955 void visit(Node *parent, NodeWalker *visitor);
3956 Node* clone() const;
3957 bool verify() const;
3958};
3959} // namespace glow
3960
3961
3962namespace glow {
3963/// Performs an element-wise FLOOR(x) of the Input operand.
3964class FloorNode final : public Node {
3965 NodeHandle Input_;
3966
3967 public:
3968 enum InputIndices {
3969 InputIdx = 0,
3970 };
3971
3972 enum ResultIndices {
3973 ResultIdx = 0,
3974 };
3975
3976 FloorNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
3977 : Node(Kinded::Kind::FloorNodeKind, name), Input_(this, Input) {
3978 addResult(Result);
3979 }
3980 const NodeValue getInput() const { return Input_; }
3981 NodeValue getResult() { return getNthResult(0); }
3982 const NodeValue getResult() const { return getNthResult(0); }
3983
3984 static bool classof(const Kinded *k) {
3985 return k->getKind() == Kinded::Kind::FloorNodeKind;
3986 }
3987
3988
3989 bool isOverwrittenNthInput(unsigned idx) const {
3990 return false;
3991 }
3992
3993 unsigned getNumInputs() const;
3994 std::string getInputName(unsigned idx) const;
3995 NodeValue getNthInput(unsigned idx);
3996 void setNthInput(unsigned idx, NodeValue val);
3997 llvm::StringRef getOutputName(unsigned idx) const;
3998 bool hasSideEffects() const { return 0; }
3999 bool isCanonical() const { return 1; }
4000 bool isDataParallel() const { return 1; }
4001 std::string getDebugDesc() const;
4002 bool isEqual(const FloorNode &other) const;
4003 llvm::hash_code getHash() const;
4004 void visit(Node *parent, NodeWalker *visitor);
4005 Node* clone() const;
4006 bool verify() const;
4007};
4008} // namespace glow
4009
4010
4011namespace glow {
4012/// Performs an element-wise Sign(x) of the Input operand
4013class SignNode final : public Node {
4014 NodeHandle Input_;
4015
4016 public:
4017 enum InputIndices {
4018 InputIdx = 0,
4019 };
4020
4021 enum ResultIndices {
4022 ResultIdx = 0,
4023 };
4024
4025 SignNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4026 : Node(Kinded::Kind::SignNodeKind, name), Input_(this, Input) {
4027 addResult(Result);
4028 }
4029 const NodeValue getInput() const { return Input_; }
4030 NodeValue getResult() { return getNthResult(0); }
4031 const NodeValue getResult() const { return getNthResult(0); }
4032
4033 static bool classof(const Kinded *k) {
4034 return k->getKind() == Kinded::Kind::SignNodeKind;
4035 }
4036
4037
4038 bool isOverwrittenNthInput(unsigned idx) const {
4039 return false;
4040 }
4041
4042 unsigned getNumInputs() const;
4043 std::string getInputName(unsigned idx) const;
4044 NodeValue getNthInput(unsigned idx);
4045 void setNthInput(unsigned idx, NodeValue val);
4046 llvm::StringRef getOutputName(unsigned idx) const;
4047 bool hasSideEffects() const { return 0; }
4048 bool isCanonical() const { return 1; }
4049 bool isDataParallel() const { return 1; }
4050 std::string getDebugDesc() const;
4051 bool isEqual(const SignNode &other) const;
4052 llvm::hash_code getHash() const;
4053 void visit(Node *parent, NodeWalker *visitor);
4054 Node* clone() const;
4055 bool verify() const;
4056};
4057} // namespace glow
4058
4059
4060namespace glow {
4061/// Performs an element-wise CEIL(x) of the Input operand.
4062class CeilNode final : public Node {
4063 NodeHandle Input_;
4064
4065 public:
4066 enum InputIndices {
4067 InputIdx = 0,
4068 };
4069
4070 enum ResultIndices {
4071 ResultIdx = 0,
4072 };
4073
4074 CeilNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4075 : Node(Kinded::Kind::CeilNodeKind, name), Input_(this, Input) {
4076 addResult(Result);
4077 }
4078 const NodeValue getInput() const { return Input_; }
4079 NodeValue getResult() { return getNthResult(0); }
4080 const NodeValue getResult() const { return getNthResult(0); }
4081
4082 static bool classof(const Kinded *k) {
4083 return k->getKind() == Kinded::Kind::CeilNodeKind;
4084 }
4085
4086
4087 bool isOverwrittenNthInput(unsigned idx) const {
4088 return false;
4089 }
4090
4091 unsigned getNumInputs() const;
4092 std::string getInputName(unsigned idx) const;
4093 NodeValue getNthInput(unsigned idx);
4094 void setNthInput(unsigned idx, NodeValue val);
4095 llvm::StringRef getOutputName(unsigned idx) const;
4096 bool hasSideEffects() const { return 0; }
4097 bool isCanonical() const { return 1; }
4098 bool isDataParallel() const { return 1; }
4099 std::string getDebugDesc() const;
4100 bool isEqual(const CeilNode &other) const;
4101 llvm::hash_code getHash() const;
4102 void visit(Node *parent, NodeWalker *visitor);
4103 Node* clone() const;
4104 bool verify() const;
4105};
4106} // namespace glow
4107
4108
4109namespace glow {
4110/// Performs an element-wise ROUND(x) of the Input operand.
4111class RoundNode final : public Node {
4112 NodeHandle Input_;
4113
4114 public:
4115 enum InputIndices {
4116 InputIdx = 0,
4117 };
4118
4119 enum ResultIndices {
4120 ResultIdx = 0,
4121 };
4122
4123 RoundNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4124 : Node(Kinded::Kind::RoundNodeKind, name), Input_(this, Input) {
4125 addResult(Result);
4126 }
4127 const NodeValue getInput() const { return Input_; }
4128 NodeValue getResult() { return getNthResult(0); }
4129 const NodeValue getResult() const { return getNthResult(0); }
4130
4131 static bool classof(const Kinded *k) {
4132 return k->getKind() == Kinded::Kind::RoundNodeKind;
4133 }
4134
4135
4136 bool isOverwrittenNthInput(unsigned idx) const {
4137 return false;
4138 }
4139
4140 unsigned getNumInputs() const;
4141 std::string getInputName(unsigned idx) const;
4142 NodeValue getNthInput(unsigned idx);
4143 void setNthInput(unsigned idx, NodeValue val);
4144 llvm::StringRef getOutputName(unsigned idx) const;
4145 bool hasSideEffects() const { return 0; }
4146 bool isCanonical() const { return 1; }
4147 bool isDataParallel() const { return 1; }
4148 std::string getDebugDesc() const;
4149 bool isEqual(const RoundNode &other) const;
4150 llvm::hash_code getHash() const;
4151 void visit(Node *parent, NodeWalker *visitor);
4152 Node* clone() const;
4153 bool verify() const;
4154};
4155} // namespace glow
4156
4157
4158namespace glow {
4159/// Performs an element-wise TRUNCATE(x) of the Input operand.
4160class TruncateNode final : public Node {
4161 NodeHandle Input_;
4162
4163 public:
4164 enum InputIndices {
4165 InputIdx = 0,
4166 };
4167
4168 enum ResultIndices {
4169 ResultIdx = 0,
4170 };
4171
4172 TruncateNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4173 : Node(Kinded::Kind::TruncateNodeKind, name), Input_(this, Input) {
4174 addResult(Result);
4175 }
4176 const NodeValue getInput() const { return Input_; }
4177 NodeValue getResult() { return getNthResult(0); }
4178 const NodeValue getResult() const { return getNthResult(0); }
4179
4180 static bool classof(const Kinded *k) {
4181 return k->getKind() == Kinded::Kind::TruncateNodeKind;
4182 }
4183
4184
4185 bool isOverwrittenNthInput(unsigned idx) const {
4186 return false;
4187 }
4188
4189 unsigned getNumInputs() const;
4190 std::string getInputName(unsigned idx) const;
4191 NodeValue getNthInput(unsigned idx);
4192 void setNthInput(unsigned idx, NodeValue val);
4193 llvm::StringRef getOutputName(unsigned idx) const;
4194 bool hasSideEffects() const { return 0; }
4195 bool isCanonical() const { return 1; }
4196 bool isDataParallel() const { return 1; }
4197 std::string getDebugDesc() const;
4198 bool isEqual(const TruncateNode &other) const;
4199 llvm::hash_code getHash() const;
4200 void visit(Node *parent, NodeWalker *visitor);
4201 Node* clone() const;
4202 bool verify() const;
4203};
4204} // namespace glow
4205
4206
4207namespace glow {
4208/// Performs an element-wise SQRT(x) of the Input operand.
4209class SqrtNode final : public Node {
4210 NodeHandle Input_;
4211
4212 public:
4213 enum InputIndices {
4214 InputIdx = 0,
4215 };
4216
4217 enum ResultIndices {
4218 ResultIdx = 0,
4219 };
4220
4221 SqrtNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4222 : Node(Kinded::Kind::SqrtNodeKind, name), Input_(this, Input) {
4223 addResult(Result);
4224 }
4225 const NodeValue getInput() const { return Input_; }
4226 NodeValue getResult() { return getNthResult(0); }
4227 const NodeValue getResult() const { return getNthResult(0); }
4228
4229 static bool classof(const Kinded *k) {
4230 return k->getKind() == Kinded::Kind::SqrtNodeKind;
4231 }
4232
4233
4234 bool isOverwrittenNthInput(unsigned idx) const {
4235 return false;
4236 }
4237
4238 unsigned getNumInputs() const;
4239 std::string getInputName(unsigned idx) const;
4240 NodeValue getNthInput(unsigned idx);
4241 void setNthInput(unsigned idx, NodeValue val);
4242 llvm::StringRef getOutputName(unsigned idx) const;
4243 bool hasSideEffects() const { return 0; }
4244 bool isCanonical() const { return 1; }
4245 bool isDataParallel() const { return 1; }
4246 std::string getDebugDesc() const;
4247 bool isEqual(const SqrtNode &other) const;
4248 llvm::hash_code getHash() const;
4249 void visit(Node *parent, NodeWalker *visitor);
4250 Node* clone() const;
4251 bool verify() const;
4252};
4253} // namespace glow
4254
4255
4256namespace glow {
4257/// Performs an element-wise RSQRT(x) = 1 / SQRT(x) of the Input operand.
4258class RsqrtNode final : public Node {
4259 NodeHandle Input_;
4260
4261 public:
4262 enum InputIndices {
4263 InputIdx = 0,
4264 };
4265
4266 enum ResultIndices {
4267 ResultIdx = 0,
4268 };
4269
4270 RsqrtNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4271 : Node(Kinded::Kind::RsqrtNodeKind, name), Input_(this, Input) {
4272 addResult(Result);
4273 }
4274 const NodeValue getInput() const { return Input_; }
4275 NodeValue getResult() { return getNthResult(0); }
4276 const NodeValue getResult() const { return getNthResult(0); }
4277
4278 static bool classof(const Kinded *k) {
4279 return k->getKind() == Kinded::Kind::RsqrtNodeKind;
4280 }
4281
4282
4283 bool isOverwrittenNthInput(unsigned idx) const {
4284 return false;
4285 }
4286
4287 unsigned getNumInputs() const;
4288 std::string getInputName(unsigned idx) const;
4289 NodeValue getNthInput(unsigned idx);
4290 void setNthInput(unsigned idx, NodeValue val);
4291 llvm::StringRef getOutputName(unsigned idx) const;
4292 bool hasSideEffects() const { return 0; }
4293 bool isCanonical() const { return 1; }
4294 bool isDataParallel() const { return 1; }
4295 std::string getDebugDesc() const;
4296 bool isEqual(const RsqrtNode &other) const;
4297 llvm::hash_code getHash() const;
4298 void visit(Node *parent, NodeWalker *visitor);
4299 Node* clone() const;
4300 bool verify() const;
4301};
4302} // namespace glow
4303
4304
4305namespace glow {
4306/// Performs an element-wise RECIPROCAL(x) = 1 / x of the Input operand.
4307class ReciprocalNode final : public Node {
4308 NodeHandle Input_;
4309
4310 public:
4311 enum InputIndices {
4312 InputIdx = 0,
4313 };
4314
4315 enum ResultIndices {
4316 ResultIdx = 0,
4317 };
4318
4319 ReciprocalNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4320 : Node(Kinded::Kind::ReciprocalNodeKind, name), Input_(this, Input) {
4321 addResult(Result);
4322 }
4323 const NodeValue getInput() const { return Input_; }
4324 NodeValue getResult() { return getNthResult(0); }
4325 const NodeValue getResult() const { return getNthResult(0); }
4326
4327 static bool classof(const Kinded *k) {
4328 return k->getKind() == Kinded::Kind::ReciprocalNodeKind;
4329 }
4330
4331
4332 bool isOverwrittenNthInput(unsigned idx) const {
4333 return false;
4334 }
4335
4336 unsigned getNumInputs() const;
4337 std::string getInputName(unsigned idx) const;
4338 NodeValue getNthInput(unsigned idx);
4339 void setNthInput(unsigned idx, NodeValue val);
4340 llvm::StringRef getOutputName(unsigned idx) const;
4341 bool hasSideEffects() const { return 0; }
4342 bool isCanonical() const { return 1; }
4343 bool isDataParallel() const { return 1; }
4344 std::string getDebugDesc() const;
4345 bool isEqual(const ReciprocalNode &other) const;
4346 llvm::hash_code getHash() const;
4347 void visit(Node *parent, NodeWalker *visitor);
4348 Node* clone() const;
4349 bool verify() const;
4350};
4351} // namespace glow
4352
4353
4354namespace glow {
4355/// Performs an element-wise SIN(x) of the Input operand.
4356class SinNode final : public Node {
4357 NodeHandle Input_;
4358
4359 public:
4360 enum InputIndices {
4361 InputIdx = 0,
4362 };
4363
4364 enum ResultIndices {
4365 ResultIdx = 0,
4366 };
4367
4368 SinNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4369 : Node(Kinded::Kind::SinNodeKind, name), Input_(this, Input) {
4370 addResult(Result);
4371 }
4372 const NodeValue getInput() const { return Input_; }
4373 NodeValue getResult() { return getNthResult(0); }
4374 const NodeValue getResult() const { return getNthResult(0); }
4375
4376 static bool classof(const Kinded *k) {
4377 return k->getKind() == Kinded::Kind::SinNodeKind;
4378 }
4379
4380
4381 bool isOverwrittenNthInput(unsigned idx) const {
4382 return false;
4383 }
4384
4385 unsigned getNumInputs() const;
4386 std::string getInputName(unsigned idx) const;
4387 NodeValue getNthInput(unsigned idx);
4388 void setNthInput(unsigned idx, NodeValue val);
4389 llvm::StringRef getOutputName(unsigned idx) const;
4390 bool hasSideEffects() const { return 0; }
4391 bool isCanonical() const { return 1; }
4392 bool isDataParallel() const { return 1; }
4393 std::string getDebugDesc() const;
4394 bool isEqual(const SinNode &other) const;
4395 llvm::hash_code getHash() const;
4396 void visit(Node *parent, NodeWalker *visitor);
4397 Node* clone() const;
4398 bool verify() const;
4399};
4400} // namespace glow
4401
4402
4403namespace glow {
4404/// Performs an element-wise COS(x) of the Input operand.
4405class CosNode final : public Node {
4406 NodeHandle Input_;
4407
4408 public:
4409 enum InputIndices {
4410 InputIdx = 0,
4411 };
4412
4413 enum ResultIndices {
4414 ResultIdx = 0,
4415 };
4416
4417 CosNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4418 : Node(Kinded::Kind::CosNodeKind, name), Input_(this, Input) {
4419 addResult(Result);
4420 }
4421 const NodeValue getInput() const { return Input_; }
4422 NodeValue getResult() { return getNthResult(0); }
4423 const NodeValue getResult() const { return getNthResult(0); }
4424
4425 static bool classof(const Kinded *k) {
4426 return k->getKind() == Kinded::Kind::CosNodeKind;
4427 }
4428
4429
4430 bool isOverwrittenNthInput(unsigned idx) const {
4431 return false;
4432 }
4433
4434 unsigned getNumInputs() const;
4435 std::string getInputName(unsigned idx) const;
4436 NodeValue getNthInput(unsigned idx);
4437 void setNthInput(unsigned idx, NodeValue val);
4438 llvm::StringRef getOutputName(unsigned idx) const;
4439 bool hasSideEffects() const { return 0; }
4440 bool isCanonical() const { return 1; }
4441 bool isDataParallel() const { return 1; }
4442 std::string getDebugDesc() const;
4443 bool isEqual(const CosNode &other) const;
4444 llvm::hash_code getHash() const;
4445 void visit(Node *parent, NodeWalker *visitor);
4446 Node* clone() const;
4447 bool verify() const;
4448};
4449} // namespace glow
4450
4451
4452namespace glow {
4453/// Performs element-wise natural log to the Input.
4454class LogNode final : public Node {
4455 NodeHandle Input_;
4456
4457 public:
4458 enum InputIndices {
4459 InputIdx = 0,
4460 };
4461
4462 enum ResultIndices {
4463 ResultIdx = 0,
4464 };
4465
4466 LogNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4467 : Node(Kinded::Kind::LogNodeKind, name), Input_(this, Input) {
4468 addResult(Result);
4469 }
4470 const NodeValue getInput() const { return Input_; }
4471 NodeValue getResult() { return getNthResult(0); }
4472 const NodeValue getResult() const { return getNthResult(0); }
4473
4474 static bool classof(const Kinded *k) {
4475 return k->getKind() == Kinded::Kind::LogNodeKind;
4476 }
4477
4478
4479 bool isOverwrittenNthInput(unsigned idx) const {
4480 return false;
4481 }
4482
4483 unsigned getNumInputs() const;
4484 std::string getInputName(unsigned idx) const;
4485 NodeValue getNthInput(unsigned idx);
4486 void setNthInput(unsigned idx, NodeValue val);
4487 llvm::StringRef getOutputName(unsigned idx) const;
4488 bool hasSideEffects() const { return 0; }
4489 bool isCanonical() const { return 1; }
4490 bool isDataParallel() const { return 1; }
4491 std::string getDebugDesc() const;
4492 bool isEqual(const LogNode &other) const;
4493 llvm::hash_code getHash() const;
4494 void visit(Node *parent, NodeWalker *visitor);
4495 Node* clone() const;
4496 bool verify() const;
4497};
4498} // namespace glow
4499
4500
4501namespace glow {
4502/// Performs an element-wise Arccosine(x) of the Input operand.
4503class AcosNode final : public Node {
4504 NodeHandle Input_;
4505
4506 public:
4507 enum InputIndices {
4508 InputIdx = 0,
4509 };
4510
4511 enum ResultIndices {
4512 ResultIdx = 0,
4513 };
4514
4515 AcosNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4516 : Node(Kinded::Kind::AcosNodeKind, name), Input_(this, Input) {
4517 addResult(Result);
4518 }
4519 const NodeValue getInput() const { return Input_; }
4520 NodeValue getResult() { return getNthResult(0); }
4521 const NodeValue getResult() const { return getNthResult(0); }
4522
4523 static bool classof(const Kinded *k) {
4524 return k->getKind() == Kinded::Kind::AcosNodeKind;
4525 }
4526
4527
4528 bool isOverwrittenNthInput(unsigned idx) const {
4529 return false;
4530 }
4531
4532 unsigned getNumInputs() const;
4533 std::string getInputName(unsigned idx) const;
4534 NodeValue getNthInput(unsigned idx);
4535 void setNthInput(unsigned idx, NodeValue val);
4536 llvm::StringRef getOutputName(unsigned idx) const;
4537 bool hasSideEffects() const { return 0; }
4538 bool isCanonical() const { return 1; }
4539 bool isDataParallel() const { return 1; }
4540 std::string getDebugDesc() const;
4541 bool isEqual(const AcosNode &other) const;
4542 llvm::hash_code getHash() const;
4543 void visit(Node *parent, NodeWalker *visitor);
4544 Node* clone() const;
4545 bool verify() const;
4546};
4547} // namespace glow
4548
4549
4550namespace glow {
4551/// Performs an element-wise Arcsine(x) of the Input operand.
4552class AsinNode final : public Node {
4553 NodeHandle Input_;
4554
4555 public:
4556 enum InputIndices {
4557 InputIdx = 0,
4558 };
4559
4560 enum ResultIndices {
4561 ResultIdx = 0,
4562 };
4563
4564 AsinNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4565 : Node(Kinded::Kind::AsinNodeKind, name), Input_(this, Input) {
4566 addResult(Result);
4567 }
4568 const NodeValue getInput() const { return Input_; }
4569 NodeValue getResult() { return getNthResult(0); }
4570 const NodeValue getResult() const { return getNthResult(0); }
4571
4572 static bool classof(const Kinded *k) {
4573 return k->getKind() == Kinded::Kind::AsinNodeKind;
4574 }
4575
4576
4577 bool isOverwrittenNthInput(unsigned idx) const {
4578 return false;
4579 }
4580
4581 unsigned getNumInputs() const;
4582 std::string getInputName(unsigned idx) const;
4583 NodeValue getNthInput(unsigned idx);
4584 void setNthInput(unsigned idx, NodeValue val);
4585 llvm::StringRef getOutputName(unsigned idx) const;
4586 bool hasSideEffects() const { return 0; }
4587 bool isCanonical() const { return 1; }
4588 bool isDataParallel() const { return 1; }
4589 std::string getDebugDesc() const;
4590 bool isEqual(const AsinNode &other) const;
4591 llvm::hash_code getHash() const;
4592 void visit(Node *parent, NodeWalker *visitor);
4593 Node* clone() const;
4594 bool verify() const;
4595};
4596} // namespace glow
4597
4598
4599namespace glow {
4600/// Performs an element-wise Arctan(x) of the Input operand.
4601class AtanNode final : public Node {
4602 NodeHandle Input_;
4603
4604 public:
4605 enum InputIndices {
4606 InputIdx = 0,
4607 };
4608
4609 enum ResultIndices {
4610 ResultIdx = 0,
4611 };
4612
4613 AtanNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4614 : Node(Kinded::Kind::AtanNodeKind, name), Input_(this, Input) {
4615 addResult(Result);
4616 }
4617 const NodeValue getInput() const { return Input_; }
4618 NodeValue getResult() { return getNthResult(0); }
4619 const NodeValue getResult() const { return getNthResult(0); }
4620
4621 static bool classof(const Kinded *k) {
4622 return k->getKind() == Kinded::Kind::AtanNodeKind;
4623 }
4624
4625
4626 bool isOverwrittenNthInput(unsigned idx) const {
4627 return false;
4628 }
4629
4630 unsigned getNumInputs() const;
4631 std::string getInputName(unsigned idx) const;
4632 NodeValue getNthInput(unsigned idx);
4633 void setNthInput(unsigned idx, NodeValue val);
4634 llvm::StringRef getOutputName(unsigned idx) const;
4635 bool hasSideEffects() const { return 0; }
4636 bool isCanonical() const { return 1; }
4637 bool isDataParallel() const { return 1; }
4638 std::string getDebugDesc() const;
4639 bool isEqual(const AtanNode &other) const;
4640 llvm::hash_code getHash() const;
4641 void visit(Node *parent, NodeWalker *visitor);
4642 Node* clone() const;
4643 bool verify() const;
4644};
4645} // namespace glow
4646
4647
4648namespace glow {
4649/// Performs an element-wise Erf(x) of the Input operand.
4650class ErfNode final : public Node {
4651 NodeHandle Input_;
4652
4653 public:
4654 enum InputIndices {
4655 InputIdx = 0,
4656 };
4657
4658 enum ResultIndices {
4659 ResultIdx = 0,
4660 };
4661
4662 ErfNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4663 : Node(Kinded::Kind::ErfNodeKind, name), Input_(this, Input) {
4664 addResult(Result);
4665 }
4666 const NodeValue getInput() const { return Input_; }
4667 NodeValue getResult() { return getNthResult(0); }
4668 const NodeValue getResult() const { return getNthResult(0); }
4669
4670 static bool classof(const Kinded *k) {
4671 return k->getKind() == Kinded::Kind::ErfNodeKind;
4672 }
4673
4674
4675 bool isOverwrittenNthInput(unsigned idx) const {
4676 return false;
4677 }
4678
4679 unsigned getNumInputs() const;
4680 std::string getInputName(unsigned idx) const;
4681 NodeValue getNthInput(unsigned idx);
4682 void setNthInput(unsigned idx, NodeValue val);
4683 llvm::StringRef getOutputName(unsigned idx) const;
4684 bool hasSideEffects() const { return 0; }
4685 bool isCanonical() const { return 1; }
4686 bool isDataParallel() const { return 1; }
4687 std::string getDebugDesc() const;
4688 bool isEqual(const ErfNode &other) const;
4689 llvm::hash_code getHash() const;
4690 void visit(Node *parent, NodeWalker *visitor);
4691 Node* clone() const;
4692 bool verify() const;
4693};
4694} // namespace glow
4695
4696
4697namespace glow {
4698/// Performs element-wise exponential to the Input.
4699class ExpNode final : public Node {
4700 NodeHandle Input_;
4701
4702 public:
4703 enum InputIndices {
4704 InputIdx = 0,
4705 };
4706
4707 enum ResultIndices {
4708 ResultIdx = 0,
4709 };
4710
4711 ExpNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
4712 : Node(Kinded::Kind::ExpNodeKind, name), Input_(this, Input) {
4713 addResult(Result);
4714 }
4715 const NodeValue getInput() const { return Input_; }
4716 NodeValue getResult() { return getNthResult(0); }
4717 const NodeValue getResult() const { return getNthResult(0); }
4718
4719 static bool classof(const Kinded *k) {
4720 return k->getKind() == Kinded::Kind::ExpNodeKind;
4721 }
4722
4723
4724 bool isOverwrittenNthInput(unsigned idx) const {
4725 return false;
4726 }
4727
4728 unsigned getNumInputs() const;
4729 std::string getInputName(unsigned idx) const;
4730 NodeValue getNthInput(unsigned idx);
4731 void setNthInput(unsigned idx, NodeValue val);
4732 llvm::StringRef getOutputName(unsigned idx) const;
4733 bool hasSideEffects() const { return 0; }
4734 bool isCanonical() const { return 1; }
4735 bool isDataParallel() const { return 1; }
4736 std::string getDebugDesc() const;
4737 bool isEqual(const ExpNode &other) const;
4738 llvm::hash_code getHash() const;
4739 void visit(Node *parent, NodeWalker *visitor);
4740 Node* clone() const;
4741 bool verify() const;
4742};
4743} // namespace glow
4744
4745
4746namespace glow {
4747/// Computes elementwise: result = log(input / (1 - input)).
4748class LogitNode final : public Node {
4749 NodeHandle Input_;
4750 float Epsilon_;
4751
4752 public:
4753 enum InputIndices {
4754 InputIdx = 0,
4755 };
4756
4757 enum ResultIndices {
4758 ResultIdx = 0,
4759 };
4760
4761 LogitNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Epsilon)
4762 : Node(Kinded::Kind::LogitNodeKind, name), Input_(this, Input), Epsilon_(Epsilon) {
4763 addResult(Result);
4764 }
4765 const NodeValue getInput() const { return Input_; }
4766 NodeValue getResult() { return getNthResult(0); }
4767 const NodeValue getResult() const { return getNthResult(0); }
4768 float getEpsilon() const { return Epsilon_; }
4769
4770 static bool classof(const Kinded *k) {
4771 return k->getKind() == Kinded::Kind::LogitNodeKind;
4772 }
4773
4774
4775 bool isOverwrittenNthInput(unsigned idx) const {
4776 return false;
4777 }
4778
4779 unsigned getNumInputs() const;
4780 std::string getInputName(unsigned idx) const;
4781 NodeValue getNthInput(unsigned idx);
4782 void setNthInput(unsigned idx, NodeValue val);
4783 llvm::StringRef getOutputName(unsigned idx) const;
4784 bool hasSideEffects() const { return 0; }
4785 bool isCanonical() const { return 1; }
4786 bool isDataParallel() const { return 1; }
4787 std::string getDebugDesc() const;
4788 bool isEqual(const LogitNode &other) const;
4789 llvm::hash_code getHash() const;
4790 void visit(Node *parent, NodeWalker *visitor);
4791 Node* clone() const;
4792 bool verify() const;
4793};
4794} // namespace glow
4795
4796
4797namespace glow {
4798/// Selects indices of the true elements in Cond
4799class NonZeroNode final : public Node {
4800 NodeHandle Cond_;
4801
4802 public:
4803 enum InputIndices {
4804 CondIdx = 0,
4805 };
4806
4807 enum ResultIndices {
4808 ResultIdx = 0,
4809 };
4810
4811 NonZeroNode(llvm::StringRef name, TypeRef Result , NodeValue Cond)
4812 : Node(Kinded::Kind::NonZeroNodeKind, name), Cond_(this, Cond) {
4813 addResult(Result);
4814 }
4815 const NodeValue getCond() const { return Cond_; }
4816 NodeValue getResult() { return getNthResult(0); }
4817 const NodeValue getResult() const { return getNthResult(0); }
4818
4819 static bool classof(const Kinded *k) {
4820 return k->getKind() == Kinded::Kind::NonZeroNodeKind;
4821 }
4822
4823
4824 bool isOverwrittenNthInput(unsigned idx) const {
4825 return false;
4826 }
4827
4828 unsigned getNumInputs() const;
4829 std::string getInputName(unsigned idx) const;
4830 NodeValue getNthInput(unsigned idx);
4831 void setNthInput(unsigned idx, NodeValue val);
4832 llvm::StringRef getOutputName(unsigned idx) const;
4833 bool hasSideEffects() const { return 0; }
4834 bool isCanonical() const { return 1; }
4835 bool isDataParallel() const { return 1; }
4836 std::string getDebugDesc() const;
4837 bool isEqual(const NonZeroNode &other) const;
4838 llvm::hash_code getHash() const;
4839 void visit(Node *parent, NodeWalker *visitor);
4840 Node* clone() const;
4841 bool verify() const;
4842};
4843} // namespace glow
4844
4845
4846namespace glow {
4847/// Selects between values on the LHS or RHS, depending on the value of Cond. Cond is generated by the compare instruction, and is target- and type-specific.
4848class SelectNode final : public Node {
4849 NodeHandle Cond_;
4850 NodeHandle LHS_;
4851 NodeHandle RHS_;
4852
4853 public:
4854 enum InputIndices {
4855 CondIdx = 0,
4856 LHSIdx = 1,
4857 RHSIdx = 2,
4858 };
4859
4860 enum ResultIndices {
4861 ResultIdx = 0,
4862 };
4863
4864 SelectNode(llvm::StringRef name, TypeRef Result , NodeValue Cond, NodeValue LHS, NodeValue RHS)
4865 : Node(Kinded::Kind::SelectNodeKind, name), Cond_(this, Cond), LHS_(this, LHS), RHS_(this, RHS) {
4866 addResult(Result);
4867 }
4868 const NodeValue getCond() const { return Cond_; }
4869 const NodeValue getLHS() const { return LHS_; }
4870 const NodeValue getRHS() const { return RHS_; }
4871 NodeValue getResult() { return getNthResult(0); }
4872 const NodeValue getResult() const { return getNthResult(0); }
4873
4874 static bool classof(const Kinded *k) {
4875 return k->getKind() == Kinded::Kind::SelectNodeKind;
4876 }
4877
4878
4879 bool isOverwrittenNthInput(unsigned idx) const {
4880 return false;
4881 }
4882
4883 unsigned getNumInputs() const;
4884 std::string getInputName(unsigned idx) const;
4885 NodeValue getNthInput(unsigned idx);
4886 void setNthInput(unsigned idx, NodeValue val);
4887 llvm::StringRef getOutputName(unsigned idx) const;
4888 bool hasSideEffects() const { return 0; }
4889 bool isCanonical() const { return 1; }
4890 bool isDataParallel() const { return 1; }
4891 std::string getDebugDesc() const;
4892 bool isEqual(const SelectNode &other) const;
4893 llvm::hash_code getHash() const;
4894 void visit(Node *parent, NodeWalker *visitor);
4895 Node* clone() const;
4896 bool verify() const;
4897};
4898} // namespace glow
4899
4900
4901namespace glow {
4902/// Adds the 'Slice' operand to each one of the slices in the batch.
4903class BatchedAddNode final : public Node {
4904 NodeHandle Batch_;
4905 NodeHandle Slice_;
4906
4907 public:
4908 enum InputIndices {
4909 BatchIdx = 0,
4910 SliceIdx = 1,
4911 };
4912
4913 enum ResultIndices {
4914 ResultIdx = 0,
4915 };
4916
4917 BatchedAddNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, NodeValue Slice)
4918 : Node(Kinded::Kind::BatchedAddNodeKind, name), Batch_(this, Batch), Slice_(this, Slice) {
4919 addResult(Result);
4920 }
4921 const NodeValue getBatch() const { return Batch_; }
4922 const NodeValue getSlice() const { return Slice_; }
4923 NodeValue getResult() { return getNthResult(0); }
4924 const NodeValue getResult() const { return getNthResult(0); }
4925
4926 static bool classof(const Kinded *k) {
4927 return k->getKind() == Kinded::Kind::BatchedAddNodeKind;
4928 }
4929
4930
4931 bool isOverwrittenNthInput(unsigned idx) const {
4932 return false;
4933 }
4934
4935 unsigned getNumInputs() const;
4936 std::string getInputName(unsigned idx) const;
4937 NodeValue getNthInput(unsigned idx);
4938 void setNthInput(unsigned idx, NodeValue val);
4939 llvm::StringRef getOutputName(unsigned idx) const;
4940 bool hasSideEffects() const { return 0; }
4941 bool isCanonical() const { return 1; }
4942 bool isDataParallel() const { return 0; }
4943 std::string getDebugDesc() const;
4944 bool isEqual(const BatchedAddNode &other) const;
4945 llvm::hash_code getHash() const;
4946 void visit(Node *parent, NodeWalker *visitor);
4947 Node* clone() const;
4948 bool verify() const;
4949};
4950} // namespace glow
4951
4952
4953namespace glow {
4954/// Multiplies the 'Slice' operand to each one of the slices in the batch.
4955class BatchedMulNode final : public Node {
4956 NodeHandle Batch_;
4957 NodeHandle Slice_;
4958
4959 public:
4960 enum InputIndices {
4961 BatchIdx = 0,
4962 SliceIdx = 1,
4963 };
4964
4965 enum ResultIndices {
4966 ResultIdx = 0,
4967 };
4968
4969 BatchedMulNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, NodeValue Slice)
4970 : Node(Kinded::Kind::BatchedMulNodeKind, name), Batch_(this, Batch), Slice_(this, Slice) {
4971 addResult(Result);
4972 }
4973 const NodeValue getBatch() const { return Batch_; }
4974 const NodeValue getSlice() const { return Slice_; }
4975 NodeValue getResult() { return getNthResult(0); }
4976 const NodeValue getResult() const { return getNthResult(0); }
4977
4978 static bool classof(const Kinded *k) {
4979 return k->getKind() == Kinded::Kind::BatchedMulNodeKind;
4980 }
4981
4982
4983 bool isOverwrittenNthInput(unsigned idx) const {
4984 return false;
4985 }
4986
4987 unsigned getNumInputs() const;
4988 std::string getInputName(unsigned idx) const;
4989 NodeValue getNthInput(unsigned idx);
4990 void setNthInput(unsigned idx, NodeValue val);
4991 llvm::StringRef getOutputName(unsigned idx) const;
4992 bool hasSideEffects() const { return 0; }
4993 bool isCanonical() const { return 1; }
4994 bool isDataParallel() const { return 0; }
4995 std::string getDebugDesc() const;
4996 bool isEqual(const BatchedMulNode &other) const;
4997 llvm::hash_code getHash() const;
4998 void visit(Node *parent, NodeWalker *visitor);
4999 Node* clone() const;
5000 bool verify() const;
5001};
5002} // namespace glow
5003
5004
5005namespace glow {
5006/// Performs matrix multiplication between the LHS and RHS.Example: (A, Z) x (Z, B) => (A, B)
5007class MatMulNode final : public Node {
5008 NodeHandle LHS_;
5009 NodeHandle RHS_;
5010
5011 public:
5012 enum InputIndices {
5013 LHSIdx = 0,
5014 RHSIdx = 1,
5015 };
5016
5017 enum ResultIndices {
5018 ResultIdx = 0,
5019 };
5020
5021 MatMulNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
5022 : Node(Kinded::Kind::MatMulNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
5023 addResult(Result);
5024 }
5025 const NodeValue getLHS() const { return LHS_; }
5026 const NodeValue getRHS() const { return RHS_; }
5027 NodeValue getResult() { return getNthResult(0); }
5028 const NodeValue getResult() const { return getNthResult(0); }
5029
5030 static bool classof(const Kinded *k) {
5031 return k->getKind() == Kinded::Kind::MatMulNodeKind;
5032 }
5033
5034
5035 bool isOverwrittenNthInput(unsigned idx) const {
5036 return false;
5037 }
5038
5039 unsigned getNumInputs() const;
5040 std::string getInputName(unsigned idx) const;
5041 NodeValue getNthInput(unsigned idx);
5042 void setNthInput(unsigned idx, NodeValue val);
5043 llvm::StringRef getOutputName(unsigned idx) const;
5044 bool hasSideEffects() const { return 0; }
5045 bool isCanonical() const { return 1; }
5046 bool isDataParallel() const { return 0; }
5047 std::string getDebugDesc() const;
5048 bool isEqual(const MatMulNode &other) const;
5049 llvm::hash_code getHash() const;
5050 void visit(Node *parent, NodeWalker *visitor);
5051 Node* clone() const;
5052 bool verify() const;
5053};
5054} // namespace glow
5055
5056
5057namespace glow {
5058/// Performs batch matrix multiplication between the LHS and RHS. The operands are a stack of two dimensional matrices. Example: (N, A, Z) x (N, Z, B) => (N, A, B)
5059class BatchMatMulNode final : public Node {
5060 NodeHandle LHS_;
5061 NodeHandle RHS_;
5062
5063 public:
5064 enum InputIndices {
5065 LHSIdx = 0,
5066 RHSIdx = 1,
5067 };
5068
5069 enum ResultIndices {
5070 ResultIdx = 0,
5071 };
5072
5073 BatchMatMulNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS)
5074 : Node(Kinded::Kind::BatchMatMulNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) {
5075 addResult(Result);
5076 }
5077 const NodeValue getLHS() const { return LHS_; }
5078 const NodeValue getRHS() const { return RHS_; }
5079 NodeValue getResult() { return getNthResult(0); }
5080 const NodeValue getResult() const { return getNthResult(0); }
5081
5082 static bool classof(const Kinded *k) {
5083 return k->getKind() == Kinded::Kind::BatchMatMulNodeKind;
5084 }
5085
5086
5087 bool isOverwrittenNthInput(unsigned idx) const {
5088 return false;
5089 }
5090
5091 unsigned getNumInputs() const;
5092 std::string getInputName(unsigned idx) const;
5093 NodeValue getNthInput(unsigned idx);
5094 void setNthInput(unsigned idx, NodeValue val);
5095 llvm::StringRef getOutputName(unsigned idx) const;
5096 bool hasSideEffects() const { return 0; }
5097 bool isCanonical() const { return 1; }
5098 bool isDataParallel() const { return 0; }
5099 std::string getDebugDesc() const;
5100 bool isEqual(const BatchMatMulNode &other) const;
5101 llvm::hash_code getHash() const;
5102 void visit(Node *parent, NodeWalker *visitor);
5103 Node* clone() const;
5104 bool verify() const;
5105};
5106} // namespace glow
5107
5108
5109namespace glow {
5110/// Accumulates all of the layers in the batch and produce a tensor that has the same dimensions as the input tensor without the first dimension.
5111class BatchedReduceAddNode final : public Node {
5112 NodeHandle Batch_;
5113 unsigned_t Axis_;
5114
5115 public:
5116 enum InputIndices {
5117 BatchIdx = 0,
5118 };
5119
5120 enum ResultIndices {
5121 ResultIdx = 0,
5122 };
5123
5124 BatchedReduceAddNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, unsigned_t Axis)
5125 : Node(Kinded::Kind::BatchedReduceAddNodeKind, name), Batch_(this, Batch), Axis_(Axis) {
5126 addResult(Result);
5127 }
5128 const NodeValue getBatch() const { return Batch_; }
5129 NodeValue getResult() { return getNthResult(0); }
5130 const NodeValue getResult() const { return getNthResult(0); }
5131 unsigned_t getAxis() const { return Axis_; }
5132
5133 static bool classof(const Kinded *k) {
5134 return k->getKind() == Kinded::Kind::BatchedReduceAddNodeKind;
5135 }
5136
5137
5138 bool isOverwrittenNthInput(unsigned idx) const {
5139 return false;
5140 }
5141
5142 unsigned getNumInputs() const;
5143 std::string getInputName(unsigned idx) const;
5144 NodeValue getNthInput(unsigned idx);
5145 void setNthInput(unsigned idx, NodeValue val);
5146 llvm::StringRef getOutputName(unsigned idx) const;
5147 bool hasSideEffects() const { return 0; }
5148 bool isCanonical() const { return 1; }
5149 bool isDataParallel() const { return 0; }
5150 std::string getDebugDesc() const;
5151 bool isEqual(const BatchedReduceAddNode &other) const;
5152 llvm::hash_code getHash() const;
5153 void visit(Node *parent, NodeWalker *visitor);
5154 Node* clone() const;
5155 bool verify() const;
5156};
5157} // namespace glow
5158
5159
5160namespace glow {
5161/// Accumulates squares of all of the layers in the batch and produce a tensor that has the same dimensions as the input tensor without the first dimension.
5162class BatchedReduceSumSquareNode final : public Node {
5163 NodeHandle Batch_;
5164 unsigned_t Axis_;
5165
5166 public:
5167 enum InputIndices {
5168 BatchIdx = 0,
5169 };
5170
5171 enum ResultIndices {
5172 ResultIdx = 0,
5173 };
5174
5175 BatchedReduceSumSquareNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, unsigned_t Axis)
5176 : Node(Kinded::Kind::BatchedReduceSumSquareNodeKind, name), Batch_(this, Batch), Axis_(Axis) {
5177 addResult(Result);
5178 }
5179 const NodeValue getBatch() const { return Batch_; }
5180 NodeValue getResult() { return getNthResult(0); }
5181 const NodeValue getResult() const { return getNthResult(0); }
5182 unsigned_t getAxis() const { return Axis_; }
5183
5184 static bool classof(const Kinded *k) {
5185 return k->getKind() == Kinded::Kind::BatchedReduceSumSquareNodeKind;
5186 }
5187
5188
5189 bool isOverwrittenNthInput(unsigned idx) const {
5190 return false;
5191 }
5192
5193 unsigned getNumInputs() const;
5194 std::string getInputName(unsigned idx) const;
5195 NodeValue getNthInput(unsigned idx);
5196 void setNthInput(unsigned idx, NodeValue val);
5197 llvm::StringRef getOutputName(unsigned idx) const;
5198 bool hasSideEffects() const { return 0; }
5199 bool isCanonical() const { return 1; }
5200 bool isDataParallel() const { return 0; }
5201 std::string getDebugDesc() const;
5202 bool isEqual(const BatchedReduceSumSquareNode &other) const;
5203 llvm::hash_code getHash() const;
5204 void visit(Node *parent, NodeWalker *visitor);
5205 Node* clone() const;
5206 bool verify() const;
5207};
5208} // namespace glow
5209
5210
5211namespace glow {
5212/// Performs Average Mean operation on the Input given Axes.
5213class BatchedReduceMeanNode final : public Node {
5214 NodeHandle Batch_;
5215 std::vector<unsigned_t> Axes_;
5216
5217 public:
5218 enum InputIndices {
5219 BatchIdx = 0,
5220 };
5221
5222 enum ResultIndices {
5223 ResultIdx = 0,
5224 };
5225
5226 BatchedReduceMeanNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, std::vector<unsigned_t> Axes)
5227 : Node(Kinded::Kind::BatchedReduceMeanNodeKind, name), Batch_(this, Batch), Axes_(Axes) {
5228 addResult(Result);
5229 }
5230 const NodeValue getBatch() const { return Batch_; }
5231 NodeValue getResult() { return getNthResult(0); }
5232 const NodeValue getResult() const { return getNthResult(0); }
5233 llvm::ArrayRef<unsigned_t> getAxes() const { return Axes_; }
5234
5235 static bool classof(const Kinded *k) {
5236 return k->getKind() == Kinded::Kind::BatchedReduceMeanNodeKind;
5237 }
5238
5239
5240 bool isOverwrittenNthInput(unsigned idx) const {
5241 return false;
5242 }
5243
5244 unsigned getNumInputs() const;
5245 std::string getInputName(unsigned idx) const;
5246 NodeValue getNthInput(unsigned idx);
5247 void setNthInput(unsigned idx, NodeValue val);
5248 llvm::StringRef getOutputName(unsigned idx) const;
5249 bool hasSideEffects() const { return 0; }
5250 bool isCanonical() const { return 1; }
5251 bool isDataParallel() const { return 0; }
5252 std::string getDebugDesc() const;
5253 bool isEqual(const BatchedReduceMeanNode &other) const;
5254 llvm::hash_code getHash() const;
5255 void visit(Node *parent, NodeWalker *visitor);
5256 Node* clone() const;
5257 bool verify() const;
5258};
5259} // namespace glow
5260
5261
5262namespace glow {
5263/// Performs Reduce Min operation on the Input given Axes.
5264class BatchedReduceMinNode final : public Node {
5265 NodeHandle Batch_;
5266 std::vector<unsigned_t> Axes_;
5267
5268 public:
5269 enum InputIndices {
5270 BatchIdx = 0,
5271 };
5272
5273 enum ResultIndices {
5274 ResultIdx = 0,
5275 };
5276
5277 BatchedReduceMinNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, std::vector<unsigned_t> Axes)
5278 : Node(Kinded::Kind::BatchedReduceMinNodeKind, name), Batch_(this, Batch), Axes_(Axes) {
5279 addResult(Result);
5280 }
5281 const NodeValue getBatch() const { return Batch_; }
5282 NodeValue getResult() { return getNthResult(0); }
5283 const NodeValue getResult() const { return getNthResult(0); }
5284 llvm::ArrayRef<unsigned_t> getAxes() const { return Axes_; }
5285
5286 static bool classof(const Kinded *k) {
5287 return k->getKind() == Kinded::Kind::BatchedReduceMinNodeKind;
5288 }
5289
5290
5291 bool isOverwrittenNthInput(unsigned idx) const {
5292 return false;
5293 }
5294
5295 unsigned getNumInputs() const;
5296 std::string getInputName(unsigned idx) const;
5297 NodeValue getNthInput(unsigned idx);
5298 void setNthInput(unsigned idx, NodeValue val);
5299 llvm::StringRef getOutputName(unsigned idx) const;
5300 bool hasSideEffects() const { return 0; }
5301 bool isCanonical() const { return 1; }
5302 bool isDataParallel() const { return 0; }
5303 std::string getDebugDesc() const;
5304 bool isEqual(const BatchedReduceMinNode &other) const;
5305 llvm::hash_code getHash() const;
5306 void visit(Node *parent, NodeWalker *visitor);
5307 Node* clone() const;
5308 bool verify() const;
5309};
5310} // namespace glow
5311
5312
5313namespace glow {
5314/// Performs Reduce Max operation on the Input given Axes.
5315class BatchedReduceMaxNode final : public Node {
5316 NodeHandle Batch_;
5317 std::vector<unsigned_t> Axes_;
5318
5319 public:
5320 enum InputIndices {
5321 BatchIdx = 0,
5322 };
5323
5324 enum ResultIndices {
5325 ResultIdx = 0,
5326 };
5327
5328 BatchedReduceMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, std::vector<unsigned_t> Axes)
5329 : Node(Kinded::Kind::BatchedReduceMaxNodeKind, name), Batch_(this, Batch), Axes_(Axes) {
5330 addResult(Result);
5331 }
5332 const NodeValue getBatch() const { return Batch_; }
5333 NodeValue getResult() { return getNthResult(0); }
5334 const NodeValue getResult() const { return getNthResult(0); }
5335 llvm::ArrayRef<unsigned_t> getAxes() const { return Axes_; }
5336
5337 static bool classof(const Kinded *k) {
5338 return k->getKind() == Kinded::Kind::BatchedReduceMaxNodeKind;
5339 }
5340
5341
5342 bool isOverwrittenNthInput(unsigned idx) const {
5343 return false;
5344 }
5345
5346 unsigned getNumInputs() const;
5347 std::string getInputName(unsigned idx) const;
5348 NodeValue getNthInput(unsigned idx);
5349 void setNthInput(unsigned idx, NodeValue val);
5350 llvm::StringRef getOutputName(unsigned idx) const;
5351 bool hasSideEffects() const { return 0; }
5352 bool isCanonical() const { return 1; }
5353 bool isDataParallel() const { return 0; }
5354 std::string getDebugDesc() const;
5355 bool isEqual(const BatchedReduceMaxNode &other) const;
5356 llvm::hash_code getHash() const;
5357 void visit(Node *parent, NodeWalker *visitor);
5358 Node* clone() const;
5359 bool verify() const;
5360};
5361} // namespace glow
5362
5363
5364namespace glow {
5365/// Accumulates the product all of the layers in the batch and produce a tensor that has the same dimensions as the input tensor without the first dimension.
5366class BatchedReduceProdNode final : public Node {
5367 NodeHandle Batch_;
5368 unsigned_t Axis_;
5369
5370 public:
5371 enum InputIndices {
5372 BatchIdx = 0,
5373 };
5374
5375 enum ResultIndices {
5376 ResultIdx = 0,
5377 };
5378
5379 BatchedReduceProdNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, unsigned_t Axis)
5380 : Node(Kinded::Kind::BatchedReduceProdNodeKind, name), Batch_(this, Batch), Axis_(Axis) {
5381 addResult(Result);
5382 }
5383 const NodeValue getBatch() const { return Batch_; }
5384 NodeValue getResult() { return getNthResult(0); }
5385 const NodeValue getResult() const { return getNthResult(0); }
5386 unsigned_t getAxis() const { return Axis_; }
5387
5388 static bool classof(const Kinded *k) {
5389 return k->getKind() == Kinded::Kind::BatchedReduceProdNodeKind;
5390 }
5391
5392
5393 bool isOverwrittenNthInput(unsigned idx) const {
5394 return false;
5395 }
5396
5397 unsigned getNumInputs() const;
5398 std::string getInputName(unsigned idx) const;
5399 NodeValue getNthInput(unsigned idx);
5400 void setNthInput(unsigned idx, NodeValue val);
5401 llvm::StringRef getOutputName(unsigned idx) const;
5402 bool hasSideEffects() const { return 0; }
5403 bool isCanonical() const { return 1; }
5404 bool isDataParallel() const { return 0; }
5405 std::string getDebugDesc() const;
5406 bool isEqual(const BatchedReduceProdNode &other) const;
5407 llvm::hash_code getHash() const;
5408 void visit(Node *parent, NodeWalker *visitor);
5409 Node* clone() const;
5410 bool verify() const;
5411};
5412} // namespace glow
5413
5414
5415namespace glow {
5416/// Performs Channel shuffle.
5417class ChannelShuffleNode final : public Node {
5418 NodeHandle Input_;
5419 unsigned_t Group_;
5420 unsigned_t Kernel_;
5421
5422 public:
5423 enum InputIndices {
5424 InputIdx = 0,
5425 };
5426
5427 enum ResultIndices {
5428 ResultIdx = 0,
5429 };
5430
5431 ChannelShuffleNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Group, unsigned_t Kernel)
5432 : Node(Kinded::Kind::ChannelShuffleNodeKind, name), Input_(this, Input), Group_(Group), Kernel_(Kernel) {
5433 addResult(Result);
5434 }
5435 const NodeValue getInput() const { return Input_; }
5436 NodeValue getResult() { return getNthResult(0); }
5437 const NodeValue getResult() const { return getNthResult(0); }
5438 unsigned_t getGroup() const { return Group_; }
5439 unsigned_t getKernel() const { return Kernel_; }
5440
5441 static bool classof(const Kinded *k) {
5442 return k->getKind() == Kinded::Kind::ChannelShuffleNodeKind;
5443 }
5444
5445
5446 bool isOverwrittenNthInput(unsigned idx) const {
5447 return false;
5448 }
5449
5450 unsigned getNumInputs() const;
5451 std::string getInputName(unsigned idx) const;
5452 NodeValue getNthInput(unsigned idx);
5453 void setNthInput(unsigned idx, NodeValue val);
5454 llvm::StringRef getOutputName(unsigned idx) const;
5455 bool hasSideEffects() const { return 0; }
5456 bool isCanonical() const { return 1; }
5457 bool isDataParallel() const { return 0; }
5458 std::string getDebugDesc() const;
5459 bool isEqual(const ChannelShuffleNode &other) const;
5460 llvm::hash_code getHash() const;
5461 void visit(Node *parent, NodeWalker *visitor);
5462 Node* clone() const;
5463 bool verify() const;
5464};
5465} // namespace glow
5466
5467
5468namespace glow {
5469/// Performs a Cumulative Sum operation over a 1D vector with flags for working in exclusive mode and in reverse. In each case the output size is the same as in input size.e.g (default) [1, 2, 3, 4] -> [1, 3, 6, 10]. (exclusive) [1, 2, 3, 4] -> [0, 1, 3, 6]. (reverse) [1, 2, 3, 4] -> [10, 9, 7, 4].
5470class CumSumNode final : public Node {
5471 NodeHandle Input_;
5472 int64_t Dim_;
5473 unsigned_t Exclusive_;
5474 unsigned_t Reverse_;
5475
5476 public:
5477 enum InputIndices {
5478 InputIdx = 0,
5479 };
5480
5481 enum ResultIndices {
5482 ResultIdx = 0,
5483 };
5484
5485 CumSumNode(llvm::StringRef name, TypeRef Result , NodeValue Input, int64_t Dim, unsigned_t Exclusive, unsigned_t Reverse)
5486 : Node(Kinded::Kind::CumSumNodeKind, name), Input_(this, Input), Dim_(Dim), Exclusive_(Exclusive), Reverse_(Reverse) {
5487 addResult(Result);
5488 }
5489 const NodeValue getInput() const { return Input_; }
5490 NodeValue getResult() { return getNthResult(0); }
5491 const NodeValue getResult() const { return getNthResult(0); }
5492 int64_t getDim() const { return Dim_; }
5493 unsigned_t getExclusive() const { return Exclusive_; }
5494 unsigned_t getReverse() const { return Reverse_; }
5495
5496 static bool classof(const Kinded *k) {
5497 return k->getKind() == Kinded::Kind::CumSumNodeKind;
5498 }
5499
5500
5501 bool isOverwrittenNthInput(unsigned idx) const {
5502 return false;
5503 }
5504
5505 unsigned getNumInputs() const;
5506 std::string getInputName(unsigned idx) const;
5507 NodeValue getNthInput(unsigned idx);
5508 void setNthInput(unsigned idx, NodeValue val);
5509 llvm::StringRef getOutputName(unsigned idx) const;
5510 bool hasSideEffects() const { return 0; }
5511 bool isCanonical() const { return 1; }
5512 bool isDataParallel() const { return 1; }
5513 std::string getDebugDesc() const;
5514 bool isEqual(const CumSumNode &other) const;
5515 llvm::hash_code getHash() const;
5516 void visit(Node *parent, NodeWalker *visitor);
5517 Node* clone() const;
5518 bool verify() const;
5519};
5520} // namespace glow
5521
5522
5523namespace glow {
5524/// Sums slices of the outermost dimension of Data in groups defined by Lengths. The first Lengths[0] slices are added together and stored in Result[0], the subsequent Lengths[1] slices are added together and stored in Result[1], etc.
5525class LengthsSumNode final : public Node {
5526 NodeHandle Data_;
5527 NodeHandle Lengths_;
5528
5529 public:
5530 enum InputIndices {
5531 DataIdx = 0,
5532 LengthsIdx = 1,
5533 };
5534
5535 enum ResultIndices {
5536 ResultIdx = 0,
5537 };
5538
5539 LengthsSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Lengths)
5540 : Node(Kinded::Kind::LengthsSumNodeKind, name), Data_(this, Data), Lengths_(this, Lengths) {
5541 addResult(Result);
5542 }
5543 const NodeValue getData() const { return Data_; }
5544 const NodeValue getLengths() const { return Lengths_; }
5545 NodeValue getResult() { return getNthResult(0); }
5546 const NodeValue getResult() const { return getNthResult(0); }
5547
5548 static bool classof(const Kinded *k) {
5549 return k->getKind() == Kinded::Kind::LengthsSumNodeKind;
5550 }
5551
5552
5553 bool isOverwrittenNthInput(unsigned idx) const {
5554 return false;
5555 }
5556
5557 unsigned getNumInputs() const;
5558 std::string getInputName(unsigned idx) const;
5559 NodeValue getNthInput(unsigned idx);
5560 void setNthInput(unsigned idx, NodeValue val);
5561 llvm::StringRef getOutputName(unsigned idx) const;
5562 bool hasSideEffects() const { return 0; }
5563 bool isCanonical() const { return 1; }
5564 bool isDataParallel() const { return 0; }
5565 std::string getDebugDesc() const;
5566 bool isEqual(const LengthsSumNode &other) const;
5567 llvm::hash_code getHash() const;
5568 void visit(Node *parent, NodeWalker *visitor);
5569 Node* clone() const;
5570 bool verify() const;
5571};
5572} // namespace glow
5573
5574
5575namespace glow {
5576class SparseLengthsSumGradNode final : public Node {
5577 NodeHandle Data_;
5578 NodeHandle Indices_;
5579 NodeHandle Lengths_;
5580 NodeHandle OriginalOutputForResult_;
5581 NodeHandle GradOfOriginalOutputNamedResult_;
5582 glow::LengthsMode LengthsMode_;
5583 float AvgLength_;
5584
5585 public:
5586 enum InputIndices {
5587 DataIdx = 0,
5588 IndicesIdx = 1,
5589 LengthsIdx = 2,
5590 OriginalOutputForResultIdx = 3,
5591 GradOfOriginalOutputNamedResultIdx = 4,
5592 };
5593
5594 enum ResultIndices {
5595 GradOfInputNamedDataIdx = 0,
5596 GradOfInputNamedIndicesIdx = 1,
5597 GradOfInputNamedLengthsIdx = 2,
5598 };
5599
5600 SparseLengthsSumGradNode(llvm::StringRef name, NodeValue Data, NodeValue Indices, NodeValue Lengths, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, glow::LengthsMode LengthsMode, float AvgLength)
5601 : Node(Kinded::Kind::SparseLengthsSumGradNodeKind, name), Data_(this, Data), Indices_(this, Indices), Lengths_(this, Lengths), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
5602 addResult(Data.getType());
5603 addResult(Indices.getType());
5604 addResult(Lengths.getType());
5605 }
5606 const NodeValue getData() const { return Data_; }
5607 const NodeValue getIndices() const { return Indices_; }
5608 const NodeValue getLengths() const { return Lengths_; }
5609 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
5610 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
5611 NodeValue getGradOfInputNamedData() { return getNthResult(0); }
5612 const NodeValue getGradOfInputNamedData() const { return getNthResult(0); }
5613 NodeValue getGradOfInputNamedIndices() { return getNthResult(1); }
5614 const NodeValue getGradOfInputNamedIndices() const { return getNthResult(1); }
5615 NodeValue getGradOfInputNamedLengths() { return getNthResult(2); }
5616 const NodeValue getGradOfInputNamedLengths() const { return getNthResult(2); }
5617 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
5618 float getAvgLength() const { return AvgLength_; }
5619
5620 static bool classof(const Kinded *k) {
5621 return k->getKind() == Kinded::Kind::SparseLengthsSumGradNodeKind;
5622 }
5623
5624
5625 bool isOverwrittenNthInput(unsigned idx) const {
5626 return false;
5627 }
5628
5629 unsigned getNumInputs() const;
5630 std::string getInputName(unsigned idx) const;
5631 NodeValue getNthInput(unsigned idx);
5632 void setNthInput(unsigned idx, NodeValue val);
5633 llvm::StringRef getOutputName(unsigned idx) const;
5634 bool hasSideEffects() const { return 0; }
5635 bool isCanonical() const { return 1; }
5636 bool isDataParallel() const { return 0; }
5637 std::string getDebugDesc() const;
5638 bool isEqual(const SparseLengthsSumGradNode &other) const;
5639 llvm::hash_code getHash() const;
5640 void visit(Node *parent, NodeWalker *visitor);
5641 Node* clone() const;
5642 bool verify() const;
5643};
5644} // namespace glow
5645
5646
5647namespace glow {
5648/// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices).
5649class SparseLengthsSumNode final : public Node {
5650 NodeHandle Data_;
5651 NodeHandle Indices_;
5652 NodeHandle Lengths_;
5653 glow::LengthsMode LengthsMode_;
5654 float AvgLength_;
5655
5656 public:
5657 enum InputIndices {
5658 DataIdx = 0,
5659 IndicesIdx = 1,
5660 LengthsIdx = 2,
5661 };
5662
5663 enum ResultIndices {
5664 ResultIdx = 0,
5665 };
5666
5667 SparseLengthsSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, NodeValue Lengths, glow::LengthsMode LengthsMode, float AvgLength)
5668 : Node(Kinded::Kind::SparseLengthsSumNodeKind, name), Data_(this, Data), Indices_(this, Indices), Lengths_(this, Lengths), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
5669 addResult(Result);
5670 }
5671 const NodeValue getData() const { return Data_; }
5672 const NodeValue getIndices() const { return Indices_; }
5673 const NodeValue getLengths() const { return Lengths_; }
5674 NodeValue getResult() { return getNthResult(0); }
5675 const NodeValue getResult() const { return getNthResult(0); }
5676 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
5677 float getAvgLength() const { return AvgLength_; }
5678
5679 static bool classof(const Kinded *k) {
5680 return k->getKind() == Kinded::Kind::SparseLengthsSumNodeKind;
5681 }
5682
5683
5684 bool isOverwrittenNthInput(unsigned idx) const {
5685 return false;
5686 }
5687
5688 unsigned getNumInputs() const;
5689 std::string getInputName(unsigned idx) const;
5690 NodeValue getNthInput(unsigned idx);
5691 void setNthInput(unsigned idx, NodeValue val);
5692 llvm::StringRef getOutputName(unsigned idx) const;
5693 bool hasSideEffects() const { return 0; }
5694 bool isCanonical() const { return 1; }
5695 bool isDataParallel() const { return 0; }
5696 std::string getDebugDesc() const;
5697 bool isEqual(const SparseLengthsSumNode &other) const;
5698 llvm::hash_code getHash() const;
5699 void visit(Node *parent, NodeWalker *visitor);
5700 Node* clone() const;
5701 bool verify() const;
5702 SparseLengthsSumGradNode *getGrad(GraphGradMapper &builder);
5703};
5704} // namespace glow
5705
5706
5707namespace glow {
5708class SparseLengthsWeightedSumGradNode final : public Node {
5709 NodeHandle Data_;
5710 NodeHandle Weights_;
5711 NodeHandle Indices_;
5712 NodeHandle Lengths_;
5713 NodeHandle OriginalOutputForResult_;
5714 NodeHandle GradOfOriginalOutputNamedResult_;
5715 glow::LengthsMode LengthsMode_;
5716 float AvgLength_;
5717
5718 public:
5719 enum InputIndices {
5720 DataIdx = 0,
5721 WeightsIdx = 1,
5722 IndicesIdx = 2,
5723 LengthsIdx = 3,
5724 OriginalOutputForResultIdx = 4,
5725 GradOfOriginalOutputNamedResultIdx = 5,
5726 };
5727
5728 enum ResultIndices {
5729 GradOfInputNamedDataIdx = 0,
5730 GradOfInputNamedWeightsIdx = 1,
5731 GradOfInputNamedIndicesIdx = 2,
5732 GradOfInputNamedLengthsIdx = 3,
5733 };
5734
5735 SparseLengthsWeightedSumGradNode(llvm::StringRef name, NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Lengths, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, glow::LengthsMode LengthsMode, float AvgLength)
5736 : Node(Kinded::Kind::SparseLengthsWeightedSumGradNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
5737 addResult(Data.getType());
5738 addResult(Weights.getType());
5739 addResult(Indices.getType());
5740 addResult(Lengths.getType());
5741 }
5742 const NodeValue getData() const { return Data_; }
5743 const NodeValue getWeights() const { return Weights_; }
5744 const NodeValue getIndices() const { return Indices_; }
5745 const NodeValue getLengths() const { return Lengths_; }
5746 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
5747 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
5748 NodeValue getGradOfInputNamedData() { return getNthResult(0); }
5749 const NodeValue getGradOfInputNamedData() const { return getNthResult(0); }
5750 NodeValue getGradOfInputNamedWeights() { return getNthResult(1); }
5751 const NodeValue getGradOfInputNamedWeights() const { return getNthResult(1); }
5752 NodeValue getGradOfInputNamedIndices() { return getNthResult(2); }
5753 const NodeValue getGradOfInputNamedIndices() const { return getNthResult(2); }
5754 NodeValue getGradOfInputNamedLengths() { return getNthResult(3); }
5755 const NodeValue getGradOfInputNamedLengths() const { return getNthResult(3); }
5756 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
5757 float getAvgLength() const { return AvgLength_; }
5758
5759 static bool classof(const Kinded *k) {
5760 return k->getKind() == Kinded::Kind::SparseLengthsWeightedSumGradNodeKind;
5761 }
5762
5763
5764 bool isOverwrittenNthInput(unsigned idx) const {
5765 return false;
5766 }
5767
5768 unsigned getNumInputs() const;
5769 std::string getInputName(unsigned idx) const;
5770 NodeValue getNthInput(unsigned idx);
5771 void setNthInput(unsigned idx, NodeValue val);
5772 llvm::StringRef getOutputName(unsigned idx) const;
5773 bool hasSideEffects() const { return 0; }
5774 bool isCanonical() const { return 1; }
5775 bool isDataParallel() const { return 0; }
5776 std::string getDebugDesc() const;
5777 bool isEqual(const SparseLengthsWeightedSumGradNode &other) const;
5778 llvm::hash_code getHash() const;
5779 void visit(Node *parent, NodeWalker *visitor);
5780 Node* clone() const;
5781 bool verify() const;
5782};
5783} // namespace glow
5784
5785
5786namespace glow {
5787/// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices).
5788class SparseLengthsWeightedSumNode final : public Node {
5789 NodeHandle Data_;
5790 NodeHandle Weights_;
5791 NodeHandle Indices_;
5792 NodeHandle Lengths_;
5793 glow::LengthsMode LengthsMode_;
5794 float AvgLength_;
5795
5796 public:
5797 enum InputIndices {
5798 DataIdx = 0,
5799 WeightsIdx = 1,
5800 IndicesIdx = 2,
5801 LengthsIdx = 3,
5802 };
5803
5804 enum ResultIndices {
5805 ResultIdx = 0,
5806 };
5807
5808 SparseLengthsWeightedSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Lengths, glow::LengthsMode LengthsMode, float AvgLength)
5809 : Node(Kinded::Kind::SparseLengthsWeightedSumNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
5810 addResult(Result);
5811 }
5812 const NodeValue getData() const { return Data_; }
5813 const NodeValue getWeights() const { return Weights_; }
5814 const NodeValue getIndices() const { return Indices_; }
5815 const NodeValue getLengths() const { return Lengths_; }
5816 NodeValue getResult() { return getNthResult(0); }
5817 const NodeValue getResult() const { return getNthResult(0); }
5818 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
5819 float getAvgLength() const { return AvgLength_; }
5820
5821 static bool classof(const Kinded *k) {
5822 return k->getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind;
5823 }
5824
5825
5826 bool isOverwrittenNthInput(unsigned idx) const {
5827 return false;
5828 }
5829
5830 unsigned getNumInputs() const;
5831 std::string getInputName(unsigned idx) const;
5832 NodeValue getNthInput(unsigned idx);
5833 void setNthInput(unsigned idx, NodeValue val);
5834 llvm::StringRef getOutputName(unsigned idx) const;
5835 bool hasSideEffects() const { return 0; }
5836 bool isCanonical() const { return 1; }
5837 bool isDataParallel() const { return 0; }
5838 std::string getDebugDesc() const;
5839 bool isEqual(const SparseLengthsWeightedSumNode &other) const;
5840 llvm::hash_code getHash() const;
5841 void visit(Node *parent, NodeWalker *visitor);
5842 Node* clone() const;
5843 bool verify() const;
5844 SparseLengthsWeightedSumGradNode *getGrad(GraphGradMapper &builder);
5845};
5846} // namespace glow
5847
5848
5849namespace glow {
5850/// Gathers slices of the outer-most dimension of Weights indexed by Indices tensor.
5851class EmbeddingNode final : public Node {
5852 NodeHandle Weights_;
5853 NodeHandle Indices_;
5854 int64_t PadIdx_;
5855 bool Scale_;
5856 bool Sparse_;
5857
5858 public:
5859 enum InputIndices {
5860 WeightsIdx = 0,
5861 IndicesIdx = 1,
5862 };
5863
5864 enum ResultIndices {
5865 ResultIdx = 0,
5866 };
5867
5868 EmbeddingNode(llvm::StringRef name, TypeRef Result , NodeValue Weights, NodeValue Indices, int64_t PadIdx, bool Scale, bool Sparse)
5869 : Node(Kinded::Kind::EmbeddingNodeKind, name), Weights_(this, Weights), Indices_(this, Indices), PadIdx_(PadIdx), Scale_(Scale), Sparse_(Sparse) {
5870 addResult(Result);
5871 }
5872 const NodeValue getWeights() const { return Weights_; }
5873 const NodeValue getIndices() const { return Indices_; }
5874 NodeValue getResult() { return getNthResult(0); }
5875 const NodeValue getResult() const { return getNthResult(0); }
5876 int64_t getPadIdx() const { return PadIdx_; }
5877 bool getScale() const { return Scale_; }
5878 bool getSparse() const { return Sparse_; }
5879
5880 static bool classof(const Kinded *k) {
5881 return k->getKind() == Kinded::Kind::EmbeddingNodeKind;
5882 }
5883
5884
5885 bool isOverwrittenNthInput(unsigned idx) const {
5886 return false;
5887 }
5888
5889 unsigned getNumInputs() const;
5890 std::string getInputName(unsigned idx) const;
5891 NodeValue getNthInput(unsigned idx);
5892 void setNthInput(unsigned idx, NodeValue val);
5893 llvm::StringRef getOutputName(unsigned idx) const;
5894 bool hasSideEffects() const { return 0; }
5895 bool isCanonical() const { return 1; }
5896 bool isDataParallel() const { return 0; }
5897 std::string getDebugDesc() const;
5898 bool isEqual(const EmbeddingNode &other) const;
5899 llvm::hash_code getHash() const;
5900 void visit(Node *parent, NodeWalker *visitor);
5901 Node* clone() const;
5902 bool verify() const;
5903};
5904} // namespace glow
5905
5906
5907namespace glow {
5908/// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Offsets) entries: first slice between Offsets[0] and Offsets[1] (or total length if there's only one elem in Offsets) are aggregated to Result[0], etc. I.e. largest offset must be less than or equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices).
5909class EmbeddingBagNode final : public Node {
5910 NodeHandle Data_;
5911 NodeHandle Weights_;
5912 NodeHandle Indices_;
5913 NodeHandle Offsets_;
5914 bool HasEndOffset_;
5915 glow::LengthsMode LengthsMode_;
5916 float AvgLength_;
5917
5918 public:
5919 enum InputIndices {
5920 DataIdx = 0,
5921 WeightsIdx = 1,
5922 IndicesIdx = 2,
5923 OffsetsIdx = 3,
5924 };
5925
5926 enum ResultIndices {
5927 ResultIdx = 0,
5928 };
5929
5930 EmbeddingBagNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Offsets, bool HasEndOffset, glow::LengthsMode LengthsMode, float AvgLength)
5931 : Node(Kinded::Kind::EmbeddingBagNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Offsets_(this, Offsets), HasEndOffset_(HasEndOffset), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
5932 addResult(Result);
5933 }
5934 const NodeValue getData() const { return Data_; }
5935 const NodeValue getWeights() const { return Weights_; }
5936 const NodeValue getIndices() const { return Indices_; }
5937 const NodeValue getOffsets() const { return Offsets_; }
5938 NodeValue getResult() { return getNthResult(0); }
5939 const NodeValue getResult() const { return getNthResult(0); }
5940 bool getHasEndOffset() const { return HasEndOffset_; }
5941 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
5942 float getAvgLength() const { return AvgLength_; }
5943
5944 static bool classof(const Kinded *k) {
5945 return k->getKind() == Kinded::Kind::EmbeddingBagNodeKind;
5946 }
5947
5948
5949 bool isOverwrittenNthInput(unsigned idx) const {
5950 return false;
5951 }
5952
5953 unsigned getNumInputs() const;
5954 std::string getInputName(unsigned idx) const;
5955 NodeValue getNthInput(unsigned idx);
5956 void setNthInput(unsigned idx, NodeValue val);
5957 llvm::StringRef getOutputName(unsigned idx) const;
5958 bool hasSideEffects() const { return 0; }
5959 bool isCanonical() const { return 1; }
5960 bool isDataParallel() const { return 0; }
5961 std::string getDebugDesc() const;
5962 bool isEqual(const EmbeddingBagNode &other) const;
5963 llvm::hash_code getHash() const;
5964 void visit(Node *parent, NodeWalker *visitor);
5965 Node* clone() const;
5966 bool verify() const;
5967};
5968} // namespace glow
5969
5970
5971namespace glow {
5972/// Same as FusedRowwiseQuantizedSparseLengthsWeightedSum but using offsets instead of lengths.
5973class EmbeddingBagByteRowwiseOffsetsNode final : public Node {
5974 NodeHandle Data_;
5975 NodeHandle Weights_;
5976 NodeHandle Indices_;
5977 NodeHandle Offsets_;
5978 bool UseFP16Accumulation_;
5979 bool HasEndOffset_;
5980 glow::LengthsMode LengthsMode_;
5981 float AvgLength_;
5982
5983 public:
5984 enum InputIndices {
5985 DataIdx = 0,
5986 WeightsIdx = 1,
5987 IndicesIdx = 2,
5988 OffsetsIdx = 3,
5989 };
5990
5991 enum ResultIndices {
5992 ResultIdx = 0,
5993 };
5994
5995 EmbeddingBagByteRowwiseOffsetsNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Offsets, bool UseFP16Accumulation, bool HasEndOffset, glow::LengthsMode LengthsMode, float AvgLength)
5996 : Node(Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Offsets_(this, Offsets), UseFP16Accumulation_(UseFP16Accumulation), HasEndOffset_(HasEndOffset), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
5997 addResult(Result);
5998 }
5999 const NodeValue getData() const { return Data_; }
6000 const NodeValue getWeights() const { return Weights_; }
6001 const NodeValue getIndices() const { return Indices_; }
6002 const NodeValue getOffsets() const { return Offsets_; }
6003 NodeValue getResult() { return getNthResult(0); }
6004 const NodeValue getResult() const { return getNthResult(0); }
6005 bool getUseFP16Accumulation() const { return UseFP16Accumulation_; }
6006 void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; }
6007 bool getHasEndOffset() const { return HasEndOffset_; }
6008 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
6009 float getAvgLength() const { return AvgLength_; }
6010
6011 static bool classof(const Kinded *k) {
6012 return k->getKind() == Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind;
6013 }
6014
6015
6016 bool isOverwrittenNthInput(unsigned idx) const {
6017 return false;
6018 }
6019
6020 unsigned getNumInputs() const;
6021 std::string getInputName(unsigned idx) const;
6022 NodeValue getNthInput(unsigned idx);
6023 void setNthInput(unsigned idx, NodeValue val);
6024 llvm::StringRef getOutputName(unsigned idx) const;
6025 bool hasSideEffects() const { return 0; }
6026 bool isCanonical() const { return 1; }
6027 bool isDataParallel() const { return 0; }
6028 std::string getDebugDesc() const;
6029 bool isEqual(const EmbeddingBagByteRowwiseOffsetsNode &other) const;
6030 llvm::hash_code getHash() const;
6031 void visit(Node *parent, NodeWalker *visitor);
6032 Node* clone() const;
6033 bool verify() const;
6034};
6035} // namespace glow
6036
6037
6038namespace glow {
6039/// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices). The input data is rowwise-quantized, where the Scales and Offsets are 1D tensors of length equal to the first dim of Data.
6040class RowwiseQuantizedSparseLengthsWeightedSumNode final : public Node {
6041 NodeHandle Data_;
6042 NodeHandle Scales_;
6043 NodeHandle Offsets_;
6044 NodeHandle Weights_;
6045 NodeHandle Indices_;
6046 NodeHandle Lengths_;
6047 bool UseFP16Accumulation_;
6048 glow::LengthsMode LengthsMode_;
6049 float AvgLength_;
6050
6051 public:
6052 enum InputIndices {
6053 DataIdx = 0,
6054 ScalesIdx = 1,
6055 OffsetsIdx = 2,
6056 WeightsIdx = 3,
6057 IndicesIdx = 4,
6058 LengthsIdx = 5,
6059 };
6060
6061 enum ResultIndices {
6062 ResultIdx = 0,
6063 };
6064
6065 RowwiseQuantizedSparseLengthsWeightedSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Scales, NodeValue Offsets, NodeValue Weights, NodeValue Indices, NodeValue Lengths, bool UseFP16Accumulation, glow::LengthsMode LengthsMode, float AvgLength)
6066 : Node(Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind, name), Data_(this, Data), Scales_(this, Scales), Offsets_(this, Offsets), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), UseFP16Accumulation_(UseFP16Accumulation), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
6067 addResult(Result);
6068 }
6069 const NodeValue getData() const { return Data_; }
6070 const NodeValue getScales() const { return Scales_; }
6071 const NodeValue getOffsets() const { return Offsets_; }
6072 const NodeValue getWeights() const { return Weights_; }
6073 const NodeValue getIndices() const { return Indices_; }
6074 const NodeValue getLengths() const { return Lengths_; }
6075 NodeValue getResult() { return getNthResult(0); }
6076 const NodeValue getResult() const { return getNthResult(0); }
6077 bool getUseFP16Accumulation() const { return UseFP16Accumulation_; }
6078 void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; }
6079 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
6080 float getAvgLength() const { return AvgLength_; }
6081
6082 static bool classof(const Kinded *k) {
6083 return k->getKind() == Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind;
6084 }
6085
6086
6087 bool isOverwrittenNthInput(unsigned idx) const {
6088 return false;
6089 }
6090
6091 unsigned getNumInputs() const;
6092 std::string getInputName(unsigned idx) const;
6093 NodeValue getNthInput(unsigned idx);
6094 void setNthInput(unsigned idx, NodeValue val);
6095 llvm::StringRef getOutputName(unsigned idx) const;
6096 bool hasSideEffects() const { return 0; }
6097 bool isCanonical() const { return 1; }
6098 bool isDataParallel() const { return 0; }
6099 std::string getDebugDesc() const;
6100 bool isEqual(const RowwiseQuantizedSparseLengthsWeightedSumNode &other) const;
6101 llvm::hash_code getHash() const;
6102 void visit(Node *parent, NodeWalker *visitor);
6103 Node* clone() const;
6104 bool verify() const;
6105};
6106} // namespace glow
6107
6108
6109namespace glow {
6110/// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices). The input data is fused rowwise-quantized, where the Scales and Offsets are appended to the end of each row. Thus, Data must be a two-dimensional tensor.
6111class FusedRowwiseQuantizedSparseLengthsWeightedSumNode final : public Node {
6112 NodeHandle Data_;
6113 NodeHandle Weights_;
6114 NodeHandle Indices_;
6115 NodeHandle Lengths_;
6116 bool UseFP16Accumulation_;
6117 glow::LengthsMode LengthsMode_;
6118 float AvgLength_;
6119
6120 public:
6121 enum InputIndices {
6122 DataIdx = 0,
6123 WeightsIdx = 1,
6124 IndicesIdx = 2,
6125 LengthsIdx = 3,
6126 };
6127
6128 enum ResultIndices {
6129 ResultIdx = 0,
6130 };
6131
6132 FusedRowwiseQuantizedSparseLengthsWeightedSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Lengths, bool UseFP16Accumulation, glow::LengthsMode LengthsMode, float AvgLength)
6133 : Node(Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), UseFP16Accumulation_(UseFP16Accumulation), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
6134 addResult(Result);
6135 }
6136 const NodeValue getData() const { return Data_; }
6137 const NodeValue getWeights() const { return Weights_; }
6138 const NodeValue getIndices() const { return Indices_; }
6139 const NodeValue getLengths() const { return Lengths_; }
6140 NodeValue getResult() { return getNthResult(0); }
6141 const NodeValue getResult() const { return getNthResult(0); }
6142 bool getUseFP16Accumulation() const { return UseFP16Accumulation_; }
6143 void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; }
6144 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
6145 float getAvgLength() const { return AvgLength_; }
6146
6147 static bool classof(const Kinded *k) {
6148 return k->getKind() == Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind;
6149 }
6150
6151
6152 bool isOverwrittenNthInput(unsigned idx) const {
6153 return false;
6154 }
6155
6156 unsigned getNumInputs() const;
6157 std::string getInputName(unsigned idx) const;
6158 NodeValue getNthInput(unsigned idx);
6159 void setNthInput(unsigned idx, NodeValue val);
6160 llvm::StringRef getOutputName(unsigned idx) const;
6161 bool hasSideEffects() const { return 0; }
6162 bool isCanonical() const { return 1; }
6163 bool isDataParallel() const { return 0; }
6164 std::string getDebugDesc() const;
6165 bool isEqual(const FusedRowwiseQuantizedSparseLengthsWeightedSumNode &other) const;
6166 llvm::hash_code getHash() const;
6167 void visit(Node *parent, NodeWalker *visitor);
6168 Node* clone() const;
6169 bool verify() const;
6170};
6171} // namespace glow
6172
6173
6174namespace glow {
6175/// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). The input data is fused rowwise-quantized, where the Scales and Offsets are appended to the end of each row. Thus, Data must be a two-dimensional tensor.
6176class FusedRowwiseQuantizedSparseLengthsSumNode final : public Node {
6177 NodeHandle Data_;
6178 NodeHandle Indices_;
6179 NodeHandle Lengths_;
6180 bool UseFP16Accumulation_;
6181 glow::LengthsMode LengthsMode_;
6182 float AvgLength_;
6183
6184 public:
6185 enum InputIndices {
6186 DataIdx = 0,
6187 IndicesIdx = 1,
6188 LengthsIdx = 2,
6189 };
6190
6191 enum ResultIndices {
6192 ResultIdx = 0,
6193 };
6194
6195 FusedRowwiseQuantizedSparseLengthsSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, NodeValue Lengths, bool UseFP16Accumulation, glow::LengthsMode LengthsMode, float AvgLength)
6196 : Node(Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind, name), Data_(this, Data), Indices_(this, Indices), Lengths_(this, Lengths), UseFP16Accumulation_(UseFP16Accumulation), LengthsMode_(LengthsMode), AvgLength_(AvgLength) {
6197 addResult(Result);
6198 }
6199 const NodeValue getData() const { return Data_; }
6200 const NodeValue getIndices() const { return Indices_; }
6201 const NodeValue getLengths() const { return Lengths_; }
6202 NodeValue getResult() { return getNthResult(0); }
6203 const NodeValue getResult() const { return getNthResult(0); }
6204 bool getUseFP16Accumulation() const { return UseFP16Accumulation_; }
6205 void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; }
6206 glow::LengthsMode getLengthsMode() const { return LengthsMode_; }
6207 float getAvgLength() const { return AvgLength_; }
6208
6209 static bool classof(const Kinded *k) {
6210 return k->getKind() == Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind;
6211 }
6212
6213
6214 bool isOverwrittenNthInput(unsigned idx) const {
6215 return false;
6216 }
6217
6218 unsigned getNumInputs() const;
6219 std::string getInputName(unsigned idx) const;
6220 NodeValue getNthInput(unsigned idx);
6221 void setNthInput(unsigned idx, NodeValue val);
6222 llvm::StringRef getOutputName(unsigned idx) const;
6223 bool hasSideEffects() const { return 0; }
6224 bool isCanonical() const { return 1; }
6225 bool isDataParallel() const { return 0; }
6226 std::string getDebugDesc() const;
6227 bool isEqual(const FusedRowwiseQuantizedSparseLengthsSumNode &other) const;
6228 llvm::hash_code getHash() const;
6229 void visit(Node *parent, NodeWalker *visitor);
6230 Node* clone() const;
6231 bool verify() const;
6232};
6233} // namespace glow
6234
6235
6236namespace glow {
6237/// Given a vector of segment lengths, calculates offsets of each segment and packs them next to the lengths. For the input vector of length N the output is a Nx2 matrix with (offset, lengths) packaged for each segment.
6238class LengthsToRangesNode final : public Node {
6239 NodeHandle Lengths_;
6240
6241 public:
6242 enum InputIndices {
6243 LengthsIdx = 0,
6244 };
6245
6246 enum ResultIndices {
6247 ResultIdx = 0,
6248 };
6249
6250 LengthsToRangesNode(llvm::StringRef name, TypeRef Result , NodeValue Lengths)
6251 : Node(Kinded::Kind::LengthsToRangesNodeKind, name), Lengths_(this, Lengths) {
6252 addResult(Result);
6253 }
6254 const NodeValue getLengths() const { return Lengths_; }
6255 NodeValue getResult() { return getNthResult(0); }
6256 const NodeValue getResult() const { return getNthResult(0); }
6257
6258 static bool classof(const Kinded *k) {
6259 return k->getKind() == Kinded::Kind::LengthsToRangesNodeKind;
6260 }
6261
6262
6263 bool isOverwrittenNthInput(unsigned idx) const {
6264 return false;
6265 }
6266
6267 unsigned getNumInputs() const;
6268 std::string getInputName(unsigned idx) const;
6269 NodeValue getNthInput(unsigned idx);
6270 void setNthInput(unsigned idx, NodeValue val);
6271 llvm::StringRef getOutputName(unsigned idx) const;
6272 bool hasSideEffects() const { return 0; }
6273 bool isCanonical() const { return 1; }
6274 bool isDataParallel() const { return 0; }
6275 std::string getDebugDesc() const;
6276 bool isEqual(const LengthsToRangesNode &other) const;
6277 llvm::hash_code getHash() const;
6278 void visit(Node *parent, NodeWalker *visitor);
6279 Node* clone() const;
6280 bool verify() const;
6281};
6282} // namespace glow
6283
6284
6285namespace glow {
6286/// Converts an input Lengths 1D vector into a range sequence.
6287class LengthsRangeFillNode final : public Node {
6288 NodeHandle Lengths_;
6289
6290 public:
6291 enum InputIndices {
6292 LengthsIdx = 0,
6293 };
6294
6295 enum ResultIndices {
6296 ResultIdx = 0,
6297 };
6298
6299 LengthsRangeFillNode(llvm::StringRef name, TypeRef Result , NodeValue Lengths)
6300 : Node(Kinded::Kind::LengthsRangeFillNodeKind, name), Lengths_(this, Lengths) {
6301 addResult(Result);
6302 }
6303 const NodeValue getLengths() const { return Lengths_; }
6304 NodeValue getResult() { return getNthResult(0); }
6305 const NodeValue getResult() const { return getNthResult(0); }
6306
6307 static bool classof(const Kinded *k) {
6308 return k->getKind() == Kinded::Kind::LengthsRangeFillNodeKind;
6309 }
6310
6311
6312 bool isOverwrittenNthInput(unsigned idx) const {
6313 return false;
6314 }
6315
6316 unsigned getNumInputs() const;
6317 std::string getInputName(unsigned idx) const;
6318 NodeValue getNthInput(unsigned idx);
6319 void setNthInput(unsigned idx, NodeValue val);
6320 llvm::StringRef getOutputName(unsigned idx) const;
6321 bool hasSideEffects() const { return 0; }
6322 bool isCanonical() const { return 1; }
6323 bool isDataParallel() const { return 0; }
6324 std::string getDebugDesc() const;
6325 bool isEqual(const LengthsRangeFillNode &other) const;
6326 llvm::hash_code getHash() const;
6327 void visit(Node *parent, NodeWalker *visitor);
6328 Node* clone() const;
6329 bool verify() const;
6330};
6331} // namespace glow
6332
6333
6334namespace glow {
6335/// Converts the sparse representation specified by (Lengths, Indices, Values) into a dense one. In the dense representation, elements of the lengths vector represent the number of indices in the corresponding batch, where each batch contains each value from Values at the corresponding index specified in Indices, and is filled with DefaultValue otherwise. Within each batch, Indices shouldn't contain duplicate indices.
6336class BatchSparseToDenseNode final : public Node {
6337 NodeHandle Lengths_;
6338 NodeHandle Indices_;
6339 NodeHandle Values_;
6340 float DefaultValue_;
6341 unsigned_t DenseLastDim_;
6342
6343 public:
6344 enum InputIndices {
6345 LengthsIdx = 0,
6346 IndicesIdx = 1,
6347 ValuesIdx = 2,
6348 };
6349
6350 enum ResultIndices {
6351 ResultIdx = 0,
6352 };
6353
6354 BatchSparseToDenseNode(llvm::StringRef name, TypeRef Result , NodeValue Lengths, NodeValue Indices, NodeValue Values, float DefaultValue, unsigned_t DenseLastDim)
6355 : Node(Kinded::Kind::BatchSparseToDenseNodeKind, name), Lengths_(this, Lengths), Indices_(this, Indices), Values_(this, Values), DefaultValue_(DefaultValue), DenseLastDim_(DenseLastDim) {
6356 addResult(Result);
6357 }
6358 const NodeValue getLengths() const { return Lengths_; }
6359 const NodeValue getIndices() const { return Indices_; }
6360 const NodeValue getValues() const { return Values_; }
6361 NodeValue getResult() { return getNthResult(0); }
6362 const NodeValue getResult() const { return getNthResult(0); }
6363 float getDefaultValue() const { return DefaultValue_; }
6364 unsigned_t getDenseLastDim() const { return DenseLastDim_; }
6365
6366 static bool classof(const Kinded *k) {
6367 return k->getKind() == Kinded::Kind::BatchSparseToDenseNodeKind;
6368 }
6369
6370
6371 bool isOverwrittenNthInput(unsigned idx) const {
6372 return false;
6373 }
6374
6375 unsigned getNumInputs() const;
6376 std::string getInputName(unsigned idx) const;
6377 NodeValue getNthInput(unsigned idx);
6378 void setNthInput(unsigned idx, NodeValue val);
6379 llvm::StringRef getOutputName(unsigned idx) const;
6380 bool hasSideEffects() const { return 0; }
6381 bool isCanonical() const { return 1; }
6382 bool isDataParallel() const { return 0; }
6383 std::string getDebugDesc() const;
6384 bool isEqual(const BatchSparseToDenseNode &other) const;
6385 llvm::hash_code getHash() const;
6386 void visit(Node *parent, NodeWalker *visitor);
6387 Node* clone() const;
6388 bool verify() const;
6389};
6390} // namespace glow
6391
6392
6393namespace glow {
6394/// Inserts zeros into data along axis=0 for indices where indicator is zero.
6395class FillExamplesWithIndicatorNode final : public Node {
6396 NodeHandle Data_;
6397 NodeHandle Indicator_;
6398
6399 public:
6400 enum InputIndices {
6401 DataIdx = 0,
6402 IndicatorIdx = 1,
6403 };
6404
6405 enum ResultIndices {
6406 ResultIdx = 0,
6407 };
6408
6409 FillExamplesWithIndicatorNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indicator)
6410 : Node(Kinded::Kind::FillExamplesWithIndicatorNodeKind, name), Data_(this, Data), Indicator_(this, Indicator) {
6411 addResult(Result);
6412 }
6413 const NodeValue getData() const { return Data_; }
6414 const NodeValue getIndicator() const { return Indicator_; }
6415 NodeValue getResult() { return getNthResult(0); }
6416 const NodeValue getResult() const { return getNthResult(0); }
6417
6418 static bool classof(const Kinded *k) {
6419 return k->getKind() == Kinded::Kind::FillExamplesWithIndicatorNodeKind;
6420 }
6421
6422
6423 bool isOverwrittenNthInput(unsigned idx) const {
6424 return false;
6425 }
6426
6427 unsigned getNumInputs() const;
6428 std::string getInputName(unsigned idx) const;
6429 NodeValue getNthInput(unsigned idx);
6430 void setNthInput(unsigned idx, NodeValue val);
6431 llvm::StringRef getOutputName(unsigned idx) const;
6432 bool hasSideEffects() const { return 0; }
6433 bool isCanonical() const { return 1; }
6434 bool isDataParallel() const { return 0; }
6435 std::string getDebugDesc() const;
6436 bool isEqual(const FillExamplesWithIndicatorNode &other) const;
6437 llvm::hash_code getHash() const;
6438 void visit(Node *parent, NodeWalker *visitor);
6439 Node* clone() const;
6440 bool verify() const;
6441};
6442} // namespace glow
6443
6444
6445namespace glow {
6446/// Converts the sparse representation specified by the pair (Indices, Values) into a dense one, where compacted tensor only contains IDs from given Mask. Indices cannot contain duplicate values. Lengths is used to distinguish elements from different examples of one batch. That is, first Lengths[0] index-value pairs belong to batch's example 0, next Lengths[1] pairs belong to example 1, and so on.
6447class SparseToDenseMaskNode final : public Node {
6448 NodeHandle Indices_;
6449 NodeHandle Values_;
6450 NodeHandle DefaultValue_;
6451 NodeHandle Lengths_;
6452 std::vector<dim_t> Mask_;
6453
6454 public:
6455 enum InputIndices {
6456 IndicesIdx = 0,
6457 ValuesIdx = 1,
6458 DefaultValueIdx = 2,
6459 LengthsIdx = 3,
6460 };
6461
6462 enum ResultIndices {
6463 ResultIdx = 0,
6464 };
6465
6466 SparseToDenseMaskNode(llvm::StringRef name, TypeRef Result , NodeValue Indices, NodeValue Values, NodeValue DefaultValue, NodeValue Lengths, std::vector<dim_t> Mask)
6467 : Node(Kinded::Kind::SparseToDenseMaskNodeKind, name), Indices_(this, Indices), Values_(this, Values), DefaultValue_(this, DefaultValue), Lengths_(this, Lengths), Mask_(Mask) {
6468 addResult(Result);
6469 }
6470 const NodeValue getIndices() const { return Indices_; }
6471 const NodeValue getValues() const { return Values_; }
6472 const NodeValue getDefaultValue() const { return DefaultValue_; }
6473 const NodeValue getLengths() const { return Lengths_; }
6474 NodeValue getResult() { return getNthResult(0); }
6475 const NodeValue getResult() const { return getNthResult(0); }
6476 llvm::ArrayRef<dim_t> getMask() const { return Mask_; }
6477
6478 static bool classof(const Kinded *k) {
6479 return k->getKind() == Kinded::Kind::SparseToDenseMaskNodeKind;
6480 }
6481
6482
6483 bool isOverwrittenNthInput(unsigned idx) const {
6484 return false;
6485 }
6486
6487 unsigned getNumInputs() const;
6488 std::string getInputName(unsigned idx) const;
6489 NodeValue getNthInput(unsigned idx);
6490 void setNthInput(unsigned idx, NodeValue val);
6491 llvm::StringRef getOutputName(unsigned idx) const;
6492 bool hasSideEffects() const { return 0; }
6493 bool isCanonical() const { return 1; }
6494 bool isDataParallel() const { return 0; }
6495 std::string getDebugDesc() const;
6496 bool isEqual(const SparseToDenseMaskNode &other) const;
6497 llvm::hash_code getHash() const;
6498 void visit(Node *parent, NodeWalker *visitor);
6499 Node* clone() const;
6500 bool verify() const;
6501};
6502} // namespace glow
6503
6504
6505namespace glow {
6506/// Determines whether each element of the Input is NaN and generates a mask that can be consumed by a Select node.
6507class IsNaNNode final : public Node {
6508 NodeHandle Input_;
6509
6510 public:
6511 enum InputIndices {
6512 InputIdx = 0,
6513 };
6514
6515 enum ResultIndices {
6516 ResultIdx = 0,
6517 };
6518
6519 IsNaNNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
6520 : Node(Kinded::Kind::IsNaNNodeKind, name), Input_(this, Input) {
6521 addResult(Result);
6522 }
6523 const NodeValue getInput() const { return Input_; }
6524 NodeValue getResult() { return getNthResult(0); }
6525 const NodeValue getResult() const { return getNthResult(0); }
6526
6527 static bool classof(const Kinded *k) {
6528 return k->getKind() == Kinded::Kind::IsNaNNodeKind;
6529 }
6530
6531
6532 bool isOverwrittenNthInput(unsigned idx) const {
6533 return false;
6534 }
6535
6536 unsigned getNumInputs() const;
6537 std::string getInputName(unsigned idx) const;
6538 NodeValue getNthInput(unsigned idx);
6539 void setNthInput(unsigned idx, NodeValue val);
6540 llvm::StringRef getOutputName(unsigned idx) const;
6541 bool hasSideEffects() const { return 0; }
6542 bool isCanonical() const { return 1; }
6543 bool isDataParallel() const { return 1; }
6544 std::string getDebugDesc() const;
6545 bool isEqual(const IsNaNNode &other) const;
6546 llvm::hash_code getHash() const;
6547 void visit(Node *parent, NodeWalker *visitor);
6548 Node* clone() const;
6549 bool verify() const;
6550};
6551} // namespace glow
6552
6553
6554namespace glow {
6555/// Replaces NaNs found in Input with Value.
6556class ReplaceNaNNode final : public Node {
6557 NodeHandle Input_;
6558 float Value_;
6559
6560 public:
6561 enum InputIndices {
6562 InputIdx = 0,
6563 };
6564
6565 enum ResultIndices {
6566 ResultIdx = 0,
6567 };
6568
6569 ReplaceNaNNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Value)
6570 : Node(Kinded::Kind::ReplaceNaNNodeKind, name), Input_(this, Input), Value_(Value) {
6571 addResult(Result);
6572 }
6573 const NodeValue getInput() const { return Input_; }
6574 NodeValue getResult() { return getNthResult(0); }
6575 const NodeValue getResult() const { return getNthResult(0); }
6576 float getValue() const { return Value_; }
6577
6578 static bool classof(const Kinded *k) {
6579 return k->getKind() == Kinded::Kind::ReplaceNaNNodeKind;
6580 }
6581
6582
6583 bool isOverwrittenNthInput(unsigned idx) const {
6584 return false;
6585 }
6586
6587 unsigned getNumInputs() const;
6588 std::string getInputName(unsigned idx) const;
6589 NodeValue getNthInput(unsigned idx);
6590 void setNthInput(unsigned idx, NodeValue val);
6591 llvm::StringRef getOutputName(unsigned idx) const;
6592 bool hasSideEffects() const { return 0; }
6593 bool isCanonical() const { return 1; }
6594 bool isDataParallel() const { return 0; }
6595 std::string getDebugDesc() const;
6596 bool isEqual(const ReplaceNaNNode &other) const;
6597 llvm::hash_code getHash() const;
6598 void visit(Node *parent, NodeWalker *visitor);
6599 Node* clone() const;
6600 bool verify() const;
6601};
6602} // namespace glow
6603
6604
6605namespace glow {
6606/// Performs elementwise modulo operation on the input where each element in the output is the corresponding element in the input data modulo Divisor.
6607class ModuloNode final : public Node {
6608 NodeHandle Input_;
6609 int64_t Divisor_;
6610 bool SignFollowDivisor_;
6611
6612 public:
6613 enum InputIndices {
6614 InputIdx = 0,
6615 };
6616
6617 enum ResultIndices {
6618 ResultIdx = 0,
6619 };
6620
6621 ModuloNode(llvm::StringRef name, TypeRef Result , NodeValue Input, int64_t Divisor, bool SignFollowDivisor)
6622 : Node(Kinded::Kind::ModuloNodeKind, name), Input_(this, Input), Divisor_(Divisor), SignFollowDivisor_(SignFollowDivisor) {
6623 addResult(Result);
6624 }
6625 const NodeValue getInput() const { return Input_; }
6626 NodeValue getResult() { return getNthResult(0); }
6627 const NodeValue getResult() const { return getNthResult(0); }
6628 int64_t getDivisor() const { return Divisor_; }
6629 bool getSignFollowDivisor() const { return SignFollowDivisor_; }
6630
6631 static bool classof(const Kinded *k) {
6632 return k->getKind() == Kinded::Kind::ModuloNodeKind;
6633 }
6634
6635
6636 bool isOverwrittenNthInput(unsigned idx) const {
6637 return false;
6638 }
6639
6640 unsigned getNumInputs() const;
6641 std::string getInputName(unsigned idx) const;
6642 NodeValue getNthInput(unsigned idx);
6643 void setNthInput(unsigned idx, NodeValue val);
6644 llvm::StringRef getOutputName(unsigned idx) const;
6645 bool hasSideEffects() const { return 0; }
6646 bool isCanonical() const { return 1; }
6647 bool isDataParallel() const { return 1; }
6648 std::string getDebugDesc() const;
6649 bool isEqual(const ModuloNode &other) const;
6650 llvm::hash_code getHash() const;
6651 void visit(Node *parent, NodeWalker *visitor);
6652 Node* clone() const;
6653 bool verify() const;
6654};
6655} // namespace glow
6656
6657
6658namespace glow {
6659/// Performs batched pairwise dot products of the input vectors
6660class BatchedPairwiseDotProductNode final : public Node {
6661 std::vector<NodeHandle> Inputs_;
6662
6663 public:
6664 enum InputIndices {
6665 };
6666
6667 enum ResultIndices {
6668 ResultIdx = 0,
6669 };
6670
6671 BatchedPairwiseDotProductNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> Inputs)
6672 : Node(Kinded::Kind::BatchedPairwiseDotProductNodeKind, name) {
6673 addResult(Result);
6674 Inputs_.resize(Inputs.size());
6675 for (size_t idx = 0, e = Inputs.size(); idx < e; ++idx) {
6676 Inputs_[idx] = Inputs[idx];
6677 Inputs_[idx].setParent(this);
6678 }
6679 }
6680 NodeValue getResult() { return getNthResult(0); }
6681 const NodeValue getResult() const { return getNthResult(0); }
6682 NodeValueArrayRef getInputs() const { return Inputs_; }
6683
6684 static bool classof(const Kinded *k) {
6685 return k->getKind() == Kinded::Kind::BatchedPairwiseDotProductNodeKind;
6686 }
6687
6688
6689 bool isOverwrittenNthInput(unsigned idx) const {
6690 return false;
6691 }
6692
6693 unsigned getNumInputs() const;
6694 std::string getInputName(unsigned idx) const;
6695 NodeValue getNthInput(unsigned idx);
6696 void setNthInput(unsigned idx, NodeValue val);
6697 llvm::StringRef getOutputName(unsigned idx) const;
6698 bool hasSideEffects() const { return 0; }
6699 bool isCanonical() const { return 1; }
6700 bool isDataParallel() const { return 0; }
6701 std::string getDebugDesc() const;
6702 bool isEqual(const BatchedPairwiseDotProductNode &other) const;
6703 llvm::hash_code getHash() const;
6704 void visit(Node *parent, NodeWalker *visitor);
6705 Node* clone() const;
6706 bool verify() const;
6707};
6708} // namespace glow
6709
6710
6711namespace glow {
6712/// Performs the gradient operation for BatchedPairwiseDotProduct
6713class BatchedPairwiseDotProductGradNode final : public Node {
6714 NodeHandle OutputGrad_;
6715 std::vector<NodeHandle> OriginalInputs_;
6716
6717 public:
6718 enum InputIndices {
6719 OutputGradIdx = 0,
6720 };
6721
6722 enum ResultIndices {
6723 };
6724
6725 BatchedPairwiseDotProductGradNode(llvm::StringRef name, NodeValue OutputGrad, std::vector<NodeValue> OriginalInputs)
6726 : Node(Kinded::Kind::BatchedPairwiseDotProductGradNodeKind, name), OutputGrad_(this, OutputGrad) {
6727 OriginalInputs_.resize(OriginalInputs.size());
6728 for (size_t idx = 0, e = OriginalInputs.size(); idx < e; ++idx) {
6729 OriginalInputs_[idx] = OriginalInputs[idx];
6730 OriginalInputs_[idx].setParent(this);
6731 }
6732 }
6733 const NodeValue getOutputGrad() const { return OutputGrad_; }
6734 NodeValueArrayRef getOriginalInputs() const { return OriginalInputs_; }
6735
6736 static bool classof(const Kinded *k) {
6737 return k->getKind() == Kinded::Kind::BatchedPairwiseDotProductGradNodeKind;
6738 }
6739
6740
6741 bool isOverwrittenNthInput(unsigned idx) const {
6742 return false;
6743 }
6744
6745 unsigned getNumInputs() const;
6746 std::string getInputName(unsigned idx) const;
6747 NodeValue getNthInput(unsigned idx);
6748 void setNthInput(unsigned idx, NodeValue val);
6749 llvm::StringRef getOutputName(unsigned idx) const;
6750 bool hasSideEffects() const { return 0; }
6751 bool isCanonical() const { return 1; }
6752 bool isDataParallel() const { return 0; }
6753 std::string getDebugDesc() const;
6754 bool isEqual(const BatchedPairwiseDotProductGradNode &other) const;
6755 llvm::hash_code getHash() const;
6756 void visit(Node *parent, NodeWalker *visitor);
6757 Node* clone() const;
6758 bool verify() const;
6759 void addExtraResult(TypeRef T) { addResult(T); }
6760};
6761} // namespace glow
6762
6763
6764namespace glow {
6765/// Sum weight embeddings according to offsets and indices
6766class BatchedUnaryEmbeddingsBagsNode final : public Node {
6767 NodeHandle Weights_;
6768 NodeHandle TableOffsets_;
6769 NodeHandle Offsets_;
6770 NodeHandle Indices_;
6771
6772 public:
6773 enum InputIndices {
6774 WeightsIdx = 0,
6775 TableOffsetsIdx = 1,
6776 OffsetsIdx = 2,
6777 IndicesIdx = 3,
6778 };
6779
6780 enum ResultIndices {
6781 ResultIdx = 0,
6782 };
6783
6784 BatchedUnaryEmbeddingsBagsNode(llvm::StringRef name, TypeRef Result , NodeValue Weights, NodeValue TableOffsets, NodeValue Offsets, NodeValue Indices)
6785 : Node(Kinded::Kind::BatchedUnaryEmbeddingsBagsNodeKind, name), Weights_(this, Weights), TableOffsets_(this, TableOffsets), Offsets_(this, Offsets), Indices_(this, Indices) {
6786 addResult(Result);
6787 }
6788 const NodeValue getWeights() const { return Weights_; }
6789 const NodeValue getTableOffsets() const { return TableOffsets_; }
6790 const NodeValue getOffsets() const { return Offsets_; }
6791 const NodeValue getIndices() const { return Indices_; }
6792 NodeValue getResult() { return getNthResult(0); }
6793 const NodeValue getResult() const { return getNthResult(0); }
6794
6795 static bool classof(const Kinded *k) {
6796 return k->getKind() == Kinded::Kind::BatchedUnaryEmbeddingsBagsNodeKind;
6797 }
6798
6799
6800 bool isOverwrittenNthInput(unsigned idx) const {
6801 return false;
6802 }
6803
6804 unsigned getNumInputs() const;
6805 std::string getInputName(unsigned idx) const;
6806 NodeValue getNthInput(unsigned idx);
6807 void setNthInput(unsigned idx, NodeValue val);
6808 llvm::StringRef getOutputName(unsigned idx) const;
6809 bool hasSideEffects() const { return 0; }
6810 bool isCanonical() const { return 1; }
6811 bool isDataParallel() const { return 0; }
6812 std::string getDebugDesc() const;
6813 bool isEqual(const BatchedUnaryEmbeddingsBagsNode &other) const;
6814 llvm::hash_code getHash() const;
6815 void visit(Node *parent, NodeWalker *visitor);
6816 Node* clone() const;
6817 bool verify() const;
6818};
6819} // namespace glow
6820
6821
6822namespace glow {
6823/// Table based batched embeddingbags with quantization support. Experimental only and subject to change.
6824class IntNBitSplitEmbeddingBagsNode final : public Node {
6825 NodeHandle DevWeights_;
6826 NodeHandle UvmWeights_;
6827 NodeHandle WeightsPlacements_;
6828 NodeHandle WeightsOffsets_;
6829 NodeHandle WeightsTys_;
6830 NodeHandle DimOffsets_;
6831 NodeHandle Indices_;
6832 NodeHandle Offsets_;
6833 int64_t TotalDims_;
6834 glow::SplitEmbeddingPoolingMode PoolingMode_;
6835 glow::SplitEmbeddingSparseType OutputDType_;
6836
6837 public:
6838 enum InputIndices {
6839 DevWeightsIdx = 0,
6840 UvmWeightsIdx = 1,
6841 WeightsPlacementsIdx = 2,
6842 WeightsOffsetsIdx = 3,
6843 WeightsTysIdx = 4,
6844 DimOffsetsIdx = 5,
6845 IndicesIdx = 6,
6846 OffsetsIdx = 7,
6847 };
6848
6849 enum ResultIndices {
6850 ResultIdx = 0,
6851 };
6852
6853 IntNBitSplitEmbeddingBagsNode(llvm::StringRef name, TypeRef Result , NodeValue DevWeights, NodeValue UvmWeights, NodeValue WeightsPlacements, NodeValue WeightsOffsets, NodeValue WeightsTys, NodeValue DimOffsets, NodeValue Indices, NodeValue Offsets, int64_t TotalDims, glow::SplitEmbeddingPoolingMode PoolingMode, glow::SplitEmbeddingSparseType OutputDType)
6854 : Node(Kinded::Kind::IntNBitSplitEmbeddingBagsNodeKind, name), DevWeights_(this, DevWeights), UvmWeights_(this, UvmWeights), WeightsPlacements_(this, WeightsPlacements), WeightsOffsets_(this, WeightsOffsets), WeightsTys_(this, WeightsTys), DimOffsets_(this, DimOffsets), Indices_(this, Indices), Offsets_(this, Offsets), TotalDims_(TotalDims), PoolingMode_(PoolingMode), OutputDType_(OutputDType) {
6855 addResult(Result);
6856 }
6857 const NodeValue getDevWeights() const { return DevWeights_; }
6858 const NodeValue getUvmWeights() const { return UvmWeights_; }
6859 const NodeValue getWeightsPlacements() const { return WeightsPlacements_; }
6860 const NodeValue getWeightsOffsets() const { return WeightsOffsets_; }
6861 const NodeValue getWeightsTys() const { return WeightsTys_; }
6862 const NodeValue getDimOffsets() const { return DimOffsets_; }
6863 const NodeValue getIndices() const { return Indices_; }
6864 const NodeValue getOffsets() const { return Offsets_; }
6865 NodeValue getResult() { return getNthResult(0); }
6866 const NodeValue getResult() const { return getNthResult(0); }
6867 int64_t getTotalDims() const { return TotalDims_; }
6868 glow::SplitEmbeddingPoolingMode getPoolingMode() const { return PoolingMode_; }
6869 glow::SplitEmbeddingSparseType getOutputDType() const { return OutputDType_; }
6870
6871 static bool classof(const Kinded *k) {
6872 return k->getKind() == Kinded::Kind::IntNBitSplitEmbeddingBagsNodeKind;
6873 }
6874
6875
6876 bool isOverwrittenNthInput(unsigned idx) const {
6877 return false;
6878 }
6879
6880 unsigned getNumInputs() const;
6881 std::string getInputName(unsigned idx) const;
6882 NodeValue getNthInput(unsigned idx);
6883 void setNthInput(unsigned idx, NodeValue val);
6884 llvm::StringRef getOutputName(unsigned idx) const;
6885 bool hasSideEffects() const { return 0; }
6886 bool isCanonical() const { return 1; }
6887 bool isDataParallel() const { return 0; }
6888 std::string getDebugDesc() const;
6889 bool isEqual(const IntNBitSplitEmbeddingBagsNode &other) const;
6890 llvm::hash_code getHash() const;
6891 void visit(Node *parent, NodeWalker *visitor);
6892 Node* clone() const;
6893 bool verify() const;
6894};
6895} // namespace glow
6896
6897
6898namespace glow {
6899/// Table based batched embeddingbags with quantization support and indice weights. Experimental only and subject to change.
6900class IntNBitSplitEmbeddingWeightedBagsNode final : public Node {
6901 NodeHandle DevWeights_;
6902 NodeHandle UvmWeights_;
6903 NodeHandle WeightsPlacements_;
6904 NodeHandle WeightsOffsets_;
6905 NodeHandle WeightsTys_;
6906 NodeHandle DimOffsets_;
6907 NodeHandle Indices_;
6908 NodeHandle Offsets_;
6909 NodeHandle IndiceWeight_;
6910 int64_t TotalDims_;
6911 glow::SplitEmbeddingPoolingMode PoolingMode_;
6912 glow::SplitEmbeddingSparseType OutputDType_;
6913
6914 public:
6915 enum InputIndices {
6916 DevWeightsIdx = 0,
6917 UvmWeightsIdx = 1,
6918 WeightsPlacementsIdx = 2,
6919 WeightsOffsetsIdx = 3,
6920 WeightsTysIdx = 4,
6921 DimOffsetsIdx = 5,
6922 IndicesIdx = 6,
6923 OffsetsIdx = 7,
6924 IndiceWeightIdx = 8,
6925 };
6926
6927 enum ResultIndices {
6928 ResultIdx = 0,
6929 };
6930
6931 IntNBitSplitEmbeddingWeightedBagsNode(llvm::StringRef name, TypeRef Result , NodeValue DevWeights, NodeValue UvmWeights, NodeValue WeightsPlacements, NodeValue WeightsOffsets, NodeValue WeightsTys, NodeValue DimOffsets, NodeValue Indices, NodeValue Offsets, NodeValue IndiceWeight, int64_t TotalDims, glow::SplitEmbeddingPoolingMode PoolingMode, glow::SplitEmbeddingSparseType OutputDType)
6932 : Node(Kinded::Kind::IntNBitSplitEmbeddingWeightedBagsNodeKind, name), DevWeights_(this, DevWeights), UvmWeights_(this, UvmWeights), WeightsPlacements_(this, WeightsPlacements), WeightsOffsets_(this, WeightsOffsets), WeightsTys_(this, WeightsTys), DimOffsets_(this, DimOffsets), Indices_(this, Indices), Offsets_(this, Offsets), IndiceWeight_(this, IndiceWeight), TotalDims_(TotalDims), PoolingMode_(PoolingMode), OutputDType_(OutputDType) {
6933 addResult(Result);
6934 }
6935 const NodeValue getDevWeights() const { return DevWeights_; }
6936 const NodeValue getUvmWeights() const { return UvmWeights_; }
6937 const NodeValue getWeightsPlacements() const { return WeightsPlacements_; }
6938 const NodeValue getWeightsOffsets() const { return WeightsOffsets_; }
6939 const NodeValue getWeightsTys() const { return WeightsTys_; }
6940 const NodeValue getDimOffsets() const { return DimOffsets_; }
6941 const NodeValue getIndices() const { return Indices_; }
6942 const NodeValue getOffsets() const { return Offsets_; }
6943 const NodeValue getIndiceWeight() const { return IndiceWeight_; }
6944 NodeValue getResult() { return getNthResult(0); }
6945 const NodeValue getResult() const { return getNthResult(0); }
6946 int64_t getTotalDims() const { return TotalDims_; }
6947 glow::SplitEmbeddingPoolingMode getPoolingMode() const { return PoolingMode_; }
6948 glow::SplitEmbeddingSparseType getOutputDType() const { return OutputDType_; }
6949
6950 static bool classof(const Kinded *k) {
6951 return k->getKind() == Kinded::Kind::IntNBitSplitEmbeddingWeightedBagsNodeKind;
6952 }
6953
6954
6955 bool isOverwrittenNthInput(unsigned idx) const {
6956 return false;
6957 }
6958
6959 unsigned getNumInputs() const;
6960 std::string getInputName(unsigned idx) const;
6961 NodeValue getNthInput(unsigned idx);
6962 void setNthInput(unsigned idx, NodeValue val);
6963 llvm::StringRef getOutputName(unsigned idx) const;
6964 bool hasSideEffects() const { return 0; }
6965 bool isCanonical() const { return 1; }
6966 bool isDataParallel() const { return 0; }
6967 std::string getDebugDesc() const;
6968 bool isEqual(const IntNBitSplitEmbeddingWeightedBagsNode &other) const;
6969 llvm::hash_code getHash() const;
6970 void visit(Node *parent, NodeWalker *visitor);
6971 Node* clone() const;
6972 bool verify() const;
6973};
6974} // namespace glow
6975
6976
6977namespace glow {
6978/// Fills an output tensor with samples drawn from a normal distribution specified by the mean and standard deviation arguments. The output tensor shape is determined by the input shape if provided, and shape otherwise
6979class GaussianFillNode final : public Node {
6980 NodeHandle Input_;
6981 float Mean_;
6982 float Scale_;
6983 float Seed_;
6984
6985 public:
6986 enum InputIndices {
6987 InputIdx = 0,
6988 };
6989
6990 enum ResultIndices {
6991 ResultIdx = 0,
6992 };
6993
6994 GaussianFillNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Mean, float Scale, float Seed)
6995 : Node(Kinded::Kind::GaussianFillNodeKind, name), Input_(this, Input), Mean_(Mean), Scale_(Scale), Seed_(Seed) {
6996 addResult(Result);
6997 }
6998 const NodeValue getInput() const { return Input_; }
6999 NodeValue getResult() { return getNthResult(0); }
7000 const NodeValue getResult() const { return getNthResult(0); }
7001 float getMean() const { return Mean_; }
7002 float getScale() const { return Scale_; }
7003 float getSeed() const { return Seed_; }
7004
7005 static bool classof(const Kinded *k) {
7006 return k->getKind() == Kinded::Kind::GaussianFillNodeKind;
7007 }
7008
7009
7010 bool isOverwrittenNthInput(unsigned idx) const {
7011 return false;
7012 }
7013
7014 unsigned getNumInputs() const;
7015 std::string getInputName(unsigned idx) const;
7016 NodeValue getNthInput(unsigned idx);
7017 void setNthInput(unsigned idx, NodeValue val);
7018 llvm::StringRef getOutputName(unsigned idx) const;
7019 bool hasSideEffects() const { return 0; }
7020 bool isCanonical() const { return 1; }
7021 bool isDataParallel() const { return 0; }
7022 std::string getDebugDesc() const;
7023 bool isEqual(const GaussianFillNode &other) const;
7024 llvm::hash_code getHash() const;
7025 void visit(Node *parent, NodeWalker *visitor);
7026 Node* clone() const;
7027 bool verify() const;
7028};
7029} // namespace glow
7030
7031
7032namespace glow {
7033class ReluGradNode final : public Node {
7034 NodeHandle Input_;
7035 NodeHandle OriginalOutputForResult_;
7036 NodeHandle GradOfOriginalOutputNamedResult_;
7037
7038 public:
7039 enum InputIndices {
7040 InputIdx = 0,
7041 OriginalOutputForResultIdx = 1,
7042 GradOfOriginalOutputNamedResultIdx = 2,
7043 };
7044
7045 enum ResultIndices {
7046 GradOfInputNamedInputIdx = 0,
7047 };
7048
7049 ReluGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
7050 : Node(Kinded::Kind::ReluGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
7051 addResult(Input.getType());
7052 }
7053 const NodeValue getInput() const { return Input_; }
7054 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
7055 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
7056 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
7057 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
7058
7059 static bool classof(const Kinded *k) {
7060 return k->getKind() == Kinded::Kind::ReluGradNodeKind;
7061 }
7062
7063
7064 bool isOverwrittenNthInput(unsigned idx) const {
7065 return false;
7066 }
7067
7068 unsigned getNumInputs() const;
7069 std::string getInputName(unsigned idx) const;
7070 NodeValue getNthInput(unsigned idx);
7071 void setNthInput(unsigned idx, NodeValue val);
7072 llvm::StringRef getOutputName(unsigned idx) const;
7073 bool hasSideEffects() const { return 0; }
7074 bool isCanonical() const { return 1; }
7075 bool isDataParallel() const { return 1; }
7076 std::string getDebugDesc() const;
7077 bool isEqual(const ReluGradNode &other) const;
7078 llvm::hash_code getHash() const;
7079 void visit(Node *parent, NodeWalker *visitor);
7080 Node* clone() const;
7081 bool verify() const;
7082};
7083} // namespace glow
7084
7085
7086namespace glow {
7087/// Applies ReLU, max(0, x), to each element in the Input tensor.
7088class ReluNode final : public Node {
7089 NodeHandle Input_;
7090
7091 public:
7092 enum InputIndices {
7093 InputIdx = 0,
7094 };
7095
7096 enum ResultIndices {
7097 ResultIdx = 0,
7098 };
7099
7100 ReluNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7101 : Node(Kinded::Kind::ReluNodeKind, name), Input_(this, Input) {
7102 addResult(Result);
7103 }
7104 const NodeValue getInput() const { return Input_; }
7105 NodeValue getResult() { return getNthResult(0); }
7106 const NodeValue getResult() const { return getNthResult(0); }
7107
7108 static bool classof(const Kinded *k) {
7109 return k->getKind() == Kinded::Kind::ReluNodeKind;
7110 }
7111
7112
7113 bool isOverwrittenNthInput(unsigned idx) const {
7114 return false;
7115 }
7116
7117 unsigned getNumInputs() const;
7118 std::string getInputName(unsigned idx) const;
7119 NodeValue getNthInput(unsigned idx);
7120 void setNthInput(unsigned idx, NodeValue val);
7121 llvm::StringRef getOutputName(unsigned idx) const;
7122 bool hasSideEffects() const { return 0; }
7123 bool isCanonical() const { return 1; }
7124 bool isDataParallel() const { return 1; }
7125 std::string getDebugDesc() const;
7126 bool isEqual(const ReluNode &other) const;
7127 llvm::hash_code getHash() const;
7128 void visit(Node *parent, NodeWalker *visitor);
7129 Node* clone() const;
7130 bool verify() const;
7131 ReluGradNode *getGrad(GraphGradMapper &builder);
7132};
7133} // namespace glow
7134
7135
7136namespace glow {
7137/// Applies HardSwish to each element in the Input tensor.
7138class HardSwishNode final : public Node {
7139 NodeHandle Input_;
7140
7141 public:
7142 enum InputIndices {
7143 InputIdx = 0,
7144 };
7145
7146 enum ResultIndices {
7147 ResultIdx = 0,
7148 };
7149
7150 HardSwishNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7151 : Node(Kinded::Kind::HardSwishNodeKind, name), Input_(this, Input) {
7152 addResult(Result);
7153 }
7154 const NodeValue getInput() const { return Input_; }
7155 NodeValue getResult() { return getNthResult(0); }
7156 const NodeValue getResult() const { return getNthResult(0); }
7157
7158 static bool classof(const Kinded *k) {
7159 return k->getKind() == Kinded::Kind::HardSwishNodeKind;
7160 }
7161
7162
7163 bool isOverwrittenNthInput(unsigned idx) const {
7164 return false;
7165 }
7166
7167 unsigned getNumInputs() const;
7168 std::string getInputName(unsigned idx) const;
7169 NodeValue getNthInput(unsigned idx);
7170 void setNthInput(unsigned idx, NodeValue val);
7171 llvm::StringRef getOutputName(unsigned idx) const;
7172 bool hasSideEffects() const { return 0; }
7173 bool isCanonical() const { return 1; }
7174 bool isDataParallel() const { return 1; }
7175 std::string getDebugDesc() const;
7176 bool isEqual(const HardSwishNode &other) const;
7177 llvm::hash_code getHash() const;
7178 void visit(Node *parent, NodeWalker *visitor);
7179 Node* clone() const;
7180 bool verify() const;
7181};
7182} // namespace glow
7183
7184
7185namespace glow {
7186/// Applies GeLU, to each element in the Input tensor.
7187class GeluNode final : public Node {
7188 NodeHandle Input_;
7189
7190 public:
7191 enum InputIndices {
7192 InputIdx = 0,
7193 };
7194
7195 enum ResultIndices {
7196 ResultIdx = 0,
7197 };
7198
7199 GeluNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7200 : Node(Kinded::Kind::GeluNodeKind, name), Input_(this, Input) {
7201 addResult(Result);
7202 }
7203 const NodeValue getInput() const { return Input_; }
7204 NodeValue getResult() { return getNthResult(0); }
7205 const NodeValue getResult() const { return getNthResult(0); }
7206
7207 static bool classof(const Kinded *k) {
7208 return k->getKind() == Kinded::Kind::GeluNodeKind;
7209 }
7210
7211
7212 bool isOverwrittenNthInput(unsigned idx) const {
7213 return false;
7214 }
7215
7216 unsigned getNumInputs() const;
7217 std::string getInputName(unsigned idx) const;
7218 NodeValue getNthInput(unsigned idx);
7219 void setNthInput(unsigned idx, NodeValue val);
7220 llvm::StringRef getOutputName(unsigned idx) const;
7221 bool hasSideEffects() const { return 0; }
7222 bool isCanonical() const { return 1; }
7223 bool isDataParallel() const { return 1; }
7224 std::string getDebugDesc() const;
7225 bool isEqual(const GeluNode &other) const;
7226 llvm::hash_code getHash() const;
7227 void visit(Node *parent, NodeWalker *visitor);
7228 Node* clone() const;
7229 bool verify() const;
7230};
7231} // namespace glow
7232
7233
7234namespace glow {
7235/// Clip range of inputs to lie in [Min, Max].
7236class ClipNode final : public Node {
7237 NodeHandle Input_;
7238 float Min_;
7239 float Max_;
7240
7241 public:
7242 enum InputIndices {
7243 InputIdx = 0,
7244 };
7245
7246 enum ResultIndices {
7247 ResultIdx = 0,
7248 };
7249
7250 ClipNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Min, float Max)
7251 : Node(Kinded::Kind::ClipNodeKind, name), Input_(this, Input), Min_(Min), Max_(Max) {
7252 addResult(Result);
7253 }
7254 const NodeValue getInput() const { return Input_; }
7255 NodeValue getResult() { return getNthResult(0); }
7256 const NodeValue getResult() const { return getNthResult(0); }
7257 float getMin() const { return Min_; }
7258 float getMax() const { return Max_; }
7259
7260 static bool classof(const Kinded *k) {
7261 return k->getKind() == Kinded::Kind::ClipNodeKind;
7262 }
7263
7264
7265 bool isOverwrittenNthInput(unsigned idx) const {
7266 return false;
7267 }
7268
7269 unsigned getNumInputs() const;
7270 std::string getInputName(unsigned idx) const;
7271 NodeValue getNthInput(unsigned idx);
7272 void setNthInput(unsigned idx, NodeValue val);
7273 llvm::StringRef getOutputName(unsigned idx) const;
7274 bool hasSideEffects() const { return 0; }
7275 bool isCanonical() const { return 1; }
7276 bool isDataParallel() const { return 1; }
7277 std::string getDebugDesc() const;
7278 bool isEqual(const ClipNode &other) const;
7279 llvm::hash_code getHash() const;
7280 void visit(Node *parent, NodeWalker *visitor);
7281 Node* clone() const;
7282 bool verify() const;
7283};
7284} // namespace glow
7285
7286
7287namespace glow {
7288/// Applies PReLU, slope * min(0, x) + max(0, x), to each element in the Input tensor.
7289class PReluNode final : public Node {
7290 NodeHandle Input_;
7291 NodeHandle Slope_;
7292
7293 public:
7294 enum InputIndices {
7295 InputIdx = 0,
7296 SlopeIdx = 1,
7297 };
7298
7299 enum ResultIndices {
7300 ResultIdx = 0,
7301 };
7302
7303 PReluNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Slope)
7304 : Node(Kinded::Kind::PReluNodeKind, name), Input_(this, Input), Slope_(this, Slope) {
7305 addResult(Result);
7306 }
7307 const NodeValue getInput() const { return Input_; }
7308 const NodeValue getSlope() const { return Slope_; }
7309 NodeValue getResult() { return getNthResult(0); }
7310 const NodeValue getResult() const { return getNthResult(0); }
7311
7312 static bool classof(const Kinded *k) {
7313 return k->getKind() == Kinded::Kind::PReluNodeKind;
7314 }
7315
7316
7317 bool isOverwrittenNthInput(unsigned idx) const {
7318 return false;
7319 }
7320
7321 unsigned getNumInputs() const;
7322 std::string getInputName(unsigned idx) const;
7323 NodeValue getNthInput(unsigned idx);
7324 void setNthInput(unsigned idx, NodeValue val);
7325 llvm::StringRef getOutputName(unsigned idx) const;
7326 bool hasSideEffects() const { return 0; }
7327 bool isCanonical() const { return 1; }
7328 bool isDataParallel() const { return 1; }
7329 std::string getDebugDesc() const;
7330 bool isEqual(const PReluNode &other) const;
7331 llvm::hash_code getHash() const;
7332 void visit(Node *parent, NodeWalker *visitor);
7333 Node* clone() const;
7334 bool verify() const;
7335};
7336} // namespace glow
7337
7338
7339namespace glow {
7340class SigmoidGradNode final : public Node {
7341 NodeHandle Input_;
7342 NodeHandle OriginalOutputForResult_;
7343 NodeHandle GradOfOriginalOutputNamedResult_;
7344
7345 public:
7346 enum InputIndices {
7347 InputIdx = 0,
7348 OriginalOutputForResultIdx = 1,
7349 GradOfOriginalOutputNamedResultIdx = 2,
7350 };
7351
7352 enum ResultIndices {
7353 GradOfInputNamedInputIdx = 0,
7354 };
7355
7356 SigmoidGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
7357 : Node(Kinded::Kind::SigmoidGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
7358 addResult(Input.getType());
7359 }
7360 const NodeValue getInput() const { return Input_; }
7361 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
7362 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
7363 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
7364 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
7365
7366 static bool classof(const Kinded *k) {
7367 return k->getKind() == Kinded::Kind::SigmoidGradNodeKind;
7368 }
7369
7370
7371 bool isOverwrittenNthInput(unsigned idx) const {
7372 return false;
7373 }
7374
7375 unsigned getNumInputs() const;
7376 std::string getInputName(unsigned idx) const;
7377 NodeValue getNthInput(unsigned idx);
7378 void setNthInput(unsigned idx, NodeValue val);
7379 llvm::StringRef getOutputName(unsigned idx) const;
7380 bool hasSideEffects() const { return 0; }
7381 bool isCanonical() const { return 1; }
7382 bool isDataParallel() const { return 1; }
7383 std::string getDebugDesc() const;
7384 bool isEqual(const SigmoidGradNode &other) const;
7385 llvm::hash_code getHash() const;
7386 void visit(Node *parent, NodeWalker *visitor);
7387 Node* clone() const;
7388 bool verify() const;
7389};
7390} // namespace glow
7391
7392
7393namespace glow {
7394/// Applies Sigmoid, 1 / (1 + exp(-x)), to each element in the Input tensor.
7395class SigmoidNode final : public Node {
7396 NodeHandle Input_;
7397
7398 public:
7399 enum InputIndices {
7400 InputIdx = 0,
7401 };
7402
7403 enum ResultIndices {
7404 ResultIdx = 0,
7405 };
7406
7407 SigmoidNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7408 : Node(Kinded::Kind::SigmoidNodeKind, name), Input_(this, Input) {
7409 addResult(Result);
7410 }
7411 const NodeValue getInput() const { return Input_; }
7412 NodeValue getResult() { return getNthResult(0); }
7413 const NodeValue getResult() const { return getNthResult(0); }
7414
7415 static bool classof(const Kinded *k) {
7416 return k->getKind() == Kinded::Kind::SigmoidNodeKind;
7417 }
7418
7419
7420 bool isOverwrittenNthInput(unsigned idx) const {
7421 return false;
7422 }
7423
7424 unsigned getNumInputs() const;
7425 std::string getInputName(unsigned idx) const;
7426 NodeValue getNthInput(unsigned idx);
7427 void setNthInput(unsigned idx, NodeValue val);
7428 llvm::StringRef getOutputName(unsigned idx) const;
7429 bool hasSideEffects() const { return 0; }
7430 bool isCanonical() const { return 1; }
7431 bool isDataParallel() const { return 1; }
7432 std::string getDebugDesc() const;
7433 bool isEqual(const SigmoidNode &other) const;
7434 llvm::hash_code getHash() const;
7435 void visit(Node *parent, NodeWalker *visitor);
7436 Node* clone() const;
7437 bool verify() const;
7438 SigmoidGradNode *getGrad(GraphGradMapper &builder);
7439};
7440} // namespace glow
7441
7442
7443namespace glow {
7444/// Applies Swish, X * Sigmoid(X), to each element in the Input tensor.
7445class SwishNode final : public Node {
7446 NodeHandle Input_;
7447
7448 public:
7449 enum InputIndices {
7450 InputIdx = 0,
7451 };
7452
7453 enum ResultIndices {
7454 ResultIdx = 0,
7455 };
7456
7457 SwishNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7458 : Node(Kinded::Kind::SwishNodeKind, name), Input_(this, Input) {
7459 addResult(Result);
7460 }
7461 const NodeValue getInput() const { return Input_; }
7462 NodeValue getResult() { return getNthResult(0); }
7463 const NodeValue getResult() const { return getNthResult(0); }
7464
7465 static bool classof(const Kinded *k) {
7466 return k->getKind() == Kinded::Kind::SwishNodeKind;
7467 }
7468
7469
7470 bool isOverwrittenNthInput(unsigned idx) const {
7471 return false;
7472 }
7473
7474 unsigned getNumInputs() const;
7475 std::string getInputName(unsigned idx) const;
7476 NodeValue getNthInput(unsigned idx);
7477 void setNthInput(unsigned idx, NodeValue val);
7478 llvm::StringRef getOutputName(unsigned idx) const;
7479 bool hasSideEffects() const { return 0; }
7480 bool isCanonical() const { return 1; }
7481 bool isDataParallel() const { return 1; }
7482 std::string getDebugDesc() const;
7483 bool isEqual(const SwishNode &other) const;
7484 llvm::hash_code getHash() const;
7485 void visit(Node *parent, NodeWalker *visitor);
7486 Node* clone() const;
7487 bool verify() const;
7488};
7489} // namespace glow
7490
7491
7492namespace glow {
7493class TanhGradNode final : public Node {
7494 NodeHandle Input_;
7495 NodeHandle OriginalOutputForResult_;
7496 NodeHandle GradOfOriginalOutputNamedResult_;
7497
7498 public:
7499 enum InputIndices {
7500 InputIdx = 0,
7501 OriginalOutputForResultIdx = 1,
7502 GradOfOriginalOutputNamedResultIdx = 2,
7503 };
7504
7505 enum ResultIndices {
7506 GradOfInputNamedInputIdx = 0,
7507 };
7508
7509 TanhGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult)
7510 : Node(Kinded::Kind::TanhGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) {
7511 addResult(Input.getType());
7512 }
7513 const NodeValue getInput() const { return Input_; }
7514 const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; }
7515 const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; }
7516 NodeValue getGradOfInputNamedInput() { return getNthResult(0); }
7517 const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); }
7518
7519 static bool classof(const Kinded *k) {
7520 return k->getKind() == Kinded::Kind::TanhGradNodeKind;
7521 }
7522
7523
7524 bool isOverwrittenNthInput(unsigned idx) const {
7525 return false;
7526 }
7527
7528 unsigned getNumInputs() const;
7529 std::string getInputName(unsigned idx) const;
7530 NodeValue getNthInput(unsigned idx);
7531 void setNthInput(unsigned idx, NodeValue val);
7532 llvm::StringRef getOutputName(unsigned idx) const;
7533 bool hasSideEffects() const { return 0; }
7534 bool isCanonical() const { return 1; }
7535 bool isDataParallel() const { return 1; }
7536 std::string getDebugDesc() const;
7537 bool isEqual(const TanhGradNode &other) const;
7538 llvm::hash_code getHash() const;
7539 void visit(Node *parent, NodeWalker *visitor);
7540 Node* clone() const;
7541 bool verify() const;
7542};
7543} // namespace glow
7544
7545
7546namespace glow {
7547/// Applies hyperbolic tangent to each element in the Input tensor.
7548class TanhNode final : public Node {
7549 NodeHandle Input_;
7550
7551 public:
7552 enum InputIndices {
7553 InputIdx = 0,
7554 };
7555
7556 enum ResultIndices {
7557 ResultIdx = 0,
7558 };
7559
7560 TanhNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7561 : Node(Kinded::Kind::TanhNodeKind, name), Input_(this, Input) {
7562 addResult(Result);
7563 }
7564 const NodeValue getInput() const { return Input_; }
7565 NodeValue getResult() { return getNthResult(0); }
7566 const NodeValue getResult() const { return getNthResult(0); }
7567
7568 static bool classof(const Kinded *k) {
7569 return k->getKind() == Kinded::Kind::TanhNodeKind;
7570 }
7571
7572
7573 bool isOverwrittenNthInput(unsigned idx) const {
7574 return false;
7575 }
7576
7577 unsigned getNumInputs() const;
7578 std::string getInputName(unsigned idx) const;
7579 NodeValue getNthInput(unsigned idx);
7580 void setNthInput(unsigned idx, NodeValue val);
7581 llvm::StringRef getOutputName(unsigned idx) const;
7582 bool hasSideEffects() const { return 0; }
7583 bool isCanonical() const { return 1; }
7584 bool isDataParallel() const { return 1; }
7585 std::string getDebugDesc() const;
7586 bool isEqual(const TanhNode &other) const;
7587 llvm::hash_code getHash() const;
7588 void visit(Node *parent, NodeWalker *visitor);
7589 Node* clone() const;
7590 bool verify() const;
7591 TanhGradNode *getGrad(GraphGradMapper &builder);
7592};
7593} // namespace glow
7594
7595
7596namespace glow {
7597/// Applies LeakyReLU = x for positive x and alpha * x for negative x to each element in the Input tensor.
7598class LeakyReluNode final : public Node {
7599 NodeHandle Input_;
7600 float Alpha_;
7601
7602 public:
7603 enum InputIndices {
7604 InputIdx = 0,
7605 };
7606
7607 enum ResultIndices {
7608 ResultIdx = 0,
7609 };
7610
7611 LeakyReluNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Alpha)
7612 : Node(Kinded::Kind::LeakyReluNodeKind, name), Input_(this, Input), Alpha_(Alpha) {
7613 addResult(Result);
7614 }
7615 const NodeValue getInput() const { return Input_; }
7616 NodeValue getResult() { return getNthResult(0); }
7617 const NodeValue getResult() const { return getNthResult(0); }
7618 float getAlpha() const { return Alpha_; }
7619
7620 static bool classof(const Kinded *k) {
7621 return k->getKind() == Kinded::Kind::LeakyReluNodeKind;
7622 }
7623
7624
7625 bool isOverwrittenNthInput(unsigned idx) const {
7626 return false;
7627 }
7628
7629 unsigned getNumInputs() const;
7630 std::string getInputName(unsigned idx) const;
7631 NodeValue getNthInput(unsigned idx);
7632 void setNthInput(unsigned idx, NodeValue val);
7633 llvm::StringRef getOutputName(unsigned idx) const;
7634 bool hasSideEffects() const { return 0; }
7635 bool isCanonical() const { return 1; }
7636 bool isDataParallel() const { return 1; }
7637 std::string getDebugDesc() const;
7638 bool isEqual(const LeakyReluNode &other) const;
7639 llvm::hash_code getHash() const;
7640 void visit(Node *parent, NodeWalker *visitor);
7641 Node* clone() const;
7642 bool verify() const;
7643};
7644} // namespace glow
7645
7646
7647namespace glow {
7648/// Performs SoftPlus, ln(exp(x) + 1), to each element in the Input tensor.
7649class SoftPlusNode final : public Node {
7650 NodeHandle Input_;
7651
7652 public:
7653 enum InputIndices {
7654 InputIdx = 0,
7655 };
7656
7657 enum ResultIndices {
7658 ResultIdx = 0,
7659 };
7660
7661 SoftPlusNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
7662 : Node(Kinded::Kind::SoftPlusNodeKind, name), Input_(this, Input) {
7663 addResult(Result);
7664 }
7665 const NodeValue getInput() const { return Input_; }
7666 NodeValue getResult() { return getNthResult(0); }
7667 const NodeValue getResult() const { return getNthResult(0); }
7668
7669 static bool classof(const Kinded *k) {
7670 return k->getKind() == Kinded::Kind::SoftPlusNodeKind;
7671 }
7672
7673
7674 bool isOverwrittenNthInput(unsigned idx) const {
7675 return false;
7676 }
7677
7678 unsigned getNumInputs() const;
7679 std::string getInputName(unsigned idx) const;
7680 NodeValue getNthInput(unsigned idx);
7681 void setNthInput(unsigned idx, NodeValue val);
7682 llvm::StringRef getOutputName(unsigned idx) const;
7683 bool hasSideEffects() const { return 0; }
7684 bool isCanonical() const { return 1; }
7685 bool isDataParallel() const { return 1; }
7686 std::string getDebugDesc() const;
7687 bool isEqual(const SoftPlusNode &other) const;
7688 llvm::hash_code getHash() const;
7689 void visit(Node *parent, NodeWalker *visitor);
7690 Node* clone() const;
7691 bool verify() const;
7692};
7693} // namespace glow
7694
7695
7696namespace glow {
7697/// Reshape the Input tensor to shape Dims.
7698class ReshapeNode final : public Node {
7699 NodeHandle Input_;
7700 std::vector<dim_t> Dims_;
7701 std::string Layout_;
7702
7703 public:
7704 enum InputIndices {
7705 InputIdx = 0,
7706 };
7707
7708 enum ResultIndices {
7709 ResultIdx = 0,
7710 };
7711
7712 ReshapeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<dim_t> Dims, std::string Layout)
7713 : Node(Kinded::Kind::ReshapeNodeKind, name), Input_(this, Input), Dims_(Dims), Layout_(Layout) {
7714 addResult(Result);
7715 }
7716 const NodeValue getInput() const { return Input_; }
7717 NodeValue getResult() { return getNthResult(0); }
7718 const NodeValue getResult() const { return getNthResult(0); }
7719 llvm::ArrayRef<dim_t> getDims() const { return Dims_; }
7720 std::string getLayout() const { return Layout_; }
7721
7722 static bool classof(const Kinded *k) {
7723 return k->getKind() == Kinded::Kind::ReshapeNodeKind;
7724 }
7725
7726
7727 bool isOverwrittenNthInput(unsigned idx) const {
7728 return false;
7729 }
7730
7731 unsigned getNumInputs() const;
7732 std::string getInputName(unsigned idx) const;
7733 NodeValue getNthInput(unsigned idx);
7734 void setNthInput(unsigned idx, NodeValue val);
7735 llvm::StringRef getOutputName(unsigned idx) const;
7736 bool hasSideEffects() const { return 0; }
7737 bool isCanonical() const { return 1; }
7738 bool isDataParallel() const { return 0; }
7739 std::string getDebugDesc() const;
7740 bool isEqual(const ReshapeNode &other) const;
7741 llvm::hash_code getHash() const;
7742 void visit(Node *parent, NodeWalker *visitor);
7743 Node* clone() const;
7744 bool verify() const;
7745};
7746} // namespace glow
7747
7748
7749namespace glow {
7750/// Transpose the Input tensor based on the vector Shuffle, which assigns a new axis for each dimension in Input.
7751class TransposeNode final : public Node {
7752 NodeHandle Input_;
7753 std::vector<unsigned_t> Shuffle_;
7754 std::string Layout_;
7755
7756 public:
7757 enum InputIndices {
7758 InputIdx = 0,
7759 };
7760
7761 enum ResultIndices {
7762 ResultIdx = 0,
7763 };
7764
7765 TransposeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<unsigned_t> Shuffle, std::string Layout)
7766 : Node(Kinded::Kind::TransposeNodeKind, name), Input_(this, Input), Shuffle_(Shuffle), Layout_(Layout) {
7767 addResult(Result);
7768 }
7769 const NodeValue getInput() const { return Input_; }
7770 NodeValue getResult() { return getNthResult(0); }
7771 const NodeValue getResult() const { return getNthResult(0); }
7772 llvm::ArrayRef<unsigned_t> getShuffle() const { return Shuffle_; }
7773 std::string getLayout() const { return Layout_; }
7774
7775 static bool classof(const Kinded *k) {
7776 return k->getKind() == Kinded::Kind::TransposeNodeKind;
7777 }
7778
7779
7780 bool isOverwrittenNthInput(unsigned idx) const {
7781 return false;
7782 }
7783
7784 unsigned getNumInputs() const;
7785 std::string getInputName(unsigned idx) const;
7786 NodeValue getNthInput(unsigned idx);
7787 void setNthInput(unsigned idx, NodeValue val);
7788 llvm::StringRef getOutputName(unsigned idx) const;
7789 bool hasSideEffects() const { return 0; }
7790 bool isCanonical() const { return 1; }
7791 bool isDataParallel() const { return 0; }
7792 std::string getDebugDesc() const;
7793 bool isEqual(const TransposeNode &other) const;
7794 llvm::hash_code getHash() const;
7795 void visit(Node *parent, NodeWalker *visitor);
7796 Node* clone() const;
7797 bool verify() const;
7798};
7799} // namespace glow
7800
7801
7802namespace glow {
7803/// The concat operator adds two tensors together.
7804/// The parameter 'dim' specifies the dimension to use when joining the tensors.
7805class ConcatNode final : public Node {
7806 std::vector<NodeHandle> Inputs_;
7807 unsigned_t Dim_;
7808
7809 public:
7810 enum InputIndices {
7811 };
7812
7813 enum ResultIndices {
7814 ResultIdx = 0,
7815 };
7816
7817 ConcatNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> Inputs, unsigned_t Dim)
7818 : Node(Kinded::Kind::ConcatNodeKind, name), Dim_(Dim) {
7819 addResult(Result);
7820 Inputs_.resize(Inputs.size());
7821 for (size_t idx = 0, e = Inputs.size(); idx < e; ++idx) {
7822 Inputs_[idx] = Inputs[idx];
7823 Inputs_[idx].setParent(this);
7824 }
7825 }
7826 NodeValue getResult() { return getNthResult(0); }
7827 const NodeValue getResult() const { return getNthResult(0); }
7828 NodeValueArrayRef getInputs() const { return Inputs_; }
7829 unsigned_t getDim() const { return Dim_; }
7830
7831 static bool classof(const Kinded *k) {
7832 return k->getKind() == Kinded::Kind::ConcatNodeKind;
7833 }
7834
7835
7836 bool isOverwrittenNthInput(unsigned idx) const {
7837 return false;
7838 }
7839
7840 unsigned getNumInputs() const;
7841 std::string getInputName(unsigned idx) const;
7842 NodeValue getNthInput(unsigned idx);
7843 void setNthInput(unsigned idx, NodeValue val);
7844 llvm::StringRef getOutputName(unsigned idx) const;
7845 bool hasSideEffects() const { return 0; }
7846 bool isCanonical() const { return 1; }
7847 bool isDataParallel() const { return 0; }
7848 std::string getDebugDesc() const;
7849 bool isEqual(const ConcatNode &other) const;
7850 llvm::hash_code getHash() const;
7851 void visit(Node *parent, NodeWalker *visitor);
7852 Node* clone() const;
7853 bool verify() const;
7854};
7855} // namespace glow
7856
7857
7858namespace glow {
7859/// Produces a slice of the Input tensor. The Start vector defines the starting indices for each dimension from which the slice should be taken. The end index for each dimension is determined from the input type's shape.
7860class SliceNode final : public Node {
7861 NodeHandle Input_;
7862 std::vector<dim_t> Start_;
7863
7864 public:
7865 enum InputIndices {
7866 InputIdx = 0,
7867 };
7868
7869 enum ResultIndices {
7870 ResultIdx = 0,
7871 };
7872
7873 SliceNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<dim_t> Start)
7874 : Node(Kinded::Kind::SliceNodeKind, name), Input_(this, Input), Start_(Start) {
7875 addResult(Result);
7876 }
7877 const NodeValue getInput() const { return Input_; }
7878 NodeValue getResult() { return getNthResult(0); }
7879 const NodeValue getResult() const { return getNthResult(0); }
7880 llvm::ArrayRef<dim_t> getStart() const { return Start_; }
7881
7882 static bool classof(const Kinded *k) {
7883 return k->getKind() == Kinded::Kind::SliceNodeKind;
7884 }
7885
7886
7887 bool isOverwrittenNthInput(unsigned idx) const {
7888 return false;
7889 }
7890
7891 unsigned getNumInputs() const;
7892 std::string getInputName(unsigned idx) const;
7893 NodeValue getNthInput(unsigned idx);
7894 void setNthInput(unsigned idx, NodeValue val);
7895 llvm::StringRef getOutputName(unsigned idx) const;
7896 bool hasSideEffects() const { return 0; }
7897 bool isCanonical() const { return 1; }
7898 bool isDataParallel() const { return 0; }
7899 std::string getDebugDesc() const;
7900 bool isEqual(const SliceNode &other) const;
7901 llvm::hash_code getHash() const;
7902 void visit(Node *parent, NodeWalker *visitor);
7903 Node* clone() const;
7904 bool verify() const;
7905};
7906} // namespace glow
7907
7908
7909namespace glow {
7910/// Insert tensor Small into tensor Big given indices Start. Small is inserted Count times along Axis. The resulting Tensor will have the same type as the input Big tensor.
7911class InsertTensorNode final : public Node {
7912 NodeHandle Big_;
7913 NodeHandle Small_;
7914 std::vector<dim_t> Start_;
7915 unsigned_t Count_;
7916 unsigned_t Axis_;
7917
7918 public:
7919 enum InputIndices {
7920 BigIdx = 0,
7921 SmallIdx = 1,
7922 };
7923
7924 enum ResultIndices {
7925 ResultIdx = 0,
7926 };
7927
7928 InsertTensorNode(llvm::StringRef name, NodeValue Big, NodeValue Small, std::vector<dim_t> Start, unsigned_t Count, unsigned_t Axis)
7929 : Node(Kinded::Kind::InsertTensorNodeKind, name), Big_(this, Big), Small_(this, Small), Start_(Start), Count_(Count), Axis_(Axis) {
7930 addResult(Big.getType());
7931 }
7932 const NodeValue getBig() const { return Big_; }
7933 const NodeValue getSmall() const { return Small_; }
7934 NodeValue getResult() { return getNthResult(0); }
7935 const NodeValue getResult() const { return getNthResult(0); }
7936 llvm::ArrayRef<dim_t> getStart() const { return Start_; }
7937 unsigned_t getCount() const { return Count_; }
7938 unsigned_t getAxis() const { return Axis_; }
7939
7940 static bool classof(const Kinded *k) {
7941 return k->getKind() == Kinded::Kind::InsertTensorNodeKind;
7942 }
7943
7944
7945 bool isOverwrittenNthInput(unsigned idx) const {
7946 return false;
7947 }
7948
7949 unsigned getNumInputs() const;
7950 std::string getInputName(unsigned idx) const;
7951 NodeValue getNthInput(unsigned idx);
7952 void setNthInput(unsigned idx, NodeValue val);
7953 llvm::StringRef getOutputName(unsigned idx) const;
7954 bool hasSideEffects() const { return 0; }
7955 bool isCanonical() const { return 1; }
7956 bool isDataParallel() const { return 0; }
7957 std::string getDebugDesc() const;
7958 bool isEqual(const InsertTensorNode &other) const;
7959 llvm::hash_code getHash() const;
7960 void visit(Node *parent, NodeWalker *visitor);
7961 Node* clone() const;
7962 bool verify() const;
7963};
7964} // namespace glow
7965
7966
7967namespace glow {
7968/// Gathers entries of the outer-most dimension of Data indexed by Indices, and concatenates them. Output tensor will have dimensions: {I_0, I_1, ... I_n, D_1, D_2, ... D_m}, where D_i and I_j denote Data and Indices dimensions respectively. If axis is not zero, the gather operator will treat the first axis as the batch and will concat the result of the gather operation on each sample in the batch.
7969class GatherNode final : public Node {
7970 NodeHandle Data_;
7971 NodeHandle Indices_;
7972 unsigned_t BatchDims_;
7973
7974 public:
7975 enum InputIndices {
7976 DataIdx = 0,
7977 IndicesIdx = 1,
7978 };
7979
7980 enum ResultIndices {
7981 ResultIdx = 0,
7982 };
7983
7984 GatherNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, unsigned_t BatchDims)
7985 : Node(Kinded::Kind::GatherNodeKind, name), Data_(this, Data), Indices_(this, Indices), BatchDims_(BatchDims) {
7986 addResult(Result);
7987 }
7988 const NodeValue getData() const { return Data_; }
7989 const NodeValue getIndices() const { return Indices_; }
7990 NodeValue getResult() { return getNthResult(0); }
7991 const NodeValue getResult() const { return getNthResult(0); }
7992 unsigned_t getBatchDims() const { return BatchDims_; }
7993
7994 static bool classof(const Kinded *k) {
7995 return k->getKind() == Kinded::Kind::GatherNodeKind;
7996 }
7997
7998
7999 bool isOverwrittenNthInput(unsigned idx) const {
8000 return false;
8001 }
8002
8003 unsigned getNumInputs() const;
8004 std::string getInputName(unsigned idx) const;
8005 NodeValue getNthInput(unsigned idx);
8006 void setNthInput(unsigned idx, NodeValue val);
8007 llvm::StringRef getOutputName(unsigned idx) const;
8008 bool hasSideEffects() const { return 0; }
8009 bool isCanonical() const { return 1; }
8010 bool isDataParallel() const { return 0; }
8011 std::string getDebugDesc() const;
8012 bool isEqual(const GatherNode &other) const;
8013 llvm::hash_code getHash() const;
8014 void visit(Node *parent, NodeWalker *visitor);
8015 Node* clone() const;
8016 bool verify() const;
8017};
8018} // namespace glow
8019
8020
8021namespace glow {
8022/// Given Data tensor of rank r >= 1, Indices tensor of rank q >= 1 This operator gathers slices of Data into an output tensor of rank q + r - Indices_shape[-1] - 1 .
8023class GatherNDNode final : public Node {
8024 NodeHandle Data_;
8025 NodeHandle Indices_;
8026 unsigned_t BatchDims_;
8027
8028 public:
8029 enum InputIndices {
8030 DataIdx = 0,
8031 IndicesIdx = 1,
8032 };
8033
8034 enum ResultIndices {
8035 ResultIdx = 0,
8036 };
8037
8038 GatherNDNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, unsigned_t BatchDims)
8039 : Node(Kinded::Kind::GatherNDNodeKind, name), Data_(this, Data), Indices_(this, Indices), BatchDims_(BatchDims) {
8040 addResult(Result);
8041 }
8042 const NodeValue getData() const { return Data_; }
8043 const NodeValue getIndices() const { return Indices_; }
8044 NodeValue getResult() { return getNthResult(0); }
8045 const NodeValue getResult() const { return getNthResult(0); }
8046 unsigned_t getBatchDims() const { return BatchDims_; }
8047
8048 static bool classof(const Kinded *k) {
8049 return k->getKind() == Kinded::Kind::GatherNDNodeKind;
8050 }
8051
8052
8053 bool isOverwrittenNthInput(unsigned idx) const {
8054 return false;
8055 }
8056
8057 unsigned getNumInputs() const;
8058 std::string getInputName(unsigned idx) const;
8059 NodeValue getNthInput(unsigned idx);
8060 void setNthInput(unsigned idx, NodeValue val);
8061 llvm::StringRef getOutputName(unsigned idx) const;
8062 bool hasSideEffects() const { return 0; }
8063 bool isCanonical() const { return 1; }
8064 bool isDataParallel() const { return 0; }
8065 std::string getDebugDesc() const;
8066 bool isEqual(const GatherNDNode &other) const;
8067 llvm::hash_code getHash() const;
8068 void visit(Node *parent, NodeWalker *visitor);
8069 Node* clone() const;
8070 bool verify() const;
8071};
8072} // namespace glow
8073
8074
8075namespace glow {
8076/// GatherElements takes inputs data and indices of the same rank r >= 1 and an attribute axis specified by dim. It is an indexingoperation that produces its output by indexing into the input data tensor at index positions determined by elements of the indices tensor. Its output shape is the same as the shape of indices and consists of one value (gathered from the data) for each element in indices.
8077class GatherElementsNode final : public Node {
8078 NodeHandle Data_;
8079 NodeHandle Indices_;
8080 unsigned_t Dim_;
8081
8082 public:
8083 enum InputIndices {
8084 DataIdx = 0,
8085 IndicesIdx = 1,
8086 };
8087
8088 enum ResultIndices {
8089 ResultIdx = 0,
8090 };
8091
8092 GatherElementsNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, unsigned_t Dim)
8093 : Node(Kinded::Kind::GatherElementsNodeKind, name), Data_(this, Data), Indices_(this, Indices), Dim_(Dim) {
8094 addResult(Result);
8095 }
8096 const NodeValue getData() const { return Data_; }
8097 const NodeValue getIndices() const { return Indices_; }
8098 NodeValue getResult() { return getNthResult(0); }
8099 const NodeValue getResult() const { return getNthResult(0); }
8100 unsigned_t getDim() const { return Dim_; }
8101
8102 static bool classof(const Kinded *k) {
8103 return k->getKind() == Kinded::Kind::GatherElementsNodeKind;
8104 }
8105
8106
8107 bool isOverwrittenNthInput(unsigned idx) const {
8108 return false;
8109 }
8110
8111 unsigned getNumInputs() const;
8112 std::string getInputName(unsigned idx) const;
8113 NodeValue getNthInput(unsigned idx);
8114 void setNthInput(unsigned idx, NodeValue val);
8115 llvm::StringRef getOutputName(unsigned idx) const;
8116 bool hasSideEffects() const { return 0; }
8117 bool isCanonical() const { return 1; }
8118 bool isDataParallel() const { return 0; }
8119 std::string getDebugDesc() const;
8120 bool isEqual(const GatherElementsNode &other) const;
8121 llvm::hash_code getHash() const;
8122 void visit(Node *parent, NodeWalker *visitor);
8123 Node* clone() const;
8124 bool verify() const;
8125};
8126} // namespace glow
8127
8128
8129namespace glow {
8130/// Gathers entries of Data into Output in groups specified by the elements of Ranges. Each element of Ranges contains a list of pairs of indices of the form (index, length) which specify which entries of data to gather. The ordering of elements in Ranges and of pairs within an element is preserved in Output. Lengths contains the lengths of the ranges gathered by each list of pairs in Ranges.
8131class GatherRangesNode final : public Node {
8132 NodeHandle Data_;
8133 NodeHandle Ranges_;
8134
8135 public:
8136 enum InputIndices {
8137 DataIdx = 0,
8138 RangesIdx = 1,
8139 };
8140
8141 enum ResultIndices {
8142 OutputIdx = 0,
8143 LengthsIdx = 1,
8144 };
8145
8146 GatherRangesNode(llvm::StringRef name, TypeRef Output , TypeRef Lengths , NodeValue Data, NodeValue Ranges)
8147 : Node(Kinded::Kind::GatherRangesNodeKind, name), Data_(this, Data), Ranges_(this, Ranges) {
8148 addResult(Output);
8149 addResult(Lengths);
8150 }
8151 const NodeValue getData() const { return Data_; }
8152 const NodeValue getRanges() const { return Ranges_; }
8153 NodeValue getOutput() { return getNthResult(0); }
8154 const NodeValue getOutput() const { return getNthResult(0); }
8155 NodeValue getLengths() { return getNthResult(1); }
8156 const NodeValue getLengths() const { return getNthResult(1); }
8157
8158 static bool classof(const Kinded *k) {
8159 return k->getKind() == Kinded::Kind::GatherRangesNodeKind;
8160 }
8161
8162
8163 bool isOverwrittenNthInput(unsigned idx) const {
8164 return false;
8165 }
8166
8167 unsigned getNumInputs() const;
8168 std::string getInputName(unsigned idx) const;
8169 NodeValue getNthInput(unsigned idx);
8170 void setNthInput(unsigned idx, NodeValue val);
8171 llvm::StringRef getOutputName(unsigned idx) const;
8172 bool hasSideEffects() const { return 0; }
8173 bool isCanonical() const { return 1; }
8174 bool isDataParallel() const { return 0; }
8175 std::string getDebugDesc() const;
8176 bool isEqual(const GatherRangesNode &other) const;
8177 llvm::hash_code getHash() const;
8178 void visit(Node *parent, NodeWalker *visitor);
8179 Node* clone() const;
8180 bool verify() const;
8181};
8182} // namespace glow
8183
8184
8185namespace glow {
8186/// Copies each slice from Slices into Data at the corresponding index in Indices. For example, given input Data {{1,2},{3,4},{5,6}}, Slices {{-3,-4}}, and Indices {{1}}, the result is {{1,2},{-3,-4},{5,6}}. It also supports multi-dimensional indices. For example, given input Data {{1,2},{3,4},{5,6}}, Slices {-3,-4}, and Indices {{1,0},{1,1}} also produces {{1,2},{-3,-4},{5,6}}. If Cumulative is true, the node adds values from Slices to Data instead of copying. For example, given input Data {{1,2},{3,4},{5,6}}, Slices {{-3,-4}}, and Indices {1}, the result is {{1,2},{0,0},{5,6}}. If an index is specified several times, its updates will be added several times as well.
8187class ScatterDataNode final : public Node {
8188 NodeHandle Data_;
8189 NodeHandle Indices_;
8190 NodeHandle Slices_;
8191 bool Cumulative_;
8192
8193 public:
8194 enum InputIndices {
8195 DataIdx = 0,
8196 IndicesIdx = 1,
8197 SlicesIdx = 2,
8198 };
8199
8200 enum ResultIndices {
8201 ResultIdx = 0,
8202 };
8203
8204 ScatterDataNode(llvm::StringRef name, NodeValue Data, NodeValue Indices, NodeValue Slices, bool Cumulative)
8205 : Node(Kinded::Kind::ScatterDataNodeKind, name), Data_(this, Data), Indices_(this, Indices), Slices_(this, Slices), Cumulative_(Cumulative) {
8206 addResult(Data.getType());
8207 }
8208 const NodeValue getData() const { return Data_; }
8209 const NodeValue getIndices() const { return Indices_; }
8210 const NodeValue getSlices() const { return Slices_; }
8211 NodeValue getResult() { return getNthResult(0); }
8212 const NodeValue getResult() const { return getNthResult(0); }
8213 bool getCumulative() const { return Cumulative_; }
8214
8215 static bool classof(const Kinded *k) {
8216 return k->getKind() == Kinded::Kind::ScatterDataNodeKind;
8217 }
8218
8219
8220 bool isOverwrittenNthInput(unsigned idx) const {
8221 return false;
8222 }
8223
8224 unsigned getNumInputs() const;
8225 std::string getInputName(unsigned idx) const;
8226 NodeValue getNthInput(unsigned idx);
8227 void setNthInput(unsigned idx, NodeValue val);
8228 llvm::StringRef getOutputName(unsigned idx) const;
8229 bool hasSideEffects() const { return 0; }
8230 bool isCanonical() const { return 1; }
8231 bool isDataParallel() const { return 0; }
8232 std::string getDebugDesc() const;
8233 bool isEqual(const ScatterDataNode &other) const;
8234 llvm::hash_code getHash() const;
8235 void visit(Node *parent, NodeWalker *visitor);
8236 Node* clone() const;
8237 bool verify() const;
8238};
8239} // namespace glow
8240
8241
8242namespace glow {
8243/// Tile an Input tensor Count times along Axis.
8244class TileNode final : public Node {
8245 NodeHandle Input_;
8246 unsigned_t Count_;
8247 unsigned_t Axis_;
8248
8249 public:
8250 enum InputIndices {
8251 InputIdx = 0,
8252 };
8253
8254 enum ResultIndices {
8255 ResultIdx = 0,
8256 };
8257
8258 TileNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Count, unsigned_t Axis)
8259 : Node(Kinded::Kind::TileNodeKind, name), Input_(this, Input), Count_(Count), Axis_(Axis) {
8260 addResult(Result);
8261 }
8262 const NodeValue getInput() const { return Input_; }
8263 NodeValue getResult() { return getNthResult(0); }
8264 const NodeValue getResult() const { return getNthResult(0); }
8265 unsigned_t getCount() const { return Count_; }
8266 unsigned_t getAxis() const { return Axis_; }
8267
8268 static bool classof(const Kinded *k) {
8269 return k->getKind() == Kinded::Kind::TileNodeKind;
8270 }
8271
8272
8273 bool isOverwrittenNthInput(unsigned idx) const {
8274 return false;
8275 }
8276
8277 unsigned getNumInputs() const;
8278 std::string getInputName(unsigned idx) const;
8279 NodeValue getNthInput(unsigned idx);
8280 void setNthInput(unsigned idx, NodeValue val);
8281 llvm::StringRef getOutputName(unsigned idx) const;
8282 bool hasSideEffects() const { return 0; }
8283 bool isCanonical() const { return 1; }
8284 bool isDataParallel() const { return 0; }
8285 std::string getDebugDesc() const;
8286 bool isEqual(const TileNode &other) const;
8287 llvm::hash_code getHash() const;
8288 void visit(Node *parent, NodeWalker *visitor);
8289 Node* clone() const;
8290 bool verify() const;
8291};
8292} // namespace glow
8293
8294
8295namespace glow {
8296/// Expands each row of the Data to a row of zeros and ones, according to One Hot Encoding. i-th element of Result's row is one iff Values[i] equals to the corresponding element of Data.
8297class BatchOneHotNode final : public Node {
8298 NodeHandle Data_;
8299 NodeHandle Lengths_;
8300 NodeHandle Values_;
8301
8302 public:
8303 enum InputIndices {
8304 DataIdx = 0,
8305 LengthsIdx = 1,
8306 ValuesIdx = 2,
8307 };
8308
8309 enum ResultIndices {
8310 ResultIdx = 0,
8311 };
8312
8313 BatchOneHotNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Lengths, NodeValue Values)
8314 : Node(Kinded::Kind::BatchOneHotNodeKind, name), Data_(this, Data), Lengths_(this, Lengths), Values_(this, Values) {
8315 addResult(Result);
8316 }
8317 const NodeValue getData() const { return Data_; }
8318 const NodeValue getLengths() const { return Lengths_; }
8319 const NodeValue getValues() const { return Values_; }
8320 NodeValue getResult() { return getNthResult(0); }
8321 const NodeValue getResult() const { return getNthResult(0); }
8322
8323 static bool classof(const Kinded *k) {
8324 return k->getKind() == Kinded::Kind::BatchOneHotNodeKind;
8325 }
8326
8327
8328 bool isOverwrittenNthInput(unsigned idx) const {
8329 return false;
8330 }
8331
8332 unsigned getNumInputs() const;
8333 std::string getInputName(unsigned idx) const;
8334 NodeValue getNthInput(unsigned idx);
8335 void setNthInput(unsigned idx, NodeValue val);
8336 llvm::StringRef getOutputName(unsigned idx) const;
8337 bool hasSideEffects() const { return 0; }
8338 bool isCanonical() const { return 1; }
8339 bool isDataParallel() const { return 0; }
8340 std::string getDebugDesc() const;
8341 bool isEqual(const BatchOneHotNode &other) const;
8342 llvm::hash_code getHash() const;
8343 void visit(Node *parent, NodeWalker *visitor);
8344 Node* clone() const;
8345 bool verify() const;
8346};
8347} // namespace glow
8348
8349
8350namespace glow {
8351/// Given Input tensor of [N,H,W,C], where N is the batch axis, C is the channel or depth, H is the height and W is the width. This produces Output tensor of [N, H/BlockSize, W/BlockSize, C * BlockSize * BlockSize].
8352class SpaceToDepthNode final : public Node {
8353 NodeHandle Input_;
8354 unsigned_t BlockSize_;
8355
8356 public:
8357 enum InputIndices {
8358 InputIdx = 0,
8359 };
8360
8361 enum ResultIndices {
8362 ResultIdx = 0,
8363 };
8364
8365 SpaceToDepthNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t BlockSize)
8366 : Node(Kinded::Kind::SpaceToDepthNodeKind, name), Input_(this, Input), BlockSize_(BlockSize) {
8367 addResult(Result);
8368 }
8369 const NodeValue getInput() const { return Input_; }
8370 NodeValue getResult() { return getNthResult(0); }
8371 const NodeValue getResult() const { return getNthResult(0); }
8372 unsigned_t getBlockSize() const { return BlockSize_; }
8373
8374 static bool classof(const Kinded *k) {
8375 return k->getKind() == Kinded::Kind::SpaceToDepthNodeKind;
8376 }
8377
8378
8379 bool isOverwrittenNthInput(unsigned idx) const {
8380 return false;
8381 }
8382
8383 unsigned getNumInputs() const;
8384 std::string getInputName(unsigned idx) const;
8385 NodeValue getNthInput(unsigned idx);
8386 void setNthInput(unsigned idx, NodeValue val);
8387 llvm::StringRef getOutputName(unsigned idx) const;
8388 bool hasSideEffects() const { return 0; }
8389 bool isCanonical() const { return 1; }
8390 bool isDataParallel() const { return 0; }
8391 std::string getDebugDesc() const;
8392 bool isEqual(const SpaceToDepthNode &other) const;
8393 llvm::hash_code getHash() const;
8394 void visit(Node *parent, NodeWalker *visitor);
8395 Node* clone() const;
8396 bool verify() const;
8397};
8398} // namespace glow
8399
8400
8401namespace glow {
8402/// Given Input tensor of 3D, 4D, 5D or 6D, generates an Output tensor with resized spatial dimensions using nearest neighbor interpolation. The Output tensor is of shape floor(input_dimension * scale)
8403class ResizeNearestNode final : public Node {
8404 NodeHandle Input_;
8405 std::vector<float> Scale_;
8406
8407 public:
8408 enum InputIndices {
8409 InputIdx = 0,
8410 };
8411
8412 enum ResultIndices {
8413 ResultIdx = 0,
8414 };
8415
8416 ResizeNearestNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<float> Scale)
8417 : Node(Kinded::Kind::ResizeNearestNodeKind, name), Input_(this, Input), Scale_(Scale) {
8418 addResult(Result);
8419 }
8420 const NodeValue getInput() const { return Input_; }
8421 NodeValue getResult() { return getNthResult(0); }
8422 const NodeValue getResult() const { return getNthResult(0); }
8423 llvm::ArrayRef<float> getScale() const { return Scale_; }
8424
8425 static bool classof(const Kinded *k) {
8426 return k->getKind() == Kinded::Kind::ResizeNearestNodeKind;
8427 }
8428
8429
8430 bool isOverwrittenNthInput(unsigned idx) const {
8431 return false;
8432 }
8433
8434 unsigned getNumInputs() const;
8435 std::string getInputName(unsigned idx) const;
8436 NodeValue getNthInput(unsigned idx);
8437 void setNthInput(unsigned idx, NodeValue val);
8438 llvm::StringRef getOutputName(unsigned idx) const;
8439 bool hasSideEffects() const { return 0; }
8440 bool isCanonical() const { return 1; }
8441 bool isDataParallel() const { return 0; }
8442 std::string getDebugDesc() const;
8443 bool isEqual(const ResizeNearestNode &other) const;
8444 llvm::hash_code getHash() const;
8445 void visit(Node *parent, NodeWalker *visitor);
8446 Node* clone() const;
8447 bool verify() const;
8448};
8449} // namespace glow
8450
8451
8452namespace glow {
8453/// Given Input tensor of [N,H,W,C], where N is the batch, C is the channel or depth, H is the height and W is the width, Generates an Output tensor with resized spatial dimensions using bilinear neighbor interpolation. The Output tensor is of shape floor(input_dimension * scale)
8454class ResizeBilinearNode final : public Node {
8455 NodeHandle Input_;
8456 std::vector<float> Scale_;
8457
8458 public:
8459 enum InputIndices {
8460 InputIdx = 0,
8461 };
8462
8463 enum ResultIndices {
8464 ResultIdx = 0,
8465 };
8466
8467 ResizeBilinearNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<float> Scale)
8468 : Node(Kinded::Kind::ResizeBilinearNodeKind, name), Input_(this, Input), Scale_(Scale) {
8469 addResult(Result);
8470 }
8471 const NodeValue getInput() const { return Input_; }
8472 NodeValue getResult() { return getNthResult(0); }
8473 const NodeValue getResult() const { return getNthResult(0); }
8474 llvm::ArrayRef<float> getScale() const { return Scale_; }
8475
8476 static bool classof(const Kinded *k) {
8477 return k->getKind() == Kinded::Kind::ResizeBilinearNodeKind;
8478 }
8479
8480
8481 bool isOverwrittenNthInput(unsigned idx) const {
8482 return false;
8483 }
8484
8485 unsigned getNumInputs() const;
8486 std::string getInputName(unsigned idx) const;
8487 NodeValue getNthInput(unsigned idx);
8488 void setNthInput(unsigned idx, NodeValue val);
8489 llvm::StringRef getOutputName(unsigned idx) const;
8490 bool hasSideEffects() const { return 0; }
8491 bool isCanonical() const { return 1; }
8492 bool isDataParallel() const { return 0; }
8493 std::string getDebugDesc() const;
8494 bool isEqual(const ResizeBilinearNode &other) const;
8495 llvm::hash_code getHash() const;
8496 void visit(Node *parent, NodeWalker *visitor);
8497 Node* clone() const;
8498 bool verify() const;
8499};
8500} // namespace glow
8501
8502
8503namespace glow {
8504/// Broadcast the Input tensor to TargetDim using Axis to indicate the offset between Input dimension and TargetDim
8505class BroadcastNode final : public Node {
8506 NodeHandle Input_;
8507 unsigned_t Axis_;
8508 std::vector<dim_t> TargetDim_;
8509
8510 public:
8511 enum InputIndices {
8512 InputIdx = 0,
8513 };
8514
8515 enum ResultIndices {
8516 ResultIdx = 0,
8517 };
8518
8519 BroadcastNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, std::vector<dim_t> TargetDim)
8520 : Node(Kinded::Kind::BroadcastNodeKind, name), Input_(this, Input), Axis_(Axis), TargetDim_(TargetDim) {
8521 addResult(Result);
8522 }
8523 const NodeValue getInput() const { return Input_; }
8524 NodeValue getResult() { return getNthResult(0); }
8525 const NodeValue getResult() const { return getNthResult(0); }
8526 unsigned_t getAxis() const { return Axis_; }
8527 llvm::ArrayRef<dim_t> getTargetDim() const { return TargetDim_; }
8528
8529 static bool classof(const Kinded *k) {
8530 return k->getKind() == Kinded::Kind::BroadcastNodeKind;
8531 }
8532
8533
8534 bool isOverwrittenNthInput(unsigned idx) const {
8535 return false;
8536 }
8537
8538 unsigned getNumInputs() const;
8539 std::string getInputName(unsigned idx) const;
8540 NodeValue getNthInput(unsigned idx);
8541 void setNthInput(unsigned idx, NodeValue val);
8542 llvm::StringRef getOutputName(unsigned idx) const;
8543 bool hasSideEffects() const { return 0; }
8544 bool isCanonical() const { return 1; }
8545 bool isDataParallel() const { return 0; }
8546 std::string getDebugDesc() const;
8547 bool isEqual(const BroadcastNode &other) const;
8548 llvm::hash_code getHash() const;
8549 void visit(Node *parent, NodeWalker *visitor);
8550 Node* clone() const;
8551 bool verify() const;
8552};
8553} // namespace glow
8554
8555
8556namespace glow {
8557/// TODO
8558class SparseLabelSplitNode final : public Node {
8559 NodeHandle Lengths_;
8560 NodeHandle Indices_;
8561 NodeHandle Values_;
8562 unsigned_t NumLabels_;
8563
8564 public:
8565 enum InputIndices {
8566 LengthsIdx = 0,
8567 IndicesIdx = 1,
8568 ValuesIdx = 2,
8569 };
8570
8571 enum ResultIndices {
8572 LabelValuesIdx = 0,
8573 ExampleIdsIdx = 1,
8574 GradientOffsetMapIdx = 2,
8575 };
8576
8577 SparseLabelSplitNode(llvm::StringRef name, TypeRef LabelValues , TypeRef ExampleIds , TypeRef GradientOffsetMap , NodeValue Lengths, NodeValue Indices, NodeValue Values, unsigned_t NumLabels)
8578 : Node(Kinded::Kind::SparseLabelSplitNodeKind, name), Lengths_(this, Lengths), Indices_(this, Indices), Values_(this, Values), NumLabels_(NumLabels) {
8579 addResult(LabelValues);
8580 addResult(ExampleIds);
8581 addResult(GradientOffsetMap);
8582 }
8583 const NodeValue getLengths() const { return Lengths_; }
8584 const NodeValue getIndices() const { return Indices_; }
8585 const NodeValue getValues() const { return Values_; }
8586 NodeValue getLabelValues() { return getNthResult(0); }
8587 const NodeValue getLabelValues() const { return getNthResult(0); }
8588 NodeValue getExampleIds() { return getNthResult(1); }
8589 const NodeValue getExampleIds() const { return getNthResult(1); }
8590 NodeValue getGradientOffsetMap() { return getNthResult(2); }
8591 const NodeValue getGradientOffsetMap() const { return getNthResult(2); }
8592 unsigned_t getNumLabels() const { return NumLabels_; }
8593
8594 static bool classof(const Kinded *k) {
8595 return k->getKind() == Kinded::Kind::SparseLabelSplitNodeKind;
8596 }
8597
8598
8599 bool isOverwrittenNthInput(unsigned idx) const {
8600 return false;
8601 }
8602
8603 unsigned getNumInputs() const;
8604 std::string getInputName(unsigned idx) const;
8605 NodeValue getNthInput(unsigned idx);
8606 void setNthInput(unsigned idx, NodeValue val);
8607 llvm::StringRef getOutputName(unsigned idx) const;
8608 bool hasSideEffects() const { return 0; }
8609 bool isCanonical() const { return 1; }
8610 bool isDataParallel() const { return 0; }
8611 std::string getDebugDesc() const;
8612 bool isEqual(const SparseLabelSplitNode &other) const;
8613 llvm::hash_code getHash() const;
8614 void visit(Node *parent, NodeWalker *visitor);
8615 Node* clone() const;
8616 bool verify() const;
8617};
8618} // namespace glow
8619
8620
8621namespace glow {
8622/// Reverse the order of elements in a tensor along the given axis. The shape of the tensor is preserved, but the elements are reordered. The node is inspired from Python numpy.
8623class FlipNode final : public Node {
8624 NodeHandle Input_;
8625 unsigned_t Axis_;
8626
8627 public:
8628 enum InputIndices {
8629 InputIdx = 0,
8630 };
8631
8632 enum ResultIndices {
8633 ResultIdx = 0,
8634 };
8635
8636 FlipNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis)
8637 : Node(Kinded::Kind::FlipNodeKind, name), Input_(this, Input), Axis_(Axis) {
8638 addResult(Result);
8639 }
8640 const NodeValue getInput() const { return Input_; }
8641 NodeValue getResult() { return getNthResult(0); }
8642 const NodeValue getResult() const { return getNthResult(0); }
8643 unsigned_t getAxis() const { return Axis_; }
8644
8645 static bool classof(const Kinded *k) {
8646 return k->getKind() == Kinded::Kind::FlipNodeKind;
8647 }
8648
8649
8650 bool isOverwrittenNthInput(unsigned idx) const {
8651 return false;
8652 }
8653
8654 unsigned getNumInputs() const;
8655 std::string getInputName(unsigned idx) const;
8656 NodeValue getNthInput(unsigned idx);
8657 void setNthInput(unsigned idx, NodeValue val);
8658 llvm::StringRef getOutputName(unsigned idx) const;
8659 bool hasSideEffects() const { return 0; }
8660 bool isCanonical() const { return 1; }
8661 bool isDataParallel() const { return 0; }
8662 std::string getDebugDesc() const;
8663 bool isEqual(const FlipNode &other) const;
8664 llvm::hash_code getHash() const;
8665 void visit(Node *parent, NodeWalker *visitor);
8666 Node* clone() const;
8667 bool verify() const;
8668};
8669} // namespace glow
8670
8671
8672namespace glow {
8673/// Generate a tensor of a specific type filled with 'Value'.Splat always keep floating point value internally but canquantize it based on the output type.
8674class SplatNode final : public Node {
8675 float Value_;
8676
8677 public:
8678 enum InputIndices {
8679 };
8680
8681 enum ResultIndices {
8682 ResultIdx = 0,
8683 };
8684
8685 SplatNode(llvm::StringRef name, TypeRef Result , float Value)
8686 : Node(Kinded::Kind::SplatNodeKind, name), Value_(Value) {
8687 addResult(Result);
8688 }
8689 NodeValue getResult() { return getNthResult(0); }
8690 const NodeValue getResult() const { return getNthResult(0); }
8691 float getValue() const { return Value_; }
8692
8693 static bool classof(const Kinded *k) {
8694 return k->getKind() == Kinded::Kind::SplatNodeKind;
8695 }
8696
8697
8698 bool isOverwrittenNthInput(unsigned idx) const {
8699 return false;
8700 }
8701
8702 unsigned getNumInputs() const;
8703 std::string getInputName(unsigned idx) const;
8704 NodeValue getNthInput(unsigned idx);
8705 void setNthInput(unsigned idx, NodeValue val);
8706 llvm::StringRef getOutputName(unsigned idx) const;
8707 bool hasSideEffects() const { return 0; }
8708 bool isCanonical() const { return 1; }
8709 bool isDataParallel() const { return 0; }
8710 std::string getDebugDesc() const;
8711 bool isEqual(const SplatNode &other) const;
8712 llvm::hash_code getHash() const;
8713 void visit(Node *parent, NodeWalker *visitor);
8714 Node* clone() const;
8715 bool verify() const;
8716};
8717} // namespace glow
8718
8719
8720namespace glow {
8721/// Generate a tensor of a specific type without initializing it. This is useful when filling a big tensor entirely with multiple small slices using InsertTensor nodes such that the big tensor is not required to be initialized (filled) with some value prior to insertion. This node is intended to remove the overhead associated with the initialization in situations where it is not required.
8722class TouchNode final : public Node {
8723
8724 public:
8725 enum InputIndices {
8726 };
8727
8728 enum ResultIndices {
8729 ResultIdx = 0,
8730 };
8731
8732 TouchNode(llvm::StringRef name, TypeRef Result )
8733 : Node(Kinded::Kind::TouchNodeKind, name) {
8734 addResult(Result);
8735 }
8736 NodeValue getResult() { return getNthResult(0); }
8737 const NodeValue getResult() const { return getNthResult(0); }
8738
8739 static bool classof(const Kinded *k) {
8740 return k->getKind() == Kinded::Kind::TouchNodeKind;
8741 }
8742
8743
8744 bool isOverwrittenNthInput(unsigned idx) const {
8745 return false;
8746 }
8747
8748 unsigned getNumInputs() const;
8749 std::string getInputName(unsigned idx) const;
8750 NodeValue getNthInput(unsigned idx);
8751 void setNthInput(unsigned idx, NodeValue val);
8752 llvm::StringRef getOutputName(unsigned idx) const;
8753 bool hasSideEffects() const { return 0; }
8754 bool isCanonical() const { return 1; }
8755 bool isDataParallel() const { return 0; }
8756 std::string getDebugDesc() const;
8757 bool isEqual(const TouchNode &other) const;
8758 llvm::hash_code getHash() const;
8759 void visit(Node *parent, NodeWalker *visitor);
8760 Node* clone() const;
8761 bool verify() const;
8762};
8763} // namespace glow
8764
8765
8766namespace glow {
8767/// Stochastic Gradient Descent node used during training. Produces the updated weight that needs to be used instead of Weight for the next iteration.
8768class SGDNode final : public Node {
8769 NodeHandle Gradient_;
8770 NodeHandle Weight_;
8771 float L1Decay_;
8772 float L2Decay_;
8773 float LearningRate_;
8774 float Momentum_;
8775 unsigned_t BatchSize_;
8776
8777 public:
8778 enum InputIndices {
8779 GradientIdx = 0,
8780 WeightIdx = 1,
8781 };
8782
8783 enum ResultIndices {
8784 UpdatedWeightIdx = 0,
8785 };
8786
8787 SGDNode(llvm::StringRef name, NodeValue Gradient, NodeValue Weight, float L1Decay, float L2Decay, float LearningRate, float Momentum, unsigned_t BatchSize)
8788 : Node(Kinded::Kind::SGDNodeKind, name), Gradient_(this, Gradient), Weight_(this, Weight), L1Decay_(L1Decay), L2Decay_(L2Decay), LearningRate_(LearningRate), Momentum_(Momentum), BatchSize_(BatchSize) {
8789 addResult(Weight.getType());
8790 }
8791 const NodeValue getGradient() const { return Gradient_; }
8792 const NodeValue getWeight() const { return Weight_; }
8793 NodeValue getUpdatedWeight() { return getNthResult(0); }
8794 const NodeValue getUpdatedWeight() const { return getNthResult(0); }
8795 float getL1Decay() const { return L1Decay_; }
8796 float getL2Decay() const { return L2Decay_; }
8797 float getLearningRate() const { return LearningRate_; }
8798 float getMomentum() const { return Momentum_; }
8799 unsigned_t getBatchSize() const { return BatchSize_; }
8800
8801 static bool classof(const Kinded *k) {
8802 return k->getKind() == Kinded::Kind::SGDNodeKind;
8803 }
8804
8805
8806 bool isOverwrittenNthInput(unsigned idx) const {
8807 return false;
8808 }
8809
8810 unsigned getNumInputs() const;
8811 std::string getInputName(unsigned idx) const;
8812 NodeValue getNthInput(unsigned idx);
8813 void setNthInput(unsigned idx, NodeValue val);
8814 llvm::StringRef getOutputName(unsigned idx) const;
8815 bool hasSideEffects() const { return 1; }
8816 bool isCanonical() const { return 1; }
8817 bool isDataParallel() const { return 0; }
8818 std::string getDebugDesc() const;
8819 bool isEqual(const SGDNode &other) const;
8820 llvm::hash_code getHash() const;
8821 void visit(Node *parent, NodeWalker *visitor);
8822 Node* clone() const;
8823 bool verify() const;
8824};
8825} // namespace glow
8826
8827
8828namespace glow {
8829/// Inserts a TraceEvent for profiling.
8830class TraceEventNode final : public Node {
8831 NodeHandle Data_;
8832 std::string EventName_;
8833 std::string EventType_;
8834 unsigned_t Index_;
8835
8836 public:
8837 enum InputIndices {
8838 DataIdx = 0,
8839 };
8840
8841 enum ResultIndices {
8842 };
8843
8844 TraceEventNode(llvm::StringRef name, NodeValue Data, std::string EventName, std::string EventType, unsigned_t Index)
8845 : Node(Kinded::Kind::TraceEventNodeKind, name), Data_(this, Data), EventName_(EventName), EventType_(EventType), Index_(Index) {
8846 }
8847 const NodeValue getData() const { return Data_; }
8848 std::string getEventName() const { return EventName_; }
8849 std::string getEventType() const { return EventType_; }
8850 unsigned_t getIndex() const { return Index_; }
8851
8852 static bool classof(const Kinded *k) {
8853 return k->getKind() == Kinded::Kind::TraceEventNodeKind;
8854 }
8855
8856
8857 bool isOverwrittenNthInput(unsigned idx) const {
8858 return false;
8859 }
8860
8861 unsigned getNumInputs() const;
8862 std::string getInputName(unsigned idx) const;
8863 NodeValue getNthInput(unsigned idx);
8864 void setNthInput(unsigned idx, NodeValue val);
8865 llvm::StringRef getOutputName(unsigned idx) const;
8866 bool hasSideEffects() const { return 1; }
8867 bool isCanonical() const { return 1; }
8868 bool isDataParallel() const { return 0; }
8869 std::string getDebugDesc() const;
8870 bool isEqual(const TraceEventNode &other) const;
8871 llvm::hash_code getHash() const;
8872 void visit(Node *parent, NodeWalker *visitor);
8873 Node* clone() const;
8874 bool verify() const;
8875};
8876} // namespace glow
8877
8878
8879namespace glow {
8880/// Generate profile (distribution of values) of the Input tensor. This data is used for quantization of the tensor later on. ProfiledNodeName contains the name of the node which is profiled by the QuantizationProfile node. ProfiledNodeName is helpful as lowering might transform the original graph. ProfiledOutputNumber contains the position of the node's output which gets profiled.
8881class QuantizationProfileNode final : public Node {
8882 NodeHandle Input_;
8883 NodeHandle Histogram_;
8884 NodeHandle ComputationInfo_;
8885 std::string ProfiledNodeName_;
8886 unsigned_t ProfiledOutputNumber_;
8887
8888 public:
8889 enum InputIndices {
8890 InputIdx = 0,
8891 HistogramIdx = 1,
8892 ComputationInfoIdx = 2,
8893 };
8894
8895 enum ResultIndices {
8896 };
8897
8898 QuantizationProfileNode(llvm::StringRef name, NodeValue Input, NodeValue Histogram, NodeValue ComputationInfo, std::string ProfiledNodeName, unsigned_t ProfiledOutputNumber)
8899 : Node(Kinded::Kind::QuantizationProfileNodeKind, name), Input_(this, Input), Histogram_(this, Histogram), ComputationInfo_(this, ComputationInfo), ProfiledNodeName_(ProfiledNodeName), ProfiledOutputNumber_(ProfiledOutputNumber) {
8900 }
8901 const NodeValue getInput() const { return Input_; }
8902 const NodeValue getHistogram() const { return Histogram_; }
8903 const NodeValue getComputationInfo() const { return ComputationInfo_; }
8904 std::string getProfiledNodeName() const { return ProfiledNodeName_; }
8905 unsigned_t getProfiledOutputNumber() const { return ProfiledOutputNumber_; }
8906
8907 static bool classof(const Kinded *k) {
8908 return k->getKind() == Kinded::Kind::QuantizationProfileNodeKind;
8909 }
8910
8911
8912 bool isOverwrittenNthInput(unsigned idx) const {
8913 if (idx == 2) return true;
8914 if (idx == 1) return true;
8915 return false;
8916 }
8917
8918 unsigned getNumInputs() const;
8919 std::string getInputName(unsigned idx) const;
8920 NodeValue getNthInput(unsigned idx);
8921 void setNthInput(unsigned idx, NodeValue val);
8922 llvm::StringRef getOutputName(unsigned idx) const;
8923 bool hasSideEffects() const { return 1; }
8924 bool isCanonical() const { return 1; }
8925 bool isDataParallel() const { return 0; }
8926 std::string getDebugDesc() const;
8927 bool isEqual(const QuantizationProfileNode &other) const;
8928 llvm::hash_code getHash() const;
8929 void visit(Node *parent, NodeWalker *visitor);
8930 Node* clone() const;
8931 bool verify() const;
8932 Placeholder *getHistogramPlaceholder() const ;
8933 Placeholder *getComputationInfoPlaceholder() const;
8934};
8935} // namespace glow
8936
8937
8938namespace glow {
8939/// Simple mapping between quantized numbers.This can be used as quantized sigmoid or tanh functions.
8940class IntLookupTableNode final : public Node {
8941 NodeHandle Input_;
8942 NodeHandle Mapping_;
8943
8944 public:
8945 enum InputIndices {
8946 InputIdx = 0,
8947 MappingIdx = 1,
8948 };
8949
8950 enum ResultIndices {
8951 ResultIdx = 0,
8952 };
8953
8954 IntLookupTableNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Mapping)
8955 : Node(Kinded::Kind::IntLookupTableNodeKind, name), Input_(this, Input), Mapping_(this, Mapping) {
8956 addResult(Result);
8957 }
8958 const NodeValue getInput() const { return Input_; }
8959 const NodeValue getMapping() const { return Mapping_; }
8960 NodeValue getResult() { return getNthResult(0); }
8961 const NodeValue getResult() const { return getNthResult(0); }
8962
8963 static bool classof(const Kinded *k) {
8964 return k->getKind() == Kinded::Kind::IntLookupTableNodeKind;
8965 }
8966
8967
8968 bool isOverwrittenNthInput(unsigned idx) const {
8969 return false;
8970 }
8971
8972 unsigned getNumInputs() const;
8973 std::string getInputName(unsigned idx) const;
8974 NodeValue getNthInput(unsigned idx);
8975 void setNthInput(unsigned idx, NodeValue val);
8976 llvm::StringRef getOutputName(unsigned idx) const;
8977 bool hasSideEffects() const { return 0; }
8978 bool isCanonical() const { return 1; }
8979 bool isDataParallel() const { return 1; }
8980 std::string getDebugDesc() const;
8981 bool isEqual(const IntLookupTableNode &other) const;
8982 llvm::hash_code getHash() const;
8983 void visit(Node *parent, NodeWalker *visitor);
8984 Node* clone() const;
8985 bool verify() const;
8986};
8987} // namespace glow
8988
8989
8990namespace glow {
8991/// Quantize floating point tensor. This operation converts floating point numbers to integers based on the given Scale and Offset. Scale and Offset are deduced from the type of the output.x_q = clip(round(x/Scale) + Offset, -128, 127)
8992class QuantizeNode final : public Node {
8993 NodeHandle Input_;
8994
8995 public:
8996 enum InputIndices {
8997 InputIdx = 0,
8998 };
8999
9000 enum ResultIndices {
9001 ResultIdx = 0,
9002 };
9003
9004 QuantizeNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
9005 : Node(Kinded::Kind::QuantizeNodeKind, name), Input_(this, Input) {
9006 addResult(Result);
9007 }
9008 const NodeValue getInput() const { return Input_; }
9009 NodeValue getResult() { return getNthResult(0); }
9010 const NodeValue getResult() const { return getNthResult(0); }
9011
9012 static bool classof(const Kinded *k) {
9013 return k->getKind() == Kinded::Kind::QuantizeNodeKind;
9014 }
9015
9016
9017 bool isOverwrittenNthInput(unsigned idx) const {
9018 return false;
9019 }
9020
9021 unsigned getNumInputs() const;
9022 std::string getInputName(unsigned idx) const;
9023 NodeValue getNthInput(unsigned idx);
9024 void setNthInput(unsigned idx, NodeValue val);
9025 llvm::StringRef getOutputName(unsigned idx) const;
9026 bool hasSideEffects() const { return 0; }
9027 bool isCanonical() const { return 1; }
9028 bool isDataParallel() const { return 1; }
9029 std::string getDebugDesc() const;
9030 bool isEqual(const QuantizeNode &other) const;
9031 llvm::hash_code getHash() const;
9032 void visit(Node *parent, NodeWalker *visitor);
9033 Node* clone() const;
9034 bool verify() const;
9035};
9036} // namespace glow
9037
9038
9039namespace glow {
9040/// Convert quantized input tensor into the float representation. x = Scale * (x_q - Offset).
9041class DequantizeNode final : public Node {
9042 NodeHandle Input_;
9043
9044 public:
9045 enum InputIndices {
9046 InputIdx = 0,
9047 };
9048
9049 enum ResultIndices {
9050 ResultIdx = 0,
9051 };
9052
9053 DequantizeNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
9054 : Node(Kinded::Kind::DequantizeNodeKind, name), Input_(this, Input) {
9055 addResult(Result);
9056 }
9057 const NodeValue getInput() const { return Input_; }
9058 NodeValue getResult() { return getNthResult(0); }
9059 const NodeValue getResult() const { return getNthResult(0); }
9060
9061 static bool classof(const Kinded *k) {
9062 return k->getKind() == Kinded::Kind::DequantizeNodeKind;
9063 }
9064
9065
9066 bool isOverwrittenNthInput(unsigned idx) const {
9067 return false;
9068 }
9069
9070 unsigned getNumInputs() const;
9071 std::string getInputName(unsigned idx) const;
9072 NodeValue getNthInput(unsigned idx);
9073 void setNthInput(unsigned idx, NodeValue val);
9074 llvm::StringRef getOutputName(unsigned idx) const;
9075 bool hasSideEffects() const { return 0; }
9076 bool isCanonical() const { return 1; }
9077 bool isDataParallel() const { return 1; }
9078 std::string getDebugDesc() const;
9079 bool isEqual(const DequantizeNode &other) const;
9080 llvm::hash_code getHash() const;
9081 void visit(Node *parent, NodeWalker *visitor);
9082 Node* clone() const;
9083 bool verify() const;
9084};
9085} // namespace glow
9086
9087
9088namespace glow {
9089/// Rescale the input quantized tensor to a new Scale and Offset. The new Scale and Offset are specified by the output type passed to the constructor
9090class RescaleQuantizedNode final : public Node {
9091 NodeHandle Input_;
9092
9093 public:
9094 enum InputIndices {
9095 InputIdx = 0,
9096 };
9097
9098 enum ResultIndices {
9099 ResultIdx = 0,
9100 };
9101
9102 RescaleQuantizedNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
9103 : Node(Kinded::Kind::RescaleQuantizedNodeKind, name), Input_(this, Input) {
9104 addResult(Result);
9105 }
9106 const NodeValue getInput() const { return Input_; }
9107 NodeValue getResult() { return getNthResult(0); }
9108 const NodeValue getResult() const { return getNthResult(0); }
9109
9110 static bool classof(const Kinded *k) {
9111 return k->getKind() == Kinded::Kind::RescaleQuantizedNodeKind;
9112 }
9113
9114
9115 bool isOverwrittenNthInput(unsigned idx) const {
9116 return false;
9117 }
9118
9119 unsigned getNumInputs() const;
9120 std::string getInputName(unsigned idx) const;
9121 NodeValue getNthInput(unsigned idx);
9122 void setNthInput(unsigned idx, NodeValue val);
9123 llvm::StringRef getOutputName(unsigned idx) const;
9124 bool hasSideEffects() const { return 0; }
9125 bool isCanonical() const { return 1; }
9126 bool isDataParallel() const { return 1; }
9127 std::string getDebugDesc() const;
9128 bool isEqual(const RescaleQuantizedNode &other) const;
9129 llvm::hash_code getHash() const;
9130 void visit(Node *parent, NodeWalker *visitor);
9131 Node* clone() const;
9132 bool verify() const;
9133};
9134} // namespace glow
9135
9136
9137namespace glow {
9138/// Finds the top K maximal elements for each vector in the tensor. Vectors are defined as the last dimension in the tensor. The input shape {D_0, D_1, ... D_n} results in the outputs {D_0, D_1, ... D_n-1, K}, sorted in non-decreasing order.
9139class TopKNode final : public Node {
9140 NodeHandle Input_;
9141 unsigned_t K_;
9142
9143 public:
9144 enum InputIndices {
9145 InputIdx = 0,
9146 };
9147
9148 enum ResultIndices {
9149 ValuesIdx = 0,
9150 IndicesIdx = 1,
9151 };
9152
9153 TopKNode(llvm::StringRef name, TypeRef Values , TypeRef Indices , NodeValue Input, unsigned_t K)
9154 : Node(Kinded::Kind::TopKNodeKind, name), Input_(this, Input), K_(K) {
9155 addResult(Values);
9156 addResult(Indices);
9157 }
9158 const NodeValue getInput() const { return Input_; }
9159 NodeValue getValues() { return getNthResult(0); }
9160 const NodeValue getValues() const { return getNthResult(0); }
9161 NodeValue getIndices() { return getNthResult(1); }
9162 const NodeValue getIndices() const { return getNthResult(1); }
9163 unsigned_t getK() const { return K_; }
9164
9165 static bool classof(const Kinded *k) {
9166 return k->getKind() == Kinded::Kind::TopKNodeKind;
9167 }
9168
9169
9170 bool isOverwrittenNthInput(unsigned idx) const {
9171 return false;
9172 }
9173
9174 unsigned getNumInputs() const;
9175 std::string getInputName(unsigned idx) const;
9176 NodeValue getNthInput(unsigned idx);
9177 void setNthInput(unsigned idx, NodeValue val);
9178 llvm::StringRef getOutputName(unsigned idx) const;
9179 bool hasSideEffects() const { return 0; }
9180 bool isCanonical() const { return 1; }
9181 bool isDataParallel() const { return 0; }
9182 std::string getDebugDesc() const;
9183 bool isEqual(const TopKNode &other) const;
9184 llvm::hash_code getHash() const;
9185 void visit(Node *parent, NodeWalker *visitor);
9186 Node* clone() const;
9187 bool verify() const;
9188};
9189} // namespace glow
9190
9191
9192namespace glow {
9193/// A LSTM unit node, take Input as I, F, G, O,takes F from forget gate, I from input gate,O from output gate, G from cell gate and C from cell state. Calulates newC = sigmoid(F) * C + sigmoid(I) * tanh(G), newH = tanh(newC) * sigmoid(O).
9194class LSTMUnitNode final : public Node {
9195 NodeHandle Input_;
9196 NodeHandle C_;
9197
9198 public:
9199 enum InputIndices {
9200 InputIdx = 0,
9201 CIdx = 1,
9202 };
9203
9204 enum ResultIndices {
9205 newCIdx = 0,
9206 newHIdx = 1,
9207 };
9208
9209 LSTMUnitNode(llvm::StringRef name, NodeValue Input, NodeValue C)
9210 : Node(Kinded::Kind::LSTMUnitNodeKind, name), Input_(this, Input), C_(this, C) {
9211 addResult(C.getType());
9212 addResult(C.getType());
9213 }
9214 const NodeValue getInput() const { return Input_; }
9215 const NodeValue getC() const { return C_; }
9216 NodeValue getnewC() { return getNthResult(0); }
9217 const NodeValue getnewC() const { return getNthResult(0); }
9218 NodeValue getnewH() { return getNthResult(1); }
9219 const NodeValue getnewH() const { return getNthResult(1); }
9220
9221 static bool classof(const Kinded *k) {
9222 return k->getKind() == Kinded::Kind::LSTMUnitNodeKind;
9223 }
9224
9225
9226 bool isOverwrittenNthInput(unsigned idx) const {
9227 return false;
9228 }
9229
9230 unsigned getNumInputs() const;
9231 std::string getInputName(unsigned idx) const;
9232 NodeValue getNthInput(unsigned idx);
9233 void setNthInput(unsigned idx, NodeValue val);
9234 llvm::StringRef getOutputName(unsigned idx) const;
9235 bool hasSideEffects() const { return 0; }
9236 bool isCanonical() const { return 1; }
9237 bool isDataParallel() const { return 0; }
9238 std::string getDebugDesc() const;
9239 bool isEqual(const LSTMUnitNode &other) const;
9240 llvm::hash_code getHash() const;
9241 void visit(Node *parent, NodeWalker *visitor);
9242 Node* clone() const;
9243 bool verify() const;
9244};
9245} // namespace glow
9246
9247
9248namespace glow {
9249/// Convert the input from its current type to the destination type. The input and output types must have the same shapes. Moreover the input and output types must not be quantized types. Quantized types should use the appropriate Quantize, Dequantize, and Rescale nodes.
9250class ConvertToNode final : public Node {
9251 NodeHandle Input_;
9252
9253 public:
9254 enum InputIndices {
9255 InputIdx = 0,
9256 };
9257
9258 enum ResultIndices {
9259 ResultIdx = 0,
9260 };
9261
9262 ConvertToNode(llvm::StringRef name, TypeRef Result , NodeValue Input)
9263 : Node(Kinded::Kind::ConvertToNodeKind, name), Input_(this, Input) {
9264 addResult(Result);
9265 }
9266 const NodeValue getInput() const { return Input_; }
9267 NodeValue getResult() { return getNthResult(0); }
9268 const NodeValue getResult() const { return getNthResult(0); }
9269
9270 static bool classof(const Kinded *k) {
9271 return k->getKind() == Kinded::Kind::ConvertToNodeKind;
9272 }
9273
9274
9275 bool isOverwrittenNthInput(unsigned idx) const {
9276 return false;
9277 }
9278
9279 unsigned getNumInputs() const;
9280 std::string getInputName(unsigned idx) const;
9281 NodeValue getNthInput(unsigned idx);
9282 void setNthInput(unsigned idx, NodeValue val);
9283 llvm::StringRef getOutputName(unsigned idx) const;
9284 bool hasSideEffects() const { return 0; }
9285 bool isCanonical() const { return 1; }
9286 bool isDataParallel() const { return 1; }
9287 std::string getDebugDesc() const;
9288 bool isEqual(const ConvertToNode &other) const;
9289 llvm::hash_code getHash() const;
9290 void visit(Node *parent, NodeWalker *visitor);
9291 Node* clone() const;
9292 bool verify() const;
9293};
9294} // namespace glow
9295
9296
9297namespace glow {
9298/// This is a node representing an external function call. One possible use of this capability is to pass a source code for a function/kernel. When processing this node, a backend can compile and execute the source code. This node can also be used to pass binary or pointers to executable code. The semantics and implementation of this node not standardized and is very backend-specific.
9299class ExternalFunctionCallNode final : public Node {
9300 std::vector<NodeHandle> Inputs_;
9301 std::string FunctionName_;
9302 std::string FunctionImpl_;
9303 std::string FunctionKind_;
9304
9305 public:
9306 enum InputIndices {
9307 };
9308
9309 enum ResultIndices {
9310 ResultIdx = 0,
9311 };
9312
9313 ExternalFunctionCallNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> Inputs, std::string FunctionName, std::string FunctionImpl, std::string FunctionKind)
9314 : Node(Kinded::Kind::ExternalFunctionCallNodeKind, name), FunctionName_(FunctionName), FunctionImpl_(FunctionImpl), FunctionKind_(FunctionKind) {
9315 addResult(Result);
9316 Inputs_.resize(Inputs.size());
9317 for (size_t idx = 0, e = Inputs.size(); idx < e; ++idx) {
9318 Inputs_[idx] = Inputs[idx];
9319 Inputs_[idx].setParent(this);
9320 }
9321 }
9322 NodeValue getResult() { return getNthResult(0); }
9323 const NodeValue getResult() const { return getNthResult(0); }
9324 NodeValueArrayRef getInputs() const { return Inputs_; }
9325 std::string getFunctionName() const { return FunctionName_; }
9326 std::string getFunctionImpl() const { return FunctionImpl_; }
9327 std::string getFunctionKind() const { return FunctionKind_; }
9328
9329 static bool classof(const Kinded *k) {
9330 return k->getKind() == Kinded::Kind::ExternalFunctionCallNodeKind;
9331 }
9332
9333
9334 bool isOverwrittenNthInput(unsigned idx) const {
9335 return false;
9336 }
9337
9338 unsigned getNumInputs() const;
9339 std::string getInputName(unsigned idx) const;
9340 NodeValue getNthInput(unsigned idx);
9341 void setNthInput(unsigned idx, NodeValue val);
9342 llvm::StringRef getOutputName(unsigned idx) const;
9343 bool hasSideEffects() const { return 1; }
9344 bool isCanonical() const { return 1; }
9345 bool isDataParallel() const { return 0; }
9346 std::string getDebugDesc() const;
9347 bool isEqual(const ExternalFunctionCallNode &other) const;
9348 llvm::hash_code getHash() const;
9349 void visit(Node *parent, NodeWalker *visitor);
9350 Node* clone() const;
9351 bool verify() const;
9352};
9353} // namespace glow
9354
9355
9356namespace glow {
9357/// Computes the spectrogram of a mono audio signal using given window size and stride. The FFT length used to compute the spectrogram is the next power of 2 (for a window size of 640 the FFT length is 1024). The length of each spectrogram window is FFT_length / 2 + 1. This node is inspired from TensorFlow.
9358class AudioSpectrogramNode final : public Node {
9359 NodeHandle Input_;
9360 NodeHandle Window_;
9361 NodeHandle TwiddleFactors_;
9362 NodeHandle BitReverseIndices_;
9363 NodeHandle ComplexToRealWeights_;
9364 unsigned_t WindowSize_;
9365 unsigned_t WindowStride_;
9366 bool MagnitudeSquared_;
9367
9368 public:
9369 enum InputIndices {
9370 InputIdx = 0,
9371 WindowIdx = 1,
9372 TwiddleFactorsIdx = 2,
9373 BitReverseIndicesIdx = 3,
9374 ComplexToRealWeightsIdx = 4,
9375 };
9376
9377 enum ResultIndices {
9378 SpectrogramIdx = 0,
9379 };
9380
9381 AudioSpectrogramNode(llvm::StringRef name, TypeRef Spectrogram , NodeValue Input, NodeValue Window, NodeValue TwiddleFactors, NodeValue BitReverseIndices, NodeValue ComplexToRealWeights, unsigned_t WindowSize, unsigned_t WindowStride, bool MagnitudeSquared)
9382 : Node(Kinded::Kind::AudioSpectrogramNodeKind, name), Input_(this, Input), Window_(this, Window), TwiddleFactors_(this, TwiddleFactors), BitReverseIndices_(this, BitReverseIndices), ComplexToRealWeights_(this, ComplexToRealWeights), WindowSize_(WindowSize), WindowStride_(WindowStride), MagnitudeSquared_(MagnitudeSquared) {
9383 addResult(Spectrogram);
9384 }
9385 const NodeValue getInput() const { return Input_; }
9386 const NodeValue getWindow() const { return Window_; }
9387 const NodeValue getTwiddleFactors() const { return TwiddleFactors_; }
9388 const NodeValue getBitReverseIndices() const { return BitReverseIndices_; }
9389 const NodeValue getComplexToRealWeights() const { return ComplexToRealWeights_; }
9390 NodeValue getSpectrogram() { return getNthResult(0); }
9391 const NodeValue getSpectrogram() const { return getNthResult(0); }
9392 unsigned_t getWindowSize() const { return WindowSize_; }
9393 unsigned_t getWindowStride() const { return WindowStride_; }
9394 bool getMagnitudeSquared() const { return MagnitudeSquared_; }
9395
9396 static bool classof(const Kinded *k) {
9397 return k->getKind() == Kinded::Kind::AudioSpectrogramNodeKind;
9398 }
9399
9400
9401 bool isOverwrittenNthInput(unsigned idx) const {
9402 return false;
9403 }
9404
9405 unsigned getNumInputs() const;
9406 std::string getInputName(unsigned idx) const;
9407 NodeValue getNthInput(unsigned idx);
9408 void setNthInput(unsigned idx, NodeValue val);
9409 llvm::StringRef getOutputName(unsigned idx) const;
9410 bool hasSideEffects() const { return 0; }
9411 bool isCanonical() const { return 1; }
9412 bool isDataParallel() const { return 0; }
9413 std::string getDebugDesc() const;
9414 bool isEqual(const AudioSpectrogramNode &other) const;
9415 llvm::hash_code getHash() const;
9416 void visit(Node *parent, NodeWalker *visitor);
9417 Node* clone() const;
9418 bool verify() const;
9419};
9420} // namespace glow
9421
9422
9423namespace glow {
9424/// Computes the MFCC (Mel Frequency Cepstral Coefficient) for the given spectrogram. This node is mostly used as feature extractor for voice/speech audio data in voice command or keyword spotting applications. The input is assumed to be a power spectrogram and not a magnitude.This node is inspired from TensorFlow.
9425class MFCCNode final : public Node {
9426 NodeHandle Spectrogram_;
9427 NodeHandle MelWeights_;
9428 NodeHandle MelRanges_;
9429 NodeHandle DctMat_;
9430 float SampleRate_;
9431 float LowerFrequency_;
9432 float UpperFrequency_;
9433 unsigned_t FilterBankCount_;
9434 unsigned_t NumCoefficients_;
9435
9436 public:
9437 enum InputIndices {
9438 SpectrogramIdx = 0,
9439 MelWeightsIdx = 1,
9440 MelRangesIdx = 2,
9441 DctMatIdx = 3,
9442 };
9443
9444 enum ResultIndices {
9445 CoefficientsIdx = 0,
9446 };
9447
9448 MFCCNode(llvm::StringRef name, TypeRef Coefficients , NodeValue Spectrogram, NodeValue MelWeights, NodeValue MelRanges, NodeValue DctMat, float SampleRate, float LowerFrequency, float UpperFrequency, unsigned_t FilterBankCount, unsigned_t NumCoefficients)
9449 : Node(Kinded::Kind::MFCCNodeKind, name), Spectrogram_(this, Spectrogram), MelWeights_(this, MelWeights), MelRanges_(this, MelRanges), DctMat_(this, DctMat), SampleRate_(SampleRate), LowerFrequency_(LowerFrequency), UpperFrequency_(UpperFrequency), FilterBankCount_(FilterBankCount), NumCoefficients_(NumCoefficients) {
9450 addResult(Coefficients);
9451 }
9452 const NodeValue getSpectrogram() const { return Spectrogram_; }
9453 const NodeValue getMelWeights() const { return MelWeights_; }
9454 const NodeValue getMelRanges() const { return MelRanges_; }
9455 const NodeValue getDctMat() const { return DctMat_; }
9456 NodeValue getCoefficients() { return getNthResult(0); }
9457 const NodeValue getCoefficients() const { return getNthResult(0); }
9458 float getSampleRate() const { return SampleRate_; }
9459 float getLowerFrequency() const { return LowerFrequency_; }
9460 float getUpperFrequency() const { return UpperFrequency_; }
9461 unsigned_t getFilterBankCount() const { return FilterBankCount_; }
9462 unsigned_t getNumCoefficients() const { return NumCoefficients_; }
9463
9464 static bool classof(const Kinded *k) {
9465 return k->getKind() == Kinded::Kind::MFCCNodeKind;
9466 }
9467
9468
9469 bool isOverwrittenNthInput(unsigned idx) const {
9470 return false;
9471 }
9472
9473 unsigned getNumInputs() const;
9474 std::string getInputName(unsigned idx) const;
9475 NodeValue getNthInput(unsigned idx);
9476 void setNthInput(unsigned idx, NodeValue val);
9477 llvm::StringRef getOutputName(unsigned idx) const;
9478 bool hasSideEffects() const { return 0; }
9479 bool isCanonical() const { return 1; }
9480 bool isDataParallel() const { return 0; }
9481 std::string getDebugDesc() const;
9482 bool isEqual(const MFCCNode &other) const;
9483 llvm::hash_code getHash() const;
9484 void visit(Node *parent, NodeWalker *visitor);
9485 Node* clone() const;
9486 bool verify() const;
9487};
9488} // namespace glow
9489
9490
9491namespace glow {
9492/// This is a mix of ONNX and TF NMSv4. It supports multiple classes and does per class NMS. It also supports TF NMS V4 by outputting indices and scalar tensor with number of valid indices. It pads the rest with global MIN box.
9493class NonMaxSuppressionNode final : public Node {
9494 NodeHandle Boxes_;
9495 NodeHandle Scores_;
9496 unsigned_t CenterPointBox_;
9497 unsigned_t MaxOutputBoxesPerClass_;
9498 float IouThreshold_;
9499 float ScoreThreshold_;
9500 bool IsTFVersion4_;
9501
9502 public:
9503 enum InputIndices {
9504 BoxesIdx = 0,
9505 ScoresIdx = 1,
9506 };
9507
9508 enum ResultIndices {
9509 IndicesIdx = 0,
9510 NumberOfSelectedIndicesIdx = 1,
9511 };
9512
9513 NonMaxSuppressionNode(llvm::StringRef name, TypeRef Indices , TypeRef NumberOfSelectedIndices , NodeValue Boxes, NodeValue Scores, unsigned_t CenterPointBox, unsigned_t MaxOutputBoxesPerClass, float IouThreshold, float ScoreThreshold, bool IsTFVersion4)
9514 : Node(Kinded::Kind::NonMaxSuppressionNodeKind, name), Boxes_(this, Boxes), Scores_(this, Scores), CenterPointBox_(CenterPointBox), MaxOutputBoxesPerClass_(MaxOutputBoxesPerClass), IouThreshold_(IouThreshold), ScoreThreshold_(ScoreThreshold), IsTFVersion4_(IsTFVersion4) {
9515 addResult(Indices);
9516 addResult(NumberOfSelectedIndices);
9517 }
9518 const NodeValue getBoxes() const { return Boxes_; }
9519 const NodeValue getScores() const { return Scores_; }
9520 NodeValue getIndices() { return getNthResult(0); }
9521 const NodeValue getIndices() const { return getNthResult(0); }
9522 NodeValue getNumberOfSelectedIndices() { return getNthResult(1); }
9523 const NodeValue getNumberOfSelectedIndices() const { return getNthResult(1); }
9524 unsigned_t getCenterPointBox() const { return CenterPointBox_; }
9525 unsigned_t getMaxOutputBoxesPerClass() const { return MaxOutputBoxesPerClass_; }
9526 float getIouThreshold() const { return IouThreshold_; }
9527 float getScoreThreshold() const { return ScoreThreshold_; }
9528 bool getIsTFVersion4() const { return IsTFVersion4_; }
9529
9530 static bool classof(const Kinded *k) {
9531 return k->getKind() == Kinded::Kind::NonMaxSuppressionNodeKind;
9532 }
9533
9534
9535 bool isOverwrittenNthInput(unsigned idx) const {
9536 return false;
9537 }
9538
9539 unsigned getNumInputs() const;
9540 std::string getInputName(unsigned idx) const;
9541 NodeValue getNthInput(unsigned idx);
9542 void setNthInput(unsigned idx, NodeValue val);
9543 llvm::StringRef getOutputName(unsigned idx) const;
9544 bool hasSideEffects() const { return 0; }
9545 bool isCanonical() const { return 1; }
9546 bool isDataParallel() const { return 0; }
9547 std::string getDebugDesc() const;
9548 bool isEqual(const NonMaxSuppressionNode &other) const;
9549 llvm::hash_code getHash() const;
9550 void visit(Node *parent, NodeWalker *visitor);
9551 Node* clone() const;
9552 bool verify() const;
9553};
9554} // namespace glow
9555
9556
9557namespace glow {
9558/// This node is a TensorFlowLite version of NonMaxSuppresion. The node has the following inputs: Boxes with size [N, B, 4], Scores with size [N, B, C] and Anchors with size [B, 4] where N is the batch size, B is the number of boxes and C is the number of classes. The node has the following attributes (parameters): NumClasses - Number of classes (without the background class). MaxDetections - The maximum number of detections. MaxClassesPerDetection - Maximum classes per detection (Fast NMS). MaxDetectionsPerClass - Maximum detections per class (Regular NMS). IouThreshold - Detection threshold for IoU metric. ScoreThreshold - Detection threshold for scores. XScale - X scale used for decoding the boxes. YScale - Y scale used for decoding the boxes. HScale - H scale used for decoding the boxes. WScale - W scale used for decoding the boxes. RegularNMS - Whether the NMS is 'Regular' or 'Fast'. The node will have the following outputs: DetectionBoxes - the chosen boxes (float). DetectionClasses - the classes of the chosen boxes (int32). DetectionScores - the scores of the chosen boxes (float). NumDetections - number of chose boxes (int32). The first three output tensors will be allocated using the maximum number of possible detections (worst case scenario) but the actual usage will be given by the 'NumDetections' output.
9559class TFLiteDetectionPostProcessNode final : public Node {
9560 NodeHandle Boxes_;
9561 NodeHandle Scores_;
9562 NodeHandle Anchors_;
9563 unsigned_t NumClasses_;
9564 unsigned_t MaxDetections_;
9565 unsigned_t MaxClassesPerDetection_;
9566 unsigned_t MaxDetectionsPerClass_;
9567 float IouThreshold_;
9568 float ScoreThreshold_;
9569 float XScale_;
9570 float YScale_;
9571 float HScale_;
9572 float WScale_;
9573 bool RegularNMS_;
9574
9575 public:
9576 enum InputIndices {
9577 BoxesIdx = 0,
9578 ScoresIdx = 1,
9579 AnchorsIdx = 2,
9580 };
9581
9582 enum ResultIndices {
9583 DetectionBoxesIdx = 0,
9584 DetectionClassesIdx = 1,
9585 DetectionScoresIdx = 2,
9586 NumDetectionsIdx = 3,
9587 };
9588
9589 TFLiteDetectionPostProcessNode(llvm::StringRef name, TypeRef DetectionBoxes , TypeRef DetectionClasses , TypeRef DetectionScores , TypeRef NumDetections , NodeValue Boxes, NodeValue Scores, NodeValue Anchors, unsigned_t NumClasses, unsigned_t MaxDetections, unsigned_t MaxClassesPerDetection, unsigned_t MaxDetectionsPerClass, float IouThreshold, float ScoreThreshold, float XScale, float YScale, float HScale, float WScale, bool RegularNMS)
9590 : Node(Kinded::Kind::TFLiteDetectionPostProcessNodeKind, name), Boxes_(this, Boxes), Scores_(this, Scores), Anchors_(this, Anchors), NumClasses_(NumClasses), MaxDetections_(MaxDetections), MaxClassesPerDetection_(MaxClassesPerDetection), MaxDetectionsPerClass_(MaxDetectionsPerClass), IouThreshold_(IouThreshold), ScoreThreshold_(ScoreThreshold), XScale_(XScale), YScale_(YScale), HScale_(HScale), WScale_(WScale), RegularNMS_(RegularNMS) {
9591 addResult(DetectionBoxes);
9592 addResult(DetectionClasses);
9593 addResult(DetectionScores);
9594 addResult(NumDetections);
9595 }
9596 const NodeValue getBoxes() const { return Boxes_; }
9597 const NodeValue getScores() const { return Scores_; }
9598 const NodeValue getAnchors() const { return Anchors_; }
9599 NodeValue getDetectionBoxes() { return getNthResult(0); }
9600 const NodeValue getDetectionBoxes() const { return getNthResult(0); }
9601 NodeValue getDetectionClasses() { return getNthResult(1); }
9602 const NodeValue getDetectionClasses() const { return getNthResult(1); }
9603 NodeValue getDetectionScores() { return getNthResult(2); }
9604 const NodeValue getDetectionScores() const { return getNthResult(2); }
9605 NodeValue getNumDetections() { return getNthResult(3); }
9606 const NodeValue getNumDetections() const { return getNthResult(3); }
9607 unsigned_t getNumClasses() const { return NumClasses_; }
9608 unsigned_t getMaxDetections() const { return MaxDetections_; }
9609 unsigned_t getMaxClassesPerDetection() const { return MaxClassesPerDetection_; }
9610 unsigned_t getMaxDetectionsPerClass() const { return MaxDetectionsPerClass_; }
9611 float getIouThreshold() const { return IouThreshold_; }
9612 float getScoreThreshold() const { return ScoreThreshold_; }
9613 float getXScale() const { return XScale_; }
9614 float getYScale() const { return YScale_; }
9615 float getHScale() const { return HScale_; }
9616 float getWScale() const { return WScale_; }
9617 bool getRegularNMS() const { return RegularNMS_; }
9618
9619 static bool classof(const Kinded *k) {
9620 return k->getKind() == Kinded::Kind::TFLiteDetectionPostProcessNodeKind;
9621 }
9622
9623
9624 bool isOverwrittenNthInput(unsigned idx) const {
9625 return false;
9626 }
9627
9628 unsigned getNumInputs() const;
9629 std::string getInputName(unsigned idx) const;
9630 NodeValue getNthInput(unsigned idx);
9631 void setNthInput(unsigned idx, NodeValue val);
9632 llvm::StringRef getOutputName(unsigned idx) const;
9633 bool hasSideEffects() const { return 0; }
9634 bool isCanonical() const { return 1; }
9635 bool isDataParallel() const { return 0; }
9636 std::string getDebugDesc() const;
9637 bool isEqual(const TFLiteDetectionPostProcessNode &other) const;
9638 llvm::hash_code getHash() const;
9639 void visit(Node *parent, NodeWalker *visitor);
9640 Node* clone() const;
9641 bool verify() const;
9642};
9643} // namespace glow
9644
9645
9646namespace glow {
9647/// Performs region of interest align (ROI) operator. FeatureMap - a tensor of [N,H,W,C]. N is the batch, C is the channel, H is the height, W is the width. Boxes - a tensor of [K,4] or [K,5] with format [[optinal_batch_index] x0, y0, x1, y1]. K is the number of boxes. BatchIndices - a tensor of [K,]. If N > 1 and Box shape is [K,4], BatchIndices must be valid. Output is a tensor with shape [K, OutputHeight, OutputWidth, C]. Aligned - if true, coordinates are aligned to a center of a pixel.
9648class ROIAlignNode final : public Node {
9649 NodeHandle FeatureMap_;
9650 NodeHandle Boxes_;
9651 NodeHandle BatchIndices_;
9652 unsigned_t Mode_;
9653 unsigned_t OutputHeight_;
9654 unsigned_t OutputWidth_;
9655 unsigned_t SamplingRatio_;
9656 float SpatialScale_;
9657 bool Aligned_;
9658 bool Rotated_;
9659
9660 public:
9661 enum InputIndices {
9662 FeatureMapIdx = 0,
9663 BoxesIdx = 1,
9664 BatchIndicesIdx = 2,
9665 };
9666
9667 enum ResultIndices {
9668 ResultIdx = 0,
9669 };
9670
9671 ROIAlignNode(llvm::StringRef name, TypeRef Result , NodeValue FeatureMap, NodeValue Boxes, NodeValue BatchIndices, unsigned_t Mode, unsigned_t OutputHeight, unsigned_t OutputWidth, unsigned_t SamplingRatio, float SpatialScale, bool Aligned, bool Rotated)
9672 : Node(Kinded::Kind::ROIAlignNodeKind, name), FeatureMap_(this, FeatureMap), Boxes_(this, Boxes), BatchIndices_(this, BatchIndices), Mode_(Mode), OutputHeight_(OutputHeight), OutputWidth_(OutputWidth), SamplingRatio_(SamplingRatio), SpatialScale_(SpatialScale), Aligned_(Aligned), Rotated_(Rotated) {
9673 addResult(Result);
9674 }
9675 const NodeValue getFeatureMap() const { return FeatureMap_; }
9676 const NodeValue getBoxes() const { return Boxes_; }
9677 const NodeValue getBatchIndices() const { return BatchIndices_; }
9678 NodeValue getResult() { return getNthResult(0); }
9679 const NodeValue getResult() const { return getNthResult(0); }
9680 unsigned_t getMode() const { return Mode_; }
9681 unsigned_t getOutputHeight() const { return OutputHeight_; }
9682 unsigned_t getOutputWidth() const { return OutputWidth_; }
9683 unsigned_t getSamplingRatio() const { return SamplingRatio_; }
9684 float getSpatialScale() const { return SpatialScale_; }
9685 bool getAligned() const { return Aligned_; }
9686 bool getRotated() const { return Rotated_; }
9687
9688 static bool classof(const Kinded *k) {
9689 return k->getKind() == Kinded::Kind::ROIAlignNodeKind;
9690 }
9691
9692
9693 bool isOverwrittenNthInput(unsigned idx) const {
9694 return false;
9695 }
9696
9697 unsigned getNumInputs() const;
9698 std::string getInputName(unsigned idx) const;
9699 NodeValue getNthInput(unsigned idx);
9700 void setNthInput(unsigned idx, NodeValue val);
9701 llvm::StringRef getOutputName(unsigned idx) const;
9702 bool hasSideEffects() const { return 0; }
9703 bool isCanonical() const { return 1; }
9704 bool isDataParallel() const { return 0; }
9705 std::string getDebugDesc() const;
9706 bool isEqual(const ROIAlignNode &other) const;
9707 llvm::hash_code getHash() const;
9708 void visit(Node *parent, NodeWalker *visitor);
9709 Node* clone() const;
9710 bool verify() const;
9711};
9712} // namespace glow
9713
9714
9715namespace glow {
9716/// Transform proposal bounding boxes to target bounding box using bounding box regression deltas. Rois tensor's format is: <[optional_batch_index], x1, y1, x2, y2>, shape (M, 4) or (M, 5) where M is the number of Rois. For rotated boxes, this would have an additional angle (in degrees) in the format <[optional_batch_id], ctr_x, ctr_y, w, h, angle> Deltas are of shape (M, K*4) with format <dx, dy, dw, dh>, where K is the number of classes. For rotated Rois: shape (M, K*5), format <dx, dy, dw, dh, da>. ImInfo is of shape <batch_size, 3> with format <img_height, img_width, img_scale>.If proposals from multiple images in a batch are present, they should be grouped sequentially and in incremental order.
9717class BBoxTransformNode final : public Node {
9718 NodeHandle Rois_;
9719 NodeHandle Deltas_;
9720 NodeHandle ImInfo_;
9721 std::vector<float> Weights_;
9722 bool ApplyScale_;
9723 bool Rotated_;
9724 bool AngleBoundOn_;
9725 int64_t AngleBoundLo_;
9726 int64_t AngleBoundHi_;
9727 float ClipAngleThresh_;
9728 bool LegacyPlusOne_;
9729
9730 public:
9731 enum InputIndices {
9732 RoisIdx = 0,
9733 DeltasIdx = 1,
9734 ImInfoIdx = 2,
9735 };
9736
9737 enum ResultIndices {
9738 BoxOutIdx = 0,
9739 RoiBatchSplitsIdx = 1,
9740 };
9741
9742 BBoxTransformNode(llvm::StringRef name, TypeRef BoxOut , TypeRef RoiBatchSplits , NodeValue Rois, NodeValue Deltas, NodeValue ImInfo, std::vector<float> Weights, bool ApplyScale, bool Rotated, bool AngleBoundOn, int64_t AngleBoundLo, int64_t AngleBoundHi, float ClipAngleThresh, bool LegacyPlusOne)
9743 : Node(Kinded::Kind::BBoxTransformNodeKind, name), Rois_(this, Rois), Deltas_(this, Deltas), ImInfo_(this, ImInfo), Weights_(Weights), ApplyScale_(ApplyScale), Rotated_(Rotated), AngleBoundOn_(AngleBoundOn), AngleBoundLo_(AngleBoundLo), AngleBoundHi_(AngleBoundHi), ClipAngleThresh_(ClipAngleThresh), LegacyPlusOne_(LegacyPlusOne) {
9744 addResult(BoxOut);
9745 addResult(RoiBatchSplits);
9746 }
9747 const NodeValue getRois() const { return Rois_; }
9748 const NodeValue getDeltas() const { return Deltas_; }
9749 const NodeValue getImInfo() const { return ImInfo_; }
9750 NodeValue getBoxOut() { return getNthResult(0); }
9751 const NodeValue getBoxOut() const { return getNthResult(0); }
9752 NodeValue getRoiBatchSplits() { return getNthResult(1); }
9753 const NodeValue getRoiBatchSplits() const { return getNthResult(1); }
9754 llvm::ArrayRef<float> getWeights() const { return Weights_; }
9755 bool getApplyScale() const { return ApplyScale_; }
9756 bool getRotated() const { return Rotated_; }
9757 bool getAngleBoundOn() const { return AngleBoundOn_; }
9758 int64_t getAngleBoundLo() const { return AngleBoundLo_; }
9759 int64_t getAngleBoundHi() const { return AngleBoundHi_; }
9760 float getClipAngleThresh() const { return ClipAngleThresh_; }
9761 bool getLegacyPlusOne() const { return LegacyPlusOne_; }
9762
9763 static bool classof(const Kinded *k) {
9764 return k->getKind() == Kinded::Kind::BBoxTransformNodeKind;
9765 }
9766
9767
9768 bool isOverwrittenNthInput(unsigned idx) const {
9769 return false;
9770 }
9771
9772 unsigned getNumInputs() const;
9773 std::string getInputName(unsigned idx) const;
9774 NodeValue getNthInput(unsigned idx);
9775 void setNthInput(unsigned idx, NodeValue val);
9776 llvm::StringRef getOutputName(unsigned idx) const;
9777 bool hasSideEffects() const { return 0; }
9778 bool isCanonical() const { return 1; }
9779 bool isDataParallel() const { return 0; }
9780 std::string getDebugDesc() const;
9781 bool isEqual(const BBoxTransformNode &other) const;
9782 llvm::hash_code getHash() const;
9783 void visit(Node *parent, NodeWalker *visitor);
9784 Node* clone() const;
9785 bool verify() const;
9786};
9787} // namespace glow
9788
9789
9790namespace glow {
9791/// Given RpnMinLevel, RpnMaxLevel and RpnPostNmsTopN CollectRpnProposals merges RoisIn based on RoisProbsIn and returns top proposals limited to RpnPostNmsTopN total, size (n x B), where B is box dimensions and based on dimension of input rois. Format for upright boxes is (image_index, x1, y1, x2, y2).Format for rotated boxes (image_index, ctr_x, ctr_y, w, h, angle)RpnPostNmsTopN should be greater than zero
9792class CollectRpnProposalsNode final : public Node {
9793 std::vector<NodeHandle> RoisIn_;
9794 std::vector<NodeHandle> RoisProbsIn_;
9795 int64_t RpnMaxLevel_;
9796 int64_t RpnMinLevel_;
9797 unsigned_t RpnPostNmsTopN_;
9798
9799 public:
9800 enum InputIndices {
9801 };
9802
9803 enum ResultIndices {
9804 ResultIdx = 0,
9805 };
9806
9807 CollectRpnProposalsNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> RoisIn, std::vector<NodeValue> RoisProbsIn, int64_t RpnMaxLevel, int64_t RpnMinLevel, unsigned_t RpnPostNmsTopN)
9808 : Node(Kinded::Kind::CollectRpnProposalsNodeKind, name), RpnMaxLevel_(RpnMaxLevel), RpnMinLevel_(RpnMinLevel), RpnPostNmsTopN_(RpnPostNmsTopN) {
9809 addResult(Result);
9810 RoisIn_.resize(RoisIn.size());
9811 for (size_t idx = 0, e = RoisIn.size(); idx < e; ++idx) {
9812 RoisIn_[idx] = RoisIn[idx];
9813 RoisIn_[idx].setParent(this);
9814 }
9815 RoisProbsIn_.resize(RoisProbsIn.size());
9816 for (size_t idx = 0, e = RoisProbsIn.size(); idx < e; ++idx) {
9817 RoisProbsIn_[idx] = RoisProbsIn[idx];
9818 RoisProbsIn_[idx].setParent(this);
9819 }
9820 }
9821 NodeValue getResult() { return getNthResult(0); }
9822 const NodeValue getResult() const { return getNthResult(0); }
9823 NodeValueArrayRef getRoisIn() const { return RoisIn_; }
9824 NodeValueArrayRef getRoisProbsIn() const { return RoisProbsIn_; }
9825 int64_t getRpnMaxLevel() const { return RpnMaxLevel_; }
9826 int64_t getRpnMinLevel() const { return RpnMinLevel_; }
9827 unsigned_t getRpnPostNmsTopN() const { return RpnPostNmsTopN_; }
9828
9829 static bool classof(const Kinded *k) {
9830 return k->getKind() == Kinded::Kind::CollectRpnProposalsNodeKind;
9831 }
9832
9833
9834 bool isOverwrittenNthInput(unsigned idx) const {
9835 return false;
9836 }
9837
9838 unsigned getNumInputs() const;
9839 std::string getInputName(unsigned idx) const;
9840 NodeValue getNthInput(unsigned idx);
9841 void setNthInput(unsigned idx, NodeValue val);
9842 llvm::StringRef getOutputName(unsigned idx) const;
9843 bool hasSideEffects() const { return 0; }
9844 bool isCanonical() const { return 1; }
9845 bool isDataParallel() const { return 0; }
9846 std::string getDebugDesc() const;
9847 bool isEqual(const CollectRpnProposalsNode &other) const;
9848 llvm::hash_code getHash() const;
9849 void visit(Node *parent, NodeWalker *visitor);
9850 Node* clone() const;
9851 bool verify() const;
9852};
9853} // namespace glow
9854
9855
9856namespace glow {
9857/// LookupTable based data-parallel operation.Given an interpolation table and and index table, return interpolated approximations for arbitrary functions.
9858class LookupTableNode final : public Node {
9859 NodeHandle Input_;
9860 NodeHandle Table_;
9861 NodeHandle TableIdx_;
9862 glow::LUTOperator Operator_;
9863 std::vector<float> OperatorArgs_;
9864
9865 public:
9866 enum InputIndices {
9867 InputIdx = 0,
9868 TableIdx = 1,
9869 TableIdxIdx = 2,
9870 };
9871
9872 enum ResultIndices {
9873 ResultIdx = 0,
9874 };
9875
9876 LookupTableNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Table, NodeValue TableIdx, glow::LUTOperator Operator, std::vector<float> OperatorArgs)
9877 : Node(Kinded::Kind::LookupTableNodeKind, name), Input_(this, Input), Table_(this, Table), TableIdx_(this, TableIdx), Operator_(Operator), OperatorArgs_(OperatorArgs) {
9878 addResult(Result);
9879 }
9880 const NodeValue getInput() const { return Input_; }
9881 const NodeValue getTable() const { return Table_; }
9882 const NodeValue getTableIdx() const { return TableIdx_; }
9883 NodeValue getResult() { return getNthResult(0); }
9884 const NodeValue getResult() const { return getNthResult(0); }
9885 glow::LUTOperator getOperator() const { return Operator_; }
9886 llvm::ArrayRef<float> getOperatorArgs() const { return OperatorArgs_; }
9887
9888 static bool classof(const Kinded *k) {
9889 return k->getKind() == Kinded::Kind::LookupTableNodeKind;
9890 }
9891
9892
9893 bool isOverwrittenNthInput(unsigned idx) const {
9894 return false;
9895 }
9896
9897 unsigned getNumInputs() const;
9898 std::string getInputName(unsigned idx) const;
9899 NodeValue getNthInput(unsigned idx);
9900 void setNthInput(unsigned idx, NodeValue val);
9901 llvm::StringRef getOutputName(unsigned idx) const;
9902 bool hasSideEffects() const { return 0; }
9903 bool isCanonical() const { return 1; }
9904 bool isDataParallel() const { return 1; }
9905 std::string getDebugDesc() const;
9906 bool isEqual(const LookupTableNode &other) const;
9907 llvm::hash_code getHash() const;
9908 void visit(Node *parent, NodeWalker *visitor);
9909 Node* clone() const;
9910 bool verify() const;
9911};
9912} // namespace glow
9913
9914
9915namespace glow {
9916/// A Max node with one splat input; CPU specific.
9917class CPUMaxSplatNode final : public Node {
9918 NodeHandle Input_;
9919 float SplatValue_;
9920
9921 public:
9922 enum InputIndices {
9923 InputIdx = 0,
9924 };
9925
9926 enum ResultIndices {
9927 ResultIdx = 0,
9928 };
9929
9930 CPUMaxSplatNode(llvm::StringRef name, NodeValue Input, float SplatValue)
9931 : Node(Kinded::Kind::CPUMaxSplatNodeKind, name), Input_(this, Input), SplatValue_(SplatValue) {
9932 addResult(Input.getType());
9933 }
9934 const NodeValue getInput() const { return Input_; }
9935 NodeValue getResult() { return getNthResult(0); }
9936 const NodeValue getResult() const { return getNthResult(0); }
9937 float getSplatValue() const { return SplatValue_; }
9938
9939 static bool classof(const Kinded *k) {
9940 return k->getKind() == Kinded::Kind::CPUMaxSplatNodeKind;
9941 }
9942
9943
9944 bool isOverwrittenNthInput(unsigned idx) const {
9945 return false;
9946 }
9947
9948 unsigned getNumInputs() const;
9949 std::string getInputName(unsigned idx) const;
9950 NodeValue getNthInput(unsigned idx);
9951 void setNthInput(unsigned idx, NodeValue val);
9952 llvm::StringRef getOutputName(unsigned idx) const;
9953 bool hasSideEffects() const { return 0; }
9954 bool isCanonical() const { return 0; }
9955 bool isDataParallel() const { return 0; }
9956 std::string getDebugDesc() const;
9957 bool isEqual(const CPUMaxSplatNode &other) const;
9958 llvm::hash_code getHash() const;
9959 void visit(Node *parent, NodeWalker *visitor);
9960 Node* clone() const;
9961 bool verify() const;
9962};
9963} // namespace glow
9964
9965
9966namespace glow {
9967/// This is a cpu-specific convolution implementation where the filter is transposed to the shape [D/8, K, K, C, 8]
9968class CPUConvDKKC8Node final : public Node {
9969 NodeHandle Input_;
9970 NodeHandle Filter_;
9971 NodeHandle Bias_;
9972 std::vector<unsigned_t> Kernels_;
9973 std::vector<unsigned_t> Strides_;
9974 std::vector<unsigned_t> Pads_;
9975 unsigned_t Group_;
9976
9977 public:
9978 enum InputIndices {
9979 InputIdx = 0,
9980 FilterIdx = 1,
9981 BiasIdx = 2,
9982 };
9983
9984 enum ResultIndices {
9985 ResultIdx = 0,
9986 };
9987
9988 CPUConvDKKC8Node(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group)
9989 : Node(Kinded::Kind::CPUConvDKKC8NodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group) {
9990 addResult(Result);
9991 }
9992 const NodeValue getInput() const { return Input_; }
9993 const NodeValue getFilter() const { return Filter_; }
9994 const NodeValue getBias() const { return Bias_; }
9995 NodeValue getResult() { return getNthResult(0); }
9996 const NodeValue getResult() const { return getNthResult(0); }
9997 llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; }
9998 llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; }
9999 llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; }
10000 unsigned_t getGroup() const { return Group_; }
10001
10002 static bool classof(const Kinded *k) {
10003 return k->getKind() == Kinded::Kind::CPUConvDKKC8NodeKind;
10004 }
10005
10006
10007 bool isOverwrittenNthInput(unsigned idx) const {
10008 return false;
10009 }
10010
10011 unsigned getNumInputs() const;
10012 std::string getInputName(unsigned idx) const;
10013 NodeValue getNthInput(unsigned idx);
10014 void setNthInput(unsigned idx, NodeValue val);
10015 llvm::StringRef getOutputName(unsigned idx) const;
10016 bool hasSideEffects() const { return 0; }
10017 bool isCanonical() const { return 0; }
10018 bool isDataParallel() const { return 0; }
10019 std::string getDebugDesc() const;
10020 bool isEqual(const CPUConvDKKC8Node &other) const;
10021 llvm::hash_code getHash() const;
10022 void visit(Node *parent, NodeWalker *visitor);
10023 Node* clone() const;
10024 bool verify() const;
10025};
10026} // namespace glow
10027