1#include "glow/Graph/Nodes.h"
2#include "glow/Base/Type.h"
3#include "glow/Support/Support.h"
4using namespace glow;
5
6unsigned SaveNode::getNumInputs() const {
7 return 2;
8}
9
10std::string SaveNode::getInputName(unsigned idx) const {
11 if (idx == 0) { return "Input"; }
12 if (idx == 1) { return "Output"; }
13 idx -= 2;
14 llvm_unreachable("Invalid index");
15}
16
17NodeValue SaveNode::getNthInput(unsigned idx) {
18 if (idx == 0) { return Input_; }
19 if (idx == 1) { return Output_; }
20 idx -= 2;
21 llvm_unreachable("Invalid index");
22}
23
24void SaveNode::setNthInput(unsigned idx, NodeValue val) {
25 if (idx == 0) { Input_ = val; return; }
26 if (idx == 1) { Output_ = val; return; }
27 idx -= 2;
28 llvm_unreachable("Invalid index");
29}
30
31llvm::StringRef SaveNode::getOutputName(unsigned idx) const {
32 llvm_unreachable("Invalid index");
33}
34
35std::string SaveNode::getDebugDesc() const {
36 DescriptionBuilder db(getKindName());
37 db.addParam("Name", separateString(getName(), 100, "\n"));
38 if (hasPredicate()) db.addParam("Predicate", "Yes");
39 db
40 .addParam("Input", *(getInput().getType()))
41 .addParam("Output", *(getOutput().getType()))
42 .addParam("Users", getNumUsers());
43 return db;
44}
45
46void SaveNode::visit(Node *parent, NodeWalker *visitor) {
47 if (!visitor->shouldVisit(parent, this)) { return; }
48 visitor->pre(parent, this);
49if (hasPredicate())
50 getPredicate().getNode()->visit(this, visitor);
51 getInput().getNode()->visit(this, visitor);
52 getOutput().getNode()->visit(this, visitor);
53 visitor->post(parent, this);
54}
55
56bool SaveNode::isEqual(const SaveNode &other) const {
57 return true &&
58 Input_ == other.Input_ &&
59 Output_ == other.Output_ &&
60 predicate_ == other.predicate_;
61}
62
63Node* SaveNode::clone() const {
64 return new SaveNode(getName(), getInput(), getOutput());
65}
66
67llvm::hash_code SaveNode::getHash() const {
68 return llvm::hash_combine(
69 Input_,
70 Output_);
71}
72Placeholder *SaveNode::getPlaceholder() const { return llvm::cast<Placeholder>(Output_.getNode()); };
73unsigned PadNode::getNumInputs() const {
74 return 1;
75}
76
77std::string PadNode::getInputName(unsigned idx) const {
78 if (idx == 0) { return "Input"; }
79 idx -= 1;
80 llvm_unreachable("Invalid index");
81}
82
83NodeValue PadNode::getNthInput(unsigned idx) {
84 if (idx == 0) { return Input_; }
85 idx -= 1;
86 llvm_unreachable("Invalid index");
87}
88
89void PadNode::setNthInput(unsigned idx, NodeValue val) {
90 if (idx == 0) { Input_ = val; return; }
91 idx -= 1;
92 llvm_unreachable("Invalid index");
93}
94
95llvm::StringRef PadNode::getOutputName(unsigned idx) const {
96 if (idx == 0) { return "Result"; }
97 llvm_unreachable("Invalid index");
98}
99
100std::string PadNode::getDebugDesc() const {
101 DescriptionBuilder db(getKindName());
102 db.addParam("Name", separateString(getName(), 100, "\n"));
103 if (hasPredicate()) db.addParam("Predicate", "Yes");
104 db
105 .addParam("Input", *(getInput().getType()))
106 .addParam("Mode", getMode())
107 .addParam("Pads", getPads())
108 .addParam("Value", getValue())
109 .addParam("Users", getNumUsers());
110 db.addParam("Result", *(getResult().getType()));
111 return db;
112}
113
114void PadNode::visit(Node *parent, NodeWalker *visitor) {
115 if (!visitor->shouldVisit(parent, this)) { return; }
116 visitor->pre(parent, this);
117if (hasPredicate())
118 getPredicate().getNode()->visit(this, visitor);
119 getInput().getNode()->visit(this, visitor);
120 visitor->post(parent, this);
121}
122
123bool PadNode::isEqual(const PadNode &other) const {
124 return true &&
125 Input_ == other.Input_ &&
126 predicate_ == other.predicate_ &&
127 Mode_ == other.Mode_ &&
128 Pads_ == other.Pads_ &&
129 Value_ == other.Value_ &&
130 getType(0) == other.getType(0);
131}
132
133Node* PadNode::clone() const {
134 return new PadNode(getName(), getResult().getType(), getInput(), getMode(), getPads(), getValue());
135}
136
137llvm::hash_code PadNode::getHash() const {
138 return llvm::hash_combine(
139 Mode_,
140 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
141 toBinary(Value_),
142 Input_);
143}
144
145unsigned ConvolutionGradNode::getNumInputs() const {
146 return 5;
147}
148
149std::string ConvolutionGradNode::getInputName(unsigned idx) const {
150 if (idx == 0) { return "Input"; }
151 if (idx == 1) { return "Filter"; }
152 if (idx == 2) { return "Bias"; }
153 if (idx == 3) { return "OriginalOutputForResult"; }
154 if (idx == 4) { return "GradOfOriginalOutputNamedResult"; }
155 idx -= 5;
156 llvm_unreachable("Invalid index");
157}
158
159NodeValue ConvolutionGradNode::getNthInput(unsigned idx) {
160 if (idx == 0) { return Input_; }
161 if (idx == 1) { return Filter_; }
162 if (idx == 2) { return Bias_; }
163 if (idx == 3) { return OriginalOutputForResult_; }
164 if (idx == 4) { return GradOfOriginalOutputNamedResult_; }
165 idx -= 5;
166 llvm_unreachable("Invalid index");
167}
168
169void ConvolutionGradNode::setNthInput(unsigned idx, NodeValue val) {
170 if (idx == 0) { Input_ = val; return; }
171 if (idx == 1) { Filter_ = val; return; }
172 if (idx == 2) { Bias_ = val; return; }
173 if (idx == 3) { OriginalOutputForResult_ = val; return; }
174 if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; }
175 idx -= 5;
176 llvm_unreachable("Invalid index");
177}
178
179llvm::StringRef ConvolutionGradNode::getOutputName(unsigned idx) const {
180 if (idx == 0) { return "GradOfInputNamedInput"; }
181 if (idx == 1) { return "GradOfInputNamedFilter"; }
182 if (idx == 2) { return "GradOfInputNamedBias"; }
183 llvm_unreachable("Invalid index");
184}
185
186std::string ConvolutionGradNode::getDebugDesc() const {
187 DescriptionBuilder db(getKindName());
188 db.addParam("Name", separateString(getName(), 100, "\n"));
189 if (hasPredicate()) db.addParam("Predicate", "Yes");
190 db
191 .addParam("Input", *(getInput().getType()))
192 .addParam("Filter", *(getFilter().getType()))
193 .addParam("Bias", *(getBias().getType()))
194 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
195 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
196 .addParam("Kernels", getKernels())
197 .addParam("Strides", getStrides())
198 .addParam("Pads", getPads())
199 .addParam("Group", getGroup())
200 .addParam("Dilation", getDilation())
201 .addParam("Layout", getLayout())
202 .addParam("FusedActivation", getFusedActivation())
203 .addParam("FusedActivationArgs", getFusedActivationArgs())
204 .addParam("Users", getNumUsers());
205 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
206 db.addParam("GradOfInputNamedFilter", *(getGradOfInputNamedFilter().getType()));
207 db.addParam("GradOfInputNamedBias", *(getGradOfInputNamedBias().getType()));
208 return db;
209}
210
211void ConvolutionGradNode::visit(Node *parent, NodeWalker *visitor) {
212 if (!visitor->shouldVisit(parent, this)) { return; }
213 visitor->pre(parent, this);
214if (hasPredicate())
215 getPredicate().getNode()->visit(this, visitor);
216 getInput().getNode()->visit(this, visitor);
217 getFilter().getNode()->visit(this, visitor);
218 getBias().getNode()->visit(this, visitor);
219 getOriginalOutputForResult().getNode()->visit(this, visitor);
220 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
221 visitor->post(parent, this);
222}
223
224bool ConvolutionGradNode::isEqual(const ConvolutionGradNode &other) const {
225 return true &&
226 Input_ == other.Input_ &&
227 Filter_ == other.Filter_ &&
228 Bias_ == other.Bias_ &&
229 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
230 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
231 predicate_ == other.predicate_ &&
232 Kernels_ == other.Kernels_ &&
233 Strides_ == other.Strides_ &&
234 Pads_ == other.Pads_ &&
235 Group_ == other.Group_ &&
236 Dilation_ == other.Dilation_ &&
237 Layout_ == other.Layout_ &&
238 FusedActivation_ == other.FusedActivation_ &&
239 FusedActivationArgs_ == other.FusedActivationArgs_ &&
240 getType(0) == other.getType(0) &&
241 getType(1) == other.getType(1) &&
242 getType(2) == other.getType(2);
243}
244
245Node* ConvolutionGradNode::clone() const {
246 return new ConvolutionGradNode(getName(), getInput(), getFilter(), getBias(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getLayout(), getFusedActivation(), getFusedActivationArgs());
247}
248
249llvm::hash_code ConvolutionGradNode::getHash() const {
250 return llvm::hash_combine(
251 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
252 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
253 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
254 Group_,
255 llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()),
256 Layout_,
257 FusedActivation_,
258 [](const std::vector<float>& floatVec) -> llvm::hash_code {
259 std::vector<size_t> sizeVec = toBinary(floatVec);
260 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
261 }(FusedActivationArgs_),
262 Input_,
263 Filter_,
264 Bias_,
265 OriginalOutputForResult_,
266 GradOfOriginalOutputNamedResult_);
267}
268
269unsigned ConvolutionNode::getNumInputs() const {
270 return 3;
271}
272
273std::string ConvolutionNode::getInputName(unsigned idx) const {
274 if (idx == 0) { return "Input"; }
275 if (idx == 1) { return "Filter"; }
276 if (idx == 2) { return "Bias"; }
277 idx -= 3;
278 llvm_unreachable("Invalid index");
279}
280
281NodeValue ConvolutionNode::getNthInput(unsigned idx) {
282 if (idx == 0) { return Input_; }
283 if (idx == 1) { return Filter_; }
284 if (idx == 2) { return Bias_; }
285 idx -= 3;
286 llvm_unreachable("Invalid index");
287}
288
289void ConvolutionNode::setNthInput(unsigned idx, NodeValue val) {
290 if (idx == 0) { Input_ = val; return; }
291 if (idx == 1) { Filter_ = val; return; }
292 if (idx == 2) { Bias_ = val; return; }
293 idx -= 3;
294 llvm_unreachable("Invalid index");
295}
296
297llvm::StringRef ConvolutionNode::getOutputName(unsigned idx) const {
298 if (idx == 0) { return "Result"; }
299 llvm_unreachable("Invalid index");
300}
301
302std::string ConvolutionNode::getDebugDesc() const {
303 DescriptionBuilder db(getKindName());
304 db.addParam("Name", separateString(getName(), 100, "\n"));
305 if (hasPredicate()) db.addParam("Predicate", "Yes");
306 db
307 .addParam("Input", *(getInput().getType()))
308 .addParam("Filter", *(getFilter().getType()))
309 .addParam("Bias", *(getBias().getType()))
310 .addParam("Kernels", getKernels())
311 .addParam("Strides", getStrides())
312 .addParam("Pads", getPads())
313 .addParam("Group", getGroup())
314 .addParam("Dilation", getDilation())
315 .addParam("Layout", getLayout())
316 .addParam("FusedActivation", getFusedActivation())
317 .addParam("FusedActivationArgs", getFusedActivationArgs())
318 .addParam("Users", getNumUsers());
319 db.addParam("Result", *(getResult().getType()));
320 return db;
321}
322
323void ConvolutionNode::visit(Node *parent, NodeWalker *visitor) {
324 if (!visitor->shouldVisit(parent, this)) { return; }
325 visitor->pre(parent, this);
326if (hasPredicate())
327 getPredicate().getNode()->visit(this, visitor);
328 getInput().getNode()->visit(this, visitor);
329 getFilter().getNode()->visit(this, visitor);
330 getBias().getNode()->visit(this, visitor);
331 visitor->post(parent, this);
332}
333
334bool ConvolutionNode::isEqual(const ConvolutionNode &other) const {
335 return true &&
336 Input_ == other.Input_ &&
337 Filter_ == other.Filter_ &&
338 Bias_ == other.Bias_ &&
339 predicate_ == other.predicate_ &&
340 Kernels_ == other.Kernels_ &&
341 Strides_ == other.Strides_ &&
342 Pads_ == other.Pads_ &&
343 Group_ == other.Group_ &&
344 Dilation_ == other.Dilation_ &&
345 Layout_ == other.Layout_ &&
346 FusedActivation_ == other.FusedActivation_ &&
347 FusedActivationArgs_ == other.FusedActivationArgs_ &&
348 getType(0) == other.getType(0);
349}
350
351Node* ConvolutionNode::clone() const {
352 return new ConvolutionNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getLayout(), getFusedActivation(), getFusedActivationArgs());
353}
354
355llvm::hash_code ConvolutionNode::getHash() const {
356 return llvm::hash_combine(
357 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
358 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
359 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
360 Group_,
361 llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()),
362 Layout_,
363 FusedActivation_,
364 [](const std::vector<float>& floatVec) -> llvm::hash_code {
365 std::vector<size_t> sizeVec = toBinary(floatVec);
366 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
367 }(FusedActivationArgs_),
368 Input_,
369 Filter_,
370 Bias_);
371}
372bool ConvolutionNode::hasFusedActivation() const { return getFusedActivation() != FusedActivation::NONE; }
373ConvolutionGradNode *ConvolutionNode::getGrad(GraphGradMapper &builder) {
374 auto *x = new ConvolutionGradNode(getName().str() + "_grad", getInput(), getFilter(), getBias(), getResult(), builder.getGradient(getResult()), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getLayout(), getFusedActivation(), getFusedActivationArgs());
375 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
376 builder.addGradient(getFilter(), x->getGradOfInputNamedFilter());
377 builder.addGradient(getBias(), x->getGradOfInputNamedBias());
378 return x;
379}
380
381unsigned ChannelwiseQuantizedConvolutionNode::getNumInputs() const {
382 return 7;
383}
384
385std::string ChannelwiseQuantizedConvolutionNode::getInputName(unsigned idx) const {
386 if (idx == 0) { return "Input"; }
387 if (idx == 1) { return "Filter"; }
388 if (idx == 2) { return "Bias"; }
389 if (idx == 3) { return "FilterScales"; }
390 if (idx == 4) { return "FilterOffsets"; }
391 if (idx == 5) { return "BiasScales"; }
392 if (idx == 6) { return "BiasOffsets"; }
393 idx -= 7;
394 llvm_unreachable("Invalid index");
395}
396
397NodeValue ChannelwiseQuantizedConvolutionNode::getNthInput(unsigned idx) {
398 if (idx == 0) { return Input_; }
399 if (idx == 1) { return Filter_; }
400 if (idx == 2) { return Bias_; }
401 if (idx == 3) { return FilterScales_; }
402 if (idx == 4) { return FilterOffsets_; }
403 if (idx == 5) { return BiasScales_; }
404 if (idx == 6) { return BiasOffsets_; }
405 idx -= 7;
406 llvm_unreachable("Invalid index");
407}
408
409void ChannelwiseQuantizedConvolutionNode::setNthInput(unsigned idx, NodeValue val) {
410 if (idx == 0) { Input_ = val; return; }
411 if (idx == 1) { Filter_ = val; return; }
412 if (idx == 2) { Bias_ = val; return; }
413 if (idx == 3) { FilterScales_ = val; return; }
414 if (idx == 4) { FilterOffsets_ = val; return; }
415 if (idx == 5) { BiasScales_ = val; return; }
416 if (idx == 6) { BiasOffsets_ = val; return; }
417 idx -= 7;
418 llvm_unreachable("Invalid index");
419}
420
421llvm::StringRef ChannelwiseQuantizedConvolutionNode::getOutputName(unsigned idx) const {
422 if (idx == 0) { return "Result"; }
423 llvm_unreachable("Invalid index");
424}
425
426std::string ChannelwiseQuantizedConvolutionNode::getDebugDesc() const {
427 DescriptionBuilder db(getKindName());
428 db.addParam("Name", separateString(getName(), 100, "\n"));
429 if (hasPredicate()) db.addParam("Predicate", "Yes");
430 db
431 .addParam("Input", *(getInput().getType()))
432 .addParam("Filter", *(getFilter().getType()))
433 .addParam("Bias", *(getBias().getType()))
434 .addParam("FilterScales", *(getFilterScales().getType()))
435 .addParam("FilterOffsets", *(getFilterOffsets().getType()))
436 .addParam("BiasScales", *(getBiasScales().getType()))
437 .addParam("BiasOffsets", *(getBiasOffsets().getType()))
438 .addParam("Kernels", getKernels())
439 .addParam("Strides", getStrides())
440 .addParam("Pads", getPads())
441 .addParam("Group", getGroup())
442 .addParam("Dilation", getDilation())
443 .addParam("FusedActivation", getFusedActivation())
444 .addParam("FusedActivationArgs", getFusedActivationArgs())
445 .addParam("Users", getNumUsers());
446 db.addParam("Result", *(getResult().getType()));
447 return db;
448}
449
450void ChannelwiseQuantizedConvolutionNode::visit(Node *parent, NodeWalker *visitor) {
451 if (!visitor->shouldVisit(parent, this)) { return; }
452 visitor->pre(parent, this);
453if (hasPredicate())
454 getPredicate().getNode()->visit(this, visitor);
455 getInput().getNode()->visit(this, visitor);
456 getFilter().getNode()->visit(this, visitor);
457 getBias().getNode()->visit(this, visitor);
458 getFilterScales().getNode()->visit(this, visitor);
459 getFilterOffsets().getNode()->visit(this, visitor);
460 getBiasScales().getNode()->visit(this, visitor);
461 getBiasOffsets().getNode()->visit(this, visitor);
462 visitor->post(parent, this);
463}
464
465bool ChannelwiseQuantizedConvolutionNode::isEqual(const ChannelwiseQuantizedConvolutionNode &other) const {
466 return true &&
467 Input_ == other.Input_ &&
468 Filter_ == other.Filter_ &&
469 Bias_ == other.Bias_ &&
470 FilterScales_ == other.FilterScales_ &&
471 FilterOffsets_ == other.FilterOffsets_ &&
472 BiasScales_ == other.BiasScales_ &&
473 BiasOffsets_ == other.BiasOffsets_ &&
474 predicate_ == other.predicate_ &&
475 Kernels_ == other.Kernels_ &&
476 Strides_ == other.Strides_ &&
477 Pads_ == other.Pads_ &&
478 Group_ == other.Group_ &&
479 Dilation_ == other.Dilation_ &&
480 FusedActivation_ == other.FusedActivation_ &&
481 FusedActivationArgs_ == other.FusedActivationArgs_ &&
482 getType(0) == other.getType(0);
483}
484
485Node* ChannelwiseQuantizedConvolutionNode::clone() const {
486 return new ChannelwiseQuantizedConvolutionNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getFilterScales(), getFilterOffsets(), getBiasScales(), getBiasOffsets(), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getFusedActivation(), getFusedActivationArgs());
487}
488
489llvm::hash_code ChannelwiseQuantizedConvolutionNode::getHash() const {
490 return llvm::hash_combine(
491 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
492 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
493 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
494 Group_,
495 llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()),
496 FusedActivation_,
497 [](const std::vector<float>& floatVec) -> llvm::hash_code {
498 std::vector<size_t> sizeVec = toBinary(floatVec);
499 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
500 }(FusedActivationArgs_),
501 Input_,
502 Filter_,
503 Bias_,
504 FilterScales_,
505 FilterOffsets_,
506 BiasScales_,
507 BiasOffsets_);
508}
509bool ChannelwiseQuantizedConvolutionNode::hasFusedActivation() const { return getFusedActivation() != FusedActivation::NONE; }
510unsigned ConvTransposeNode::getNumInputs() const {
511 return 3;
512}
513
514std::string ConvTransposeNode::getInputName(unsigned idx) const {
515 if (idx == 0) { return "Input"; }
516 if (idx == 1) { return "Filter"; }
517 if (idx == 2) { return "Bias"; }
518 idx -= 3;
519 llvm_unreachable("Invalid index");
520}
521
522NodeValue ConvTransposeNode::getNthInput(unsigned idx) {
523 if (idx == 0) { return Input_; }
524 if (idx == 1) { return Filter_; }
525 if (idx == 2) { return Bias_; }
526 idx -= 3;
527 llvm_unreachable("Invalid index");
528}
529
530void ConvTransposeNode::setNthInput(unsigned idx, NodeValue val) {
531 if (idx == 0) { Input_ = val; return; }
532 if (idx == 1) { Filter_ = val; return; }
533 if (idx == 2) { Bias_ = val; return; }
534 idx -= 3;
535 llvm_unreachable("Invalid index");
536}
537
538llvm::StringRef ConvTransposeNode::getOutputName(unsigned idx) const {
539 if (idx == 0) { return "Result"; }
540 llvm_unreachable("Invalid index");
541}
542
543std::string ConvTransposeNode::getDebugDesc() const {
544 DescriptionBuilder db(getKindName());
545 db.addParam("Name", separateString(getName(), 100, "\n"));
546 if (hasPredicate()) db.addParam("Predicate", "Yes");
547 db
548 .addParam("Input", *(getInput().getType()))
549 .addParam("Filter", *(getFilter().getType()))
550 .addParam("Bias", *(getBias().getType()))
551 .addParam("Kernels", getKernels())
552 .addParam("Strides", getStrides())
553 .addParam("Pads", getPads())
554 .addParam("Group", getGroup())
555 .addParam("Dilation", getDilation())
556 .addParam("Users", getNumUsers());
557 db.addParam("Result", *(getResult().getType()));
558 return db;
559}
560
561void ConvTransposeNode::visit(Node *parent, NodeWalker *visitor) {
562 if (!visitor->shouldVisit(parent, this)) { return; }
563 visitor->pre(parent, this);
564if (hasPredicate())
565 getPredicate().getNode()->visit(this, visitor);
566 getInput().getNode()->visit(this, visitor);
567 getFilter().getNode()->visit(this, visitor);
568 getBias().getNode()->visit(this, visitor);
569 visitor->post(parent, this);
570}
571
572bool ConvTransposeNode::isEqual(const ConvTransposeNode &other) const {
573 return true &&
574 Input_ == other.Input_ &&
575 Filter_ == other.Filter_ &&
576 Bias_ == other.Bias_ &&
577 predicate_ == other.predicate_ &&
578 Kernels_ == other.Kernels_ &&
579 Strides_ == other.Strides_ &&
580 Pads_ == other.Pads_ &&
581 Group_ == other.Group_ &&
582 Dilation_ == other.Dilation_ &&
583 getType(0) == other.getType(0);
584}
585
586Node* ConvTransposeNode::clone() const {
587 return new ConvTransposeNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup(), getDilation());
588}
589
590llvm::hash_code ConvTransposeNode::getHash() const {
591 return llvm::hash_combine(
592 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
593 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
594 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
595 Group_,
596 llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()),
597 Input_,
598 Filter_,
599 Bias_);
600}
601
602unsigned Convolution3DGradNode::getNumInputs() const {
603 return 5;
604}
605
606std::string Convolution3DGradNode::getInputName(unsigned idx) const {
607 if (idx == 0) { return "Input"; }
608 if (idx == 1) { return "Filter"; }
609 if (idx == 2) { return "Bias"; }
610 if (idx == 3) { return "OriginalOutputForResult"; }
611 if (idx == 4) { return "GradOfOriginalOutputNamedResult"; }
612 idx -= 5;
613 llvm_unreachable("Invalid index");
614}
615
616NodeValue Convolution3DGradNode::getNthInput(unsigned idx) {
617 if (idx == 0) { return Input_; }
618 if (idx == 1) { return Filter_; }
619 if (idx == 2) { return Bias_; }
620 if (idx == 3) { return OriginalOutputForResult_; }
621 if (idx == 4) { return GradOfOriginalOutputNamedResult_; }
622 idx -= 5;
623 llvm_unreachable("Invalid index");
624}
625
626void Convolution3DGradNode::setNthInput(unsigned idx, NodeValue val) {
627 if (idx == 0) { Input_ = val; return; }
628 if (idx == 1) { Filter_ = val; return; }
629 if (idx == 2) { Bias_ = val; return; }
630 if (idx == 3) { OriginalOutputForResult_ = val; return; }
631 if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; }
632 idx -= 5;
633 llvm_unreachable("Invalid index");
634}
635
636llvm::StringRef Convolution3DGradNode::getOutputName(unsigned idx) const {
637 if (idx == 0) { return "GradOfInputNamedInput"; }
638 if (idx == 1) { return "GradOfInputNamedFilter"; }
639 if (idx == 2) { return "GradOfInputNamedBias"; }
640 llvm_unreachable("Invalid index");
641}
642
643std::string Convolution3DGradNode::getDebugDesc() const {
644 DescriptionBuilder db(getKindName());
645 db.addParam("Name", separateString(getName(), 100, "\n"));
646 if (hasPredicate()) db.addParam("Predicate", "Yes");
647 db
648 .addParam("Input", *(getInput().getType()))
649 .addParam("Filter", *(getFilter().getType()))
650 .addParam("Bias", *(getBias().getType()))
651 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
652 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
653 .addParam("Kernels", getKernels())
654 .addParam("Strides", getStrides())
655 .addParam("Pads", getPads())
656 .addParam("Group", getGroup())
657 .addParam("Users", getNumUsers());
658 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
659 db.addParam("GradOfInputNamedFilter", *(getGradOfInputNamedFilter().getType()));
660 db.addParam("GradOfInputNamedBias", *(getGradOfInputNamedBias().getType()));
661 return db;
662}
663
664void Convolution3DGradNode::visit(Node *parent, NodeWalker *visitor) {
665 if (!visitor->shouldVisit(parent, this)) { return; }
666 visitor->pre(parent, this);
667if (hasPredicate())
668 getPredicate().getNode()->visit(this, visitor);
669 getInput().getNode()->visit(this, visitor);
670 getFilter().getNode()->visit(this, visitor);
671 getBias().getNode()->visit(this, visitor);
672 getOriginalOutputForResult().getNode()->visit(this, visitor);
673 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
674 visitor->post(parent, this);
675}
676
677bool Convolution3DGradNode::isEqual(const Convolution3DGradNode &other) const {
678 return true &&
679 Input_ == other.Input_ &&
680 Filter_ == other.Filter_ &&
681 Bias_ == other.Bias_ &&
682 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
683 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
684 predicate_ == other.predicate_ &&
685 Kernels_ == other.Kernels_ &&
686 Strides_ == other.Strides_ &&
687 Pads_ == other.Pads_ &&
688 Group_ == other.Group_ &&
689 getType(0) == other.getType(0) &&
690 getType(1) == other.getType(1) &&
691 getType(2) == other.getType(2);
692}
693
694Node* Convolution3DGradNode::clone() const {
695 return new Convolution3DGradNode(getName(), getInput(), getFilter(), getBias(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getKernels(), getStrides(), getPads(), getGroup());
696}
697
698llvm::hash_code Convolution3DGradNode::getHash() const {
699 return llvm::hash_combine(
700 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
701 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
702 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
703 Group_,
704 Input_,
705 Filter_,
706 Bias_,
707 OriginalOutputForResult_,
708 GradOfOriginalOutputNamedResult_);
709}
710
711unsigned Convolution3DNode::getNumInputs() const {
712 return 3;
713}
714
715std::string Convolution3DNode::getInputName(unsigned idx) const {
716 if (idx == 0) { return "Input"; }
717 if (idx == 1) { return "Filter"; }
718 if (idx == 2) { return "Bias"; }
719 idx -= 3;
720 llvm_unreachable("Invalid index");
721}
722
723NodeValue Convolution3DNode::getNthInput(unsigned idx) {
724 if (idx == 0) { return Input_; }
725 if (idx == 1) { return Filter_; }
726 if (idx == 2) { return Bias_; }
727 idx -= 3;
728 llvm_unreachable("Invalid index");
729}
730
731void Convolution3DNode::setNthInput(unsigned idx, NodeValue val) {
732 if (idx == 0) { Input_ = val; return; }
733 if (idx == 1) { Filter_ = val; return; }
734 if (idx == 2) { Bias_ = val; return; }
735 idx -= 3;
736 llvm_unreachable("Invalid index");
737}
738
739llvm::StringRef Convolution3DNode::getOutputName(unsigned idx) const {
740 if (idx == 0) { return "Result"; }
741 llvm_unreachable("Invalid index");
742}
743
744std::string Convolution3DNode::getDebugDesc() const {
745 DescriptionBuilder db(getKindName());
746 db.addParam("Name", separateString(getName(), 100, "\n"));
747 if (hasPredicate()) db.addParam("Predicate", "Yes");
748 db
749 .addParam("Input", *(getInput().getType()))
750 .addParam("Filter", *(getFilter().getType()))
751 .addParam("Bias", *(getBias().getType()))
752 .addParam("Kernels", getKernels())
753 .addParam("Strides", getStrides())
754 .addParam("Pads", getPads())
755 .addParam("Group", getGroup())
756 .addParam("Users", getNumUsers());
757 db.addParam("Result", *(getResult().getType()));
758 return db;
759}
760
761void Convolution3DNode::visit(Node *parent, NodeWalker *visitor) {
762 if (!visitor->shouldVisit(parent, this)) { return; }
763 visitor->pre(parent, this);
764if (hasPredicate())
765 getPredicate().getNode()->visit(this, visitor);
766 getInput().getNode()->visit(this, visitor);
767 getFilter().getNode()->visit(this, visitor);
768 getBias().getNode()->visit(this, visitor);
769 visitor->post(parent, this);
770}
771
772bool Convolution3DNode::isEqual(const Convolution3DNode &other) const {
773 return true &&
774 Input_ == other.Input_ &&
775 Filter_ == other.Filter_ &&
776 Bias_ == other.Bias_ &&
777 predicate_ == other.predicate_ &&
778 Kernels_ == other.Kernels_ &&
779 Strides_ == other.Strides_ &&
780 Pads_ == other.Pads_ &&
781 Group_ == other.Group_ &&
782 getType(0) == other.getType(0);
783}
784
785Node* Convolution3DNode::clone() const {
786 return new Convolution3DNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup());
787}
788
789llvm::hash_code Convolution3DNode::getHash() const {
790 return llvm::hash_combine(
791 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
792 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
793 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
794 Group_,
795 Input_,
796 Filter_,
797 Bias_);
798}
799
800Convolution3DGradNode *Convolution3DNode::getGrad(GraphGradMapper &builder) {
801 auto *x = new Convolution3DGradNode(getName().str() + "_grad", getInput(), getFilter(), getBias(), getResult(), builder.getGradient(getResult()), getKernels(), getStrides(), getPads(), getGroup());
802 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
803 builder.addGradient(getFilter(), x->getGradOfInputNamedFilter());
804 builder.addGradient(getBias(), x->getGradOfInputNamedBias());
805 return x;
806}
807
808unsigned MaxPoolGradNode::getNumInputs() const {
809 return 5;
810}
811
812std::string MaxPoolGradNode::getInputName(unsigned idx) const {
813 if (idx == 0) { return "Input"; }
814 if (idx == 1) { return "OriginalOutputForResult"; }
815 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
816 if (idx == 3) { return "OriginalOutputForArgmax"; }
817 if (idx == 4) { return "GradOfOriginalOutputNamedArgmax"; }
818 idx -= 5;
819 llvm_unreachable("Invalid index");
820}
821
822NodeValue MaxPoolGradNode::getNthInput(unsigned idx) {
823 if (idx == 0) { return Input_; }
824 if (idx == 1) { return OriginalOutputForResult_; }
825 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
826 if (idx == 3) { return OriginalOutputForArgmax_; }
827 if (idx == 4) { return GradOfOriginalOutputNamedArgmax_; }
828 idx -= 5;
829 llvm_unreachable("Invalid index");
830}
831
832void MaxPoolGradNode::setNthInput(unsigned idx, NodeValue val) {
833 if (idx == 0) { Input_ = val; return; }
834 if (idx == 1) { OriginalOutputForResult_ = val; return; }
835 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
836 if (idx == 3) { OriginalOutputForArgmax_ = val; return; }
837 if (idx == 4) { GradOfOriginalOutputNamedArgmax_ = val; return; }
838 idx -= 5;
839 llvm_unreachable("Invalid index");
840}
841
842llvm::StringRef MaxPoolGradNode::getOutputName(unsigned idx) const {
843 if (idx == 0) { return "GradOfInputNamedInput"; }
844 llvm_unreachable("Invalid index");
845}
846
847std::string MaxPoolGradNode::getDebugDesc() const {
848 DescriptionBuilder db(getKindName());
849 db.addParam("Name", separateString(getName(), 100, "\n"));
850 if (hasPredicate()) db.addParam("Predicate", "Yes");
851 db
852 .addParam("Input", *(getInput().getType()))
853 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
854 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
855 .addParam("OriginalOutputForArgmax", *(getOriginalOutputForArgmax().getType()))
856 .addParam("GradOfOriginalOutputNamedArgmax", *(getGradOfOriginalOutputNamedArgmax().getType()))
857 .addParam("Kernels", getKernels())
858 .addParam("Strides", getStrides())
859 .addParam("Pads", getPads())
860 .addParam("Layout", getLayout())
861 .addParam("Users", getNumUsers());
862 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
863 return db;
864}
865
866void MaxPoolGradNode::visit(Node *parent, NodeWalker *visitor) {
867 if (!visitor->shouldVisit(parent, this)) { return; }
868 visitor->pre(parent, this);
869if (hasPredicate())
870 getPredicate().getNode()->visit(this, visitor);
871 getInput().getNode()->visit(this, visitor);
872 getOriginalOutputForResult().getNode()->visit(this, visitor);
873 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
874 getOriginalOutputForArgmax().getNode()->visit(this, visitor);
875 getGradOfOriginalOutputNamedArgmax().getNode()->visit(this, visitor);
876 visitor->post(parent, this);
877}
878
879bool MaxPoolGradNode::isEqual(const MaxPoolGradNode &other) const {
880 return true &&
881 Input_ == other.Input_ &&
882 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
883 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
884 OriginalOutputForArgmax_ == other.OriginalOutputForArgmax_ &&
885 GradOfOriginalOutputNamedArgmax_ == other.GradOfOriginalOutputNamedArgmax_ &&
886 predicate_ == other.predicate_ &&
887 Kernels_ == other.Kernels_ &&
888 Strides_ == other.Strides_ &&
889 Pads_ == other.Pads_ &&
890 Layout_ == other.Layout_ &&
891 getType(0) == other.getType(0);
892}
893
894Node* MaxPoolGradNode::clone() const {
895 return new MaxPoolGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getOriginalOutputForArgmax(), getGradOfOriginalOutputNamedArgmax(), getKernels(), getStrides(), getPads(), getLayout());
896}
897
898llvm::hash_code MaxPoolGradNode::getHash() const {
899 return llvm::hash_combine(
900 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
901 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
902 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
903 Layout_,
904 Input_,
905 OriginalOutputForResult_,
906 GradOfOriginalOutputNamedResult_,
907 OriginalOutputForArgmax_,
908 GradOfOriginalOutputNamedArgmax_);
909}
910
911unsigned MaxPoolNode::getNumInputs() const {
912 return 1;
913}
914
915std::string MaxPoolNode::getInputName(unsigned idx) const {
916 if (idx == 0) { return "Input"; }
917 idx -= 1;
918 llvm_unreachable("Invalid index");
919}
920
921NodeValue MaxPoolNode::getNthInput(unsigned idx) {
922 if (idx == 0) { return Input_; }
923 idx -= 1;
924 llvm_unreachable("Invalid index");
925}
926
927void MaxPoolNode::setNthInput(unsigned idx, NodeValue val) {
928 if (idx == 0) { Input_ = val; return; }
929 idx -= 1;
930 llvm_unreachable("Invalid index");
931}
932
933llvm::StringRef MaxPoolNode::getOutputName(unsigned idx) const {
934 if (idx == 0) { return "Result"; }
935 if (idx == 1) { return "Argmax"; }
936 llvm_unreachable("Invalid index");
937}
938
939std::string MaxPoolNode::getDebugDesc() const {
940 DescriptionBuilder db(getKindName());
941 db.addParam("Name", separateString(getName(), 100, "\n"));
942 if (hasPredicate()) db.addParam("Predicate", "Yes");
943 db
944 .addParam("Input", *(getInput().getType()))
945 .addParam("Kernels", getKernels())
946 .addParam("Strides", getStrides())
947 .addParam("Pads", getPads())
948 .addParam("Layout", getLayout())
949 .addParam("Users", getNumUsers());
950 db.addParam("Result", *(getResult().getType()));
951 db.addParam("Argmax", *(getArgmax().getType()));
952 return db;
953}
954
955void MaxPoolNode::visit(Node *parent, NodeWalker *visitor) {
956 if (!visitor->shouldVisit(parent, this)) { return; }
957 visitor->pre(parent, this);
958if (hasPredicate())
959 getPredicate().getNode()->visit(this, visitor);
960 getInput().getNode()->visit(this, visitor);
961 visitor->post(parent, this);
962}
963
964bool MaxPoolNode::isEqual(const MaxPoolNode &other) const {
965 return true &&
966 Input_ == other.Input_ &&
967 predicate_ == other.predicate_ &&
968 Kernels_ == other.Kernels_ &&
969 Strides_ == other.Strides_ &&
970 Pads_ == other.Pads_ &&
971 Layout_ == other.Layout_ &&
972 getType(0) == other.getType(0) &&
973 getType(1) == other.getType(1);
974}
975
976Node* MaxPoolNode::clone() const {
977 return new MaxPoolNode(getName(), getResult().getType(), getArgmax().getType(), getInput(), getKernels(), getStrides(), getPads(), getLayout());
978}
979
980llvm::hash_code MaxPoolNode::getHash() const {
981 return llvm::hash_combine(
982 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
983 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
984 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
985 Layout_,
986 Input_);
987}
988
989MaxPoolGradNode *MaxPoolNode::getGrad(GraphGradMapper &builder) {
990 auto *x = new MaxPoolGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()), getArgmax(), builder.getGradient(getArgmax()), getKernels(), getStrides(), getPads(), getLayout());
991 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
992 return x;
993}
994
995unsigned ArgMaxNode::getNumInputs() const {
996 return 1;
997}
998
999std::string ArgMaxNode::getInputName(unsigned idx) const {
1000 if (idx == 0) { return "Input"; }
1001 idx -= 1;
1002 llvm_unreachable("Invalid index");
1003}
1004
1005NodeValue ArgMaxNode::getNthInput(unsigned idx) {
1006 if (idx == 0) { return Input_; }
1007 idx -= 1;
1008 llvm_unreachable("Invalid index");
1009}
1010
1011void ArgMaxNode::setNthInput(unsigned idx, NodeValue val) {
1012 if (idx == 0) { Input_ = val; return; }
1013 idx -= 1;
1014 llvm_unreachable("Invalid index");
1015}
1016
1017llvm::StringRef ArgMaxNode::getOutputName(unsigned idx) const {
1018 if (idx == 0) { return "Result"; }
1019 llvm_unreachable("Invalid index");
1020}
1021
1022std::string ArgMaxNode::getDebugDesc() const {
1023 DescriptionBuilder db(getKindName());
1024 db.addParam("Name", separateString(getName(), 100, "\n"));
1025 if (hasPredicate()) db.addParam("Predicate", "Yes");
1026 db
1027 .addParam("Input", *(getInput().getType()))
1028 .addParam("Axis", getAxis())
1029 .addParam("KeepDims", getKeepDims())
1030 .addParam("Users", getNumUsers());
1031 db.addParam("Result", *(getResult().getType()));
1032 return db;
1033}
1034
1035void ArgMaxNode::visit(Node *parent, NodeWalker *visitor) {
1036 if (!visitor->shouldVisit(parent, this)) { return; }
1037 visitor->pre(parent, this);
1038if (hasPredicate())
1039 getPredicate().getNode()->visit(this, visitor);
1040 getInput().getNode()->visit(this, visitor);
1041 visitor->post(parent, this);
1042}
1043
1044bool ArgMaxNode::isEqual(const ArgMaxNode &other) const {
1045 return true &&
1046 Input_ == other.Input_ &&
1047 predicate_ == other.predicate_ &&
1048 Axis_ == other.Axis_ &&
1049 KeepDims_ == other.KeepDims_ &&
1050 getType(0) == other.getType(0);
1051}
1052
1053Node* ArgMaxNode::clone() const {
1054 return new ArgMaxNode(getName(), getResult().getType(), getInput(), getAxis(), getKeepDims());
1055}
1056
1057llvm::hash_code ArgMaxNode::getHash() const {
1058 return llvm::hash_combine(
1059 Axis_,
1060 KeepDims_,
1061 Input_);
1062}
1063
1064unsigned ArgMinNode::getNumInputs() const {
1065 return 1;
1066}
1067
1068std::string ArgMinNode::getInputName(unsigned idx) const {
1069 if (idx == 0) { return "Input"; }
1070 idx -= 1;
1071 llvm_unreachable("Invalid index");
1072}
1073
1074NodeValue ArgMinNode::getNthInput(unsigned idx) {
1075 if (idx == 0) { return Input_; }
1076 idx -= 1;
1077 llvm_unreachable("Invalid index");
1078}
1079
1080void ArgMinNode::setNthInput(unsigned idx, NodeValue val) {
1081 if (idx == 0) { Input_ = val; return; }
1082 idx -= 1;
1083 llvm_unreachable("Invalid index");
1084}
1085
1086llvm::StringRef ArgMinNode::getOutputName(unsigned idx) const {
1087 if (idx == 0) { return "Result"; }
1088 llvm_unreachable("Invalid index");
1089}
1090
1091std::string ArgMinNode::getDebugDesc() const {
1092 DescriptionBuilder db(getKindName());
1093 db.addParam("Name", separateString(getName(), 100, "\n"));
1094 if (hasPredicate()) db.addParam("Predicate", "Yes");
1095 db
1096 .addParam("Input", *(getInput().getType()))
1097 .addParam("Axis", getAxis())
1098 .addParam("KeepDims", getKeepDims())
1099 .addParam("Users", getNumUsers());
1100 db.addParam("Result", *(getResult().getType()));
1101 return db;
1102}
1103
1104void ArgMinNode::visit(Node *parent, NodeWalker *visitor) {
1105 if (!visitor->shouldVisit(parent, this)) { return; }
1106 visitor->pre(parent, this);
1107if (hasPredicate())
1108 getPredicate().getNode()->visit(this, visitor);
1109 getInput().getNode()->visit(this, visitor);
1110 visitor->post(parent, this);
1111}
1112
1113bool ArgMinNode::isEqual(const ArgMinNode &other) const {
1114 return true &&
1115 Input_ == other.Input_ &&
1116 predicate_ == other.predicate_ &&
1117 Axis_ == other.Axis_ &&
1118 KeepDims_ == other.KeepDims_ &&
1119 getType(0) == other.getType(0);
1120}
1121
1122Node* ArgMinNode::clone() const {
1123 return new ArgMinNode(getName(), getResult().getType(), getInput(), getAxis(), getKeepDims());
1124}
1125
1126llvm::hash_code ArgMinNode::getHash() const {
1127 return llvm::hash_combine(
1128 Axis_,
1129 KeepDims_,
1130 Input_);
1131}
1132
1133unsigned AvgPoolGradNode::getNumInputs() const {
1134 return 3;
1135}
1136
1137std::string AvgPoolGradNode::getInputName(unsigned idx) const {
1138 if (idx == 0) { return "Input"; }
1139 if (idx == 1) { return "OriginalOutputForResult"; }
1140 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
1141 idx -= 3;
1142 llvm_unreachable("Invalid index");
1143}
1144
1145NodeValue AvgPoolGradNode::getNthInput(unsigned idx) {
1146 if (idx == 0) { return Input_; }
1147 if (idx == 1) { return OriginalOutputForResult_; }
1148 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
1149 idx -= 3;
1150 llvm_unreachable("Invalid index");
1151}
1152
1153void AvgPoolGradNode::setNthInput(unsigned idx, NodeValue val) {
1154 if (idx == 0) { Input_ = val; return; }
1155 if (idx == 1) { OriginalOutputForResult_ = val; return; }
1156 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
1157 idx -= 3;
1158 llvm_unreachable("Invalid index");
1159}
1160
1161llvm::StringRef AvgPoolGradNode::getOutputName(unsigned idx) const {
1162 if (idx == 0) { return "GradOfInputNamedInput"; }
1163 llvm_unreachable("Invalid index");
1164}
1165
1166std::string AvgPoolGradNode::getDebugDesc() const {
1167 DescriptionBuilder db(getKindName());
1168 db.addParam("Name", separateString(getName(), 100, "\n"));
1169 if (hasPredicate()) db.addParam("Predicate", "Yes");
1170 db
1171 .addParam("Input", *(getInput().getType()))
1172 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
1173 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
1174 .addParam("Kernels", getKernels())
1175 .addParam("Strides", getStrides())
1176 .addParam("Pads", getPads())
1177 .addParam("Layout", getLayout())
1178 .addParam("CountIncludePads", getCountIncludePads())
1179 .addParam("Users", getNumUsers());
1180 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
1181 return db;
1182}
1183
1184void AvgPoolGradNode::visit(Node *parent, NodeWalker *visitor) {
1185 if (!visitor->shouldVisit(parent, this)) { return; }
1186 visitor->pre(parent, this);
1187if (hasPredicate())
1188 getPredicate().getNode()->visit(this, visitor);
1189 getInput().getNode()->visit(this, visitor);
1190 getOriginalOutputForResult().getNode()->visit(this, visitor);
1191 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
1192 visitor->post(parent, this);
1193}
1194
1195bool AvgPoolGradNode::isEqual(const AvgPoolGradNode &other) const {
1196 return true &&
1197 Input_ == other.Input_ &&
1198 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
1199 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
1200 predicate_ == other.predicate_ &&
1201 Kernels_ == other.Kernels_ &&
1202 Strides_ == other.Strides_ &&
1203 Pads_ == other.Pads_ &&
1204 Layout_ == other.Layout_ &&
1205 CountIncludePads_ == other.CountIncludePads_ &&
1206 getType(0) == other.getType(0);
1207}
1208
1209Node* AvgPoolGradNode::clone() const {
1210 return new AvgPoolGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getKernels(), getStrides(), getPads(), getLayout(), getCountIncludePads());
1211}
1212
1213llvm::hash_code AvgPoolGradNode::getHash() const {
1214 return llvm::hash_combine(
1215 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
1216 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
1217 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
1218 Layout_,
1219 CountIncludePads_,
1220 Input_,
1221 OriginalOutputForResult_,
1222 GradOfOriginalOutputNamedResult_);
1223}
1224
1225unsigned AvgPoolNode::getNumInputs() const {
1226 return 1;
1227}
1228
1229std::string AvgPoolNode::getInputName(unsigned idx) const {
1230 if (idx == 0) { return "Input"; }
1231 idx -= 1;
1232 llvm_unreachable("Invalid index");
1233}
1234
1235NodeValue AvgPoolNode::getNthInput(unsigned idx) {
1236 if (idx == 0) { return Input_; }
1237 idx -= 1;
1238 llvm_unreachable("Invalid index");
1239}
1240
1241void AvgPoolNode::setNthInput(unsigned idx, NodeValue val) {
1242 if (idx == 0) { Input_ = val; return; }
1243 idx -= 1;
1244 llvm_unreachable("Invalid index");
1245}
1246
1247llvm::StringRef AvgPoolNode::getOutputName(unsigned idx) const {
1248 if (idx == 0) { return "Result"; }
1249 llvm_unreachable("Invalid index");
1250}
1251
1252std::string AvgPoolNode::getDebugDesc() const {
1253 DescriptionBuilder db(getKindName());
1254 db.addParam("Name", separateString(getName(), 100, "\n"));
1255 if (hasPredicate()) db.addParam("Predicate", "Yes");
1256 db
1257 .addParam("Input", *(getInput().getType()))
1258 .addParam("Kernels", getKernels())
1259 .addParam("Strides", getStrides())
1260 .addParam("Pads", getPads())
1261 .addParam("Layout", getLayout())
1262 .addParam("CountIncludePads", getCountIncludePads())
1263 .addParam("Users", getNumUsers());
1264 db.addParam("Result", *(getResult().getType()));
1265 return db;
1266}
1267
1268void AvgPoolNode::visit(Node *parent, NodeWalker *visitor) {
1269 if (!visitor->shouldVisit(parent, this)) { return; }
1270 visitor->pre(parent, this);
1271if (hasPredicate())
1272 getPredicate().getNode()->visit(this, visitor);
1273 getInput().getNode()->visit(this, visitor);
1274 visitor->post(parent, this);
1275}
1276
1277bool AvgPoolNode::isEqual(const AvgPoolNode &other) const {
1278 return true &&
1279 Input_ == other.Input_ &&
1280 predicate_ == other.predicate_ &&
1281 Kernels_ == other.Kernels_ &&
1282 Strides_ == other.Strides_ &&
1283 Pads_ == other.Pads_ &&
1284 Layout_ == other.Layout_ &&
1285 CountIncludePads_ == other.CountIncludePads_ &&
1286 getType(0) == other.getType(0);
1287}
1288
1289Node* AvgPoolNode::clone() const {
1290 return new AvgPoolNode(getName(), getResult().getType(), getInput(), getKernels(), getStrides(), getPads(), getLayout(), getCountIncludePads());
1291}
1292
1293llvm::hash_code AvgPoolNode::getHash() const {
1294 return llvm::hash_combine(
1295 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
1296 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
1297 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
1298 Layout_,
1299 CountIncludePads_,
1300 Input_);
1301}
1302
1303AvgPoolGradNode *AvgPoolNode::getGrad(GraphGradMapper &builder) {
1304 auto *x = new AvgPoolGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()), getKernels(), getStrides(), getPads(), getLayout(), getCountIncludePads());
1305 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
1306 return x;
1307}
1308
1309unsigned AdaptiveAvgPoolGradNode::getNumInputs() const {
1310 return 3;
1311}
1312
1313std::string AdaptiveAvgPoolGradNode::getInputName(unsigned idx) const {
1314 if (idx == 0) { return "Input"; }
1315 if (idx == 1) { return "OriginalOutputForResult"; }
1316 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
1317 idx -= 3;
1318 llvm_unreachable("Invalid index");
1319}
1320
1321NodeValue AdaptiveAvgPoolGradNode::getNthInput(unsigned idx) {
1322 if (idx == 0) { return Input_; }
1323 if (idx == 1) { return OriginalOutputForResult_; }
1324 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
1325 idx -= 3;
1326 llvm_unreachable("Invalid index");
1327}
1328
1329void AdaptiveAvgPoolGradNode::setNthInput(unsigned idx, NodeValue val) {
1330 if (idx == 0) { Input_ = val; return; }
1331 if (idx == 1) { OriginalOutputForResult_ = val; return; }
1332 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
1333 idx -= 3;
1334 llvm_unreachable("Invalid index");
1335}
1336
1337llvm::StringRef AdaptiveAvgPoolGradNode::getOutputName(unsigned idx) const {
1338 if (idx == 0) { return "GradOfInputNamedInput"; }
1339 llvm_unreachable("Invalid index");
1340}
1341
1342std::string AdaptiveAvgPoolGradNode::getDebugDesc() const {
1343 DescriptionBuilder db(getKindName());
1344 db.addParam("Name", separateString(getName(), 100, "\n"));
1345 if (hasPredicate()) db.addParam("Predicate", "Yes");
1346 db
1347 .addParam("Input", *(getInput().getType()))
1348 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
1349 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
1350 .addParam("Users", getNumUsers());
1351 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
1352 return db;
1353}
1354
1355void AdaptiveAvgPoolGradNode::visit(Node *parent, NodeWalker *visitor) {
1356 if (!visitor->shouldVisit(parent, this)) { return; }
1357 visitor->pre(parent, this);
1358if (hasPredicate())
1359 getPredicate().getNode()->visit(this, visitor);
1360 getInput().getNode()->visit(this, visitor);
1361 getOriginalOutputForResult().getNode()->visit(this, visitor);
1362 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
1363 visitor->post(parent, this);
1364}
1365
1366bool AdaptiveAvgPoolGradNode::isEqual(const AdaptiveAvgPoolGradNode &other) const {
1367 return true &&
1368 Input_ == other.Input_ &&
1369 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
1370 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
1371 predicate_ == other.predicate_ &&
1372 getType(0) == other.getType(0);
1373}
1374
1375Node* AdaptiveAvgPoolGradNode::clone() const {
1376 return new AdaptiveAvgPoolGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
1377}
1378
1379llvm::hash_code AdaptiveAvgPoolGradNode::getHash() const {
1380 return llvm::hash_combine(
1381 Input_,
1382 OriginalOutputForResult_,
1383 GradOfOriginalOutputNamedResult_);
1384}
1385
1386unsigned AdaptiveAvgPoolNode::getNumInputs() const {
1387 return 1;
1388}
1389
1390std::string AdaptiveAvgPoolNode::getInputName(unsigned idx) const {
1391 if (idx == 0) { return "Input"; }
1392 idx -= 1;
1393 llvm_unreachable("Invalid index");
1394}
1395
1396NodeValue AdaptiveAvgPoolNode::getNthInput(unsigned idx) {
1397 if (idx == 0) { return Input_; }
1398 idx -= 1;
1399 llvm_unreachable("Invalid index");
1400}
1401
1402void AdaptiveAvgPoolNode::setNthInput(unsigned idx, NodeValue val) {
1403 if (idx == 0) { Input_ = val; return; }
1404 idx -= 1;
1405 llvm_unreachable("Invalid index");
1406}
1407
1408llvm::StringRef AdaptiveAvgPoolNode::getOutputName(unsigned idx) const {
1409 if (idx == 0) { return "Result"; }
1410 llvm_unreachable("Invalid index");
1411}
1412
1413std::string AdaptiveAvgPoolNode::getDebugDesc() const {
1414 DescriptionBuilder db(getKindName());
1415 db.addParam("Name", separateString(getName(), 100, "\n"));
1416 if (hasPredicate()) db.addParam("Predicate", "Yes");
1417 db
1418 .addParam("Input", *(getInput().getType()))
1419 .addParam("Users", getNumUsers());
1420 db.addParam("Result", *(getResult().getType()));
1421 return db;
1422}
1423
1424void AdaptiveAvgPoolNode::visit(Node *parent, NodeWalker *visitor) {
1425 if (!visitor->shouldVisit(parent, this)) { return; }
1426 visitor->pre(parent, this);
1427if (hasPredicate())
1428 getPredicate().getNode()->visit(this, visitor);
1429 getInput().getNode()->visit(this, visitor);
1430 visitor->post(parent, this);
1431}
1432
1433bool AdaptiveAvgPoolNode::isEqual(const AdaptiveAvgPoolNode &other) const {
1434 return true &&
1435 Input_ == other.Input_ &&
1436 predicate_ == other.predicate_ &&
1437 getType(0) == other.getType(0);
1438}
1439
1440Node* AdaptiveAvgPoolNode::clone() const {
1441 return new AdaptiveAvgPoolNode(getName(), getResult().getType(), getInput());
1442}
1443
1444llvm::hash_code AdaptiveAvgPoolNode::getHash() const {
1445 return llvm::hash_combine(
1446 Input_);
1447}
1448
1449AdaptiveAvgPoolGradNode *AdaptiveAvgPoolNode::getGrad(GraphGradMapper &builder) {
1450 auto *x = new AdaptiveAvgPoolGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()));
1451 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
1452 return x;
1453}
1454
1455unsigned GemmNode::getNumInputs() const {
1456 return 3;
1457}
1458
1459std::string GemmNode::getInputName(unsigned idx) const {
1460 if (idx == 0) { return "A"; }
1461 if (idx == 1) { return "B"; }
1462 if (idx == 2) { return "C"; }
1463 idx -= 3;
1464 llvm_unreachable("Invalid index");
1465}
1466
1467NodeValue GemmNode::getNthInput(unsigned idx) {
1468 if (idx == 0) { return A_; }
1469 if (idx == 1) { return B_; }
1470 if (idx == 2) { return C_; }
1471 idx -= 3;
1472 llvm_unreachable("Invalid index");
1473}
1474
1475void GemmNode::setNthInput(unsigned idx, NodeValue val) {
1476 if (idx == 0) { A_ = val; return; }
1477 if (idx == 1) { B_ = val; return; }
1478 if (idx == 2) { C_ = val; return; }
1479 idx -= 3;
1480 llvm_unreachable("Invalid index");
1481}
1482
1483llvm::StringRef GemmNode::getOutputName(unsigned idx) const {
1484 if (idx == 0) { return "Result"; }
1485 llvm_unreachable("Invalid index");
1486}
1487
1488std::string GemmNode::getDebugDesc() const {
1489 DescriptionBuilder db(getKindName());
1490 db.addParam("Name", separateString(getName(), 100, "\n"));
1491 if (hasPredicate()) db.addParam("Predicate", "Yes");
1492 db
1493 .addParam("A", *(getA().getType()))
1494 .addParam("B", *(getB().getType()))
1495 .addParam("C", *(getC().getType()))
1496 .addParam("Alpha", getAlpha())
1497 .addParam("Beta", getBeta())
1498 .addParam("TransposeA", getTransposeA())
1499 .addParam("TransposeB", getTransposeB())
1500 .addParam("Users", getNumUsers());
1501 db.addParam("Result", *(getResult().getType()));
1502 return db;
1503}
1504
1505void GemmNode::visit(Node *parent, NodeWalker *visitor) {
1506 if (!visitor->shouldVisit(parent, this)) { return; }
1507 visitor->pre(parent, this);
1508if (hasPredicate())
1509 getPredicate().getNode()->visit(this, visitor);
1510 getA().getNode()->visit(this, visitor);
1511 getB().getNode()->visit(this, visitor);
1512 getC().getNode()->visit(this, visitor);
1513 visitor->post(parent, this);
1514}
1515
1516bool GemmNode::isEqual(const GemmNode &other) const {
1517 return true &&
1518 A_ == other.A_ &&
1519 B_ == other.B_ &&
1520 C_ == other.C_ &&
1521 predicate_ == other.predicate_ &&
1522 Alpha_ == other.Alpha_ &&
1523 Beta_ == other.Beta_ &&
1524 TransposeA_ == other.TransposeA_ &&
1525 TransposeB_ == other.TransposeB_ &&
1526 getType(0) == other.getType(0);
1527}
1528
1529Node* GemmNode::clone() const {
1530 return new GemmNode(getName(), getResult().getType(), getA(), getB(), getC(), getAlpha(), getBeta(), getTransposeA(), getTransposeB());
1531}
1532
1533llvm::hash_code GemmNode::getHash() const {
1534 return llvm::hash_combine(
1535 toBinary(Alpha_),
1536 toBinary(Beta_),
1537 TransposeA_,
1538 TransposeB_,
1539 A_,
1540 B_,
1541 C_);
1542}
1543
1544unsigned FullyConnectedGradNode::getNumInputs() const {
1545 return 5;
1546}
1547
1548std::string FullyConnectedGradNode::getInputName(unsigned idx) const {
1549 if (idx == 0) { return "Input"; }
1550 if (idx == 1) { return "Weights"; }
1551 if (idx == 2) { return "Bias"; }
1552 if (idx == 3) { return "OriginalOutputForResult"; }
1553 if (idx == 4) { return "GradOfOriginalOutputNamedResult"; }
1554 idx -= 5;
1555 llvm_unreachable("Invalid index");
1556}
1557
1558NodeValue FullyConnectedGradNode::getNthInput(unsigned idx) {
1559 if (idx == 0) { return Input_; }
1560 if (idx == 1) { return Weights_; }
1561 if (idx == 2) { return Bias_; }
1562 if (idx == 3) { return OriginalOutputForResult_; }
1563 if (idx == 4) { return GradOfOriginalOutputNamedResult_; }
1564 idx -= 5;
1565 llvm_unreachable("Invalid index");
1566}
1567
1568void FullyConnectedGradNode::setNthInput(unsigned idx, NodeValue val) {
1569 if (idx == 0) { Input_ = val; return; }
1570 if (idx == 1) { Weights_ = val; return; }
1571 if (idx == 2) { Bias_ = val; return; }
1572 if (idx == 3) { OriginalOutputForResult_ = val; return; }
1573 if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; }
1574 idx -= 5;
1575 llvm_unreachable("Invalid index");
1576}
1577
1578llvm::StringRef FullyConnectedGradNode::getOutputName(unsigned idx) const {
1579 if (idx == 0) { return "GradOfInputNamedInput"; }
1580 if (idx == 1) { return "GradOfInputNamedWeights"; }
1581 if (idx == 2) { return "GradOfInputNamedBias"; }
1582 llvm_unreachable("Invalid index");
1583}
1584
1585std::string FullyConnectedGradNode::getDebugDesc() const {
1586 DescriptionBuilder db(getKindName());
1587 db.addParam("Name", separateString(getName(), 100, "\n"));
1588 if (hasPredicate()) db.addParam("Predicate", "Yes");
1589 db
1590 .addParam("Input", *(getInput().getType()))
1591 .addParam("Weights", *(getWeights().getType()))
1592 .addParam("Bias", *(getBias().getType()))
1593 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
1594 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
1595 .addParam("Users", getNumUsers());
1596 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
1597 db.addParam("GradOfInputNamedWeights", *(getGradOfInputNamedWeights().getType()));
1598 db.addParam("GradOfInputNamedBias", *(getGradOfInputNamedBias().getType()));
1599 return db;
1600}
1601
1602void FullyConnectedGradNode::visit(Node *parent, NodeWalker *visitor) {
1603 if (!visitor->shouldVisit(parent, this)) { return; }
1604 visitor->pre(parent, this);
1605if (hasPredicate())
1606 getPredicate().getNode()->visit(this, visitor);
1607 getInput().getNode()->visit(this, visitor);
1608 getWeights().getNode()->visit(this, visitor);
1609 getBias().getNode()->visit(this, visitor);
1610 getOriginalOutputForResult().getNode()->visit(this, visitor);
1611 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
1612 visitor->post(parent, this);
1613}
1614
1615bool FullyConnectedGradNode::isEqual(const FullyConnectedGradNode &other) const {
1616 return true &&
1617 Input_ == other.Input_ &&
1618 Weights_ == other.Weights_ &&
1619 Bias_ == other.Bias_ &&
1620 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
1621 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
1622 predicate_ == other.predicate_ &&
1623 getType(0) == other.getType(0) &&
1624 getType(1) == other.getType(1) &&
1625 getType(2) == other.getType(2);
1626}
1627
1628Node* FullyConnectedGradNode::clone() const {
1629 return new FullyConnectedGradNode(getName(), getInput(), getWeights(), getBias(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
1630}
1631
1632llvm::hash_code FullyConnectedGradNode::getHash() const {
1633 return llvm::hash_combine(
1634 Input_,
1635 Weights_,
1636 Bias_,
1637 OriginalOutputForResult_,
1638 GradOfOriginalOutputNamedResult_);
1639}
1640
1641unsigned FullyConnectedNode::getNumInputs() const {
1642 return 3;
1643}
1644
1645std::string FullyConnectedNode::getInputName(unsigned idx) const {
1646 if (idx == 0) { return "Input"; }
1647 if (idx == 1) { return "Weights"; }
1648 if (idx == 2) { return "Bias"; }
1649 idx -= 3;
1650 llvm_unreachable("Invalid index");
1651}
1652
1653NodeValue FullyConnectedNode::getNthInput(unsigned idx) {
1654 if (idx == 0) { return Input_; }
1655 if (idx == 1) { return Weights_; }
1656 if (idx == 2) { return Bias_; }
1657 idx -= 3;
1658 llvm_unreachable("Invalid index");
1659}
1660
1661void FullyConnectedNode::setNthInput(unsigned idx, NodeValue val) {
1662 if (idx == 0) { Input_ = val; return; }
1663 if (idx == 1) { Weights_ = val; return; }
1664 if (idx == 2) { Bias_ = val; return; }
1665 idx -= 3;
1666 llvm_unreachable("Invalid index");
1667}
1668
1669llvm::StringRef FullyConnectedNode::getOutputName(unsigned idx) const {
1670 if (idx == 0) { return "Result"; }
1671 llvm_unreachable("Invalid index");
1672}
1673
1674std::string FullyConnectedNode::getDebugDesc() const {
1675 DescriptionBuilder db(getKindName());
1676 db.addParam("Name", separateString(getName(), 100, "\n"));
1677 if (hasPredicate()) db.addParam("Predicate", "Yes");
1678 db
1679 .addParam("Input", *(getInput().getType()))
1680 .addParam("Weights", *(getWeights().getType()))
1681 .addParam("Bias", *(getBias().getType()))
1682 .addParam("Users", getNumUsers());
1683 db.addParam("Result", *(getResult().getType()));
1684 return db;
1685}
1686
1687void FullyConnectedNode::visit(Node *parent, NodeWalker *visitor) {
1688 if (!visitor->shouldVisit(parent, this)) { return; }
1689 visitor->pre(parent, this);
1690if (hasPredicate())
1691 getPredicate().getNode()->visit(this, visitor);
1692 getInput().getNode()->visit(this, visitor);
1693 getWeights().getNode()->visit(this, visitor);
1694 getBias().getNode()->visit(this, visitor);
1695 visitor->post(parent, this);
1696}
1697
1698bool FullyConnectedNode::isEqual(const FullyConnectedNode &other) const {
1699 return true &&
1700 Input_ == other.Input_ &&
1701 Weights_ == other.Weights_ &&
1702 Bias_ == other.Bias_ &&
1703 predicate_ == other.predicate_ &&
1704 getType(0) == other.getType(0);
1705}
1706
1707Node* FullyConnectedNode::clone() const {
1708 return new FullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getBias());
1709}
1710
1711llvm::hash_code FullyConnectedNode::getHash() const {
1712 return llvm::hash_combine(
1713 Input_,
1714 Weights_,
1715 Bias_);
1716}
1717
1718FullyConnectedGradNode *FullyConnectedNode::getGrad(GraphGradMapper &builder) {
1719 auto *x = new FullyConnectedGradNode(getName().str() + "_grad", getInput(), getWeights(), getBias(), getResult(), builder.getGradient(getResult()));
1720 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
1721 builder.addGradient(getWeights(), x->getGradOfInputNamedWeights());
1722 builder.addGradient(getBias(), x->getGradOfInputNamedBias());
1723 return x;
1724}
1725
1726unsigned RowwiseQuantizedFullyConnectedNode::getNumInputs() const {
1727 return 5;
1728}
1729
1730std::string RowwiseQuantizedFullyConnectedNode::getInputName(unsigned idx) const {
1731 if (idx == 0) { return "Input"; }
1732 if (idx == 1) { return "Weights"; }
1733 if (idx == 2) { return "Scales"; }
1734 if (idx == 3) { return "Offsets"; }
1735 if (idx == 4) { return "Bias"; }
1736 idx -= 5;
1737 llvm_unreachable("Invalid index");
1738}
1739
1740NodeValue RowwiseQuantizedFullyConnectedNode::getNthInput(unsigned idx) {
1741 if (idx == 0) { return Input_; }
1742 if (idx == 1) { return Weights_; }
1743 if (idx == 2) { return Scales_; }
1744 if (idx == 3) { return Offsets_; }
1745 if (idx == 4) { return Bias_; }
1746 idx -= 5;
1747 llvm_unreachable("Invalid index");
1748}
1749
1750void RowwiseQuantizedFullyConnectedNode::setNthInput(unsigned idx, NodeValue val) {
1751 if (idx == 0) { Input_ = val; return; }
1752 if (idx == 1) { Weights_ = val; return; }
1753 if (idx == 2) { Scales_ = val; return; }
1754 if (idx == 3) { Offsets_ = val; return; }
1755 if (idx == 4) { Bias_ = val; return; }
1756 idx -= 5;
1757 llvm_unreachable("Invalid index");
1758}
1759
1760llvm::StringRef RowwiseQuantizedFullyConnectedNode::getOutputName(unsigned idx) const {
1761 if (idx == 0) { return "Result"; }
1762 llvm_unreachable("Invalid index");
1763}
1764
1765std::string RowwiseQuantizedFullyConnectedNode::getDebugDesc() const {
1766 DescriptionBuilder db(getKindName());
1767 db.addParam("Name", separateString(getName(), 100, "\n"));
1768 if (hasPredicate()) db.addParam("Predicate", "Yes");
1769 db
1770 .addParam("Input", *(getInput().getType()))
1771 .addParam("Weights", *(getWeights().getType()))
1772 .addParam("Scales", *(getScales().getType()))
1773 .addParam("Offsets", *(getOffsets().getType()))
1774 .addParam("Bias", *(getBias().getType()))
1775 .addParam("Users", getNumUsers());
1776 db.addParam("Result", *(getResult().getType()));
1777 return db;
1778}
1779
1780void RowwiseQuantizedFullyConnectedNode::visit(Node *parent, NodeWalker *visitor) {
1781 if (!visitor->shouldVisit(parent, this)) { return; }
1782 visitor->pre(parent, this);
1783if (hasPredicate())
1784 getPredicate().getNode()->visit(this, visitor);
1785 getInput().getNode()->visit(this, visitor);
1786 getWeights().getNode()->visit(this, visitor);
1787 getScales().getNode()->visit(this, visitor);
1788 getOffsets().getNode()->visit(this, visitor);
1789 getBias().getNode()->visit(this, visitor);
1790 visitor->post(parent, this);
1791}
1792
1793bool RowwiseQuantizedFullyConnectedNode::isEqual(const RowwiseQuantizedFullyConnectedNode &other) const {
1794 return true &&
1795 Input_ == other.Input_ &&
1796 Weights_ == other.Weights_ &&
1797 Scales_ == other.Scales_ &&
1798 Offsets_ == other.Offsets_ &&
1799 Bias_ == other.Bias_ &&
1800 predicate_ == other.predicate_ &&
1801 getType(0) == other.getType(0);
1802}
1803
1804Node* RowwiseQuantizedFullyConnectedNode::clone() const {
1805 return new RowwiseQuantizedFullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getScales(), getOffsets(), getBias());
1806}
1807
1808llvm::hash_code RowwiseQuantizedFullyConnectedNode::getHash() const {
1809 return llvm::hash_combine(
1810 Input_,
1811 Weights_,
1812 Scales_,
1813 Offsets_,
1814 Bias_);
1815}
1816
1817unsigned DynamicQuantizedFullyConnectedNode::getNumInputs() const {
1818 return 3;
1819}
1820
1821std::string DynamicQuantizedFullyConnectedNode::getInputName(unsigned idx) const {
1822 if (idx == 0) { return "Input"; }
1823 if (idx == 1) { return "Weights"; }
1824 if (idx == 2) { return "Bias"; }
1825 idx -= 3;
1826 llvm_unreachable("Invalid index");
1827}
1828
1829NodeValue DynamicQuantizedFullyConnectedNode::getNthInput(unsigned idx) {
1830 if (idx == 0) { return Input_; }
1831 if (idx == 1) { return Weights_; }
1832 if (idx == 2) { return Bias_; }
1833 idx -= 3;
1834 llvm_unreachable("Invalid index");
1835}
1836
1837void DynamicQuantizedFullyConnectedNode::setNthInput(unsigned idx, NodeValue val) {
1838 if (idx == 0) { Input_ = val; return; }
1839 if (idx == 1) { Weights_ = val; return; }
1840 if (idx == 2) { Bias_ = val; return; }
1841 idx -= 3;
1842 llvm_unreachable("Invalid index");
1843}
1844
1845llvm::StringRef DynamicQuantizedFullyConnectedNode::getOutputName(unsigned idx) const {
1846 if (idx == 0) { return "Result"; }
1847 llvm_unreachable("Invalid index");
1848}
1849
1850std::string DynamicQuantizedFullyConnectedNode::getDebugDesc() const {
1851 DescriptionBuilder db(getKindName());
1852 db.addParam("Name", separateString(getName(), 100, "\n"));
1853 if (hasPredicate()) db.addParam("Predicate", "Yes");
1854 db
1855 .addParam("Input", *(getInput().getType()))
1856 .addParam("Weights", *(getWeights().getType()))
1857 .addParam("Bias", *(getBias().getType()))
1858 .addParam("IsSymmetric", getIsSymmetric())
1859 .addParam("IsPerBatchElement", getIsPerBatchElement())
1860 .addParam("Users", getNumUsers());
1861 db.addParam("Result", *(getResult().getType()));
1862 return db;
1863}
1864
1865void DynamicQuantizedFullyConnectedNode::visit(Node *parent, NodeWalker *visitor) {
1866 if (!visitor->shouldVisit(parent, this)) { return; }
1867 visitor->pre(parent, this);
1868if (hasPredicate())
1869 getPredicate().getNode()->visit(this, visitor);
1870 getInput().getNode()->visit(this, visitor);
1871 getWeights().getNode()->visit(this, visitor);
1872 getBias().getNode()->visit(this, visitor);
1873 visitor->post(parent, this);
1874}
1875
1876bool DynamicQuantizedFullyConnectedNode::isEqual(const DynamicQuantizedFullyConnectedNode &other) const {
1877 return true &&
1878 Input_ == other.Input_ &&
1879 Weights_ == other.Weights_ &&
1880 Bias_ == other.Bias_ &&
1881 predicate_ == other.predicate_ &&
1882 IsSymmetric_ == other.IsSymmetric_ &&
1883 IsPerBatchElement_ == other.IsPerBatchElement_ &&
1884 getType(0) == other.getType(0);
1885}
1886
1887Node* DynamicQuantizedFullyConnectedNode::clone() const {
1888 return new DynamicQuantizedFullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getBias(), getIsSymmetric(), getIsPerBatchElement());
1889}
1890
1891llvm::hash_code DynamicQuantizedFullyConnectedNode::getHash() const {
1892 return llvm::hash_combine(
1893 IsSymmetric_,
1894 IsPerBatchElement_,
1895 Input_,
1896 Weights_,
1897 Bias_);
1898}
1899
1900unsigned DynamicRowwiseQuantizedFullyConnectedNode::getNumInputs() const {
1901 return 5;
1902}
1903
1904std::string DynamicRowwiseQuantizedFullyConnectedNode::getInputName(unsigned idx) const {
1905 if (idx == 0) { return "Input"; }
1906 if (idx == 1) { return "Weights"; }
1907 if (idx == 2) { return "Bias"; }
1908 if (idx == 3) { return "Scales"; }
1909 if (idx == 4) { return "Offsets"; }
1910 idx -= 5;
1911 llvm_unreachable("Invalid index");
1912}
1913
1914NodeValue DynamicRowwiseQuantizedFullyConnectedNode::getNthInput(unsigned idx) {
1915 if (idx == 0) { return Input_; }
1916 if (idx == 1) { return Weights_; }
1917 if (idx == 2) { return Bias_; }
1918 if (idx == 3) { return Scales_; }
1919 if (idx == 4) { return Offsets_; }
1920 idx -= 5;
1921 llvm_unreachable("Invalid index");
1922}
1923
1924void DynamicRowwiseQuantizedFullyConnectedNode::setNthInput(unsigned idx, NodeValue val) {
1925 if (idx == 0) { Input_ = val; return; }
1926 if (idx == 1) { Weights_ = val; return; }
1927 if (idx == 2) { Bias_ = val; return; }
1928 if (idx == 3) { Scales_ = val; return; }
1929 if (idx == 4) { Offsets_ = val; return; }
1930 idx -= 5;
1931 llvm_unreachable("Invalid index");
1932}
1933
1934llvm::StringRef DynamicRowwiseQuantizedFullyConnectedNode::getOutputName(unsigned idx) const {
1935 if (idx == 0) { return "Result"; }
1936 llvm_unreachable("Invalid index");
1937}
1938
1939std::string DynamicRowwiseQuantizedFullyConnectedNode::getDebugDesc() const {
1940 DescriptionBuilder db(getKindName());
1941 db.addParam("Name", separateString(getName(), 100, "\n"));
1942 if (hasPredicate()) db.addParam("Predicate", "Yes");
1943 db
1944 .addParam("Input", *(getInput().getType()))
1945 .addParam("Weights", *(getWeights().getType()))
1946 .addParam("Bias", *(getBias().getType()))
1947 .addParam("Scales", *(getScales().getType()))
1948 .addParam("Offsets", *(getOffsets().getType()))
1949 .addParam("IsSymmetric", getIsSymmetric())
1950 .addParam("IsPerBatchElement", getIsPerBatchElement())
1951 .addParam("Users", getNumUsers());
1952 db.addParam("Result", *(getResult().getType()));
1953 return db;
1954}
1955
1956void DynamicRowwiseQuantizedFullyConnectedNode::visit(Node *parent, NodeWalker *visitor) {
1957 if (!visitor->shouldVisit(parent, this)) { return; }
1958 visitor->pre(parent, this);
1959if (hasPredicate())
1960 getPredicate().getNode()->visit(this, visitor);
1961 getInput().getNode()->visit(this, visitor);
1962 getWeights().getNode()->visit(this, visitor);
1963 getBias().getNode()->visit(this, visitor);
1964 getScales().getNode()->visit(this, visitor);
1965 getOffsets().getNode()->visit(this, visitor);
1966 visitor->post(parent, this);
1967}
1968
1969bool DynamicRowwiseQuantizedFullyConnectedNode::isEqual(const DynamicRowwiseQuantizedFullyConnectedNode &other) const {
1970 return true &&
1971 Input_ == other.Input_ &&
1972 Weights_ == other.Weights_ &&
1973 Bias_ == other.Bias_ &&
1974 Scales_ == other.Scales_ &&
1975 Offsets_ == other.Offsets_ &&
1976 predicate_ == other.predicate_ &&
1977 IsSymmetric_ == other.IsSymmetric_ &&
1978 IsPerBatchElement_ == other.IsPerBatchElement_ &&
1979 getType(0) == other.getType(0);
1980}
1981
1982Node* DynamicRowwiseQuantizedFullyConnectedNode::clone() const {
1983 return new DynamicRowwiseQuantizedFullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getBias(), getScales(), getOffsets(), getIsSymmetric(), getIsPerBatchElement());
1984}
1985
1986llvm::hash_code DynamicRowwiseQuantizedFullyConnectedNode::getHash() const {
1987 return llvm::hash_combine(
1988 IsSymmetric_,
1989 IsPerBatchElement_,
1990 Input_,
1991 Weights_,
1992 Bias_,
1993 Scales_,
1994 Offsets_);
1995}
1996
1997unsigned BatchNormalizationGradNode::getNumInputs() const {
1998 return 7;
1999}
2000
2001std::string BatchNormalizationGradNode::getInputName(unsigned idx) const {
2002 if (idx == 0) { return "Input"; }
2003 if (idx == 1) { return "Scale"; }
2004 if (idx == 2) { return "Bias"; }
2005 if (idx == 3) { return "Mean"; }
2006 if (idx == 4) { return "Var"; }
2007 if (idx == 5) { return "OriginalOutputForResult"; }
2008 if (idx == 6) { return "GradOfOriginalOutputNamedResult"; }
2009 idx -= 7;
2010 llvm_unreachable("Invalid index");
2011}
2012
2013NodeValue BatchNormalizationGradNode::getNthInput(unsigned idx) {
2014 if (idx == 0) { return Input_; }
2015 if (idx == 1) { return Scale_; }
2016 if (idx == 2) { return Bias_; }
2017 if (idx == 3) { return Mean_; }
2018 if (idx == 4) { return Var_; }
2019 if (idx == 5) { return OriginalOutputForResult_; }
2020 if (idx == 6) { return GradOfOriginalOutputNamedResult_; }
2021 idx -= 7;
2022 llvm_unreachable("Invalid index");
2023}
2024
2025void BatchNormalizationGradNode::setNthInput(unsigned idx, NodeValue val) {
2026 if (idx == 0) { Input_ = val; return; }
2027 if (idx == 1) { Scale_ = val; return; }
2028 if (idx == 2) { Bias_ = val; return; }
2029 if (idx == 3) { Mean_ = val; return; }
2030 if (idx == 4) { Var_ = val; return; }
2031 if (idx == 5) { OriginalOutputForResult_ = val; return; }
2032 if (idx == 6) { GradOfOriginalOutputNamedResult_ = val; return; }
2033 idx -= 7;
2034 llvm_unreachable("Invalid index");
2035}
2036
2037llvm::StringRef BatchNormalizationGradNode::getOutputName(unsigned idx) const {
2038 if (idx == 0) { return "GradOfInputNamedInput"; }
2039 if (idx == 1) { return "GradOfInputNamedScale"; }
2040 if (idx == 2) { return "GradOfInputNamedBias"; }
2041 if (idx == 3) { return "GradOfInputNamedMean"; }
2042 if (idx == 4) { return "GradOfInputNamedVar"; }
2043 llvm_unreachable("Invalid index");
2044}
2045
2046std::string BatchNormalizationGradNode::getDebugDesc() const {
2047 DescriptionBuilder db(getKindName());
2048 db.addParam("Name", separateString(getName(), 100, "\n"));
2049 if (hasPredicate()) db.addParam("Predicate", "Yes");
2050 db
2051 .addParam("Input", *(getInput().getType()))
2052 .addParam("Scale", *(getScale().getType()))
2053 .addParam("Bias", *(getBias().getType()))
2054 .addParam("Mean", *(getMean().getType()))
2055 .addParam("Var", *(getVar().getType()))
2056 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
2057 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
2058 .addParam("ChannelIdx", getChannelIdx())
2059 .addParam("Epsilon", getEpsilon())
2060 .addParam("Momentum", getMomentum())
2061 .addParam("Users", getNumUsers());
2062 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
2063 db.addParam("GradOfInputNamedScale", *(getGradOfInputNamedScale().getType()));
2064 db.addParam("GradOfInputNamedBias", *(getGradOfInputNamedBias().getType()));
2065 db.addParam("GradOfInputNamedMean", *(getGradOfInputNamedMean().getType()));
2066 db.addParam("GradOfInputNamedVar", *(getGradOfInputNamedVar().getType()));
2067 return db;
2068}
2069
2070void BatchNormalizationGradNode::visit(Node *parent, NodeWalker *visitor) {
2071 if (!visitor->shouldVisit(parent, this)) { return; }
2072 visitor->pre(parent, this);
2073if (hasPredicate())
2074 getPredicate().getNode()->visit(this, visitor);
2075 getInput().getNode()->visit(this, visitor);
2076 getScale().getNode()->visit(this, visitor);
2077 getBias().getNode()->visit(this, visitor);
2078 getMean().getNode()->visit(this, visitor);
2079 getVar().getNode()->visit(this, visitor);
2080 getOriginalOutputForResult().getNode()->visit(this, visitor);
2081 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
2082 visitor->post(parent, this);
2083}
2084
2085bool BatchNormalizationGradNode::isEqual(const BatchNormalizationGradNode &other) const {
2086 return true &&
2087 Input_ == other.Input_ &&
2088 Scale_ == other.Scale_ &&
2089 Bias_ == other.Bias_ &&
2090 Mean_ == other.Mean_ &&
2091 Var_ == other.Var_ &&
2092 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
2093 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
2094 predicate_ == other.predicate_ &&
2095 ChannelIdx_ == other.ChannelIdx_ &&
2096 Epsilon_ == other.Epsilon_ &&
2097 Momentum_ == other.Momentum_ &&
2098 getType(0) == other.getType(0) &&
2099 getType(1) == other.getType(1) &&
2100 getType(2) == other.getType(2) &&
2101 getType(3) == other.getType(3) &&
2102 getType(4) == other.getType(4);
2103}
2104
2105Node* BatchNormalizationGradNode::clone() const {
2106 return new BatchNormalizationGradNode(getName(), getInput(), getScale(), getBias(), getMean(), getVar(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getChannelIdx(), getEpsilon(), getMomentum());
2107}
2108
2109llvm::hash_code BatchNormalizationGradNode::getHash() const {
2110 return llvm::hash_combine(
2111 ChannelIdx_,
2112 toBinary(Epsilon_),
2113 toBinary(Momentum_),
2114 Input_,
2115 Scale_,
2116 Bias_,
2117 Mean_,
2118 Var_,
2119 OriginalOutputForResult_,
2120 GradOfOriginalOutputNamedResult_);
2121}
2122
2123unsigned BatchNormalizationNode::getNumInputs() const {
2124 return 5;
2125}
2126
2127std::string BatchNormalizationNode::getInputName(unsigned idx) const {
2128 if (idx == 0) { return "Input"; }
2129 if (idx == 1) { return "Scale"; }
2130 if (idx == 2) { return "Bias"; }
2131 if (idx == 3) { return "Mean"; }
2132 if (idx == 4) { return "Var"; }
2133 idx -= 5;
2134 llvm_unreachable("Invalid index");
2135}
2136
2137NodeValue BatchNormalizationNode::getNthInput(unsigned idx) {
2138 if (idx == 0) { return Input_; }
2139 if (idx == 1) { return Scale_; }
2140 if (idx == 2) { return Bias_; }
2141 if (idx == 3) { return Mean_; }
2142 if (idx == 4) { return Var_; }
2143 idx -= 5;
2144 llvm_unreachable("Invalid index");
2145}
2146
2147void BatchNormalizationNode::setNthInput(unsigned idx, NodeValue val) {
2148 if (idx == 0) { Input_ = val; return; }
2149 if (idx == 1) { Scale_ = val; return; }
2150 if (idx == 2) { Bias_ = val; return; }
2151 if (idx == 3) { Mean_ = val; return; }
2152 if (idx == 4) { Var_ = val; return; }
2153 idx -= 5;
2154 llvm_unreachable("Invalid index");
2155}
2156
2157llvm::StringRef BatchNormalizationNode::getOutputName(unsigned idx) const {
2158 if (idx == 0) { return "Result"; }
2159 llvm_unreachable("Invalid index");
2160}
2161
2162std::string BatchNormalizationNode::getDebugDesc() const {
2163 DescriptionBuilder db(getKindName());
2164 db.addParam("Name", separateString(getName(), 100, "\n"));
2165 if (hasPredicate()) db.addParam("Predicate", "Yes");
2166 db
2167 .addParam("Input", *(getInput().getType()))
2168 .addParam("Scale", *(getScale().getType()))
2169 .addParam("Bias", *(getBias().getType()))
2170 .addParam("Mean", *(getMean().getType()))
2171 .addParam("Var", *(getVar().getType()))
2172 .addParam("ChannelIdx", getChannelIdx())
2173 .addParam("Epsilon", getEpsilon())
2174 .addParam("Momentum", getMomentum())
2175 .addParam("Users", getNumUsers());
2176 db.addParam("Result", *(getResult().getType()));
2177 return db;
2178}
2179
2180void BatchNormalizationNode::visit(Node *parent, NodeWalker *visitor) {
2181 if (!visitor->shouldVisit(parent, this)) { return; }
2182 visitor->pre(parent, this);
2183if (hasPredicate())
2184 getPredicate().getNode()->visit(this, visitor);
2185 getInput().getNode()->visit(this, visitor);
2186 getScale().getNode()->visit(this, visitor);
2187 getBias().getNode()->visit(this, visitor);
2188 getMean().getNode()->visit(this, visitor);
2189 getVar().getNode()->visit(this, visitor);
2190 visitor->post(parent, this);
2191}
2192
2193bool BatchNormalizationNode::isEqual(const BatchNormalizationNode &other) const {
2194 return true &&
2195 Input_ == other.Input_ &&
2196 Scale_ == other.Scale_ &&
2197 Bias_ == other.Bias_ &&
2198 Mean_ == other.Mean_ &&
2199 Var_ == other.Var_ &&
2200 predicate_ == other.predicate_ &&
2201 ChannelIdx_ == other.ChannelIdx_ &&
2202 Epsilon_ == other.Epsilon_ &&
2203 Momentum_ == other.Momentum_ &&
2204 getType(0) == other.getType(0);
2205}
2206
2207Node* BatchNormalizationNode::clone() const {
2208 return new BatchNormalizationNode(getName(), getResult().getType(), getInput(), getScale(), getBias(), getMean(), getVar(), getChannelIdx(), getEpsilon(), getMomentum());
2209}
2210
2211llvm::hash_code BatchNormalizationNode::getHash() const {
2212 return llvm::hash_combine(
2213 ChannelIdx_,
2214 toBinary(Epsilon_),
2215 toBinary(Momentum_),
2216 Input_,
2217 Scale_,
2218 Bias_,
2219 Mean_,
2220 Var_);
2221}
2222
2223BatchNormalizationGradNode *BatchNormalizationNode::getGrad(GraphGradMapper &builder) {
2224 auto *x = new BatchNormalizationGradNode(getName().str() + "_grad", getInput(), getScale(), getBias(), getMean(), getVar(), getResult(), builder.getGradient(getResult()), getChannelIdx(), getEpsilon(), getMomentum());
2225 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
2226 builder.addGradient(getScale(), x->getGradOfInputNamedScale());
2227 builder.addGradient(getBias(), x->getGradOfInputNamedBias());
2228 builder.addGradient(getMean(), x->getGradOfInputNamedMean());
2229 builder.addGradient(getVar(), x->getGradOfInputNamedVar());
2230 return x;
2231}
2232
2233unsigned InstanceNormalizationNode::getNumInputs() const {
2234 return 3;
2235}
2236
2237std::string InstanceNormalizationNode::getInputName(unsigned idx) const {
2238 if (idx == 0) { return "Input"; }
2239 if (idx == 1) { return "Scale"; }
2240 if (idx == 2) { return "Bias"; }
2241 idx -= 3;
2242 llvm_unreachable("Invalid index");
2243}
2244
2245NodeValue InstanceNormalizationNode::getNthInput(unsigned idx) {
2246 if (idx == 0) { return Input_; }
2247 if (idx == 1) { return Scale_; }
2248 if (idx == 2) { return Bias_; }
2249 idx -= 3;
2250 llvm_unreachable("Invalid index");
2251}
2252
2253void InstanceNormalizationNode::setNthInput(unsigned idx, NodeValue val) {
2254 if (idx == 0) { Input_ = val; return; }
2255 if (idx == 1) { Scale_ = val; return; }
2256 if (idx == 2) { Bias_ = val; return; }
2257 idx -= 3;
2258 llvm_unreachable("Invalid index");
2259}
2260
2261llvm::StringRef InstanceNormalizationNode::getOutputName(unsigned idx) const {
2262 if (idx == 0) { return "Result"; }
2263 llvm_unreachable("Invalid index");
2264}
2265
2266std::string InstanceNormalizationNode::getDebugDesc() const {
2267 DescriptionBuilder db(getKindName());
2268 db.addParam("Name", separateString(getName(), 100, "\n"));
2269 if (hasPredicate()) db.addParam("Predicate", "Yes");
2270 db
2271 .addParam("Input", *(getInput().getType()))
2272 .addParam("Scale", *(getScale().getType()))
2273 .addParam("Bias", *(getBias().getType()))
2274 .addParam("ChannelIdx", getChannelIdx())
2275 .addParam("Epsilon", getEpsilon())
2276 .addParam("Users", getNumUsers());
2277 db.addParam("Result", *(getResult().getType()));
2278 return db;
2279}
2280
2281void InstanceNormalizationNode::visit(Node *parent, NodeWalker *visitor) {
2282 if (!visitor->shouldVisit(parent, this)) { return; }
2283 visitor->pre(parent, this);
2284if (hasPredicate())
2285 getPredicate().getNode()->visit(this, visitor);
2286 getInput().getNode()->visit(this, visitor);
2287 getScale().getNode()->visit(this, visitor);
2288 getBias().getNode()->visit(this, visitor);
2289 visitor->post(parent, this);
2290}
2291
2292bool InstanceNormalizationNode::isEqual(const InstanceNormalizationNode &other) const {
2293 return true &&
2294 Input_ == other.Input_ &&
2295 Scale_ == other.Scale_ &&
2296 Bias_ == other.Bias_ &&
2297 predicate_ == other.predicate_ &&
2298 ChannelIdx_ == other.ChannelIdx_ &&
2299 Epsilon_ == other.Epsilon_ &&
2300 getType(0) == other.getType(0);
2301}
2302
2303Node* InstanceNormalizationNode::clone() const {
2304 return new InstanceNormalizationNode(getName(), getInput(), getScale(), getBias(), getChannelIdx(), getEpsilon());
2305}
2306
2307llvm::hash_code InstanceNormalizationNode::getHash() const {
2308 return llvm::hash_combine(
2309 ChannelIdx_,
2310 toBinary(Epsilon_),
2311 Input_,
2312 Scale_,
2313 Bias_);
2314}
2315
2316unsigned MeanVarNormalizationNode::getNumInputs() const {
2317 return 3;
2318}
2319
2320std::string MeanVarNormalizationNode::getInputName(unsigned idx) const {
2321 if (idx == 0) { return "Input"; }
2322 if (idx == 1) { return "Mean"; }
2323 if (idx == 2) { return "Var"; }
2324 idx -= 3;
2325 llvm_unreachable("Invalid index");
2326}
2327
2328NodeValue MeanVarNormalizationNode::getNthInput(unsigned idx) {
2329 if (idx == 0) { return Input_; }
2330 if (idx == 1) { return Mean_; }
2331 if (idx == 2) { return Var_; }
2332 idx -= 3;
2333 llvm_unreachable("Invalid index");
2334}
2335
2336void MeanVarNormalizationNode::setNthInput(unsigned idx, NodeValue val) {
2337 if (idx == 0) { Input_ = val; return; }
2338 if (idx == 1) { Mean_ = val; return; }
2339 if (idx == 2) { Var_ = val; return; }
2340 idx -= 3;
2341 llvm_unreachable("Invalid index");
2342}
2343
2344llvm::StringRef MeanVarNormalizationNode::getOutputName(unsigned idx) const {
2345 if (idx == 0) { return "NewMean"; }
2346 if (idx == 1) { return "NewVar"; }
2347 llvm_unreachable("Invalid index");
2348}
2349
2350std::string MeanVarNormalizationNode::getDebugDesc() const {
2351 DescriptionBuilder db(getKindName());
2352 db.addParam("Name", separateString(getName(), 100, "\n"));
2353 if (hasPredicate()) db.addParam("Predicate", "Yes");
2354 db
2355 .addParam("Input", *(getInput().getType()))
2356 .addParam("Mean", *(getMean().getType()))
2357 .addParam("Var", *(getVar().getType()))
2358 .addParam("ChannelIdx", getChannelIdx())
2359 .addParam("Momentum", getMomentum())
2360 .addParam("Users", getNumUsers());
2361 db.addParam("NewMean", *(getNewMean().getType()));
2362 db.addParam("NewVar", *(getNewVar().getType()));
2363 return db;
2364}
2365
2366void MeanVarNormalizationNode::visit(Node *parent, NodeWalker *visitor) {
2367 if (!visitor->shouldVisit(parent, this)) { return; }
2368 visitor->pre(parent, this);
2369if (hasPredicate())
2370 getPredicate().getNode()->visit(this, visitor);
2371 getInput().getNode()->visit(this, visitor);
2372 getMean().getNode()->visit(this, visitor);
2373 getVar().getNode()->visit(this, visitor);
2374 visitor->post(parent, this);
2375}
2376
2377bool MeanVarNormalizationNode::isEqual(const MeanVarNormalizationNode &other) const {
2378 return true &&
2379 Input_ == other.Input_ &&
2380 Mean_ == other.Mean_ &&
2381 Var_ == other.Var_ &&
2382 predicate_ == other.predicate_ &&
2383 ChannelIdx_ == other.ChannelIdx_ &&
2384 Momentum_ == other.Momentum_ &&
2385 getType(0) == other.getType(0) &&
2386 getType(1) == other.getType(1);
2387}
2388
2389Node* MeanVarNormalizationNode::clone() const {
2390 return new MeanVarNormalizationNode(getName(), getInput(), getMean(), getVar(), getChannelIdx(), getMomentum());
2391}
2392
2393llvm::hash_code MeanVarNormalizationNode::getHash() const {
2394 return llvm::hash_combine(
2395 ChannelIdx_,
2396 toBinary(Momentum_),
2397 Input_,
2398 Mean_,
2399 Var_);
2400}
2401
2402unsigned LocalResponseNormalizationGradNode::getNumInputs() const {
2403 return 3;
2404}
2405
2406std::string LocalResponseNormalizationGradNode::getInputName(unsigned idx) const {
2407 if (idx == 0) { return "Input"; }
2408 if (idx == 1) { return "OriginalOutputForResult"; }
2409 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
2410 idx -= 3;
2411 llvm_unreachable("Invalid index");
2412}
2413
2414NodeValue LocalResponseNormalizationGradNode::getNthInput(unsigned idx) {
2415 if (idx == 0) { return Input_; }
2416 if (idx == 1) { return OriginalOutputForResult_; }
2417 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
2418 idx -= 3;
2419 llvm_unreachable("Invalid index");
2420}
2421
2422void LocalResponseNormalizationGradNode::setNthInput(unsigned idx, NodeValue val) {
2423 if (idx == 0) { Input_ = val; return; }
2424 if (idx == 1) { OriginalOutputForResult_ = val; return; }
2425 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
2426 idx -= 3;
2427 llvm_unreachable("Invalid index");
2428}
2429
2430llvm::StringRef LocalResponseNormalizationGradNode::getOutputName(unsigned idx) const {
2431 if (idx == 0) { return "GradOfInputNamedInput"; }
2432 llvm_unreachable("Invalid index");
2433}
2434
2435std::string LocalResponseNormalizationGradNode::getDebugDesc() const {
2436 DescriptionBuilder db(getKindName());
2437 db.addParam("Name", separateString(getName(), 100, "\n"));
2438 if (hasPredicate()) db.addParam("Predicate", "Yes");
2439 db
2440 .addParam("Input", *(getInput().getType()))
2441 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
2442 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
2443 .addParam("HalfWindowSize", getHalfWindowSize())
2444 .addParam("Alpha", getAlpha())
2445 .addParam("Beta", getBeta())
2446 .addParam("K", getK())
2447 .addParam("Users", getNumUsers());
2448 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
2449 return db;
2450}
2451
2452void LocalResponseNormalizationGradNode::visit(Node *parent, NodeWalker *visitor) {
2453 if (!visitor->shouldVisit(parent, this)) { return; }
2454 visitor->pre(parent, this);
2455if (hasPredicate())
2456 getPredicate().getNode()->visit(this, visitor);
2457 getInput().getNode()->visit(this, visitor);
2458 getOriginalOutputForResult().getNode()->visit(this, visitor);
2459 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
2460 visitor->post(parent, this);
2461}
2462
2463bool LocalResponseNormalizationGradNode::isEqual(const LocalResponseNormalizationGradNode &other) const {
2464 return true &&
2465 Input_ == other.Input_ &&
2466 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
2467 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
2468 predicate_ == other.predicate_ &&
2469 HalfWindowSize_ == other.HalfWindowSize_ &&
2470 Alpha_ == other.Alpha_ &&
2471 Beta_ == other.Beta_ &&
2472 K_ == other.K_ &&
2473 getType(0) == other.getType(0);
2474}
2475
2476Node* LocalResponseNormalizationGradNode::clone() const {
2477 return new LocalResponseNormalizationGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getHalfWindowSize(), getAlpha(), getBeta(), getK());
2478}
2479
2480llvm::hash_code LocalResponseNormalizationGradNode::getHash() const {
2481 return llvm::hash_combine(
2482 HalfWindowSize_,
2483 toBinary(Alpha_),
2484 toBinary(Beta_),
2485 toBinary(K_),
2486 Input_,
2487 OriginalOutputForResult_,
2488 GradOfOriginalOutputNamedResult_);
2489}
2490
2491unsigned LocalResponseNormalizationNode::getNumInputs() const {
2492 return 1;
2493}
2494
2495std::string LocalResponseNormalizationNode::getInputName(unsigned idx) const {
2496 if (idx == 0) { return "Input"; }
2497 idx -= 1;
2498 llvm_unreachable("Invalid index");
2499}
2500
2501NodeValue LocalResponseNormalizationNode::getNthInput(unsigned idx) {
2502 if (idx == 0) { return Input_; }
2503 idx -= 1;
2504 llvm_unreachable("Invalid index");
2505}
2506
2507void LocalResponseNormalizationNode::setNthInput(unsigned idx, NodeValue val) {
2508 if (idx == 0) { Input_ = val; return; }
2509 idx -= 1;
2510 llvm_unreachable("Invalid index");
2511}
2512
2513llvm::StringRef LocalResponseNormalizationNode::getOutputName(unsigned idx) const {
2514 if (idx == 0) { return "Result"; }
2515 llvm_unreachable("Invalid index");
2516}
2517
2518std::string LocalResponseNormalizationNode::getDebugDesc() const {
2519 DescriptionBuilder db(getKindName());
2520 db.addParam("Name", separateString(getName(), 100, "\n"));
2521 if (hasPredicate()) db.addParam("Predicate", "Yes");
2522 db
2523 .addParam("Input", *(getInput().getType()))
2524 .addParam("HalfWindowSize", getHalfWindowSize())
2525 .addParam("Alpha", getAlpha())
2526 .addParam("Beta", getBeta())
2527 .addParam("K", getK())
2528 .addParam("Users", getNumUsers());
2529 db.addParam("Result", *(getResult().getType()));
2530 return db;
2531}
2532
2533void LocalResponseNormalizationNode::visit(Node *parent, NodeWalker *visitor) {
2534 if (!visitor->shouldVisit(parent, this)) { return; }
2535 visitor->pre(parent, this);
2536if (hasPredicate())
2537 getPredicate().getNode()->visit(this, visitor);
2538 getInput().getNode()->visit(this, visitor);
2539 visitor->post(parent, this);
2540}
2541
2542bool LocalResponseNormalizationNode::isEqual(const LocalResponseNormalizationNode &other) const {
2543 return true &&
2544 Input_ == other.Input_ &&
2545 predicate_ == other.predicate_ &&
2546 HalfWindowSize_ == other.HalfWindowSize_ &&
2547 Alpha_ == other.Alpha_ &&
2548 Beta_ == other.Beta_ &&
2549 K_ == other.K_ &&
2550 getType(0) == other.getType(0);
2551}
2552
2553Node* LocalResponseNormalizationNode::clone() const {
2554 return new LocalResponseNormalizationNode(getName(), getInput(), getHalfWindowSize(), getAlpha(), getBeta(), getK());
2555}
2556
2557llvm::hash_code LocalResponseNormalizationNode::getHash() const {
2558 return llvm::hash_combine(
2559 HalfWindowSize_,
2560 toBinary(Alpha_),
2561 toBinary(Beta_),
2562 toBinary(K_),
2563 Input_);
2564}
2565
2566LocalResponseNormalizationGradNode *LocalResponseNormalizationNode::getGrad(GraphGradMapper &builder) {
2567 auto *x = new LocalResponseNormalizationGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()), getHalfWindowSize(), getAlpha(), getBeta(), getK());
2568 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
2569 return x;
2570}
2571
2572unsigned LayerNormalizationNode::getNumInputs() const {
2573 return 3;
2574}
2575
2576std::string LayerNormalizationNode::getInputName(unsigned idx) const {
2577 if (idx == 0) { return "Input"; }
2578 if (idx == 1) { return "Scale"; }
2579 if (idx == 2) { return "Bias"; }
2580 idx -= 3;
2581 llvm_unreachable("Invalid index");
2582}
2583
2584NodeValue LayerNormalizationNode::getNthInput(unsigned idx) {
2585 if (idx == 0) { return Input_; }
2586 if (idx == 1) { return Scale_; }
2587 if (idx == 2) { return Bias_; }
2588 idx -= 3;
2589 llvm_unreachable("Invalid index");
2590}
2591
2592void LayerNormalizationNode::setNthInput(unsigned idx, NodeValue val) {
2593 if (idx == 0) { Input_ = val; return; }
2594 if (idx == 1) { Scale_ = val; return; }
2595 if (idx == 2) { Bias_ = val; return; }
2596 idx -= 3;
2597 llvm_unreachable("Invalid index");
2598}
2599
2600llvm::StringRef LayerNormalizationNode::getOutputName(unsigned idx) const {
2601 if (idx == 0) { return "Result"; }
2602 llvm_unreachable("Invalid index");
2603}
2604
2605std::string LayerNormalizationNode::getDebugDesc() const {
2606 DescriptionBuilder db(getKindName());
2607 db.addParam("Name", separateString(getName(), 100, "\n"));
2608 if (hasPredicate()) db.addParam("Predicate", "Yes");
2609 db
2610 .addParam("Input", *(getInput().getType()))
2611 .addParam("Scale", *(getScale().getType()))
2612 .addParam("Bias", *(getBias().getType()))
2613 .addParam("Epsilon", getEpsilon())
2614 .addParam("Users", getNumUsers());
2615 db.addParam("Result", *(getResult().getType()));
2616 return db;
2617}
2618
2619void LayerNormalizationNode::visit(Node *parent, NodeWalker *visitor) {
2620 if (!visitor->shouldVisit(parent, this)) { return; }
2621 visitor->pre(parent, this);
2622if (hasPredicate())
2623 getPredicate().getNode()->visit(this, visitor);
2624 getInput().getNode()->visit(this, visitor);
2625 getScale().getNode()->visit(this, visitor);
2626 getBias().getNode()->visit(this, visitor);
2627 visitor->post(parent, this);
2628}
2629
2630bool LayerNormalizationNode::isEqual(const LayerNormalizationNode &other) const {
2631 return true &&
2632 Input_ == other.Input_ &&
2633 Scale_ == other.Scale_ &&
2634 Bias_ == other.Bias_ &&
2635 predicate_ == other.predicate_ &&
2636 Epsilon_ == other.Epsilon_ &&
2637 getType(0) == other.getType(0);
2638}
2639
2640Node* LayerNormalizationNode::clone() const {
2641 return new LayerNormalizationNode(getName(), getResult().getType(), getInput(), getScale(), getBias(), getEpsilon());
2642}
2643
2644llvm::hash_code LayerNormalizationNode::getHash() const {
2645 return llvm::hash_combine(
2646 toBinary(Epsilon_),
2647 Input_,
2648 Scale_,
2649 Bias_);
2650}
2651
2652unsigned BatchBoxCoxNode::getNumInputs() const {
2653 return 3;
2654}
2655
2656std::string BatchBoxCoxNode::getInputName(unsigned idx) const {
2657 if (idx == 0) { return "Input"; }
2658 if (idx == 1) { return "Lambda1"; }
2659 if (idx == 2) { return "Lambda2"; }
2660 idx -= 3;
2661 llvm_unreachable("Invalid index");
2662}
2663
2664NodeValue BatchBoxCoxNode::getNthInput(unsigned idx) {
2665 if (idx == 0) { return Input_; }
2666 if (idx == 1) { return Lambda1_; }
2667 if (idx == 2) { return Lambda2_; }
2668 idx -= 3;
2669 llvm_unreachable("Invalid index");
2670}
2671
2672void BatchBoxCoxNode::setNthInput(unsigned idx, NodeValue val) {
2673 if (idx == 0) { Input_ = val; return; }
2674 if (idx == 1) { Lambda1_ = val; return; }
2675 if (idx == 2) { Lambda2_ = val; return; }
2676 idx -= 3;
2677 llvm_unreachable("Invalid index");
2678}
2679
2680llvm::StringRef BatchBoxCoxNode::getOutputName(unsigned idx) const {
2681 if (idx == 0) { return "Result"; }
2682 llvm_unreachable("Invalid index");
2683}
2684
2685std::string BatchBoxCoxNode::getDebugDesc() const {
2686 DescriptionBuilder db(getKindName());
2687 db.addParam("Name", separateString(getName(), 100, "\n"));
2688 if (hasPredicate()) db.addParam("Predicate", "Yes");
2689 db
2690 .addParam("Input", *(getInput().getType()))
2691 .addParam("Lambda1", *(getLambda1().getType()))
2692 .addParam("Lambda2", *(getLambda2().getType()))
2693 .addParam("Epsilon", getEpsilon())
2694 .addParam("Users", getNumUsers());
2695 db.addParam("Result", *(getResult().getType()));
2696 return db;
2697}
2698
2699void BatchBoxCoxNode::visit(Node *parent, NodeWalker *visitor) {
2700 if (!visitor->shouldVisit(parent, this)) { return; }
2701 visitor->pre(parent, this);
2702if (hasPredicate())
2703 getPredicate().getNode()->visit(this, visitor);
2704 getInput().getNode()->visit(this, visitor);
2705 getLambda1().getNode()->visit(this, visitor);
2706 getLambda2().getNode()->visit(this, visitor);
2707 visitor->post(parent, this);
2708}
2709
2710bool BatchBoxCoxNode::isEqual(const BatchBoxCoxNode &other) const {
2711 return true &&
2712 Input_ == other.Input_ &&
2713 Lambda1_ == other.Lambda1_ &&
2714 Lambda2_ == other.Lambda2_ &&
2715 predicate_ == other.predicate_ &&
2716 Epsilon_ == other.Epsilon_ &&
2717 getType(0) == other.getType(0);
2718}
2719
2720Node* BatchBoxCoxNode::clone() const {
2721 return new BatchBoxCoxNode(getName(), getInput(), getLambda1(), getLambda2(), getEpsilon());
2722}
2723
2724llvm::hash_code BatchBoxCoxNode::getHash() const {
2725 return llvm::hash_combine(
2726 toBinary(Epsilon_),
2727 Input_,
2728 Lambda1_,
2729 Lambda2_);
2730}
2731
2732unsigned VectorNormNode::getNumInputs() const {
2733 return 1;
2734}
2735
2736std::string VectorNormNode::getInputName(unsigned idx) const {
2737 if (idx == 0) { return "Input"; }
2738 idx -= 1;
2739 llvm_unreachable("Invalid index");
2740}
2741
2742NodeValue VectorNormNode::getNthInput(unsigned idx) {
2743 if (idx == 0) { return Input_; }
2744 idx -= 1;
2745 llvm_unreachable("Invalid index");
2746}
2747
2748void VectorNormNode::setNthInput(unsigned idx, NodeValue val) {
2749 if (idx == 0) { Input_ = val; return; }
2750 idx -= 1;
2751 llvm_unreachable("Invalid index");
2752}
2753
2754llvm::StringRef VectorNormNode::getOutputName(unsigned idx) const {
2755 if (idx == 0) { return "Result"; }
2756 llvm_unreachable("Invalid index");
2757}
2758
2759std::string VectorNormNode::getDebugDesc() const {
2760 DescriptionBuilder db(getKindName());
2761 db.addParam("Name", separateString(getName(), 100, "\n"));
2762 if (hasPredicate()) db.addParam("Predicate", "Yes");
2763 db
2764 .addParam("Input", *(getInput().getType()))
2765 .addParam("Axis", getAxis())
2766 .addParam("P", getP())
2767 .addParam("Users", getNumUsers());
2768 db.addParam("Result", *(getResult().getType()));
2769 return db;
2770}
2771
2772void VectorNormNode::visit(Node *parent, NodeWalker *visitor) {
2773 if (!visitor->shouldVisit(parent, this)) { return; }
2774 visitor->pre(parent, this);
2775if (hasPredicate())
2776 getPredicate().getNode()->visit(this, visitor);
2777 getInput().getNode()->visit(this, visitor);
2778 visitor->post(parent, this);
2779}
2780
2781bool VectorNormNode::isEqual(const VectorNormNode &other) const {
2782 return true &&
2783 Input_ == other.Input_ &&
2784 predicate_ == other.predicate_ &&
2785 Axis_ == other.Axis_ &&
2786 P_ == other.P_ &&
2787 getType(0) == other.getType(0);
2788}
2789
2790Node* VectorNormNode::clone() const {
2791 return new VectorNormNode(getName(), getResult().getType(), getInput(), getAxis(), getP());
2792}
2793
2794llvm::hash_code VectorNormNode::getHash() const {
2795 return llvm::hash_combine(
2796 Axis_,
2797 P_,
2798 Input_);
2799}
2800
2801unsigned BucketizeNode::getNumInputs() const {
2802 return 1;
2803}
2804
2805std::string BucketizeNode::getInputName(unsigned idx) const {
2806 if (idx == 0) { return "Input"; }
2807 idx -= 1;
2808 llvm_unreachable("Invalid index");
2809}
2810
2811NodeValue BucketizeNode::getNthInput(unsigned idx) {
2812 if (idx == 0) { return Input_; }
2813 idx -= 1;
2814 llvm_unreachable("Invalid index");
2815}
2816
2817void BucketizeNode::setNthInput(unsigned idx, NodeValue val) {
2818 if (idx == 0) { Input_ = val; return; }
2819 idx -= 1;
2820 llvm_unreachable("Invalid index");
2821}
2822
2823llvm::StringRef BucketizeNode::getOutputName(unsigned idx) const {
2824 if (idx == 0) { return "Result"; }
2825 llvm_unreachable("Invalid index");
2826}
2827
2828std::string BucketizeNode::getDebugDesc() const {
2829 DescriptionBuilder db(getKindName());
2830 db.addParam("Name", separateString(getName(), 100, "\n"));
2831 if (hasPredicate()) db.addParam("Predicate", "Yes");
2832 db
2833 .addParam("Input", *(getInput().getType()))
2834 .addParam("Boundaries", getBoundaries())
2835 .addParam("Users", getNumUsers());
2836 db.addParam("Result", *(getResult().getType()));
2837 return db;
2838}
2839
2840void BucketizeNode::visit(Node *parent, NodeWalker *visitor) {
2841 if (!visitor->shouldVisit(parent, this)) { return; }
2842 visitor->pre(parent, this);
2843if (hasPredicate())
2844 getPredicate().getNode()->visit(this, visitor);
2845 getInput().getNode()->visit(this, visitor);
2846 visitor->post(parent, this);
2847}
2848
2849bool BucketizeNode::isEqual(const BucketizeNode &other) const {
2850 return true &&
2851 Input_ == other.Input_ &&
2852 predicate_ == other.predicate_ &&
2853 Boundaries_ == other.Boundaries_ &&
2854 getType(0) == other.getType(0);
2855}
2856
2857Node* BucketizeNode::clone() const {
2858 return new BucketizeNode(getName(), getResult().getType(), getInput(), getBoundaries());
2859}
2860
2861llvm::hash_code BucketizeNode::getHash() const {
2862 return llvm::hash_combine(
2863 [](const std::vector<float>& floatVec) -> llvm::hash_code {
2864 std::vector<size_t> sizeVec = toBinary(floatVec);
2865 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
2866 }(Boundaries_),
2867 Input_);
2868}
2869
2870unsigned SoftMaxGradNode::getNumInputs() const {
2871 return 4;
2872}
2873
2874std::string SoftMaxGradNode::getInputName(unsigned idx) const {
2875 if (idx == 0) { return "Input"; }
2876 if (idx == 1) { return "Selected"; }
2877 if (idx == 2) { return "OriginalOutputForResult"; }
2878 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
2879 idx -= 4;
2880 llvm_unreachable("Invalid index");
2881}
2882
2883NodeValue SoftMaxGradNode::getNthInput(unsigned idx) {
2884 if (idx == 0) { return Input_; }
2885 if (idx == 1) { return Selected_; }
2886 if (idx == 2) { return OriginalOutputForResult_; }
2887 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
2888 idx -= 4;
2889 llvm_unreachable("Invalid index");
2890}
2891
2892void SoftMaxGradNode::setNthInput(unsigned idx, NodeValue val) {
2893 if (idx == 0) { Input_ = val; return; }
2894 if (idx == 1) { Selected_ = val; return; }
2895 if (idx == 2) { OriginalOutputForResult_ = val; return; }
2896 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
2897 idx -= 4;
2898 llvm_unreachable("Invalid index");
2899}
2900
2901llvm::StringRef SoftMaxGradNode::getOutputName(unsigned idx) const {
2902 if (idx == 0) { return "GradOfInputNamedInput"; }
2903 if (idx == 1) { return "GradOfInputNamedSelected"; }
2904 llvm_unreachable("Invalid index");
2905}
2906
2907std::string SoftMaxGradNode::getDebugDesc() const {
2908 DescriptionBuilder db(getKindName());
2909 db.addParam("Name", separateString(getName(), 100, "\n"));
2910 if (hasPredicate()) db.addParam("Predicate", "Yes");
2911 db
2912 .addParam("Input", *(getInput().getType()))
2913 .addParam("Selected", *(getSelected().getType()))
2914 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
2915 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
2916 .addParam("Users", getNumUsers());
2917 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
2918 db.addParam("GradOfInputNamedSelected", *(getGradOfInputNamedSelected().getType()));
2919 return db;
2920}
2921
2922void SoftMaxGradNode::visit(Node *parent, NodeWalker *visitor) {
2923 if (!visitor->shouldVisit(parent, this)) { return; }
2924 visitor->pre(parent, this);
2925if (hasPredicate())
2926 getPredicate().getNode()->visit(this, visitor);
2927 getInput().getNode()->visit(this, visitor);
2928 getSelected().getNode()->visit(this, visitor);
2929 getOriginalOutputForResult().getNode()->visit(this, visitor);
2930 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
2931 visitor->post(parent, this);
2932}
2933
2934bool SoftMaxGradNode::isEqual(const SoftMaxGradNode &other) const {
2935 return true &&
2936 Input_ == other.Input_ &&
2937 Selected_ == other.Selected_ &&
2938 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
2939 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
2940 predicate_ == other.predicate_ &&
2941 getType(0) == other.getType(0) &&
2942 getType(1) == other.getType(1);
2943}
2944
2945Node* SoftMaxGradNode::clone() const {
2946 return new SoftMaxGradNode(getName(), getInput(), getSelected(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
2947}
2948
2949llvm::hash_code SoftMaxGradNode::getHash() const {
2950 return llvm::hash_combine(
2951 Input_,
2952 Selected_,
2953 OriginalOutputForResult_,
2954 GradOfOriginalOutputNamedResult_);
2955}
2956
2957unsigned SoftMaxNode::getNumInputs() const {
2958 return 2;
2959}
2960
2961std::string SoftMaxNode::getInputName(unsigned idx) const {
2962 if (idx == 0) { return "Input"; }
2963 if (idx == 1) { return "Selected"; }
2964 idx -= 2;
2965 llvm_unreachable("Invalid index");
2966}
2967
2968NodeValue SoftMaxNode::getNthInput(unsigned idx) {
2969 if (idx == 0) { return Input_; }
2970 if (idx == 1) { return Selected_; }
2971 idx -= 2;
2972 llvm_unreachable("Invalid index");
2973}
2974
2975void SoftMaxNode::setNthInput(unsigned idx, NodeValue val) {
2976 if (idx == 0) { Input_ = val; return; }
2977 if (idx == 1) { Selected_ = val; return; }
2978 idx -= 2;
2979 llvm_unreachable("Invalid index");
2980}
2981
2982llvm::StringRef SoftMaxNode::getOutputName(unsigned idx) const {
2983 if (idx == 0) { return "Result"; }
2984 llvm_unreachable("Invalid index");
2985}
2986
2987std::string SoftMaxNode::getDebugDesc() const {
2988 DescriptionBuilder db(getKindName());
2989 db.addParam("Name", separateString(getName(), 100, "\n"));
2990 if (hasPredicate()) db.addParam("Predicate", "Yes");
2991 db
2992 .addParam("Input", *(getInput().getType()))
2993 .addParam("Selected", *(getSelected().getType()))
2994 .addParam("Users", getNumUsers());
2995 db.addParam("Result", *(getResult().getType()));
2996 return db;
2997}
2998
2999void SoftMaxNode::visit(Node *parent, NodeWalker *visitor) {
3000 if (!visitor->shouldVisit(parent, this)) { return; }
3001 visitor->pre(parent, this);
3002if (hasPredicate())
3003 getPredicate().getNode()->visit(this, visitor);
3004 getInput().getNode()->visit(this, visitor);
3005 getSelected().getNode()->visit(this, visitor);
3006 visitor->post(parent, this);
3007}
3008
3009bool SoftMaxNode::isEqual(const SoftMaxNode &other) const {
3010 return true &&
3011 Input_ == other.Input_ &&
3012 Selected_ == other.Selected_ &&
3013 predicate_ == other.predicate_ &&
3014 getType(0) == other.getType(0);
3015}
3016
3017Node* SoftMaxNode::clone() const {
3018 return new SoftMaxNode(getName(), getResult().getType(), getInput(), getSelected());
3019}
3020
3021llvm::hash_code SoftMaxNode::getHash() const {
3022 return llvm::hash_combine(
3023 Input_,
3024 Selected_);
3025}
3026
3027SoftMaxGradNode *SoftMaxNode::getGrad(GraphGradMapper &builder) {
3028 auto *x = new SoftMaxGradNode(getName().str() + "_grad", getInput(), getSelected(), getResult(), builder.getGradient(getResult()));
3029 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
3030 builder.addGradient(getSelected(), x->getGradOfInputNamedSelected());
3031 return x;
3032}
3033
3034unsigned LogSoftMaxGradNode::getNumInputs() const {
3035 return 4;
3036}
3037
3038std::string LogSoftMaxGradNode::getInputName(unsigned idx) const {
3039 if (idx == 0) { return "Input"; }
3040 if (idx == 1) { return "Selected"; }
3041 if (idx == 2) { return "OriginalOutputForResult"; }
3042 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
3043 idx -= 4;
3044 llvm_unreachable("Invalid index");
3045}
3046
3047NodeValue LogSoftMaxGradNode::getNthInput(unsigned idx) {
3048 if (idx == 0) { return Input_; }
3049 if (idx == 1) { return Selected_; }
3050 if (idx == 2) { return OriginalOutputForResult_; }
3051 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
3052 idx -= 4;
3053 llvm_unreachable("Invalid index");
3054}
3055
3056void LogSoftMaxGradNode::setNthInput(unsigned idx, NodeValue val) {
3057 if (idx == 0) { Input_ = val; return; }
3058 if (idx == 1) { Selected_ = val; return; }
3059 if (idx == 2) { OriginalOutputForResult_ = val; return; }
3060 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
3061 idx -= 4;
3062 llvm_unreachable("Invalid index");
3063}
3064
3065llvm::StringRef LogSoftMaxGradNode::getOutputName(unsigned idx) const {
3066 if (idx == 0) { return "GradOfInputNamedInput"; }
3067 if (idx == 1) { return "GradOfInputNamedSelected"; }
3068 llvm_unreachable("Invalid index");
3069}
3070
3071std::string LogSoftMaxGradNode::getDebugDesc() const {
3072 DescriptionBuilder db(getKindName());
3073 db.addParam("Name", separateString(getName(), 100, "\n"));
3074 if (hasPredicate()) db.addParam("Predicate", "Yes");
3075 db
3076 .addParam("Input", *(getInput().getType()))
3077 .addParam("Selected", *(getSelected().getType()))
3078 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
3079 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
3080 .addParam("Users", getNumUsers());
3081 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
3082 db.addParam("GradOfInputNamedSelected", *(getGradOfInputNamedSelected().getType()));
3083 return db;
3084}
3085
3086void LogSoftMaxGradNode::visit(Node *parent, NodeWalker *visitor) {
3087 if (!visitor->shouldVisit(parent, this)) { return; }
3088 visitor->pre(parent, this);
3089if (hasPredicate())
3090 getPredicate().getNode()->visit(this, visitor);
3091 getInput().getNode()->visit(this, visitor);
3092 getSelected().getNode()->visit(this, visitor);
3093 getOriginalOutputForResult().getNode()->visit(this, visitor);
3094 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
3095 visitor->post(parent, this);
3096}
3097
3098bool LogSoftMaxGradNode::isEqual(const LogSoftMaxGradNode &other) const {
3099 return true &&
3100 Input_ == other.Input_ &&
3101 Selected_ == other.Selected_ &&
3102 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
3103 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
3104 predicate_ == other.predicate_ &&
3105 getType(0) == other.getType(0) &&
3106 getType(1) == other.getType(1);
3107}
3108
3109Node* LogSoftMaxGradNode::clone() const {
3110 return new LogSoftMaxGradNode(getName(), getInput(), getSelected(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
3111}
3112
3113llvm::hash_code LogSoftMaxGradNode::getHash() const {
3114 return llvm::hash_combine(
3115 Input_,
3116 Selected_,
3117 OriginalOutputForResult_,
3118 GradOfOriginalOutputNamedResult_);
3119}
3120
3121unsigned LogSoftMaxNode::getNumInputs() const {
3122 return 2;
3123}
3124
3125std::string LogSoftMaxNode::getInputName(unsigned idx) const {
3126 if (idx == 0) { return "Input"; }
3127 if (idx == 1) { return "Selected"; }
3128 idx -= 2;
3129 llvm_unreachable("Invalid index");
3130}
3131
3132NodeValue LogSoftMaxNode::getNthInput(unsigned idx) {
3133 if (idx == 0) { return Input_; }
3134 if (idx == 1) { return Selected_; }
3135 idx -= 2;
3136 llvm_unreachable("Invalid index");
3137}
3138
3139void LogSoftMaxNode::setNthInput(unsigned idx, NodeValue val) {
3140 if (idx == 0) { Input_ = val; return; }
3141 if (idx == 1) { Selected_ = val; return; }
3142 idx -= 2;
3143 llvm_unreachable("Invalid index");
3144}
3145
3146llvm::StringRef LogSoftMaxNode::getOutputName(unsigned idx) const {
3147 if (idx == 0) { return "Result"; }
3148 llvm_unreachable("Invalid index");
3149}
3150
3151std::string LogSoftMaxNode::getDebugDesc() const {
3152 DescriptionBuilder db(getKindName());
3153 db.addParam("Name", separateString(getName(), 100, "\n"));
3154 if (hasPredicate()) db.addParam("Predicate", "Yes");
3155 db
3156 .addParam("Input", *(getInput().getType()))
3157 .addParam("Selected", *(getSelected().getType()))
3158 .addParam("Users", getNumUsers());
3159 db.addParam("Result", *(getResult().getType()));
3160 return db;
3161}
3162
3163void LogSoftMaxNode::visit(Node *parent, NodeWalker *visitor) {
3164 if (!visitor->shouldVisit(parent, this)) { return; }
3165 visitor->pre(parent, this);
3166if (hasPredicate())
3167 getPredicate().getNode()->visit(this, visitor);
3168 getInput().getNode()->visit(this, visitor);
3169 getSelected().getNode()->visit(this, visitor);
3170 visitor->post(parent, this);
3171}
3172
3173bool LogSoftMaxNode::isEqual(const LogSoftMaxNode &other) const {
3174 return true &&
3175 Input_ == other.Input_ &&
3176 Selected_ == other.Selected_ &&
3177 predicate_ == other.predicate_ &&
3178 getType(0) == other.getType(0);
3179}
3180
3181Node* LogSoftMaxNode::clone() const {
3182 return new LogSoftMaxNode(getName(), getResult().getType(), getInput(), getSelected());
3183}
3184
3185llvm::hash_code LogSoftMaxNode::getHash() const {
3186 return llvm::hash_combine(
3187 Input_,
3188 Selected_);
3189}
3190
3191LogSoftMaxGradNode *LogSoftMaxNode::getGrad(GraphGradMapper &builder) {
3192 auto *x = new LogSoftMaxGradNode(getName().str() + "_grad", getInput(), getSelected(), getResult(), builder.getGradient(getResult()));
3193 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
3194 builder.addGradient(getSelected(), x->getGradOfInputNamedSelected());
3195 return x;
3196}
3197
3198unsigned CrossEntropyLossGradNode::getNumInputs() const {
3199 return 4;
3200}
3201
3202std::string CrossEntropyLossGradNode::getInputName(unsigned idx) const {
3203 if (idx == 0) { return "P"; }
3204 if (idx == 1) { return "Labels"; }
3205 if (idx == 2) { return "OriginalOutputForCE"; }
3206 if (idx == 3) { return "GradOfOriginalOutputNamedCE"; }
3207 idx -= 4;
3208 llvm_unreachable("Invalid index");
3209}
3210
3211NodeValue CrossEntropyLossGradNode::getNthInput(unsigned idx) {
3212 if (idx == 0) { return P_; }
3213 if (idx == 1) { return Labels_; }
3214 if (idx == 2) { return OriginalOutputForCE_; }
3215 if (idx == 3) { return GradOfOriginalOutputNamedCE_; }
3216 idx -= 4;
3217 llvm_unreachable("Invalid index");
3218}
3219
3220void CrossEntropyLossGradNode::setNthInput(unsigned idx, NodeValue val) {
3221 if (idx == 0) { P_ = val; return; }
3222 if (idx == 1) { Labels_ = val; return; }
3223 if (idx == 2) { OriginalOutputForCE_ = val; return; }
3224 if (idx == 3) { GradOfOriginalOutputNamedCE_ = val; return; }
3225 idx -= 4;
3226 llvm_unreachable("Invalid index");
3227}
3228
3229llvm::StringRef CrossEntropyLossGradNode::getOutputName(unsigned idx) const {
3230 if (idx == 0) { return "GradOfInputNamedP"; }
3231 if (idx == 1) { return "GradOfInputNamedLabels"; }
3232 llvm_unreachable("Invalid index");
3233}
3234
3235std::string CrossEntropyLossGradNode::getDebugDesc() const {
3236 DescriptionBuilder db(getKindName());
3237 db.addParam("Name", separateString(getName(), 100, "\n"));
3238 if (hasPredicate()) db.addParam("Predicate", "Yes");
3239 db
3240 .addParam("P", *(getP().getType()))
3241 .addParam("Labels", *(getLabels().getType()))
3242 .addParam("OriginalOutputForCE", *(getOriginalOutputForCE().getType()))
3243 .addParam("GradOfOriginalOutputNamedCE", *(getGradOfOriginalOutputNamedCE().getType()))
3244 .addParam("Users", getNumUsers());
3245 db.addParam("GradOfInputNamedP", *(getGradOfInputNamedP().getType()));
3246 db.addParam("GradOfInputNamedLabels", *(getGradOfInputNamedLabels().getType()));
3247 return db;
3248}
3249
3250void CrossEntropyLossGradNode::visit(Node *parent, NodeWalker *visitor) {
3251 if (!visitor->shouldVisit(parent, this)) { return; }
3252 visitor->pre(parent, this);
3253if (hasPredicate())
3254 getPredicate().getNode()->visit(this, visitor);
3255 getP().getNode()->visit(this, visitor);
3256 getLabels().getNode()->visit(this, visitor);
3257 getOriginalOutputForCE().getNode()->visit(this, visitor);
3258 getGradOfOriginalOutputNamedCE().getNode()->visit(this, visitor);
3259 visitor->post(parent, this);
3260}
3261
3262bool CrossEntropyLossGradNode::isEqual(const CrossEntropyLossGradNode &other) const {
3263 return true &&
3264 P_ == other.P_ &&
3265 Labels_ == other.Labels_ &&
3266 OriginalOutputForCE_ == other.OriginalOutputForCE_ &&
3267 GradOfOriginalOutputNamedCE_ == other.GradOfOriginalOutputNamedCE_ &&
3268 predicate_ == other.predicate_ &&
3269 getType(0) == other.getType(0) &&
3270 getType(1) == other.getType(1);
3271}
3272
3273Node* CrossEntropyLossGradNode::clone() const {
3274 return new CrossEntropyLossGradNode(getName(), getP(), getLabels(), getOriginalOutputForCE(), getGradOfOriginalOutputNamedCE());
3275}
3276
3277llvm::hash_code CrossEntropyLossGradNode::getHash() const {
3278 return llvm::hash_combine(
3279 P_,
3280 Labels_,
3281 OriginalOutputForCE_,
3282 GradOfOriginalOutputNamedCE_);
3283}
3284
3285unsigned CrossEntropyLossNode::getNumInputs() const {
3286 return 2;
3287}
3288
3289std::string CrossEntropyLossNode::getInputName(unsigned idx) const {
3290 if (idx == 0) { return "P"; }
3291 if (idx == 1) { return "Labels"; }
3292 idx -= 2;
3293 llvm_unreachable("Invalid index");
3294}
3295
3296NodeValue CrossEntropyLossNode::getNthInput(unsigned idx) {
3297 if (idx == 0) { return P_; }
3298 if (idx == 1) { return Labels_; }
3299 idx -= 2;
3300 llvm_unreachable("Invalid index");
3301}
3302
3303void CrossEntropyLossNode::setNthInput(unsigned idx, NodeValue val) {
3304 if (idx == 0) { P_ = val; return; }
3305 if (idx == 1) { Labels_ = val; return; }
3306 idx -= 2;
3307 llvm_unreachable("Invalid index");
3308}
3309
3310llvm::StringRef CrossEntropyLossNode::getOutputName(unsigned idx) const {
3311 if (idx == 0) { return "CE"; }
3312 llvm_unreachable("Invalid index");
3313}
3314
3315std::string CrossEntropyLossNode::getDebugDesc() const {
3316 DescriptionBuilder db(getKindName());
3317 db.addParam("Name", separateString(getName(), 100, "\n"));
3318 if (hasPredicate()) db.addParam("Predicate", "Yes");
3319 db
3320 .addParam("P", *(getP().getType()))
3321 .addParam("Labels", *(getLabels().getType()))
3322 .addParam("Users", getNumUsers());
3323 db.addParam("CE", *(getCE().getType()));
3324 return db;
3325}
3326
3327void CrossEntropyLossNode::visit(Node *parent, NodeWalker *visitor) {
3328 if (!visitor->shouldVisit(parent, this)) { return; }
3329 visitor->pre(parent, this);
3330if (hasPredicate())
3331 getPredicate().getNode()->visit(this, visitor);
3332 getP().getNode()->visit(this, visitor);
3333 getLabels().getNode()->visit(this, visitor);
3334 visitor->post(parent, this);
3335}
3336
3337bool CrossEntropyLossNode::isEqual(const CrossEntropyLossNode &other) const {
3338 return true &&
3339 P_ == other.P_ &&
3340 Labels_ == other.Labels_ &&
3341 predicate_ == other.predicate_ &&
3342 getType(0) == other.getType(0);
3343}
3344
3345Node* CrossEntropyLossNode::clone() const {
3346 return new CrossEntropyLossNode(getName(), getCE().getType(), getP(), getLabels());
3347}
3348
3349llvm::hash_code CrossEntropyLossNode::getHash() const {
3350 return llvm::hash_combine(
3351 P_,
3352 Labels_);
3353}
3354
3355CrossEntropyLossGradNode *CrossEntropyLossNode::getGrad(GraphGradMapper &builder) {
3356 auto *x = new CrossEntropyLossGradNode(getName().str() + "_grad", getP(), getLabels(), getCE(), builder.getGradient(getCE()));
3357 builder.addGradient(getP(), x->getGradOfInputNamedP());
3358 builder.addGradient(getLabels(), x->getGradOfInputNamedLabels());
3359 return x;
3360}
3361
3362unsigned RegressionGradNode::getNumInputs() const {
3363 return 4;
3364}
3365
3366std::string RegressionGradNode::getInputName(unsigned idx) const {
3367 if (idx == 0) { return "Input"; }
3368 if (idx == 1) { return "Expected"; }
3369 if (idx == 2) { return "OriginalOutputForResult"; }
3370 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
3371 idx -= 4;
3372 llvm_unreachable("Invalid index");
3373}
3374
3375NodeValue RegressionGradNode::getNthInput(unsigned idx) {
3376 if (idx == 0) { return Input_; }
3377 if (idx == 1) { return Expected_; }
3378 if (idx == 2) { return OriginalOutputForResult_; }
3379 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
3380 idx -= 4;
3381 llvm_unreachable("Invalid index");
3382}
3383
3384void RegressionGradNode::setNthInput(unsigned idx, NodeValue val) {
3385 if (idx == 0) { Input_ = val; return; }
3386 if (idx == 1) { Expected_ = val; return; }
3387 if (idx == 2) { OriginalOutputForResult_ = val; return; }
3388 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
3389 idx -= 4;
3390 llvm_unreachable("Invalid index");
3391}
3392
3393llvm::StringRef RegressionGradNode::getOutputName(unsigned idx) const {
3394 if (idx == 0) { return "GradOfInputNamedInput"; }
3395 if (idx == 1) { return "GradOfInputNamedExpected"; }
3396 llvm_unreachable("Invalid index");
3397}
3398
3399std::string RegressionGradNode::getDebugDesc() const {
3400 DescriptionBuilder db(getKindName());
3401 db.addParam("Name", separateString(getName(), 100, "\n"));
3402 if (hasPredicate()) db.addParam("Predicate", "Yes");
3403 db
3404 .addParam("Input", *(getInput().getType()))
3405 .addParam("Expected", *(getExpected().getType()))
3406 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
3407 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
3408 .addParam("Users", getNumUsers());
3409 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
3410 db.addParam("GradOfInputNamedExpected", *(getGradOfInputNamedExpected().getType()));
3411 return db;
3412}
3413
3414void RegressionGradNode::visit(Node *parent, NodeWalker *visitor) {
3415 if (!visitor->shouldVisit(parent, this)) { return; }
3416 visitor->pre(parent, this);
3417if (hasPredicate())
3418 getPredicate().getNode()->visit(this, visitor);
3419 getInput().getNode()->visit(this, visitor);
3420 getExpected().getNode()->visit(this, visitor);
3421 getOriginalOutputForResult().getNode()->visit(this, visitor);
3422 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
3423 visitor->post(parent, this);
3424}
3425
3426bool RegressionGradNode::isEqual(const RegressionGradNode &other) const {
3427 return true &&
3428 Input_ == other.Input_ &&
3429 Expected_ == other.Expected_ &&
3430 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
3431 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
3432 predicate_ == other.predicate_ &&
3433 getType(0) == other.getType(0) &&
3434 getType(1) == other.getType(1);
3435}
3436
3437Node* RegressionGradNode::clone() const {
3438 return new RegressionGradNode(getName(), getInput(), getExpected(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
3439}
3440
3441llvm::hash_code RegressionGradNode::getHash() const {
3442 return llvm::hash_combine(
3443 Input_,
3444 Expected_,
3445 OriginalOutputForResult_,
3446 GradOfOriginalOutputNamedResult_);
3447}
3448
3449unsigned RegressionNode::getNumInputs() const {
3450 return 2;
3451}
3452
3453std::string RegressionNode::getInputName(unsigned idx) const {
3454 if (idx == 0) { return "Input"; }
3455 if (idx == 1) { return "Expected"; }
3456 idx -= 2;
3457 llvm_unreachable("Invalid index");
3458}
3459
3460NodeValue RegressionNode::getNthInput(unsigned idx) {
3461 if (idx == 0) { return Input_; }
3462 if (idx == 1) { return Expected_; }
3463 idx -= 2;
3464 llvm_unreachable("Invalid index");
3465}
3466
3467void RegressionNode::setNthInput(unsigned idx, NodeValue val) {
3468 if (idx == 0) { Input_ = val; return; }
3469 if (idx == 1) { Expected_ = val; return; }
3470 idx -= 2;
3471 llvm_unreachable("Invalid index");
3472}
3473
3474llvm::StringRef RegressionNode::getOutputName(unsigned idx) const {
3475 if (idx == 0) { return "Result"; }
3476 llvm_unreachable("Invalid index");
3477}
3478
3479std::string RegressionNode::getDebugDesc() const {
3480 DescriptionBuilder db(getKindName());
3481 db.addParam("Name", separateString(getName(), 100, "\n"));
3482 if (hasPredicate()) db.addParam("Predicate", "Yes");
3483 db
3484 .addParam("Input", *(getInput().getType()))
3485 .addParam("Expected", *(getExpected().getType()))
3486 .addParam("Users", getNumUsers());
3487 db.addParam("Result", *(getResult().getType()));
3488 return db;
3489}
3490
3491void RegressionNode::visit(Node *parent, NodeWalker *visitor) {
3492 if (!visitor->shouldVisit(parent, this)) { return; }
3493 visitor->pre(parent, this);
3494if (hasPredicate())
3495 getPredicate().getNode()->visit(this, visitor);
3496 getInput().getNode()->visit(this, visitor);
3497 getExpected().getNode()->visit(this, visitor);
3498 visitor->post(parent, this);
3499}
3500
3501bool RegressionNode::isEqual(const RegressionNode &other) const {
3502 return true &&
3503 Input_ == other.Input_ &&
3504 Expected_ == other.Expected_ &&
3505 predicate_ == other.predicate_ &&
3506 getType(0) == other.getType(0);
3507}
3508
3509Node* RegressionNode::clone() const {
3510 return new RegressionNode(getName(), getInput(), getExpected());
3511}
3512
3513llvm::hash_code RegressionNode::getHash() const {
3514 return llvm::hash_combine(
3515 Input_,
3516 Expected_);
3517}
3518
3519RegressionGradNode *RegressionNode::getGrad(GraphGradMapper &builder) {
3520 auto *x = new RegressionGradNode(getName().str() + "_grad", getInput(), getExpected(), getResult(), builder.getGradient(getResult()));
3521 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
3522 builder.addGradient(getExpected(), x->getGradOfInputNamedExpected());
3523 return x;
3524}
3525
3526unsigned SigmoidCrossEntropyWithLogitsNode::getNumInputs() const {
3527 return 2;
3528}
3529
3530std::string SigmoidCrossEntropyWithLogitsNode::getInputName(unsigned idx) const {
3531 if (idx == 0) { return "Logits"; }
3532 if (idx == 1) { return "Targets"; }
3533 idx -= 2;
3534 llvm_unreachable("Invalid index");
3535}
3536
3537NodeValue SigmoidCrossEntropyWithLogitsNode::getNthInput(unsigned idx) {
3538 if (idx == 0) { return Logits_; }
3539 if (idx == 1) { return Targets_; }
3540 idx -= 2;
3541 llvm_unreachable("Invalid index");
3542}
3543
3544void SigmoidCrossEntropyWithLogitsNode::setNthInput(unsigned idx, NodeValue val) {
3545 if (idx == 0) { Logits_ = val; return; }
3546 if (idx == 1) { Targets_ = val; return; }
3547 idx -= 2;
3548 llvm_unreachable("Invalid index");
3549}
3550
3551llvm::StringRef SigmoidCrossEntropyWithLogitsNode::getOutputName(unsigned idx) const {
3552 if (idx == 0) { return "Result"; }
3553 llvm_unreachable("Invalid index");
3554}
3555
3556std::string SigmoidCrossEntropyWithLogitsNode::getDebugDesc() const {
3557 DescriptionBuilder db(getKindName());
3558 db.addParam("Name", separateString(getName(), 100, "\n"));
3559 if (hasPredicate()) db.addParam("Predicate", "Yes");
3560 db
3561 .addParam("Logits", *(getLogits().getType()))
3562 .addParam("Targets", *(getTargets().getType()))
3563 .addParam("Users", getNumUsers());
3564 db.addParam("Result", *(getResult().getType()));
3565 return db;
3566}
3567
3568void SigmoidCrossEntropyWithLogitsNode::visit(Node *parent, NodeWalker *visitor) {
3569 if (!visitor->shouldVisit(parent, this)) { return; }
3570 visitor->pre(parent, this);
3571if (hasPredicate())
3572 getPredicate().getNode()->visit(this, visitor);
3573 getLogits().getNode()->visit(this, visitor);
3574 getTargets().getNode()->visit(this, visitor);
3575 visitor->post(parent, this);
3576}
3577
3578bool SigmoidCrossEntropyWithLogitsNode::isEqual(const SigmoidCrossEntropyWithLogitsNode &other) const {
3579 return true &&
3580 Logits_ == other.Logits_ &&
3581 Targets_ == other.Targets_ &&
3582 predicate_ == other.predicate_ &&
3583 getType(0) == other.getType(0);
3584}
3585
3586Node* SigmoidCrossEntropyWithLogitsNode::clone() const {
3587 return new SigmoidCrossEntropyWithLogitsNode(getName(), getResult().getType(), getLogits(), getTargets());
3588}
3589
3590llvm::hash_code SigmoidCrossEntropyWithLogitsNode::getHash() const {
3591 return llvm::hash_combine(
3592 Logits_,
3593 Targets_);
3594}
3595
3596unsigned AddGradNode::getNumInputs() const {
3597 return 4;
3598}
3599
3600std::string AddGradNode::getInputName(unsigned idx) const {
3601 if (idx == 0) { return "LHS"; }
3602 if (idx == 1) { return "RHS"; }
3603 if (idx == 2) { return "OriginalOutputForResult"; }
3604 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
3605 idx -= 4;
3606 llvm_unreachable("Invalid index");
3607}
3608
3609NodeValue AddGradNode::getNthInput(unsigned idx) {
3610 if (idx == 0) { return LHS_; }
3611 if (idx == 1) { return RHS_; }
3612 if (idx == 2) { return OriginalOutputForResult_; }
3613 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
3614 idx -= 4;
3615 llvm_unreachable("Invalid index");
3616}
3617
3618void AddGradNode::setNthInput(unsigned idx, NodeValue val) {
3619 if (idx == 0) { LHS_ = val; return; }
3620 if (idx == 1) { RHS_ = val; return; }
3621 if (idx == 2) { OriginalOutputForResult_ = val; return; }
3622 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
3623 idx -= 4;
3624 llvm_unreachable("Invalid index");
3625}
3626
3627llvm::StringRef AddGradNode::getOutputName(unsigned idx) const {
3628 if (idx == 0) { return "GradOfInputNamedLHS"; }
3629 if (idx == 1) { return "GradOfInputNamedRHS"; }
3630 llvm_unreachable("Invalid index");
3631}
3632
3633std::string AddGradNode::getDebugDesc() const {
3634 DescriptionBuilder db(getKindName());
3635 db.addParam("Name", separateString(getName(), 100, "\n"));
3636 if (hasPredicate()) db.addParam("Predicate", "Yes");
3637 db
3638 .addParam("LHS", *(getLHS().getType()))
3639 .addParam("RHS", *(getRHS().getType()))
3640 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
3641 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
3642 .addParam("Users", getNumUsers());
3643 db.addParam("GradOfInputNamedLHS", *(getGradOfInputNamedLHS().getType()));
3644 db.addParam("GradOfInputNamedRHS", *(getGradOfInputNamedRHS().getType()));
3645 return db;
3646}
3647
3648void AddGradNode::visit(Node *parent, NodeWalker *visitor) {
3649 if (!visitor->shouldVisit(parent, this)) { return; }
3650 visitor->pre(parent, this);
3651if (hasPredicate())
3652 getPredicate().getNode()->visit(this, visitor);
3653 getLHS().getNode()->visit(this, visitor);
3654 getRHS().getNode()->visit(this, visitor);
3655 getOriginalOutputForResult().getNode()->visit(this, visitor);
3656 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
3657 visitor->post(parent, this);
3658}
3659
3660bool AddGradNode::isEqual(const AddGradNode &other) const {
3661 return true &&
3662 LHS_ == other.LHS_ &&
3663 RHS_ == other.RHS_ &&
3664 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
3665 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
3666 predicate_ == other.predicate_ &&
3667 getType(0) == other.getType(0) &&
3668 getType(1) == other.getType(1);
3669}
3670
3671Node* AddGradNode::clone() const {
3672 return new AddGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
3673}
3674
3675llvm::hash_code AddGradNode::getHash() const {
3676 return llvm::hash_combine(
3677 LHS_,
3678 RHS_,
3679 OriginalOutputForResult_,
3680 GradOfOriginalOutputNamedResult_);
3681}
3682
3683unsigned AddNode::getNumInputs() const {
3684 return 2;
3685}
3686
3687std::string AddNode::getInputName(unsigned idx) const {
3688 if (idx == 0) { return "LHS"; }
3689 if (idx == 1) { return "RHS"; }
3690 idx -= 2;
3691 llvm_unreachable("Invalid index");
3692}
3693
3694NodeValue AddNode::getNthInput(unsigned idx) {
3695 if (idx == 0) { return LHS_; }
3696 if (idx == 1) { return RHS_; }
3697 idx -= 2;
3698 llvm_unreachable("Invalid index");
3699}
3700
3701void AddNode::setNthInput(unsigned idx, NodeValue val) {
3702 if (idx == 0) { LHS_ = val; return; }
3703 if (idx == 1) { RHS_ = val; return; }
3704 idx -= 2;
3705 llvm_unreachable("Invalid index");
3706}
3707
3708llvm::StringRef AddNode::getOutputName(unsigned idx) const {
3709 if (idx == 0) { return "Result"; }
3710 llvm_unreachable("Invalid index");
3711}
3712
3713std::string AddNode::getDebugDesc() const {
3714 DescriptionBuilder db(getKindName());
3715 db.addParam("Name", separateString(getName(), 100, "\n"));
3716 if (hasPredicate()) db.addParam("Predicate", "Yes");
3717 db
3718 .addParam("LHS", *(getLHS().getType()))
3719 .addParam("RHS", *(getRHS().getType()))
3720 .addParam("Users", getNumUsers());
3721 db.addParam("Result", *(getResult().getType()));
3722 return db;
3723}
3724
3725void AddNode::visit(Node *parent, NodeWalker *visitor) {
3726 if (!visitor->shouldVisit(parent, this)) { return; }
3727 visitor->pre(parent, this);
3728if (hasPredicate())
3729 getPredicate().getNode()->visit(this, visitor);
3730 getLHS().getNode()->visit(this, visitor);
3731 getRHS().getNode()->visit(this, visitor);
3732 visitor->post(parent, this);
3733}
3734
3735bool AddNode::isEqual(const AddNode &other) const {
3736 return true &&
3737 LHS_ == other.LHS_ &&
3738 RHS_ == other.RHS_ &&
3739 predicate_ == other.predicate_ &&
3740 getType(0) == other.getType(0);
3741}
3742
3743Node* AddNode::clone() const {
3744 return new AddNode(getName(), getResult().getType(), getLHS(), getRHS());
3745}
3746
3747llvm::hash_code AddNode::getHash() const {
3748 return llvm::hash_combine(
3749 LHS_,
3750 RHS_);
3751}
3752
3753AddGradNode *AddNode::getGrad(GraphGradMapper &builder) {
3754 auto *x = new AddGradNode(getName().str() + "_grad", getLHS(), getRHS(), getResult(), builder.getGradient(getResult()));
3755 builder.addGradient(getLHS(), x->getGradOfInputNamedLHS());
3756 builder.addGradient(getRHS(), x->getGradOfInputNamedRHS());
3757 return x;
3758}
3759
3760unsigned MulGradNode::getNumInputs() const {
3761 return 4;
3762}
3763
3764std::string MulGradNode::getInputName(unsigned idx) const {
3765 if (idx == 0) { return "LHS"; }
3766 if (idx == 1) { return "RHS"; }
3767 if (idx == 2) { return "OriginalOutputForResult"; }
3768 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
3769 idx -= 4;
3770 llvm_unreachable("Invalid index");
3771}
3772
3773NodeValue MulGradNode::getNthInput(unsigned idx) {
3774 if (idx == 0) { return LHS_; }
3775 if (idx == 1) { return RHS_; }
3776 if (idx == 2) { return OriginalOutputForResult_; }
3777 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
3778 idx -= 4;
3779 llvm_unreachable("Invalid index");
3780}
3781
3782void MulGradNode::setNthInput(unsigned idx, NodeValue val) {
3783 if (idx == 0) { LHS_ = val; return; }
3784 if (idx == 1) { RHS_ = val; return; }
3785 if (idx == 2) { OriginalOutputForResult_ = val; return; }
3786 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
3787 idx -= 4;
3788 llvm_unreachable("Invalid index");
3789}
3790
3791llvm::StringRef MulGradNode::getOutputName(unsigned idx) const {
3792 if (idx == 0) { return "GradOfInputNamedLHS"; }
3793 if (idx == 1) { return "GradOfInputNamedRHS"; }
3794 llvm_unreachable("Invalid index");
3795}
3796
3797std::string MulGradNode::getDebugDesc() const {
3798 DescriptionBuilder db(getKindName());
3799 db.addParam("Name", separateString(getName(), 100, "\n"));
3800 if (hasPredicate()) db.addParam("Predicate", "Yes");
3801 db
3802 .addParam("LHS", *(getLHS().getType()))
3803 .addParam("RHS", *(getRHS().getType()))
3804 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
3805 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
3806 .addParam("Users", getNumUsers());
3807 db.addParam("GradOfInputNamedLHS", *(getGradOfInputNamedLHS().getType()));
3808 db.addParam("GradOfInputNamedRHS", *(getGradOfInputNamedRHS().getType()));
3809 return db;
3810}
3811
3812void MulGradNode::visit(Node *parent, NodeWalker *visitor) {
3813 if (!visitor->shouldVisit(parent, this)) { return; }
3814 visitor->pre(parent, this);
3815if (hasPredicate())
3816 getPredicate().getNode()->visit(this, visitor);
3817 getLHS().getNode()->visit(this, visitor);
3818 getRHS().getNode()->visit(this, visitor);
3819 getOriginalOutputForResult().getNode()->visit(this, visitor);
3820 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
3821 visitor->post(parent, this);
3822}
3823
3824bool MulGradNode::isEqual(const MulGradNode &other) const {
3825 return true &&
3826 LHS_ == other.LHS_ &&
3827 RHS_ == other.RHS_ &&
3828 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
3829 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
3830 predicate_ == other.predicate_ &&
3831 getType(0) == other.getType(0) &&
3832 getType(1) == other.getType(1);
3833}
3834
3835Node* MulGradNode::clone() const {
3836 return new MulGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
3837}
3838
3839llvm::hash_code MulGradNode::getHash() const {
3840 return llvm::hash_combine(
3841 LHS_,
3842 RHS_,
3843 OriginalOutputForResult_,
3844 GradOfOriginalOutputNamedResult_);
3845}
3846
3847unsigned MulNode::getNumInputs() const {
3848 return 2;
3849}
3850
3851std::string MulNode::getInputName(unsigned idx) const {
3852 if (idx == 0) { return "LHS"; }
3853 if (idx == 1) { return "RHS"; }
3854 idx -= 2;
3855 llvm_unreachable("Invalid index");
3856}
3857
3858NodeValue MulNode::getNthInput(unsigned idx) {
3859 if (idx == 0) { return LHS_; }
3860 if (idx == 1) { return RHS_; }
3861 idx -= 2;
3862 llvm_unreachable("Invalid index");
3863}
3864
3865void MulNode::setNthInput(unsigned idx, NodeValue val) {
3866 if (idx == 0) { LHS_ = val; return; }
3867 if (idx == 1) { RHS_ = val; return; }
3868 idx -= 2;
3869 llvm_unreachable("Invalid index");
3870}
3871
3872llvm::StringRef MulNode::getOutputName(unsigned idx) const {
3873 if (idx == 0) { return "Result"; }
3874 llvm_unreachable("Invalid index");
3875}
3876
3877std::string MulNode::getDebugDesc() const {
3878 DescriptionBuilder db(getKindName());
3879 db.addParam("Name", separateString(getName(), 100, "\n"));
3880 if (hasPredicate()) db.addParam("Predicate", "Yes");
3881 db
3882 .addParam("LHS", *(getLHS().getType()))
3883 .addParam("RHS", *(getRHS().getType()))
3884 .addParam("Users", getNumUsers());
3885 db.addParam("Result", *(getResult().getType()));
3886 return db;
3887}
3888
3889void MulNode::visit(Node *parent, NodeWalker *visitor) {
3890 if (!visitor->shouldVisit(parent, this)) { return; }
3891 visitor->pre(parent, this);
3892if (hasPredicate())
3893 getPredicate().getNode()->visit(this, visitor);
3894 getLHS().getNode()->visit(this, visitor);
3895 getRHS().getNode()->visit(this, visitor);
3896 visitor->post(parent, this);
3897}
3898
3899bool MulNode::isEqual(const MulNode &other) const {
3900 return true &&
3901 LHS_ == other.LHS_ &&
3902 RHS_ == other.RHS_ &&
3903 predicate_ == other.predicate_ &&
3904 getType(0) == other.getType(0);
3905}
3906
3907Node* MulNode::clone() const {
3908 return new MulNode(getName(), getResult().getType(), getLHS(), getRHS());
3909}
3910
3911llvm::hash_code MulNode::getHash() const {
3912 return llvm::hash_combine(
3913 LHS_,
3914 RHS_);
3915}
3916
3917MulGradNode *MulNode::getGrad(GraphGradMapper &builder) {
3918 auto *x = new MulGradNode(getName().str() + "_grad", getLHS(), getRHS(), getResult(), builder.getGradient(getResult()));
3919 builder.addGradient(getLHS(), x->getGradOfInputNamedLHS());
3920 builder.addGradient(getRHS(), x->getGradOfInputNamedRHS());
3921 return x;
3922}
3923
3924unsigned SubGradNode::getNumInputs() const {
3925 return 4;
3926}
3927
3928std::string SubGradNode::getInputName(unsigned idx) const {
3929 if (idx == 0) { return "LHS"; }
3930 if (idx == 1) { return "RHS"; }
3931 if (idx == 2) { return "OriginalOutputForResult"; }
3932 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
3933 idx -= 4;
3934 llvm_unreachable("Invalid index");
3935}
3936
3937NodeValue SubGradNode::getNthInput(unsigned idx) {
3938 if (idx == 0) { return LHS_; }
3939 if (idx == 1) { return RHS_; }
3940 if (idx == 2) { return OriginalOutputForResult_; }
3941 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
3942 idx -= 4;
3943 llvm_unreachable("Invalid index");
3944}
3945
3946void SubGradNode::setNthInput(unsigned idx, NodeValue val) {
3947 if (idx == 0) { LHS_ = val; return; }
3948 if (idx == 1) { RHS_ = val; return; }
3949 if (idx == 2) { OriginalOutputForResult_ = val; return; }
3950 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
3951 idx -= 4;
3952 llvm_unreachable("Invalid index");
3953}
3954
3955llvm::StringRef SubGradNode::getOutputName(unsigned idx) const {
3956 if (idx == 0) { return "GradOfInputNamedLHS"; }
3957 if (idx == 1) { return "GradOfInputNamedRHS"; }
3958 llvm_unreachable("Invalid index");
3959}
3960
3961std::string SubGradNode::getDebugDesc() const {
3962 DescriptionBuilder db(getKindName());
3963 db.addParam("Name", separateString(getName(), 100, "\n"));
3964 if (hasPredicate()) db.addParam("Predicate", "Yes");
3965 db
3966 .addParam("LHS", *(getLHS().getType()))
3967 .addParam("RHS", *(getRHS().getType()))
3968 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
3969 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
3970 .addParam("Users", getNumUsers());
3971 db.addParam("GradOfInputNamedLHS", *(getGradOfInputNamedLHS().getType()));
3972 db.addParam("GradOfInputNamedRHS", *(getGradOfInputNamedRHS().getType()));
3973 return db;
3974}
3975
3976void SubGradNode::visit(Node *parent, NodeWalker *visitor) {
3977 if (!visitor->shouldVisit(parent, this)) { return; }
3978 visitor->pre(parent, this);
3979if (hasPredicate())
3980 getPredicate().getNode()->visit(this, visitor);
3981 getLHS().getNode()->visit(this, visitor);
3982 getRHS().getNode()->visit(this, visitor);
3983 getOriginalOutputForResult().getNode()->visit(this, visitor);
3984 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
3985 visitor->post(parent, this);
3986}
3987
3988bool SubGradNode::isEqual(const SubGradNode &other) const {
3989 return true &&
3990 LHS_ == other.LHS_ &&
3991 RHS_ == other.RHS_ &&
3992 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
3993 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
3994 predicate_ == other.predicate_ &&
3995 getType(0) == other.getType(0) &&
3996 getType(1) == other.getType(1);
3997}
3998
3999Node* SubGradNode::clone() const {
4000 return new SubGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
4001}
4002
4003llvm::hash_code SubGradNode::getHash() const {
4004 return llvm::hash_combine(
4005 LHS_,
4006 RHS_,
4007 OriginalOutputForResult_,
4008 GradOfOriginalOutputNamedResult_);
4009}
4010
4011unsigned SubNode::getNumInputs() const {
4012 return 2;
4013}
4014
4015std::string SubNode::getInputName(unsigned idx) const {
4016 if (idx == 0) { return "LHS"; }
4017 if (idx == 1) { return "RHS"; }
4018 idx -= 2;
4019 llvm_unreachable("Invalid index");
4020}
4021
4022NodeValue SubNode::getNthInput(unsigned idx) {
4023 if (idx == 0) { return LHS_; }
4024 if (idx == 1) { return RHS_; }
4025 idx -= 2;
4026 llvm_unreachable("Invalid index");
4027}
4028
4029void SubNode::setNthInput(unsigned idx, NodeValue val) {
4030 if (idx == 0) { LHS_ = val; return; }
4031 if (idx == 1) { RHS_ = val; return; }
4032 idx -= 2;
4033 llvm_unreachable("Invalid index");
4034}
4035
4036llvm::StringRef SubNode::getOutputName(unsigned idx) const {
4037 if (idx == 0) { return "Result"; }
4038 llvm_unreachable("Invalid index");
4039}
4040
4041std::string SubNode::getDebugDesc() const {
4042 DescriptionBuilder db(getKindName());
4043 db.addParam("Name", separateString(getName(), 100, "\n"));
4044 if (hasPredicate()) db.addParam("Predicate", "Yes");
4045 db
4046 .addParam("LHS", *(getLHS().getType()))
4047 .addParam("RHS", *(getRHS().getType()))
4048 .addParam("Users", getNumUsers());
4049 db.addParam("Result", *(getResult().getType()));
4050 return db;
4051}
4052
4053void SubNode::visit(Node *parent, NodeWalker *visitor) {
4054 if (!visitor->shouldVisit(parent, this)) { return; }
4055 visitor->pre(parent, this);
4056if (hasPredicate())
4057 getPredicate().getNode()->visit(this, visitor);
4058 getLHS().getNode()->visit(this, visitor);
4059 getRHS().getNode()->visit(this, visitor);
4060 visitor->post(parent, this);
4061}
4062
4063bool SubNode::isEqual(const SubNode &other) const {
4064 return true &&
4065 LHS_ == other.LHS_ &&
4066 RHS_ == other.RHS_ &&
4067 predicate_ == other.predicate_ &&
4068 getType(0) == other.getType(0);
4069}
4070
4071Node* SubNode::clone() const {
4072 return new SubNode(getName(), getResult().getType(), getLHS(), getRHS());
4073}
4074
4075llvm::hash_code SubNode::getHash() const {
4076 return llvm::hash_combine(
4077 LHS_,
4078 RHS_);
4079}
4080
4081SubGradNode *SubNode::getGrad(GraphGradMapper &builder) {
4082 auto *x = new SubGradNode(getName().str() + "_grad", getLHS(), getRHS(), getResult(), builder.getGradient(getResult()));
4083 builder.addGradient(getLHS(), x->getGradOfInputNamedLHS());
4084 builder.addGradient(getRHS(), x->getGradOfInputNamedRHS());
4085 return x;
4086}
4087
4088unsigned DivGradNode::getNumInputs() const {
4089 return 4;
4090}
4091
4092std::string DivGradNode::getInputName(unsigned idx) const {
4093 if (idx == 0) { return "LHS"; }
4094 if (idx == 1) { return "RHS"; }
4095 if (idx == 2) { return "OriginalOutputForResult"; }
4096 if (idx == 3) { return "GradOfOriginalOutputNamedResult"; }
4097 idx -= 4;
4098 llvm_unreachable("Invalid index");
4099}
4100
4101NodeValue DivGradNode::getNthInput(unsigned idx) {
4102 if (idx == 0) { return LHS_; }
4103 if (idx == 1) { return RHS_; }
4104 if (idx == 2) { return OriginalOutputForResult_; }
4105 if (idx == 3) { return GradOfOriginalOutputNamedResult_; }
4106 idx -= 4;
4107 llvm_unreachable("Invalid index");
4108}
4109
4110void DivGradNode::setNthInput(unsigned idx, NodeValue val) {
4111 if (idx == 0) { LHS_ = val; return; }
4112 if (idx == 1) { RHS_ = val; return; }
4113 if (idx == 2) { OriginalOutputForResult_ = val; return; }
4114 if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; }
4115 idx -= 4;
4116 llvm_unreachable("Invalid index");
4117}
4118
4119llvm::StringRef DivGradNode::getOutputName(unsigned idx) const {
4120 if (idx == 0) { return "GradOfInputNamedLHS"; }
4121 if (idx == 1) { return "GradOfInputNamedRHS"; }
4122 llvm_unreachable("Invalid index");
4123}
4124
4125std::string DivGradNode::getDebugDesc() const {
4126 DescriptionBuilder db(getKindName());
4127 db.addParam("Name", separateString(getName(), 100, "\n"));
4128 if (hasPredicate()) db.addParam("Predicate", "Yes");
4129 db
4130 .addParam("LHS", *(getLHS().getType()))
4131 .addParam("RHS", *(getRHS().getType()))
4132 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
4133 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
4134 .addParam("Users", getNumUsers());
4135 db.addParam("GradOfInputNamedLHS", *(getGradOfInputNamedLHS().getType()));
4136 db.addParam("GradOfInputNamedRHS", *(getGradOfInputNamedRHS().getType()));
4137 return db;
4138}
4139
4140void DivGradNode::visit(Node *parent, NodeWalker *visitor) {
4141 if (!visitor->shouldVisit(parent, this)) { return; }
4142 visitor->pre(parent, this);
4143if (hasPredicate())
4144 getPredicate().getNode()->visit(this, visitor);
4145 getLHS().getNode()->visit(this, visitor);
4146 getRHS().getNode()->visit(this, visitor);
4147 getOriginalOutputForResult().getNode()->visit(this, visitor);
4148 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
4149 visitor->post(parent, this);
4150}
4151
4152bool DivGradNode::isEqual(const DivGradNode &other) const {
4153 return true &&
4154 LHS_ == other.LHS_ &&
4155 RHS_ == other.RHS_ &&
4156 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
4157 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
4158 predicate_ == other.predicate_ &&
4159 getType(0) == other.getType(0) &&
4160 getType(1) == other.getType(1);
4161}
4162
4163Node* DivGradNode::clone() const {
4164 return new DivGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
4165}
4166
4167llvm::hash_code DivGradNode::getHash() const {
4168 return llvm::hash_combine(
4169 LHS_,
4170 RHS_,
4171 OriginalOutputForResult_,
4172 GradOfOriginalOutputNamedResult_);
4173}
4174
4175unsigned DivNode::getNumInputs() const {
4176 return 2;
4177}
4178
4179std::string DivNode::getInputName(unsigned idx) const {
4180 if (idx == 0) { return "LHS"; }
4181 if (idx == 1) { return "RHS"; }
4182 idx -= 2;
4183 llvm_unreachable("Invalid index");
4184}
4185
4186NodeValue DivNode::getNthInput(unsigned idx) {
4187 if (idx == 0) { return LHS_; }
4188 if (idx == 1) { return RHS_; }
4189 idx -= 2;
4190 llvm_unreachable("Invalid index");
4191}
4192
4193void DivNode::setNthInput(unsigned idx, NodeValue val) {
4194 if (idx == 0) { LHS_ = val; return; }
4195 if (idx == 1) { RHS_ = val; return; }
4196 idx -= 2;
4197 llvm_unreachable("Invalid index");
4198}
4199
4200llvm::StringRef DivNode::getOutputName(unsigned idx) const {
4201 if (idx == 0) { return "Result"; }
4202 llvm_unreachable("Invalid index");
4203}
4204
4205std::string DivNode::getDebugDesc() const {
4206 DescriptionBuilder db(getKindName());
4207 db.addParam("Name", separateString(getName(), 100, "\n"));
4208 if (hasPredicate()) db.addParam("Predicate", "Yes");
4209 db
4210 .addParam("LHS", *(getLHS().getType()))
4211 .addParam("RHS", *(getRHS().getType()))
4212 .addParam("Users", getNumUsers());
4213 db.addParam("Result", *(getResult().getType()));
4214 return db;
4215}
4216
4217void DivNode::visit(Node *parent, NodeWalker *visitor) {
4218 if (!visitor->shouldVisit(parent, this)) { return; }
4219 visitor->pre(parent, this);
4220if (hasPredicate())
4221 getPredicate().getNode()->visit(this, visitor);
4222 getLHS().getNode()->visit(this, visitor);
4223 getRHS().getNode()->visit(this, visitor);
4224 visitor->post(parent, this);
4225}
4226
4227bool DivNode::isEqual(const DivNode &other) const {
4228 return true &&
4229 LHS_ == other.LHS_ &&
4230 RHS_ == other.RHS_ &&
4231 predicate_ == other.predicate_ &&
4232 getType(0) == other.getType(0);
4233}
4234
4235Node* DivNode::clone() const {
4236 return new DivNode(getName(), getResult().getType(), getLHS(), getRHS());
4237}
4238
4239llvm::hash_code DivNode::getHash() const {
4240 return llvm::hash_combine(
4241 LHS_,
4242 RHS_);
4243}
4244
4245DivGradNode *DivNode::getGrad(GraphGradMapper &builder) {
4246 auto *x = new DivGradNode(getName().str() + "_grad", getLHS(), getRHS(), getResult(), builder.getGradient(getResult()));
4247 builder.addGradient(getLHS(), x->getGradOfInputNamedLHS());
4248 builder.addGradient(getRHS(), x->getGradOfInputNamedRHS());
4249 return x;
4250}
4251
4252unsigned FloorDivNode::getNumInputs() const {
4253 return 2;
4254}
4255
4256std::string FloorDivNode::getInputName(unsigned idx) const {
4257 if (idx == 0) { return "LHS"; }
4258 if (idx == 1) { return "RHS"; }
4259 idx -= 2;
4260 llvm_unreachable("Invalid index");
4261}
4262
4263NodeValue FloorDivNode::getNthInput(unsigned idx) {
4264 if (idx == 0) { return LHS_; }
4265 if (idx == 1) { return RHS_; }
4266 idx -= 2;
4267 llvm_unreachable("Invalid index");
4268}
4269
4270void FloorDivNode::setNthInput(unsigned idx, NodeValue val) {
4271 if (idx == 0) { LHS_ = val; return; }
4272 if (idx == 1) { RHS_ = val; return; }
4273 idx -= 2;
4274 llvm_unreachable("Invalid index");
4275}
4276
4277llvm::StringRef FloorDivNode::getOutputName(unsigned idx) const {
4278 if (idx == 0) { return "Result"; }
4279 llvm_unreachable("Invalid index");
4280}
4281
4282std::string FloorDivNode::getDebugDesc() const {
4283 DescriptionBuilder db(getKindName());
4284 db.addParam("Name", separateString(getName(), 100, "\n"));
4285 if (hasPredicate()) db.addParam("Predicate", "Yes");
4286 db
4287 .addParam("LHS", *(getLHS().getType()))
4288 .addParam("RHS", *(getRHS().getType()))
4289 .addParam("Truncate", getTruncate())
4290 .addParam("Users", getNumUsers());
4291 db.addParam("Result", *(getResult().getType()));
4292 return db;
4293}
4294
4295void FloorDivNode::visit(Node *parent, NodeWalker *visitor) {
4296 if (!visitor->shouldVisit(parent, this)) { return; }
4297 visitor->pre(parent, this);
4298if (hasPredicate())
4299 getPredicate().getNode()->visit(this, visitor);
4300 getLHS().getNode()->visit(this, visitor);
4301 getRHS().getNode()->visit(this, visitor);
4302 visitor->post(parent, this);
4303}
4304
4305bool FloorDivNode::isEqual(const FloorDivNode &other) const {
4306 return true &&
4307 LHS_ == other.LHS_ &&
4308 RHS_ == other.RHS_ &&
4309 predicate_ == other.predicate_ &&
4310 Truncate_ == other.Truncate_ &&
4311 getType(0) == other.getType(0);
4312}
4313
4314Node* FloorDivNode::clone() const {
4315 return new FloorDivNode(getName(), getResult().getType(), getLHS(), getRHS(), getTruncate());
4316}
4317
4318llvm::hash_code FloorDivNode::getHash() const {
4319 return llvm::hash_combine(
4320 Truncate_,
4321 LHS_,
4322 RHS_);
4323}
4324
4325unsigned FmodNode::getNumInputs() const {
4326 return 2;
4327}
4328
4329std::string FmodNode::getInputName(unsigned idx) const {
4330 if (idx == 0) { return "LHS"; }
4331 if (idx == 1) { return "RHS"; }
4332 idx -= 2;
4333 llvm_unreachable("Invalid index");
4334}
4335
4336NodeValue FmodNode::getNthInput(unsigned idx) {
4337 if (idx == 0) { return LHS_; }
4338 if (idx == 1) { return RHS_; }
4339 idx -= 2;
4340 llvm_unreachable("Invalid index");
4341}
4342
4343void FmodNode::setNthInput(unsigned idx, NodeValue val) {
4344 if (idx == 0) { LHS_ = val; return; }
4345 if (idx == 1) { RHS_ = val; return; }
4346 idx -= 2;
4347 llvm_unreachable("Invalid index");
4348}
4349
4350llvm::StringRef FmodNode::getOutputName(unsigned idx) const {
4351 if (idx == 0) { return "Result"; }
4352 llvm_unreachable("Invalid index");
4353}
4354
4355std::string FmodNode::getDebugDesc() const {
4356 DescriptionBuilder db(getKindName());
4357 db.addParam("Name", separateString(getName(), 100, "\n"));
4358 if (hasPredicate()) db.addParam("Predicate", "Yes");
4359 db
4360 .addParam("LHS", *(getLHS().getType()))
4361 .addParam("RHS", *(getRHS().getType()))
4362 .addParam("Users", getNumUsers());
4363 db.addParam("Result", *(getResult().getType()));
4364 return db;
4365}
4366
4367void FmodNode::visit(Node *parent, NodeWalker *visitor) {
4368 if (!visitor->shouldVisit(parent, this)) { return; }
4369 visitor->pre(parent, this);
4370if (hasPredicate())
4371 getPredicate().getNode()->visit(this, visitor);
4372 getLHS().getNode()->visit(this, visitor);
4373 getRHS().getNode()->visit(this, visitor);
4374 visitor->post(parent, this);
4375}
4376
4377bool FmodNode::isEqual(const FmodNode &other) const {
4378 return true &&
4379 LHS_ == other.LHS_ &&
4380 RHS_ == other.RHS_ &&
4381 predicate_ == other.predicate_ &&
4382 getType(0) == other.getType(0);
4383}
4384
4385Node* FmodNode::clone() const {
4386 return new FmodNode(getName(), getResult().getType(), getLHS(), getRHS());
4387}
4388
4389llvm::hash_code FmodNode::getHash() const {
4390 return llvm::hash_combine(
4391 LHS_,
4392 RHS_);
4393}
4394
4395unsigned MaxNode::getNumInputs() const {
4396 return 2;
4397}
4398
4399std::string MaxNode::getInputName(unsigned idx) const {
4400 if (idx == 0) { return "LHS"; }
4401 if (idx == 1) { return "RHS"; }
4402 idx -= 2;
4403 llvm_unreachable("Invalid index");
4404}
4405
4406NodeValue MaxNode::getNthInput(unsigned idx) {
4407 if (idx == 0) { return LHS_; }
4408 if (idx == 1) { return RHS_; }
4409 idx -= 2;
4410 llvm_unreachable("Invalid index");
4411}
4412
4413void MaxNode::setNthInput(unsigned idx, NodeValue val) {
4414 if (idx == 0) { LHS_ = val; return; }
4415 if (idx == 1) { RHS_ = val; return; }
4416 idx -= 2;
4417 llvm_unreachable("Invalid index");
4418}
4419
4420llvm::StringRef MaxNode::getOutputName(unsigned idx) const {
4421 if (idx == 0) { return "Result"; }
4422 llvm_unreachable("Invalid index");
4423}
4424
4425std::string MaxNode::getDebugDesc() const {
4426 DescriptionBuilder db(getKindName());
4427 db.addParam("Name", separateString(getName(), 100, "\n"));
4428 if (hasPredicate()) db.addParam("Predicate", "Yes");
4429 db
4430 .addParam("LHS", *(getLHS().getType()))
4431 .addParam("RHS", *(getRHS().getType()))
4432 .addParam("Users", getNumUsers());
4433 db.addParam("Result", *(getResult().getType()));
4434 return db;
4435}
4436
4437void MaxNode::visit(Node *parent, NodeWalker *visitor) {
4438 if (!visitor->shouldVisit(parent, this)) { return; }
4439 visitor->pre(parent, this);
4440if (hasPredicate())
4441 getPredicate().getNode()->visit(this, visitor);
4442 getLHS().getNode()->visit(this, visitor);
4443 getRHS().getNode()->visit(this, visitor);
4444 visitor->post(parent, this);
4445}
4446
4447bool MaxNode::isEqual(const MaxNode &other) const {
4448 return true &&
4449 LHS_ == other.LHS_ &&
4450 RHS_ == other.RHS_ &&
4451 predicate_ == other.predicate_ &&
4452 getType(0) == other.getType(0);
4453}
4454
4455Node* MaxNode::clone() const {
4456 return new MaxNode(getName(), getResult().getType(), getLHS(), getRHS());
4457}
4458
4459llvm::hash_code MaxNode::getHash() const {
4460 return llvm::hash_combine(
4461 LHS_,
4462 RHS_);
4463}
4464
4465unsigned MinNode::getNumInputs() const {
4466 return 2;
4467}
4468
4469std::string MinNode::getInputName(unsigned idx) const {
4470 if (idx == 0) { return "LHS"; }
4471 if (idx == 1) { return "RHS"; }
4472 idx -= 2;
4473 llvm_unreachable("Invalid index");
4474}
4475
4476NodeValue MinNode::getNthInput(unsigned idx) {
4477 if (idx == 0) { return LHS_; }
4478 if (idx == 1) { return RHS_; }
4479 idx -= 2;
4480 llvm_unreachable("Invalid index");
4481}
4482
4483void MinNode::setNthInput(unsigned idx, NodeValue val) {
4484 if (idx == 0) { LHS_ = val; return; }
4485 if (idx == 1) { RHS_ = val; return; }
4486 idx -= 2;
4487 llvm_unreachable("Invalid index");
4488}
4489
4490llvm::StringRef MinNode::getOutputName(unsigned idx) const {
4491 if (idx == 0) { return "Result"; }
4492 llvm_unreachable("Invalid index");
4493}
4494
4495std::string MinNode::getDebugDesc() const {
4496 DescriptionBuilder db(getKindName());
4497 db.addParam("Name", separateString(getName(), 100, "\n"));
4498 if (hasPredicate()) db.addParam("Predicate", "Yes");
4499 db
4500 .addParam("LHS", *(getLHS().getType()))
4501 .addParam("RHS", *(getRHS().getType()))
4502 .addParam("Users", getNumUsers());
4503 db.addParam("Result", *(getResult().getType()));
4504 return db;
4505}
4506
4507void MinNode::visit(Node *parent, NodeWalker *visitor) {
4508 if (!visitor->shouldVisit(parent, this)) { return; }
4509 visitor->pre(parent, this);
4510if (hasPredicate())
4511 getPredicate().getNode()->visit(this, visitor);
4512 getLHS().getNode()->visit(this, visitor);
4513 getRHS().getNode()->visit(this, visitor);
4514 visitor->post(parent, this);
4515}
4516
4517bool MinNode::isEqual(const MinNode &other) const {
4518 return true &&
4519 LHS_ == other.LHS_ &&
4520 RHS_ == other.RHS_ &&
4521 predicate_ == other.predicate_ &&
4522 getType(0) == other.getType(0);
4523}
4524
4525Node* MinNode::clone() const {
4526 return new MinNode(getName(), getResult().getType(), getLHS(), getRHS());
4527}
4528
4529llvm::hash_code MinNode::getHash() const {
4530 return llvm::hash_combine(
4531 LHS_,
4532 RHS_);
4533}
4534
4535unsigned CmpEQNode::getNumInputs() const {
4536 return 2;
4537}
4538
4539std::string CmpEQNode::getInputName(unsigned idx) const {
4540 if (idx == 0) { return "LHS"; }
4541 if (idx == 1) { return "RHS"; }
4542 idx -= 2;
4543 llvm_unreachable("Invalid index");
4544}
4545
4546NodeValue CmpEQNode::getNthInput(unsigned idx) {
4547 if (idx == 0) { return LHS_; }
4548 if (idx == 1) { return RHS_; }
4549 idx -= 2;
4550 llvm_unreachable("Invalid index");
4551}
4552
4553void CmpEQNode::setNthInput(unsigned idx, NodeValue val) {
4554 if (idx == 0) { LHS_ = val; return; }
4555 if (idx == 1) { RHS_ = val; return; }
4556 idx -= 2;
4557 llvm_unreachable("Invalid index");
4558}
4559
4560llvm::StringRef CmpEQNode::getOutputName(unsigned idx) const {
4561 if (idx == 0) { return "Result"; }
4562 llvm_unreachable("Invalid index");
4563}
4564
4565std::string CmpEQNode::getDebugDesc() const {
4566 DescriptionBuilder db(getKindName());
4567 db.addParam("Name", separateString(getName(), 100, "\n"));
4568 if (hasPredicate()) db.addParam("Predicate", "Yes");
4569 db
4570 .addParam("LHS", *(getLHS().getType()))
4571 .addParam("RHS", *(getRHS().getType()))
4572 .addParam("Users", getNumUsers());
4573 db.addParam("Result", *(getResult().getType()));
4574 return db;
4575}
4576
4577void CmpEQNode::visit(Node *parent, NodeWalker *visitor) {
4578 if (!visitor->shouldVisit(parent, this)) { return; }
4579 visitor->pre(parent, this);
4580if (hasPredicate())
4581 getPredicate().getNode()->visit(this, visitor);
4582 getLHS().getNode()->visit(this, visitor);
4583 getRHS().getNode()->visit(this, visitor);
4584 visitor->post(parent, this);
4585}
4586
4587bool CmpEQNode::isEqual(const CmpEQNode &other) const {
4588 return true &&
4589 LHS_ == other.LHS_ &&
4590 RHS_ == other.RHS_ &&
4591 predicate_ == other.predicate_ &&
4592 getType(0) == other.getType(0);
4593}
4594
4595Node* CmpEQNode::clone() const {
4596 return new CmpEQNode(getName(), getResult().getType(), getLHS(), getRHS());
4597}
4598
4599llvm::hash_code CmpEQNode::getHash() const {
4600 return llvm::hash_combine(
4601 LHS_,
4602 RHS_);
4603}
4604
4605unsigned CmpNEQNode::getNumInputs() const {
4606 return 2;
4607}
4608
4609std::string CmpNEQNode::getInputName(unsigned idx) const {
4610 if (idx == 0) { return "LHS"; }
4611 if (idx == 1) { return "RHS"; }
4612 idx -= 2;
4613 llvm_unreachable("Invalid index");
4614}
4615
4616NodeValue CmpNEQNode::getNthInput(unsigned idx) {
4617 if (idx == 0) { return LHS_; }
4618 if (idx == 1) { return RHS_; }
4619 idx -= 2;
4620 llvm_unreachable("Invalid index");
4621}
4622
4623void CmpNEQNode::setNthInput(unsigned idx, NodeValue val) {
4624 if (idx == 0) { LHS_ = val; return; }
4625 if (idx == 1) { RHS_ = val; return; }
4626 idx -= 2;
4627 llvm_unreachable("Invalid index");
4628}
4629
4630llvm::StringRef CmpNEQNode::getOutputName(unsigned idx) const {
4631 if (idx == 0) { return "Result"; }
4632 llvm_unreachable("Invalid index");
4633}
4634
4635std::string CmpNEQNode::getDebugDesc() const {
4636 DescriptionBuilder db(getKindName());
4637 db.addParam("Name", separateString(getName(), 100, "\n"));
4638 if (hasPredicate()) db.addParam("Predicate", "Yes");
4639 db
4640 .addParam("LHS", *(getLHS().getType()))
4641 .addParam("RHS", *(getRHS().getType()))
4642 .addParam("Users", getNumUsers());
4643 db.addParam("Result", *(getResult().getType()));
4644 return db;
4645}
4646
4647void CmpNEQNode::visit(Node *parent, NodeWalker *visitor) {
4648 if (!visitor->shouldVisit(parent, this)) { return; }
4649 visitor->pre(parent, this);
4650if (hasPredicate())
4651 getPredicate().getNode()->visit(this, visitor);
4652 getLHS().getNode()->visit(this, visitor);
4653 getRHS().getNode()->visit(this, visitor);
4654 visitor->post(parent, this);
4655}
4656
4657bool CmpNEQNode::isEqual(const CmpNEQNode &other) const {
4658 return true &&
4659 LHS_ == other.LHS_ &&
4660 RHS_ == other.RHS_ &&
4661 predicate_ == other.predicate_ &&
4662 getType(0) == other.getType(0);
4663}
4664
4665Node* CmpNEQNode::clone() const {
4666 return new CmpNEQNode(getName(), getResult().getType(), getLHS(), getRHS());
4667}
4668
4669llvm::hash_code CmpNEQNode::getHash() const {
4670 return llvm::hash_combine(
4671 LHS_,
4672 RHS_);
4673}
4674
4675unsigned CmpLTNode::getNumInputs() const {
4676 return 2;
4677}
4678
4679std::string CmpLTNode::getInputName(unsigned idx) const {
4680 if (idx == 0) { return "LHS"; }
4681 if (idx == 1) { return "RHS"; }
4682 idx -= 2;
4683 llvm_unreachable("Invalid index");
4684}
4685
4686NodeValue CmpLTNode::getNthInput(unsigned idx) {
4687 if (idx == 0) { return LHS_; }
4688 if (idx == 1) { return RHS_; }
4689 idx -= 2;
4690 llvm_unreachable("Invalid index");
4691}
4692
4693void CmpLTNode::setNthInput(unsigned idx, NodeValue val) {
4694 if (idx == 0) { LHS_ = val; return; }
4695 if (idx == 1) { RHS_ = val; return; }
4696 idx -= 2;
4697 llvm_unreachable("Invalid index");
4698}
4699
4700llvm::StringRef CmpLTNode::getOutputName(unsigned idx) const {
4701 if (idx == 0) { return "Result"; }
4702 llvm_unreachable("Invalid index");
4703}
4704
4705std::string CmpLTNode::getDebugDesc() const {
4706 DescriptionBuilder db(getKindName());
4707 db.addParam("Name", separateString(getName(), 100, "\n"));
4708 if (hasPredicate()) db.addParam("Predicate", "Yes");
4709 db
4710 .addParam("LHS", *(getLHS().getType()))
4711 .addParam("RHS", *(getRHS().getType()))
4712 .addParam("Users", getNumUsers());
4713 db.addParam("Result", *(getResult().getType()));
4714 return db;
4715}
4716
4717void CmpLTNode::visit(Node *parent, NodeWalker *visitor) {
4718 if (!visitor->shouldVisit(parent, this)) { return; }
4719 visitor->pre(parent, this);
4720if (hasPredicate())
4721 getPredicate().getNode()->visit(this, visitor);
4722 getLHS().getNode()->visit(this, visitor);
4723 getRHS().getNode()->visit(this, visitor);
4724 visitor->post(parent, this);
4725}
4726
4727bool CmpLTNode::isEqual(const CmpLTNode &other) const {
4728 return true &&
4729 LHS_ == other.LHS_ &&
4730 RHS_ == other.RHS_ &&
4731 predicate_ == other.predicate_ &&
4732 getType(0) == other.getType(0);
4733}
4734
4735Node* CmpLTNode::clone() const {
4736 return new CmpLTNode(getName(), getResult().getType(), getLHS(), getRHS());
4737}
4738
4739llvm::hash_code CmpLTNode::getHash() const {
4740 return llvm::hash_combine(
4741 LHS_,
4742 RHS_);
4743}
4744
4745unsigned CmpLTENode::getNumInputs() const {
4746 return 2;
4747}
4748
4749std::string CmpLTENode::getInputName(unsigned idx) const {
4750 if (idx == 0) { return "LHS"; }
4751 if (idx == 1) { return "RHS"; }
4752 idx -= 2;
4753 llvm_unreachable("Invalid index");
4754}
4755
4756NodeValue CmpLTENode::getNthInput(unsigned idx) {
4757 if (idx == 0) { return LHS_; }
4758 if (idx == 1) { return RHS_; }
4759 idx -= 2;
4760 llvm_unreachable("Invalid index");
4761}
4762
4763void CmpLTENode::setNthInput(unsigned idx, NodeValue val) {
4764 if (idx == 0) { LHS_ = val; return; }
4765 if (idx == 1) { RHS_ = val; return; }
4766 idx -= 2;
4767 llvm_unreachable("Invalid index");
4768}
4769
4770llvm::StringRef CmpLTENode::getOutputName(unsigned idx) const {
4771 if (idx == 0) { return "Result"; }
4772 llvm_unreachable("Invalid index");
4773}
4774
4775std::string CmpLTENode::getDebugDesc() const {
4776 DescriptionBuilder db(getKindName());
4777 db.addParam("Name", separateString(getName(), 100, "\n"));
4778 if (hasPredicate()) db.addParam("Predicate", "Yes");
4779 db
4780 .addParam("LHS", *(getLHS().getType()))
4781 .addParam("RHS", *(getRHS().getType()))
4782 .addParam("Users", getNumUsers());
4783 db.addParam("Result", *(getResult().getType()));
4784 return db;
4785}
4786
4787void CmpLTENode::visit(Node *parent, NodeWalker *visitor) {
4788 if (!visitor->shouldVisit(parent, this)) { return; }
4789 visitor->pre(parent, this);
4790if (hasPredicate())
4791 getPredicate().getNode()->visit(this, visitor);
4792 getLHS().getNode()->visit(this, visitor);
4793 getRHS().getNode()->visit(this, visitor);
4794 visitor->post(parent, this);
4795}
4796
4797bool CmpLTENode::isEqual(const CmpLTENode &other) const {
4798 return true &&
4799 LHS_ == other.LHS_ &&
4800 RHS_ == other.RHS_ &&
4801 predicate_ == other.predicate_ &&
4802 getType(0) == other.getType(0);
4803}
4804
4805Node* CmpLTENode::clone() const {
4806 return new CmpLTENode(getName(), getResult().getType(), getLHS(), getRHS());
4807}
4808
4809llvm::hash_code CmpLTENode::getHash() const {
4810 return llvm::hash_combine(
4811 LHS_,
4812 RHS_);
4813}
4814
4815unsigned PowNode::getNumInputs() const {
4816 return 2;
4817}
4818
4819std::string PowNode::getInputName(unsigned idx) const {
4820 if (idx == 0) { return "LHS"; }
4821 if (idx == 1) { return "RHS"; }
4822 idx -= 2;
4823 llvm_unreachable("Invalid index");
4824}
4825
4826NodeValue PowNode::getNthInput(unsigned idx) {
4827 if (idx == 0) { return LHS_; }
4828 if (idx == 1) { return RHS_; }
4829 idx -= 2;
4830 llvm_unreachable("Invalid index");
4831}
4832
4833void PowNode::setNthInput(unsigned idx, NodeValue val) {
4834 if (idx == 0) { LHS_ = val; return; }
4835 if (idx == 1) { RHS_ = val; return; }
4836 idx -= 2;
4837 llvm_unreachable("Invalid index");
4838}
4839
4840llvm::StringRef PowNode::getOutputName(unsigned idx) const {
4841 if (idx == 0) { return "Result"; }
4842 llvm_unreachable("Invalid index");
4843}
4844
4845std::string PowNode::getDebugDesc() const {
4846 DescriptionBuilder db(getKindName());
4847 db.addParam("Name", separateString(getName(), 100, "\n"));
4848 if (hasPredicate()) db.addParam("Predicate", "Yes");
4849 db
4850 .addParam("LHS", *(getLHS().getType()))
4851 .addParam("RHS", *(getRHS().getType()))
4852 .addParam("Users", getNumUsers());
4853 db.addParam("Result", *(getResult().getType()));
4854 return db;
4855}
4856
4857void PowNode::visit(Node *parent, NodeWalker *visitor) {
4858 if (!visitor->shouldVisit(parent, this)) { return; }
4859 visitor->pre(parent, this);
4860if (hasPredicate())
4861 getPredicate().getNode()->visit(this, visitor);
4862 getLHS().getNode()->visit(this, visitor);
4863 getRHS().getNode()->visit(this, visitor);
4864 visitor->post(parent, this);
4865}
4866
4867bool PowNode::isEqual(const PowNode &other) const {
4868 return true &&
4869 LHS_ == other.LHS_ &&
4870 RHS_ == other.RHS_ &&
4871 predicate_ == other.predicate_ &&
4872 getType(0) == other.getType(0);
4873}
4874
4875Node* PowNode::clone() const {
4876 return new PowNode(getName(), getResult().getType(), getLHS(), getRHS());
4877}
4878
4879llvm::hash_code PowNode::getHash() const {
4880 return llvm::hash_combine(
4881 LHS_,
4882 RHS_);
4883}
4884
4885unsigned AndNode::getNumInputs() const {
4886 return 2;
4887}
4888
4889std::string AndNode::getInputName(unsigned idx) const {
4890 if (idx == 0) { return "LHS"; }
4891 if (idx == 1) { return "RHS"; }
4892 idx -= 2;
4893 llvm_unreachable("Invalid index");
4894}
4895
4896NodeValue AndNode::getNthInput(unsigned idx) {
4897 if (idx == 0) { return LHS_; }
4898 if (idx == 1) { return RHS_; }
4899 idx -= 2;
4900 llvm_unreachable("Invalid index");
4901}
4902
4903void AndNode::setNthInput(unsigned idx, NodeValue val) {
4904 if (idx == 0) { LHS_ = val; return; }
4905 if (idx == 1) { RHS_ = val; return; }
4906 idx -= 2;
4907 llvm_unreachable("Invalid index");
4908}
4909
4910llvm::StringRef AndNode::getOutputName(unsigned idx) const {
4911 if (idx == 0) { return "Result"; }
4912 llvm_unreachable("Invalid index");
4913}
4914
4915std::string AndNode::getDebugDesc() const {
4916 DescriptionBuilder db(getKindName());
4917 db.addParam("Name", separateString(getName(), 100, "\n"));
4918 if (hasPredicate()) db.addParam("Predicate", "Yes");
4919 db
4920 .addParam("LHS", *(getLHS().getType()))
4921 .addParam("RHS", *(getRHS().getType()))
4922 .addParam("Users", getNumUsers());
4923 db.addParam("Result", *(getResult().getType()));
4924 return db;
4925}
4926
4927void AndNode::visit(Node *parent, NodeWalker *visitor) {
4928 if (!visitor->shouldVisit(parent, this)) { return; }
4929 visitor->pre(parent, this);
4930if (hasPredicate())
4931 getPredicate().getNode()->visit(this, visitor);
4932 getLHS().getNode()->visit(this, visitor);
4933 getRHS().getNode()->visit(this, visitor);
4934 visitor->post(parent, this);
4935}
4936
4937bool AndNode::isEqual(const AndNode &other) const {
4938 return true &&
4939 LHS_ == other.LHS_ &&
4940 RHS_ == other.RHS_ &&
4941 predicate_ == other.predicate_ &&
4942 getType(0) == other.getType(0);
4943}
4944
4945Node* AndNode::clone() const {
4946 return new AndNode(getName(), getResult().getType(), getLHS(), getRHS());
4947}
4948
4949llvm::hash_code AndNode::getHash() const {
4950 return llvm::hash_combine(
4951 LHS_,
4952 RHS_);
4953}
4954
4955unsigned BitwiseAndNode::getNumInputs() const {
4956 return 2;
4957}
4958
4959std::string BitwiseAndNode::getInputName(unsigned idx) const {
4960 if (idx == 0) { return "LHS"; }
4961 if (idx == 1) { return "RHS"; }
4962 idx -= 2;
4963 llvm_unreachable("Invalid index");
4964}
4965
4966NodeValue BitwiseAndNode::getNthInput(unsigned idx) {
4967 if (idx == 0) { return LHS_; }
4968 if (idx == 1) { return RHS_; }
4969 idx -= 2;
4970 llvm_unreachable("Invalid index");
4971}
4972
4973void BitwiseAndNode::setNthInput(unsigned idx, NodeValue val) {
4974 if (idx == 0) { LHS_ = val; return; }
4975 if (idx == 1) { RHS_ = val; return; }
4976 idx -= 2;
4977 llvm_unreachable("Invalid index");
4978}
4979
4980llvm::StringRef BitwiseAndNode::getOutputName(unsigned idx) const {
4981 if (idx == 0) { return "Result"; }
4982 llvm_unreachable("Invalid index");
4983}
4984
4985std::string BitwiseAndNode::getDebugDesc() const {
4986 DescriptionBuilder db(getKindName());
4987 db.addParam("Name", separateString(getName(), 100, "\n"));
4988 if (hasPredicate()) db.addParam("Predicate", "Yes");
4989 db
4990 .addParam("LHS", *(getLHS().getType()))
4991 .addParam("RHS", *(getRHS().getType()))
4992 .addParam("Users", getNumUsers());
4993 db.addParam("Result", *(getResult().getType()));
4994 return db;
4995}
4996
4997void BitwiseAndNode::visit(Node *parent, NodeWalker *visitor) {
4998 if (!visitor->shouldVisit(parent, this)) { return; }
4999 visitor->pre(parent, this);
5000if (hasPredicate())
5001 getPredicate().getNode()->visit(this, visitor);
5002 getLHS().getNode()->visit(this, visitor);
5003 getRHS().getNode()->visit(this, visitor);
5004 visitor->post(parent, this);
5005}
5006
5007bool BitwiseAndNode::isEqual(const BitwiseAndNode &other) const {
5008 return true &&
5009 LHS_ == other.LHS_ &&
5010 RHS_ == other.RHS_ &&
5011 predicate_ == other.predicate_ &&
5012 getType(0) == other.getType(0);
5013}
5014
5015Node* BitwiseAndNode::clone() const {
5016 return new BitwiseAndNode(getName(), getResult().getType(), getLHS(), getRHS());
5017}
5018
5019llvm::hash_code BitwiseAndNode::getHash() const {
5020 return llvm::hash_combine(
5021 LHS_,
5022 RHS_);
5023}
5024
5025unsigned OrNode::getNumInputs() const {
5026 return 2;
5027}
5028
5029std::string OrNode::getInputName(unsigned idx) const {
5030 if (idx == 0) { return "LHS"; }
5031 if (idx == 1) { return "RHS"; }
5032 idx -= 2;
5033 llvm_unreachable("Invalid index");
5034}
5035
5036NodeValue OrNode::getNthInput(unsigned idx) {
5037 if (idx == 0) { return LHS_; }
5038 if (idx == 1) { return RHS_; }
5039 idx -= 2;
5040 llvm_unreachable("Invalid index");
5041}
5042
5043void OrNode::setNthInput(unsigned idx, NodeValue val) {
5044 if (idx == 0) { LHS_ = val; return; }
5045 if (idx == 1) { RHS_ = val; return; }
5046 idx -= 2;
5047 llvm_unreachable("Invalid index");
5048}
5049
5050llvm::StringRef OrNode::getOutputName(unsigned idx) const {
5051 if (idx == 0) { return "Result"; }
5052 llvm_unreachable("Invalid index");
5053}
5054
5055std::string OrNode::getDebugDesc() const {
5056 DescriptionBuilder db(getKindName());
5057 db.addParam("Name", separateString(getName(), 100, "\n"));
5058 if (hasPredicate()) db.addParam("Predicate", "Yes");
5059 db
5060 .addParam("LHS", *(getLHS().getType()))
5061 .addParam("RHS", *(getRHS().getType()))
5062 .addParam("Users", getNumUsers());
5063 db.addParam("Result", *(getResult().getType()));
5064 return db;
5065}
5066
5067void OrNode::visit(Node *parent, NodeWalker *visitor) {
5068 if (!visitor->shouldVisit(parent, this)) { return; }
5069 visitor->pre(parent, this);
5070if (hasPredicate())
5071 getPredicate().getNode()->visit(this, visitor);
5072 getLHS().getNode()->visit(this, visitor);
5073 getRHS().getNode()->visit(this, visitor);
5074 visitor->post(parent, this);
5075}
5076
5077bool OrNode::isEqual(const OrNode &other) const {
5078 return true &&
5079 LHS_ == other.LHS_ &&
5080 RHS_ == other.RHS_ &&
5081 predicate_ == other.predicate_ &&
5082 getType(0) == other.getType(0);
5083}
5084
5085Node* OrNode::clone() const {
5086 return new OrNode(getName(), getResult().getType(), getLHS(), getRHS());
5087}
5088
5089llvm::hash_code OrNode::getHash() const {
5090 return llvm::hash_combine(
5091 LHS_,
5092 RHS_);
5093}
5094
5095unsigned BitwiseOrNode::getNumInputs() const {
5096 return 2;
5097}
5098
5099std::string BitwiseOrNode::getInputName(unsigned idx) const {
5100 if (idx == 0) { return "LHS"; }
5101 if (idx == 1) { return "RHS"; }
5102 idx -= 2;
5103 llvm_unreachable("Invalid index");
5104}
5105
5106NodeValue BitwiseOrNode::getNthInput(unsigned idx) {
5107 if (idx == 0) { return LHS_; }
5108 if (idx == 1) { return RHS_; }
5109 idx -= 2;
5110 llvm_unreachable("Invalid index");
5111}
5112
5113void BitwiseOrNode::setNthInput(unsigned idx, NodeValue val) {
5114 if (idx == 0) { LHS_ = val; return; }
5115 if (idx == 1) { RHS_ = val; return; }
5116 idx -= 2;
5117 llvm_unreachable("Invalid index");
5118}
5119
5120llvm::StringRef BitwiseOrNode::getOutputName(unsigned idx) const {
5121 if (idx == 0) { return "Result"; }
5122 llvm_unreachable("Invalid index");
5123}
5124
5125std::string BitwiseOrNode::getDebugDesc() const {
5126 DescriptionBuilder db(getKindName());
5127 db.addParam("Name", separateString(getName(), 100, "\n"));
5128 if (hasPredicate()) db.addParam("Predicate", "Yes");
5129 db
5130 .addParam("LHS", *(getLHS().getType()))
5131 .addParam("RHS", *(getRHS().getType()))
5132 .addParam("Users", getNumUsers());
5133 db.addParam("Result", *(getResult().getType()));
5134 return db;
5135}
5136
5137void BitwiseOrNode::visit(Node *parent, NodeWalker *visitor) {
5138 if (!visitor->shouldVisit(parent, this)) { return; }
5139 visitor->pre(parent, this);
5140if (hasPredicate())
5141 getPredicate().getNode()->visit(this, visitor);
5142 getLHS().getNode()->visit(this, visitor);
5143 getRHS().getNode()->visit(this, visitor);
5144 visitor->post(parent, this);
5145}
5146
5147bool BitwiseOrNode::isEqual(const BitwiseOrNode &other) const {
5148 return true &&
5149 LHS_ == other.LHS_ &&
5150 RHS_ == other.RHS_ &&
5151 predicate_ == other.predicate_ &&
5152 getType(0) == other.getType(0);
5153}
5154
5155Node* BitwiseOrNode::clone() const {
5156 return new BitwiseOrNode(getName(), getResult().getType(), getLHS(), getRHS());
5157}
5158
5159llvm::hash_code BitwiseOrNode::getHash() const {
5160 return llvm::hash_combine(
5161 LHS_,
5162 RHS_);
5163}
5164
5165unsigned XorNode::getNumInputs() const {
5166 return 2;
5167}
5168
5169std::string XorNode::getInputName(unsigned idx) const {
5170 if (idx == 0) { return "LHS"; }
5171 if (idx == 1) { return "RHS"; }
5172 idx -= 2;
5173 llvm_unreachable("Invalid index");
5174}
5175
5176NodeValue XorNode::getNthInput(unsigned idx) {
5177 if (idx == 0) { return LHS_; }
5178 if (idx == 1) { return RHS_; }
5179 idx -= 2;
5180 llvm_unreachable("Invalid index");
5181}
5182
5183void XorNode::setNthInput(unsigned idx, NodeValue val) {
5184 if (idx == 0) { LHS_ = val; return; }
5185 if (idx == 1) { RHS_ = val; return; }
5186 idx -= 2;
5187 llvm_unreachable("Invalid index");
5188}
5189
5190llvm::StringRef XorNode::getOutputName(unsigned idx) const {
5191 if (idx == 0) { return "Result"; }
5192 llvm_unreachable("Invalid index");
5193}
5194
5195std::string XorNode::getDebugDesc() const {
5196 DescriptionBuilder db(getKindName());
5197 db.addParam("Name", separateString(getName(), 100, "\n"));
5198 if (hasPredicate()) db.addParam("Predicate", "Yes");
5199 db
5200 .addParam("LHS", *(getLHS().getType()))
5201 .addParam("RHS", *(getRHS().getType()))
5202 .addParam("Users", getNumUsers());
5203 db.addParam("Result", *(getResult().getType()));
5204 return db;
5205}
5206
5207void XorNode::visit(Node *parent, NodeWalker *visitor) {
5208 if (!visitor->shouldVisit(parent, this)) { return; }
5209 visitor->pre(parent, this);
5210if (hasPredicate())
5211 getPredicate().getNode()->visit(this, visitor);
5212 getLHS().getNode()->visit(this, visitor);
5213 getRHS().getNode()->visit(this, visitor);
5214 visitor->post(parent, this);
5215}
5216
5217bool XorNode::isEqual(const XorNode &other) const {
5218 return true &&
5219 LHS_ == other.LHS_ &&
5220 RHS_ == other.RHS_ &&
5221 predicate_ == other.predicate_ &&
5222 getType(0) == other.getType(0);
5223}
5224
5225Node* XorNode::clone() const {
5226 return new XorNode(getName(), getResult().getType(), getLHS(), getRHS());
5227}
5228
5229llvm::hash_code XorNode::getHash() const {
5230 return llvm::hash_combine(
5231 LHS_,
5232 RHS_);
5233}
5234
5235unsigned BitwiseXorNode::getNumInputs() const {
5236 return 2;
5237}
5238
5239std::string BitwiseXorNode::getInputName(unsigned idx) const {
5240 if (idx == 0) { return "LHS"; }
5241 if (idx == 1) { return "RHS"; }
5242 idx -= 2;
5243 llvm_unreachable("Invalid index");
5244}
5245
5246NodeValue BitwiseXorNode::getNthInput(unsigned idx) {
5247 if (idx == 0) { return LHS_; }
5248 if (idx == 1) { return RHS_; }
5249 idx -= 2;
5250 llvm_unreachable("Invalid index");
5251}
5252
5253void BitwiseXorNode::setNthInput(unsigned idx, NodeValue val) {
5254 if (idx == 0) { LHS_ = val; return; }
5255 if (idx == 1) { RHS_ = val; return; }
5256 idx -= 2;
5257 llvm_unreachable("Invalid index");
5258}
5259
5260llvm::StringRef BitwiseXorNode::getOutputName(unsigned idx) const {
5261 if (idx == 0) { return "Result"; }
5262 llvm_unreachable("Invalid index");
5263}
5264
5265std::string BitwiseXorNode::getDebugDesc() const {
5266 DescriptionBuilder db(getKindName());
5267 db.addParam("Name", separateString(getName(), 100, "\n"));
5268 if (hasPredicate()) db.addParam("Predicate", "Yes");
5269 db
5270 .addParam("LHS", *(getLHS().getType()))
5271 .addParam("RHS", *(getRHS().getType()))
5272 .addParam("Users", getNumUsers());
5273 db.addParam("Result", *(getResult().getType()));
5274 return db;
5275}
5276
5277void BitwiseXorNode::visit(Node *parent, NodeWalker *visitor) {
5278 if (!visitor->shouldVisit(parent, this)) { return; }
5279 visitor->pre(parent, this);
5280if (hasPredicate())
5281 getPredicate().getNode()->visit(this, visitor);
5282 getLHS().getNode()->visit(this, visitor);
5283 getRHS().getNode()->visit(this, visitor);
5284 visitor->post(parent, this);
5285}
5286
5287bool BitwiseXorNode::isEqual(const BitwiseXorNode &other) const {
5288 return true &&
5289 LHS_ == other.LHS_ &&
5290 RHS_ == other.RHS_ &&
5291 predicate_ == other.predicate_ &&
5292 getType(0) == other.getType(0);
5293}
5294
5295Node* BitwiseXorNode::clone() const {
5296 return new BitwiseXorNode(getName(), getResult().getType(), getLHS(), getRHS());
5297}
5298
5299llvm::hash_code BitwiseXorNode::getHash() const {
5300 return llvm::hash_combine(
5301 LHS_,
5302 RHS_);
5303}
5304
5305unsigned NotNode::getNumInputs() const {
5306 return 1;
5307}
5308
5309std::string NotNode::getInputName(unsigned idx) const {
5310 if (idx == 0) { return "Input"; }
5311 idx -= 1;
5312 llvm_unreachable("Invalid index");
5313}
5314
5315NodeValue NotNode::getNthInput(unsigned idx) {
5316 if (idx == 0) { return Input_; }
5317 idx -= 1;
5318 llvm_unreachable("Invalid index");
5319}
5320
5321void NotNode::setNthInput(unsigned idx, NodeValue val) {
5322 if (idx == 0) { Input_ = val; return; }
5323 idx -= 1;
5324 llvm_unreachable("Invalid index");
5325}
5326
5327llvm::StringRef NotNode::getOutputName(unsigned idx) const {
5328 if (idx == 0) { return "Result"; }
5329 llvm_unreachable("Invalid index");
5330}
5331
5332std::string NotNode::getDebugDesc() const {
5333 DescriptionBuilder db(getKindName());
5334 db.addParam("Name", separateString(getName(), 100, "\n"));
5335 if (hasPredicate()) db.addParam("Predicate", "Yes");
5336 db
5337 .addParam("Input", *(getInput().getType()))
5338 .addParam("Users", getNumUsers());
5339 db.addParam("Result", *(getResult().getType()));
5340 return db;
5341}
5342
5343void NotNode::visit(Node *parent, NodeWalker *visitor) {
5344 if (!visitor->shouldVisit(parent, this)) { return; }
5345 visitor->pre(parent, this);
5346if (hasPredicate())
5347 getPredicate().getNode()->visit(this, visitor);
5348 getInput().getNode()->visit(this, visitor);
5349 visitor->post(parent, this);
5350}
5351
5352bool NotNode::isEqual(const NotNode &other) const {
5353 return true &&
5354 Input_ == other.Input_ &&
5355 predicate_ == other.predicate_ &&
5356 getType(0) == other.getType(0);
5357}
5358
5359Node* NotNode::clone() const {
5360 return new NotNode(getName(), getResult().getType(), getInput());
5361}
5362
5363llvm::hash_code NotNode::getHash() const {
5364 return llvm::hash_combine(
5365 Input_);
5366}
5367
5368unsigned BitwiseNotNode::getNumInputs() const {
5369 return 1;
5370}
5371
5372std::string BitwiseNotNode::getInputName(unsigned idx) const {
5373 if (idx == 0) { return "Input"; }
5374 idx -= 1;
5375 llvm_unreachable("Invalid index");
5376}
5377
5378NodeValue BitwiseNotNode::getNthInput(unsigned idx) {
5379 if (idx == 0) { return Input_; }
5380 idx -= 1;
5381 llvm_unreachable("Invalid index");
5382}
5383
5384void BitwiseNotNode::setNthInput(unsigned idx, NodeValue val) {
5385 if (idx == 0) { Input_ = val; return; }
5386 idx -= 1;
5387 llvm_unreachable("Invalid index");
5388}
5389
5390llvm::StringRef BitwiseNotNode::getOutputName(unsigned idx) const {
5391 if (idx == 0) { return "Result"; }
5392 llvm_unreachable("Invalid index");
5393}
5394
5395std::string BitwiseNotNode::getDebugDesc() const {
5396 DescriptionBuilder db(getKindName());
5397 db.addParam("Name", separateString(getName(), 100, "\n"));
5398 if (hasPredicate()) db.addParam("Predicate", "Yes");
5399 db
5400 .addParam("Input", *(getInput().getType()))
5401 .addParam("Users", getNumUsers());
5402 db.addParam("Result", *(getResult().getType()));
5403 return db;
5404}
5405
5406void BitwiseNotNode::visit(Node *parent, NodeWalker *visitor) {
5407 if (!visitor->shouldVisit(parent, this)) { return; }
5408 visitor->pre(parent, this);
5409if (hasPredicate())
5410 getPredicate().getNode()->visit(this, visitor);
5411 getInput().getNode()->visit(this, visitor);
5412 visitor->post(parent, this);
5413}
5414
5415bool BitwiseNotNode::isEqual(const BitwiseNotNode &other) const {
5416 return true &&
5417 Input_ == other.Input_ &&
5418 predicate_ == other.predicate_ &&
5419 getType(0) == other.getType(0);
5420}
5421
5422Node* BitwiseNotNode::clone() const {
5423 return new BitwiseNotNode(getName(), getResult().getType(), getInput());
5424}
5425
5426llvm::hash_code BitwiseNotNode::getHash() const {
5427 return llvm::hash_combine(
5428 Input_);
5429}
5430
5431unsigned NegNode::getNumInputs() const {
5432 return 1;
5433}
5434
5435std::string NegNode::getInputName(unsigned idx) const {
5436 if (idx == 0) { return "Input"; }
5437 idx -= 1;
5438 llvm_unreachable("Invalid index");
5439}
5440
5441NodeValue NegNode::getNthInput(unsigned idx) {
5442 if (idx == 0) { return Input_; }
5443 idx -= 1;
5444 llvm_unreachable("Invalid index");
5445}
5446
5447void NegNode::setNthInput(unsigned idx, NodeValue val) {
5448 if (idx == 0) { Input_ = val; return; }
5449 idx -= 1;
5450 llvm_unreachable("Invalid index");
5451}
5452
5453llvm::StringRef NegNode::getOutputName(unsigned idx) const {
5454 if (idx == 0) { return "Result"; }
5455 llvm_unreachable("Invalid index");
5456}
5457
5458std::string NegNode::getDebugDesc() const {
5459 DescriptionBuilder db(getKindName());
5460 db.addParam("Name", separateString(getName(), 100, "\n"));
5461 if (hasPredicate()) db.addParam("Predicate", "Yes");
5462 db
5463 .addParam("Input", *(getInput().getType()))
5464 .addParam("Users", getNumUsers());
5465 db.addParam("Result", *(getResult().getType()));
5466 return db;
5467}
5468
5469void NegNode::visit(Node *parent, NodeWalker *visitor) {
5470 if (!visitor->shouldVisit(parent, this)) { return; }
5471 visitor->pre(parent, this);
5472if (hasPredicate())
5473 getPredicate().getNode()->visit(this, visitor);
5474 getInput().getNode()->visit(this, visitor);
5475 visitor->post(parent, this);
5476}
5477
5478bool NegNode::isEqual(const NegNode &other) const {
5479 return true &&
5480 Input_ == other.Input_ &&
5481 predicate_ == other.predicate_ &&
5482 getType(0) == other.getType(0);
5483}
5484
5485Node* NegNode::clone() const {
5486 return new NegNode(getName(), getResult().getType(), getInput());
5487}
5488
5489llvm::hash_code NegNode::getHash() const {
5490 return llvm::hash_combine(
5491 Input_);
5492}
5493
5494unsigned AbsNode::getNumInputs() const {
5495 return 1;
5496}
5497
5498std::string AbsNode::getInputName(unsigned idx) const {
5499 if (idx == 0) { return "Input"; }
5500 idx -= 1;
5501 llvm_unreachable("Invalid index");
5502}
5503
5504NodeValue AbsNode::getNthInput(unsigned idx) {
5505 if (idx == 0) { return Input_; }
5506 idx -= 1;
5507 llvm_unreachable("Invalid index");
5508}
5509
5510void AbsNode::setNthInput(unsigned idx, NodeValue val) {
5511 if (idx == 0) { Input_ = val; return; }
5512 idx -= 1;
5513 llvm_unreachable("Invalid index");
5514}
5515
5516llvm::StringRef AbsNode::getOutputName(unsigned idx) const {
5517 if (idx == 0) { return "Result"; }
5518 llvm_unreachable("Invalid index");
5519}
5520
5521std::string AbsNode::getDebugDesc() const {
5522 DescriptionBuilder db(getKindName());
5523 db.addParam("Name", separateString(getName(), 100, "\n"));
5524 if (hasPredicate()) db.addParam("Predicate", "Yes");
5525 db
5526 .addParam("Input", *(getInput().getType()))
5527 .addParam("Users", getNumUsers());
5528 db.addParam("Result", *(getResult().getType()));
5529 return db;
5530}
5531
5532void AbsNode::visit(Node *parent, NodeWalker *visitor) {
5533 if (!visitor->shouldVisit(parent, this)) { return; }
5534 visitor->pre(parent, this);
5535if (hasPredicate())
5536 getPredicate().getNode()->visit(this, visitor);
5537 getInput().getNode()->visit(this, visitor);
5538 visitor->post(parent, this);
5539}
5540
5541bool AbsNode::isEqual(const AbsNode &other) const {
5542 return true &&
5543 Input_ == other.Input_ &&
5544 predicate_ == other.predicate_ &&
5545 getType(0) == other.getType(0);
5546}
5547
5548Node* AbsNode::clone() const {
5549 return new AbsNode(getName(), getResult().getType(), getInput());
5550}
5551
5552llvm::hash_code AbsNode::getHash() const {
5553 return llvm::hash_combine(
5554 Input_);
5555}
5556
5557unsigned FloorNode::getNumInputs() const {
5558 return 1;
5559}
5560
5561std::string FloorNode::getInputName(unsigned idx) const {
5562 if (idx == 0) { return "Input"; }
5563 idx -= 1;
5564 llvm_unreachable("Invalid index");
5565}
5566
5567NodeValue FloorNode::getNthInput(unsigned idx) {
5568 if (idx == 0) { return Input_; }
5569 idx -= 1;
5570 llvm_unreachable("Invalid index");
5571}
5572
5573void FloorNode::setNthInput(unsigned idx, NodeValue val) {
5574 if (idx == 0) { Input_ = val; return; }
5575 idx -= 1;
5576 llvm_unreachable("Invalid index");
5577}
5578
5579llvm::StringRef FloorNode::getOutputName(unsigned idx) const {
5580 if (idx == 0) { return "Result"; }
5581 llvm_unreachable("Invalid index");
5582}
5583
5584std::string FloorNode::getDebugDesc() const {
5585 DescriptionBuilder db(getKindName());
5586 db.addParam("Name", separateString(getName(), 100, "\n"));
5587 if (hasPredicate()) db.addParam("Predicate", "Yes");
5588 db
5589 .addParam("Input", *(getInput().getType()))
5590 .addParam("Users", getNumUsers());
5591 db.addParam("Result", *(getResult().getType()));
5592 return db;
5593}
5594
5595void FloorNode::visit(Node *parent, NodeWalker *visitor) {
5596 if (!visitor->shouldVisit(parent, this)) { return; }
5597 visitor->pre(parent, this);
5598if (hasPredicate())
5599 getPredicate().getNode()->visit(this, visitor);
5600 getInput().getNode()->visit(this, visitor);
5601 visitor->post(parent, this);
5602}
5603
5604bool FloorNode::isEqual(const FloorNode &other) const {
5605 return true &&
5606 Input_ == other.Input_ &&
5607 predicate_ == other.predicate_ &&
5608 getType(0) == other.getType(0);
5609}
5610
5611Node* FloorNode::clone() const {
5612 return new FloorNode(getName(), getResult().getType(), getInput());
5613}
5614
5615llvm::hash_code FloorNode::getHash() const {
5616 return llvm::hash_combine(
5617 Input_);
5618}
5619
5620unsigned SignNode::getNumInputs() const {
5621 return 1;
5622}
5623
5624std::string SignNode::getInputName(unsigned idx) const {
5625 if (idx == 0) { return "Input"; }
5626 idx -= 1;
5627 llvm_unreachable("Invalid index");
5628}
5629
5630NodeValue SignNode::getNthInput(unsigned idx) {
5631 if (idx == 0) { return Input_; }
5632 idx -= 1;
5633 llvm_unreachable("Invalid index");
5634}
5635
5636void SignNode::setNthInput(unsigned idx, NodeValue val) {
5637 if (idx == 0) { Input_ = val; return; }
5638 idx -= 1;
5639 llvm_unreachable("Invalid index");
5640}
5641
5642llvm::StringRef SignNode::getOutputName(unsigned idx) const {
5643 if (idx == 0) { return "Result"; }
5644 llvm_unreachable("Invalid index");
5645}
5646
5647std::string SignNode::getDebugDesc() const {
5648 DescriptionBuilder db(getKindName());
5649 db.addParam("Name", separateString(getName(), 100, "\n"));
5650 if (hasPredicate()) db.addParam("Predicate", "Yes");
5651 db
5652 .addParam("Input", *(getInput().getType()))
5653 .addParam("Users", getNumUsers());
5654 db.addParam("Result", *(getResult().getType()));
5655 return db;
5656}
5657
5658void SignNode::visit(Node *parent, NodeWalker *visitor) {
5659 if (!visitor->shouldVisit(parent, this)) { return; }
5660 visitor->pre(parent, this);
5661if (hasPredicate())
5662 getPredicate().getNode()->visit(this, visitor);
5663 getInput().getNode()->visit(this, visitor);
5664 visitor->post(parent, this);
5665}
5666
5667bool SignNode::isEqual(const SignNode &other) const {
5668 return true &&
5669 Input_ == other.Input_ &&
5670 predicate_ == other.predicate_ &&
5671 getType(0) == other.getType(0);
5672}
5673
5674Node* SignNode::clone() const {
5675 return new SignNode(getName(), getResult().getType(), getInput());
5676}
5677
5678llvm::hash_code SignNode::getHash() const {
5679 return llvm::hash_combine(
5680 Input_);
5681}
5682
5683unsigned CeilNode::getNumInputs() const {
5684 return 1;
5685}
5686
5687std::string CeilNode::getInputName(unsigned idx) const {
5688 if (idx == 0) { return "Input"; }
5689 idx -= 1;
5690 llvm_unreachable("Invalid index");
5691}
5692
5693NodeValue CeilNode::getNthInput(unsigned idx) {
5694 if (idx == 0) { return Input_; }
5695 idx -= 1;
5696 llvm_unreachable("Invalid index");
5697}
5698
5699void CeilNode::setNthInput(unsigned idx, NodeValue val) {
5700 if (idx == 0) { Input_ = val; return; }
5701 idx -= 1;
5702 llvm_unreachable("Invalid index");
5703}
5704
5705llvm::StringRef CeilNode::getOutputName(unsigned idx) const {
5706 if (idx == 0) { return "Result"; }
5707 llvm_unreachable("Invalid index");
5708}
5709
5710std::string CeilNode::getDebugDesc() const {
5711 DescriptionBuilder db(getKindName());
5712 db.addParam("Name", separateString(getName(), 100, "\n"));
5713 if (hasPredicate()) db.addParam("Predicate", "Yes");
5714 db
5715 .addParam("Input", *(getInput().getType()))
5716 .addParam("Users", getNumUsers());
5717 db.addParam("Result", *(getResult().getType()));
5718 return db;
5719}
5720
5721void CeilNode::visit(Node *parent, NodeWalker *visitor) {
5722 if (!visitor->shouldVisit(parent, this)) { return; }
5723 visitor->pre(parent, this);
5724if (hasPredicate())
5725 getPredicate().getNode()->visit(this, visitor);
5726 getInput().getNode()->visit(this, visitor);
5727 visitor->post(parent, this);
5728}
5729
5730bool CeilNode::isEqual(const CeilNode &other) const {
5731 return true &&
5732 Input_ == other.Input_ &&
5733 predicate_ == other.predicate_ &&
5734 getType(0) == other.getType(0);
5735}
5736
5737Node* CeilNode::clone() const {
5738 return new CeilNode(getName(), getResult().getType(), getInput());
5739}
5740
5741llvm::hash_code CeilNode::getHash() const {
5742 return llvm::hash_combine(
5743 Input_);
5744}
5745
5746unsigned RoundNode::getNumInputs() const {
5747 return 1;
5748}
5749
5750std::string RoundNode::getInputName(unsigned idx) const {
5751 if (idx == 0) { return "Input"; }
5752 idx -= 1;
5753 llvm_unreachable("Invalid index");
5754}
5755
5756NodeValue RoundNode::getNthInput(unsigned idx) {
5757 if (idx == 0) { return Input_; }
5758 idx -= 1;
5759 llvm_unreachable("Invalid index");
5760}
5761
5762void RoundNode::setNthInput(unsigned idx, NodeValue val) {
5763 if (idx == 0) { Input_ = val; return; }
5764 idx -= 1;
5765 llvm_unreachable("Invalid index");
5766}
5767
5768llvm::StringRef RoundNode::getOutputName(unsigned idx) const {
5769 if (idx == 0) { return "Result"; }
5770 llvm_unreachable("Invalid index");
5771}
5772
5773std::string RoundNode::getDebugDesc() const {
5774 DescriptionBuilder db(getKindName());
5775 db.addParam("Name", separateString(getName(), 100, "\n"));
5776 if (hasPredicate()) db.addParam("Predicate", "Yes");
5777 db
5778 .addParam("Input", *(getInput().getType()))
5779 .addParam("Users", getNumUsers());
5780 db.addParam("Result", *(getResult().getType()));
5781 return db;
5782}
5783
5784void RoundNode::visit(Node *parent, NodeWalker *visitor) {
5785 if (!visitor->shouldVisit(parent, this)) { return; }
5786 visitor->pre(parent, this);
5787if (hasPredicate())
5788 getPredicate().getNode()->visit(this, visitor);
5789 getInput().getNode()->visit(this, visitor);
5790 visitor->post(parent, this);
5791}
5792
5793bool RoundNode::isEqual(const RoundNode &other) const {
5794 return true &&
5795 Input_ == other.Input_ &&
5796 predicate_ == other.predicate_ &&
5797 getType(0) == other.getType(0);
5798}
5799
5800Node* RoundNode::clone() const {
5801 return new RoundNode(getName(), getResult().getType(), getInput());
5802}
5803
5804llvm::hash_code RoundNode::getHash() const {
5805 return llvm::hash_combine(
5806 Input_);
5807}
5808
5809unsigned TruncateNode::getNumInputs() const {
5810 return 1;
5811}
5812
5813std::string TruncateNode::getInputName(unsigned idx) const {
5814 if (idx == 0) { return "Input"; }
5815 idx -= 1;
5816 llvm_unreachable("Invalid index");
5817}
5818
5819NodeValue TruncateNode::getNthInput(unsigned idx) {
5820 if (idx == 0) { return Input_; }
5821 idx -= 1;
5822 llvm_unreachable("Invalid index");
5823}
5824
5825void TruncateNode::setNthInput(unsigned idx, NodeValue val) {
5826 if (idx == 0) { Input_ = val; return; }
5827 idx -= 1;
5828 llvm_unreachable("Invalid index");
5829}
5830
5831llvm::StringRef TruncateNode::getOutputName(unsigned idx) const {
5832 if (idx == 0) { return "Result"; }
5833 llvm_unreachable("Invalid index");
5834}
5835
5836std::string TruncateNode::getDebugDesc() const {
5837 DescriptionBuilder db(getKindName());
5838 db.addParam("Name", separateString(getName(), 100, "\n"));
5839 if (hasPredicate()) db.addParam("Predicate", "Yes");
5840 db
5841 .addParam("Input", *(getInput().getType()))
5842 .addParam("Users", getNumUsers());
5843 db.addParam("Result", *(getResult().getType()));
5844 return db;
5845}
5846
5847void TruncateNode::visit(Node *parent, NodeWalker *visitor) {
5848 if (!visitor->shouldVisit(parent, this)) { return; }
5849 visitor->pre(parent, this);
5850if (hasPredicate())
5851 getPredicate().getNode()->visit(this, visitor);
5852 getInput().getNode()->visit(this, visitor);
5853 visitor->post(parent, this);
5854}
5855
5856bool TruncateNode::isEqual(const TruncateNode &other) const {
5857 return true &&
5858 Input_ == other.Input_ &&
5859 predicate_ == other.predicate_ &&
5860 getType(0) == other.getType(0);
5861}
5862
5863Node* TruncateNode::clone() const {
5864 return new TruncateNode(getName(), getResult().getType(), getInput());
5865}
5866
5867llvm::hash_code TruncateNode::getHash() const {
5868 return llvm::hash_combine(
5869 Input_);
5870}
5871
5872unsigned SqrtNode::getNumInputs() const {
5873 return 1;
5874}
5875
5876std::string SqrtNode::getInputName(unsigned idx) const {
5877 if (idx == 0) { return "Input"; }
5878 idx -= 1;
5879 llvm_unreachable("Invalid index");
5880}
5881
5882NodeValue SqrtNode::getNthInput(unsigned idx) {
5883 if (idx == 0) { return Input_; }
5884 idx -= 1;
5885 llvm_unreachable("Invalid index");
5886}
5887
5888void SqrtNode::setNthInput(unsigned idx, NodeValue val) {
5889 if (idx == 0) { Input_ = val; return; }
5890 idx -= 1;
5891 llvm_unreachable("Invalid index");
5892}
5893
5894llvm::StringRef SqrtNode::getOutputName(unsigned idx) const {
5895 if (idx == 0) { return "Result"; }
5896 llvm_unreachable("Invalid index");
5897}
5898
5899std::string SqrtNode::getDebugDesc() const {
5900 DescriptionBuilder db(getKindName());
5901 db.addParam("Name", separateString(getName(), 100, "\n"));
5902 if (hasPredicate()) db.addParam("Predicate", "Yes");
5903 db
5904 .addParam("Input", *(getInput().getType()))
5905 .addParam("Users", getNumUsers());
5906 db.addParam("Result", *(getResult().getType()));
5907 return db;
5908}
5909
5910void SqrtNode::visit(Node *parent, NodeWalker *visitor) {
5911 if (!visitor->shouldVisit(parent, this)) { return; }
5912 visitor->pre(parent, this);
5913if (hasPredicate())
5914 getPredicate().getNode()->visit(this, visitor);
5915 getInput().getNode()->visit(this, visitor);
5916 visitor->post(parent, this);
5917}
5918
5919bool SqrtNode::isEqual(const SqrtNode &other) const {
5920 return true &&
5921 Input_ == other.Input_ &&
5922 predicate_ == other.predicate_ &&
5923 getType(0) == other.getType(0);
5924}
5925
5926Node* SqrtNode::clone() const {
5927 return new SqrtNode(getName(), getResult().getType(), getInput());
5928}
5929
5930llvm::hash_code SqrtNode::getHash() const {
5931 return llvm::hash_combine(
5932 Input_);
5933}
5934
5935unsigned RsqrtNode::getNumInputs() const {
5936 return 1;
5937}
5938
5939std::string RsqrtNode::getInputName(unsigned idx) const {
5940 if (idx == 0) { return "Input"; }
5941 idx -= 1;
5942 llvm_unreachable("Invalid index");
5943}
5944
5945NodeValue RsqrtNode::getNthInput(unsigned idx) {
5946 if (idx == 0) { return Input_; }
5947 idx -= 1;
5948 llvm_unreachable("Invalid index");
5949}
5950
5951void RsqrtNode::setNthInput(unsigned idx, NodeValue val) {
5952 if (idx == 0) { Input_ = val; return; }
5953 idx -= 1;
5954 llvm_unreachable("Invalid index");
5955}
5956
5957llvm::StringRef RsqrtNode::getOutputName(unsigned idx) const {
5958 if (idx == 0) { return "Result"; }
5959 llvm_unreachable("Invalid index");
5960}
5961
5962std::string RsqrtNode::getDebugDesc() const {
5963 DescriptionBuilder db(getKindName());
5964 db.addParam("Name", separateString(getName(), 100, "\n"));
5965 if (hasPredicate()) db.addParam("Predicate", "Yes");
5966 db
5967 .addParam("Input", *(getInput().getType()))
5968 .addParam("Users", getNumUsers());
5969 db.addParam("Result", *(getResult().getType()));
5970 return db;
5971}
5972
5973void RsqrtNode::visit(Node *parent, NodeWalker *visitor) {
5974 if (!visitor->shouldVisit(parent, this)) { return; }
5975 visitor->pre(parent, this);
5976if (hasPredicate())
5977 getPredicate().getNode()->visit(this, visitor);
5978 getInput().getNode()->visit(this, visitor);
5979 visitor->post(parent, this);
5980}
5981
5982bool RsqrtNode::isEqual(const RsqrtNode &other) const {
5983 return true &&
5984 Input_ == other.Input_ &&
5985 predicate_ == other.predicate_ &&
5986 getType(0) == other.getType(0);
5987}
5988
5989Node* RsqrtNode::clone() const {
5990 return new RsqrtNode(getName(), getResult().getType(), getInput());
5991}
5992
5993llvm::hash_code RsqrtNode::getHash() const {
5994 return llvm::hash_combine(
5995 Input_);
5996}
5997
5998unsigned ReciprocalNode::getNumInputs() const {
5999 return 1;
6000}
6001
6002std::string ReciprocalNode::getInputName(unsigned idx) const {
6003 if (idx == 0) { return "Input"; }
6004 idx -= 1;
6005 llvm_unreachable("Invalid index");
6006}
6007
6008NodeValue ReciprocalNode::getNthInput(unsigned idx) {
6009 if (idx == 0) { return Input_; }
6010 idx -= 1;
6011 llvm_unreachable("Invalid index");
6012}
6013
6014void ReciprocalNode::setNthInput(unsigned idx, NodeValue val) {
6015 if (idx == 0) { Input_ = val; return; }
6016 idx -= 1;
6017 llvm_unreachable("Invalid index");
6018}
6019
6020llvm::StringRef ReciprocalNode::getOutputName(unsigned idx) const {
6021 if (idx == 0) { return "Result"; }
6022 llvm_unreachable("Invalid index");
6023}
6024
6025std::string ReciprocalNode::getDebugDesc() const {
6026 DescriptionBuilder db(getKindName());
6027 db.addParam("Name", separateString(getName(), 100, "\n"));
6028 if (hasPredicate()) db.addParam("Predicate", "Yes");
6029 db
6030 .addParam("Input", *(getInput().getType()))
6031 .addParam("Users", getNumUsers());
6032 db.addParam("Result", *(getResult().getType()));
6033 return db;
6034}
6035
6036void ReciprocalNode::visit(Node *parent, NodeWalker *visitor) {
6037 if (!visitor->shouldVisit(parent, this)) { return; }
6038 visitor->pre(parent, this);
6039if (hasPredicate())
6040 getPredicate().getNode()->visit(this, visitor);
6041 getInput().getNode()->visit(this, visitor);
6042 visitor->post(parent, this);
6043}
6044
6045bool ReciprocalNode::isEqual(const ReciprocalNode &other) const {
6046 return true &&
6047 Input_ == other.Input_ &&
6048 predicate_ == other.predicate_ &&
6049 getType(0) == other.getType(0);
6050}
6051
6052Node* ReciprocalNode::clone() const {
6053 return new ReciprocalNode(getName(), getResult().getType(), getInput());
6054}
6055
6056llvm::hash_code ReciprocalNode::getHash() const {
6057 return llvm::hash_combine(
6058 Input_);
6059}
6060
6061unsigned SinNode::getNumInputs() const {
6062 return 1;
6063}
6064
6065std::string SinNode::getInputName(unsigned idx) const {
6066 if (idx == 0) { return "Input"; }
6067 idx -= 1;
6068 llvm_unreachable("Invalid index");
6069}
6070
6071NodeValue SinNode::getNthInput(unsigned idx) {
6072 if (idx == 0) { return Input_; }
6073 idx -= 1;
6074 llvm_unreachable("Invalid index");
6075}
6076
6077void SinNode::setNthInput(unsigned idx, NodeValue val) {
6078 if (idx == 0) { Input_ = val; return; }
6079 idx -= 1;
6080 llvm_unreachable("Invalid index");
6081}
6082
6083llvm::StringRef SinNode::getOutputName(unsigned idx) const {
6084 if (idx == 0) { return "Result"; }
6085 llvm_unreachable("Invalid index");
6086}
6087
6088std::string SinNode::getDebugDesc() const {
6089 DescriptionBuilder db(getKindName());
6090 db.addParam("Name", separateString(getName(), 100, "\n"));
6091 if (hasPredicate()) db.addParam("Predicate", "Yes");
6092 db
6093 .addParam("Input", *(getInput().getType()))
6094 .addParam("Users", getNumUsers());
6095 db.addParam("Result", *(getResult().getType()));
6096 return db;
6097}
6098
6099void SinNode::visit(Node *parent, NodeWalker *visitor) {
6100 if (!visitor->shouldVisit(parent, this)) { return; }
6101 visitor->pre(parent, this);
6102if (hasPredicate())
6103 getPredicate().getNode()->visit(this, visitor);
6104 getInput().getNode()->visit(this, visitor);
6105 visitor->post(parent, this);
6106}
6107
6108bool SinNode::isEqual(const SinNode &other) const {
6109 return true &&
6110 Input_ == other.Input_ &&
6111 predicate_ == other.predicate_ &&
6112 getType(0) == other.getType(0);
6113}
6114
6115Node* SinNode::clone() const {
6116 return new SinNode(getName(), getResult().getType(), getInput());
6117}
6118
6119llvm::hash_code SinNode::getHash() const {
6120 return llvm::hash_combine(
6121 Input_);
6122}
6123
6124unsigned CosNode::getNumInputs() const {
6125 return 1;
6126}
6127
6128std::string CosNode::getInputName(unsigned idx) const {
6129 if (idx == 0) { return "Input"; }
6130 idx -= 1;
6131 llvm_unreachable("Invalid index");
6132}
6133
6134NodeValue CosNode::getNthInput(unsigned idx) {
6135 if (idx == 0) { return Input_; }
6136 idx -= 1;
6137 llvm_unreachable("Invalid index");
6138}
6139
6140void CosNode::setNthInput(unsigned idx, NodeValue val) {
6141 if (idx == 0) { Input_ = val; return; }
6142 idx -= 1;
6143 llvm_unreachable("Invalid index");
6144}
6145
6146llvm::StringRef CosNode::getOutputName(unsigned idx) const {
6147 if (idx == 0) { return "Result"; }
6148 llvm_unreachable("Invalid index");
6149}
6150
6151std::string CosNode::getDebugDesc() const {
6152 DescriptionBuilder db(getKindName());
6153 db.addParam("Name", separateString(getName(), 100, "\n"));
6154 if (hasPredicate()) db.addParam("Predicate", "Yes");
6155 db
6156 .addParam("Input", *(getInput().getType()))
6157 .addParam("Users", getNumUsers());
6158 db.addParam("Result", *(getResult().getType()));
6159 return db;
6160}
6161
6162void CosNode::visit(Node *parent, NodeWalker *visitor) {
6163 if (!visitor->shouldVisit(parent, this)) { return; }
6164 visitor->pre(parent, this);
6165if (hasPredicate())
6166 getPredicate().getNode()->visit(this, visitor);
6167 getInput().getNode()->visit(this, visitor);
6168 visitor->post(parent, this);
6169}
6170
6171bool CosNode::isEqual(const CosNode &other) const {
6172 return true &&
6173 Input_ == other.Input_ &&
6174 predicate_ == other.predicate_ &&
6175 getType(0) == other.getType(0);
6176}
6177
6178Node* CosNode::clone() const {
6179 return new CosNode(getName(), getResult().getType(), getInput());
6180}
6181
6182llvm::hash_code CosNode::getHash() const {
6183 return llvm::hash_combine(
6184 Input_);
6185}
6186
6187unsigned LogNode::getNumInputs() const {
6188 return 1;
6189}
6190
6191std::string LogNode::getInputName(unsigned idx) const {
6192 if (idx == 0) { return "Input"; }
6193 idx -= 1;
6194 llvm_unreachable("Invalid index");
6195}
6196
6197NodeValue LogNode::getNthInput(unsigned idx) {
6198 if (idx == 0) { return Input_; }
6199 idx -= 1;
6200 llvm_unreachable("Invalid index");
6201}
6202
6203void LogNode::setNthInput(unsigned idx, NodeValue val) {
6204 if (idx == 0) { Input_ = val; return; }
6205 idx -= 1;
6206 llvm_unreachable("Invalid index");
6207}
6208
6209llvm::StringRef LogNode::getOutputName(unsigned idx) const {
6210 if (idx == 0) { return "Result"; }
6211 llvm_unreachable("Invalid index");
6212}
6213
6214std::string LogNode::getDebugDesc() const {
6215 DescriptionBuilder db(getKindName());
6216 db.addParam("Name", separateString(getName(), 100, "\n"));
6217 if (hasPredicate()) db.addParam("Predicate", "Yes");
6218 db
6219 .addParam("Input", *(getInput().getType()))
6220 .addParam("Users", getNumUsers());
6221 db.addParam("Result", *(getResult().getType()));
6222 return db;
6223}
6224
6225void LogNode::visit(Node *parent, NodeWalker *visitor) {
6226 if (!visitor->shouldVisit(parent, this)) { return; }
6227 visitor->pre(parent, this);
6228if (hasPredicate())
6229 getPredicate().getNode()->visit(this, visitor);
6230 getInput().getNode()->visit(this, visitor);
6231 visitor->post(parent, this);
6232}
6233
6234bool LogNode::isEqual(const LogNode &other) const {
6235 return true &&
6236 Input_ == other.Input_ &&
6237 predicate_ == other.predicate_ &&
6238 getType(0) == other.getType(0);
6239}
6240
6241Node* LogNode::clone() const {
6242 return new LogNode(getName(), getResult().getType(), getInput());
6243}
6244
6245llvm::hash_code LogNode::getHash() const {
6246 return llvm::hash_combine(
6247 Input_);
6248}
6249
6250unsigned AcosNode::getNumInputs() const {
6251 return 1;
6252}
6253
6254std::string AcosNode::getInputName(unsigned idx) const {
6255 if (idx == 0) { return "Input"; }
6256 idx -= 1;
6257 llvm_unreachable("Invalid index");
6258}
6259
6260NodeValue AcosNode::getNthInput(unsigned idx) {
6261 if (idx == 0) { return Input_; }
6262 idx -= 1;
6263 llvm_unreachable("Invalid index");
6264}
6265
6266void AcosNode::setNthInput(unsigned idx, NodeValue val) {
6267 if (idx == 0) { Input_ = val; return; }
6268 idx -= 1;
6269 llvm_unreachable("Invalid index");
6270}
6271
6272llvm::StringRef AcosNode::getOutputName(unsigned idx) const {
6273 if (idx == 0) { return "Result"; }
6274 llvm_unreachable("Invalid index");
6275}
6276
6277std::string AcosNode::getDebugDesc() const {
6278 DescriptionBuilder db(getKindName());
6279 db.addParam("Name", separateString(getName(), 100, "\n"));
6280 if (hasPredicate()) db.addParam("Predicate", "Yes");
6281 db
6282 .addParam("Input", *(getInput().getType()))
6283 .addParam("Users", getNumUsers());
6284 db.addParam("Result", *(getResult().getType()));
6285 return db;
6286}
6287
6288void AcosNode::visit(Node *parent, NodeWalker *visitor) {
6289 if (!visitor->shouldVisit(parent, this)) { return; }
6290 visitor->pre(parent, this);
6291if (hasPredicate())
6292 getPredicate().getNode()->visit(this, visitor);
6293 getInput().getNode()->visit(this, visitor);
6294 visitor->post(parent, this);
6295}
6296
6297bool AcosNode::isEqual(const AcosNode &other) const {
6298 return true &&
6299 Input_ == other.Input_ &&
6300 predicate_ == other.predicate_ &&
6301 getType(0) == other.getType(0);
6302}
6303
6304Node* AcosNode::clone() const {
6305 return new AcosNode(getName(), getResult().getType(), getInput());
6306}
6307
6308llvm::hash_code AcosNode::getHash() const {
6309 return llvm::hash_combine(
6310 Input_);
6311}
6312
6313unsigned AsinNode::getNumInputs() const {
6314 return 1;
6315}
6316
6317std::string AsinNode::getInputName(unsigned idx) const {
6318 if (idx == 0) { return "Input"; }
6319 idx -= 1;
6320 llvm_unreachable("Invalid index");
6321}
6322
6323NodeValue AsinNode::getNthInput(unsigned idx) {
6324 if (idx == 0) { return Input_; }
6325 idx -= 1;
6326 llvm_unreachable("Invalid index");
6327}
6328
6329void AsinNode::setNthInput(unsigned idx, NodeValue val) {
6330 if (idx == 0) { Input_ = val; return; }
6331 idx -= 1;
6332 llvm_unreachable("Invalid index");
6333}
6334
6335llvm::StringRef AsinNode::getOutputName(unsigned idx) const {
6336 if (idx == 0) { return "Result"; }
6337 llvm_unreachable("Invalid index");
6338}
6339
6340std::string AsinNode::getDebugDesc() const {
6341 DescriptionBuilder db(getKindName());
6342 db.addParam("Name", separateString(getName(), 100, "\n"));
6343 if (hasPredicate()) db.addParam("Predicate", "Yes");
6344 db
6345 .addParam("Input", *(getInput().getType()))
6346 .addParam("Users", getNumUsers());
6347 db.addParam("Result", *(getResult().getType()));
6348 return db;
6349}
6350
6351void AsinNode::visit(Node *parent, NodeWalker *visitor) {
6352 if (!visitor->shouldVisit(parent, this)) { return; }
6353 visitor->pre(parent, this);
6354if (hasPredicate())
6355 getPredicate().getNode()->visit(this, visitor);
6356 getInput().getNode()->visit(this, visitor);
6357 visitor->post(parent, this);
6358}
6359
6360bool AsinNode::isEqual(const AsinNode &other) const {
6361 return true &&
6362 Input_ == other.Input_ &&
6363 predicate_ == other.predicate_ &&
6364 getType(0) == other.getType(0);
6365}
6366
6367Node* AsinNode::clone() const {
6368 return new AsinNode(getName(), getResult().getType(), getInput());
6369}
6370
6371llvm::hash_code AsinNode::getHash() const {
6372 return llvm::hash_combine(
6373 Input_);
6374}
6375
6376unsigned AtanNode::getNumInputs() const {
6377 return 1;
6378}
6379
6380std::string AtanNode::getInputName(unsigned idx) const {
6381 if (idx == 0) { return "Input"; }
6382 idx -= 1;
6383 llvm_unreachable("Invalid index");
6384}
6385
6386NodeValue AtanNode::getNthInput(unsigned idx) {
6387 if (idx == 0) { return Input_; }
6388 idx -= 1;
6389 llvm_unreachable("Invalid index");
6390}
6391
6392void AtanNode::setNthInput(unsigned idx, NodeValue val) {
6393 if (idx == 0) { Input_ = val; return; }
6394 idx -= 1;
6395 llvm_unreachable("Invalid index");
6396}
6397
6398llvm::StringRef AtanNode::getOutputName(unsigned idx) const {
6399 if (idx == 0) { return "Result"; }
6400 llvm_unreachable("Invalid index");
6401}
6402
6403std::string AtanNode::getDebugDesc() const {
6404 DescriptionBuilder db(getKindName());
6405 db.addParam("Name", separateString(getName(), 100, "\n"));
6406 if (hasPredicate()) db.addParam("Predicate", "Yes");
6407 db
6408 .addParam("Input", *(getInput().getType()))
6409 .addParam("Users", getNumUsers());
6410 db.addParam("Result", *(getResult().getType()));
6411 return db;
6412}
6413
6414void AtanNode::visit(Node *parent, NodeWalker *visitor) {
6415 if (!visitor->shouldVisit(parent, this)) { return; }
6416 visitor->pre(parent, this);
6417if (hasPredicate())
6418 getPredicate().getNode()->visit(this, visitor);
6419 getInput().getNode()->visit(this, visitor);
6420 visitor->post(parent, this);
6421}
6422
6423bool AtanNode::isEqual(const AtanNode &other) const {
6424 return true &&
6425 Input_ == other.Input_ &&
6426 predicate_ == other.predicate_ &&
6427 getType(0) == other.getType(0);
6428}
6429
6430Node* AtanNode::clone() const {
6431 return new AtanNode(getName(), getResult().getType(), getInput());
6432}
6433
6434llvm::hash_code AtanNode::getHash() const {
6435 return llvm::hash_combine(
6436 Input_);
6437}
6438
6439unsigned ErfNode::getNumInputs() const {
6440 return 1;
6441}
6442
6443std::string ErfNode::getInputName(unsigned idx) const {
6444 if (idx == 0) { return "Input"; }
6445 idx -= 1;
6446 llvm_unreachable("Invalid index");
6447}
6448
6449NodeValue ErfNode::getNthInput(unsigned idx) {
6450 if (idx == 0) { return Input_; }
6451 idx -= 1;
6452 llvm_unreachable("Invalid index");
6453}
6454
6455void ErfNode::setNthInput(unsigned idx, NodeValue val) {
6456 if (idx == 0) { Input_ = val; return; }
6457 idx -= 1;
6458 llvm_unreachable("Invalid index");
6459}
6460
6461llvm::StringRef ErfNode::getOutputName(unsigned idx) const {
6462 if (idx == 0) { return "Result"; }
6463 llvm_unreachable("Invalid index");
6464}
6465
6466std::string ErfNode::getDebugDesc() const {
6467 DescriptionBuilder db(getKindName());
6468 db.addParam("Name", separateString(getName(), 100, "\n"));
6469 if (hasPredicate()) db.addParam("Predicate", "Yes");
6470 db
6471 .addParam("Input", *(getInput().getType()))
6472 .addParam("Users", getNumUsers());
6473 db.addParam("Result", *(getResult().getType()));
6474 return db;
6475}
6476
6477void ErfNode::visit(Node *parent, NodeWalker *visitor) {
6478 if (!visitor->shouldVisit(parent, this)) { return; }
6479 visitor->pre(parent, this);
6480if (hasPredicate())
6481 getPredicate().getNode()->visit(this, visitor);
6482 getInput().getNode()->visit(this, visitor);
6483 visitor->post(parent, this);
6484}
6485
6486bool ErfNode::isEqual(const ErfNode &other) const {
6487 return true &&
6488 Input_ == other.Input_ &&
6489 predicate_ == other.predicate_ &&
6490 getType(0) == other.getType(0);
6491}
6492
6493Node* ErfNode::clone() const {
6494 return new ErfNode(getName(), getResult().getType(), getInput());
6495}
6496
6497llvm::hash_code ErfNode::getHash() const {
6498 return llvm::hash_combine(
6499 Input_);
6500}
6501
6502unsigned ExpNode::getNumInputs() const {
6503 return 1;
6504}
6505
6506std::string ExpNode::getInputName(unsigned idx) const {
6507 if (idx == 0) { return "Input"; }
6508 idx -= 1;
6509 llvm_unreachable("Invalid index");
6510}
6511
6512NodeValue ExpNode::getNthInput(unsigned idx) {
6513 if (idx == 0) { return Input_; }
6514 idx -= 1;
6515 llvm_unreachable("Invalid index");
6516}
6517
6518void ExpNode::setNthInput(unsigned idx, NodeValue val) {
6519 if (idx == 0) { Input_ = val; return; }
6520 idx -= 1;
6521 llvm_unreachable("Invalid index");
6522}
6523
6524llvm::StringRef ExpNode::getOutputName(unsigned idx) const {
6525 if (idx == 0) { return "Result"; }
6526 llvm_unreachable("Invalid index");
6527}
6528
6529std::string ExpNode::getDebugDesc() const {
6530 DescriptionBuilder db(getKindName());
6531 db.addParam("Name", separateString(getName(), 100, "\n"));
6532 if (hasPredicate()) db.addParam("Predicate", "Yes");
6533 db
6534 .addParam("Input", *(getInput().getType()))
6535 .addParam("Users", getNumUsers());
6536 db.addParam("Result", *(getResult().getType()));
6537 return db;
6538}
6539
6540void ExpNode::visit(Node *parent, NodeWalker *visitor) {
6541 if (!visitor->shouldVisit(parent, this)) { return; }
6542 visitor->pre(parent, this);
6543if (hasPredicate())
6544 getPredicate().getNode()->visit(this, visitor);
6545 getInput().getNode()->visit(this, visitor);
6546 visitor->post(parent, this);
6547}
6548
6549bool ExpNode::isEqual(const ExpNode &other) const {
6550 return true &&
6551 Input_ == other.Input_ &&
6552 predicate_ == other.predicate_ &&
6553 getType(0) == other.getType(0);
6554}
6555
6556Node* ExpNode::clone() const {
6557 return new ExpNode(getName(), getResult().getType(), getInput());
6558}
6559
6560llvm::hash_code ExpNode::getHash() const {
6561 return llvm::hash_combine(
6562 Input_);
6563}
6564
6565unsigned LogitNode::getNumInputs() const {
6566 return 1;
6567}
6568
6569std::string LogitNode::getInputName(unsigned idx) const {
6570 if (idx == 0) { return "Input"; }
6571 idx -= 1;
6572 llvm_unreachable("Invalid index");
6573}
6574
6575NodeValue LogitNode::getNthInput(unsigned idx) {
6576 if (idx == 0) { return Input_; }
6577 idx -= 1;
6578 llvm_unreachable("Invalid index");
6579}
6580
6581void LogitNode::setNthInput(unsigned idx, NodeValue val) {
6582 if (idx == 0) { Input_ = val; return; }
6583 idx -= 1;
6584 llvm_unreachable("Invalid index");
6585}
6586
6587llvm::StringRef LogitNode::getOutputName(unsigned idx) const {
6588 if (idx == 0) { return "Result"; }
6589 llvm_unreachable("Invalid index");
6590}
6591
6592std::string LogitNode::getDebugDesc() const {
6593 DescriptionBuilder db(getKindName());
6594 db.addParam("Name", separateString(getName(), 100, "\n"));
6595 if (hasPredicate()) db.addParam("Predicate", "Yes");
6596 db
6597 .addParam("Input", *(getInput().getType()))
6598 .addParam("Epsilon", getEpsilon())
6599 .addParam("Users", getNumUsers());
6600 db.addParam("Result", *(getResult().getType()));
6601 return db;
6602}
6603
6604void LogitNode::visit(Node *parent, NodeWalker *visitor) {
6605 if (!visitor->shouldVisit(parent, this)) { return; }
6606 visitor->pre(parent, this);
6607if (hasPredicate())
6608 getPredicate().getNode()->visit(this, visitor);
6609 getInput().getNode()->visit(this, visitor);
6610 visitor->post(parent, this);
6611}
6612
6613bool LogitNode::isEqual(const LogitNode &other) const {
6614 return true &&
6615 Input_ == other.Input_ &&
6616 predicate_ == other.predicate_ &&
6617 Epsilon_ == other.Epsilon_ &&
6618 getType(0) == other.getType(0);
6619}
6620
6621Node* LogitNode::clone() const {
6622 return new LogitNode(getName(), getResult().getType(), getInput(), getEpsilon());
6623}
6624
6625llvm::hash_code LogitNode::getHash() const {
6626 return llvm::hash_combine(
6627 toBinary(Epsilon_),
6628 Input_);
6629}
6630
6631unsigned NonZeroNode::getNumInputs() const {
6632 return 1;
6633}
6634
6635std::string NonZeroNode::getInputName(unsigned idx) const {
6636 if (idx == 0) { return "Cond"; }
6637 idx -= 1;
6638 llvm_unreachable("Invalid index");
6639}
6640
6641NodeValue NonZeroNode::getNthInput(unsigned idx) {
6642 if (idx == 0) { return Cond_; }
6643 idx -= 1;
6644 llvm_unreachable("Invalid index");
6645}
6646
6647void NonZeroNode::setNthInput(unsigned idx, NodeValue val) {
6648 if (idx == 0) { Cond_ = val; return; }
6649 idx -= 1;
6650 llvm_unreachable("Invalid index");
6651}
6652
6653llvm::StringRef NonZeroNode::getOutputName(unsigned idx) const {
6654 if (idx == 0) { return "Result"; }
6655 llvm_unreachable("Invalid index");
6656}
6657
6658std::string NonZeroNode::getDebugDesc() const {
6659 DescriptionBuilder db(getKindName());
6660 db.addParam("Name", separateString(getName(), 100, "\n"));
6661 if (hasPredicate()) db.addParam("Predicate", "Yes");
6662 db
6663 .addParam("Cond", *(getCond().getType()))
6664 .addParam("Users", getNumUsers());
6665 db.addParam("Result", *(getResult().getType()));
6666 return db;
6667}
6668
6669void NonZeroNode::visit(Node *parent, NodeWalker *visitor) {
6670 if (!visitor->shouldVisit(parent, this)) { return; }
6671 visitor->pre(parent, this);
6672if (hasPredicate())
6673 getPredicate().getNode()->visit(this, visitor);
6674 getCond().getNode()->visit(this, visitor);
6675 visitor->post(parent, this);
6676}
6677
6678bool NonZeroNode::isEqual(const NonZeroNode &other) const {
6679 return true &&
6680 Cond_ == other.Cond_ &&
6681 predicate_ == other.predicate_ &&
6682 getType(0) == other.getType(0);
6683}
6684
6685Node* NonZeroNode::clone() const {
6686 return new NonZeroNode(getName(), getResult().getType(), getCond());
6687}
6688
6689llvm::hash_code NonZeroNode::getHash() const {
6690 return llvm::hash_combine(
6691 Cond_);
6692}
6693
6694unsigned SelectNode::getNumInputs() const {
6695 return 3;
6696}
6697
6698std::string SelectNode::getInputName(unsigned idx) const {
6699 if (idx == 0) { return "Cond"; }
6700 if (idx == 1) { return "LHS"; }
6701 if (idx == 2) { return "RHS"; }
6702 idx -= 3;
6703 llvm_unreachable("Invalid index");
6704}
6705
6706NodeValue SelectNode::getNthInput(unsigned idx) {
6707 if (idx == 0) { return Cond_; }
6708 if (idx == 1) { return LHS_; }
6709 if (idx == 2) { return RHS_; }
6710 idx -= 3;
6711 llvm_unreachable("Invalid index");
6712}
6713
6714void SelectNode::setNthInput(unsigned idx, NodeValue val) {
6715 if (idx == 0) { Cond_ = val; return; }
6716 if (idx == 1) { LHS_ = val; return; }
6717 if (idx == 2) { RHS_ = val; return; }
6718 idx -= 3;
6719 llvm_unreachable("Invalid index");
6720}
6721
6722llvm::StringRef SelectNode::getOutputName(unsigned idx) const {
6723 if (idx == 0) { return "Result"; }
6724 llvm_unreachable("Invalid index");
6725}
6726
6727std::string SelectNode::getDebugDesc() const {
6728 DescriptionBuilder db(getKindName());
6729 db.addParam("Name", separateString(getName(), 100, "\n"));
6730 if (hasPredicate()) db.addParam("Predicate", "Yes");
6731 db
6732 .addParam("Cond", *(getCond().getType()))
6733 .addParam("LHS", *(getLHS().getType()))
6734 .addParam("RHS", *(getRHS().getType()))
6735 .addParam("Users", getNumUsers());
6736 db.addParam("Result", *(getResult().getType()));
6737 return db;
6738}
6739
6740void SelectNode::visit(Node *parent, NodeWalker *visitor) {
6741 if (!visitor->shouldVisit(parent, this)) { return; }
6742 visitor->pre(parent, this);
6743if (hasPredicate())
6744 getPredicate().getNode()->visit(this, visitor);
6745 getCond().getNode()->visit(this, visitor);
6746 getLHS().getNode()->visit(this, visitor);
6747 getRHS().getNode()->visit(this, visitor);
6748 visitor->post(parent, this);
6749}
6750
6751bool SelectNode::isEqual(const SelectNode &other) const {
6752 return true &&
6753 Cond_ == other.Cond_ &&
6754 LHS_ == other.LHS_ &&
6755 RHS_ == other.RHS_ &&
6756 predicate_ == other.predicate_ &&
6757 getType(0) == other.getType(0);
6758}
6759
6760Node* SelectNode::clone() const {
6761 return new SelectNode(getName(), getResult().getType(), getCond(), getLHS(), getRHS());
6762}
6763
6764llvm::hash_code SelectNode::getHash() const {
6765 return llvm::hash_combine(
6766 Cond_,
6767 LHS_,
6768 RHS_);
6769}
6770
6771unsigned BatchedAddNode::getNumInputs() const {
6772 return 2;
6773}
6774
6775std::string BatchedAddNode::getInputName(unsigned idx) const {
6776 if (idx == 0) { return "Batch"; }
6777 if (idx == 1) { return "Slice"; }
6778 idx -= 2;
6779 llvm_unreachable("Invalid index");
6780}
6781
6782NodeValue BatchedAddNode::getNthInput(unsigned idx) {
6783 if (idx == 0) { return Batch_; }
6784 if (idx == 1) { return Slice_; }
6785 idx -= 2;
6786 llvm_unreachable("Invalid index");
6787}
6788
6789void BatchedAddNode::setNthInput(unsigned idx, NodeValue val) {
6790 if (idx == 0) { Batch_ = val; return; }
6791 if (idx == 1) { Slice_ = val; return; }
6792 idx -= 2;
6793 llvm_unreachable("Invalid index");
6794}
6795
6796llvm::StringRef BatchedAddNode::getOutputName(unsigned idx) const {
6797 if (idx == 0) { return "Result"; }
6798 llvm_unreachable("Invalid index");
6799}
6800
6801std::string BatchedAddNode::getDebugDesc() const {
6802 DescriptionBuilder db(getKindName());
6803 db.addParam("Name", separateString(getName(), 100, "\n"));
6804 if (hasPredicate()) db.addParam("Predicate", "Yes");
6805 db
6806 .addParam("Batch", *(getBatch().getType()))
6807 .addParam("Slice", *(getSlice().getType()))
6808 .addParam("Users", getNumUsers());
6809 db.addParam("Result", *(getResult().getType()));
6810 return db;
6811}
6812
6813void BatchedAddNode::visit(Node *parent, NodeWalker *visitor) {
6814 if (!visitor->shouldVisit(parent, this)) { return; }
6815 visitor->pre(parent, this);
6816if (hasPredicate())
6817 getPredicate().getNode()->visit(this, visitor);
6818 getBatch().getNode()->visit(this, visitor);
6819 getSlice().getNode()->visit(this, visitor);
6820 visitor->post(parent, this);
6821}
6822
6823bool BatchedAddNode::isEqual(const BatchedAddNode &other) const {
6824 return true &&
6825 Batch_ == other.Batch_ &&
6826 Slice_ == other.Slice_ &&
6827 predicate_ == other.predicate_ &&
6828 getType(0) == other.getType(0);
6829}
6830
6831Node* BatchedAddNode::clone() const {
6832 return new BatchedAddNode(getName(), getResult().getType(), getBatch(), getSlice());
6833}
6834
6835llvm::hash_code BatchedAddNode::getHash() const {
6836 return llvm::hash_combine(
6837 Batch_,
6838 Slice_);
6839}
6840
6841unsigned BatchedMulNode::getNumInputs() const {
6842 return 2;
6843}
6844
6845std::string BatchedMulNode::getInputName(unsigned idx) const {
6846 if (idx == 0) { return "Batch"; }
6847 if (idx == 1) { return "Slice"; }
6848 idx -= 2;
6849 llvm_unreachable("Invalid index");
6850}
6851
6852NodeValue BatchedMulNode::getNthInput(unsigned idx) {
6853 if (idx == 0) { return Batch_; }
6854 if (idx == 1) { return Slice_; }
6855 idx -= 2;
6856 llvm_unreachable("Invalid index");
6857}
6858
6859void BatchedMulNode::setNthInput(unsigned idx, NodeValue val) {
6860 if (idx == 0) { Batch_ = val; return; }
6861 if (idx == 1) { Slice_ = val; return; }
6862 idx -= 2;
6863 llvm_unreachable("Invalid index");
6864}
6865
6866llvm::StringRef BatchedMulNode::getOutputName(unsigned idx) const {
6867 if (idx == 0) { return "Result"; }
6868 llvm_unreachable("Invalid index");
6869}
6870
6871std::string BatchedMulNode::getDebugDesc() const {
6872 DescriptionBuilder db(getKindName());
6873 db.addParam("Name", separateString(getName(), 100, "\n"));
6874 if (hasPredicate()) db.addParam("Predicate", "Yes");
6875 db
6876 .addParam("Batch", *(getBatch().getType()))
6877 .addParam("Slice", *(getSlice().getType()))
6878 .addParam("Users", getNumUsers());
6879 db.addParam("Result", *(getResult().getType()));
6880 return db;
6881}
6882
6883void BatchedMulNode::visit(Node *parent, NodeWalker *visitor) {
6884 if (!visitor->shouldVisit(parent, this)) { return; }
6885 visitor->pre(parent, this);
6886if (hasPredicate())
6887 getPredicate().getNode()->visit(this, visitor);
6888 getBatch().getNode()->visit(this, visitor);
6889 getSlice().getNode()->visit(this, visitor);
6890 visitor->post(parent, this);
6891}
6892
6893bool BatchedMulNode::isEqual(const BatchedMulNode &other) const {
6894 return true &&
6895 Batch_ == other.Batch_ &&
6896 Slice_ == other.Slice_ &&
6897 predicate_ == other.predicate_ &&
6898 getType(0) == other.getType(0);
6899}
6900
6901Node* BatchedMulNode::clone() const {
6902 return new BatchedMulNode(getName(), getResult().getType(), getBatch(), getSlice());
6903}
6904
6905llvm::hash_code BatchedMulNode::getHash() const {
6906 return llvm::hash_combine(
6907 Batch_,
6908 Slice_);
6909}
6910
6911unsigned MatMulNode::getNumInputs() const {
6912 return 2;
6913}
6914
6915std::string MatMulNode::getInputName(unsigned idx) const {
6916 if (idx == 0) { return "LHS"; }
6917 if (idx == 1) { return "RHS"; }
6918 idx -= 2;
6919 llvm_unreachable("Invalid index");
6920}
6921
6922NodeValue MatMulNode::getNthInput(unsigned idx) {
6923 if (idx == 0) { return LHS_; }
6924 if (idx == 1) { return RHS_; }
6925 idx -= 2;
6926 llvm_unreachable("Invalid index");
6927}
6928
6929void MatMulNode::setNthInput(unsigned idx, NodeValue val) {
6930 if (idx == 0) { LHS_ = val; return; }
6931 if (idx == 1) { RHS_ = val; return; }
6932 idx -= 2;
6933 llvm_unreachable("Invalid index");
6934}
6935
6936llvm::StringRef MatMulNode::getOutputName(unsigned idx) const {
6937 if (idx == 0) { return "Result"; }
6938 llvm_unreachable("Invalid index");
6939}
6940
6941std::string MatMulNode::getDebugDesc() const {
6942 DescriptionBuilder db(getKindName());
6943 db.addParam("Name", separateString(getName(), 100, "\n"));
6944 if (hasPredicate()) db.addParam("Predicate", "Yes");
6945 db
6946 .addParam("LHS", *(getLHS().getType()))
6947 .addParam("RHS", *(getRHS().getType()))
6948 .addParam("Users", getNumUsers());
6949 db.addParam("Result", *(getResult().getType()));
6950 return db;
6951}
6952
6953void MatMulNode::visit(Node *parent, NodeWalker *visitor) {
6954 if (!visitor->shouldVisit(parent, this)) { return; }
6955 visitor->pre(parent, this);
6956if (hasPredicate())
6957 getPredicate().getNode()->visit(this, visitor);
6958 getLHS().getNode()->visit(this, visitor);
6959 getRHS().getNode()->visit(this, visitor);
6960 visitor->post(parent, this);
6961}
6962
6963bool MatMulNode::isEqual(const MatMulNode &other) const {
6964 return true &&
6965 LHS_ == other.LHS_ &&
6966 RHS_ == other.RHS_ &&
6967 predicate_ == other.predicate_ &&
6968 getType(0) == other.getType(0);
6969}
6970
6971Node* MatMulNode::clone() const {
6972 return new MatMulNode(getName(), getResult().getType(), getLHS(), getRHS());
6973}
6974
6975llvm::hash_code MatMulNode::getHash() const {
6976 return llvm::hash_combine(
6977 LHS_,
6978 RHS_);
6979}
6980
6981unsigned BatchMatMulNode::getNumInputs() const {
6982 return 2;
6983}
6984
6985std::string BatchMatMulNode::getInputName(unsigned idx) const {
6986 if (idx == 0) { return "LHS"; }
6987 if (idx == 1) { return "RHS"; }
6988 idx -= 2;
6989 llvm_unreachable("Invalid index");
6990}
6991
6992NodeValue BatchMatMulNode::getNthInput(unsigned idx) {
6993 if (idx == 0) { return LHS_; }
6994 if (idx == 1) { return RHS_; }
6995 idx -= 2;
6996 llvm_unreachable("Invalid index");
6997}
6998
6999void BatchMatMulNode::setNthInput(unsigned idx, NodeValue val) {
7000 if (idx == 0) { LHS_ = val; return; }
7001 if (idx == 1) { RHS_ = val; return; }
7002 idx -= 2;
7003 llvm_unreachable("Invalid index");
7004}
7005
7006llvm::StringRef BatchMatMulNode::getOutputName(unsigned idx) const {
7007 if (idx == 0) { return "Result"; }
7008 llvm_unreachable("Invalid index");
7009}
7010
7011std::string BatchMatMulNode::getDebugDesc() const {
7012 DescriptionBuilder db(getKindName());
7013 db.addParam("Name", separateString(getName(), 100, "\n"));
7014 if (hasPredicate()) db.addParam("Predicate", "Yes");
7015 db
7016 .addParam("LHS", *(getLHS().getType()))
7017 .addParam("RHS", *(getRHS().getType()))
7018 .addParam("Users", getNumUsers());
7019 db.addParam("Result", *(getResult().getType()));
7020 return db;
7021}
7022
7023void BatchMatMulNode::visit(Node *parent, NodeWalker *visitor) {
7024 if (!visitor->shouldVisit(parent, this)) { return; }
7025 visitor->pre(parent, this);
7026if (hasPredicate())
7027 getPredicate().getNode()->visit(this, visitor);
7028 getLHS().getNode()->visit(this, visitor);
7029 getRHS().getNode()->visit(this, visitor);
7030 visitor->post(parent, this);
7031}
7032
7033bool BatchMatMulNode::isEqual(const BatchMatMulNode &other) const {
7034 return true &&
7035 LHS_ == other.LHS_ &&
7036 RHS_ == other.RHS_ &&
7037 predicate_ == other.predicate_ &&
7038 getType(0) == other.getType(0);
7039}
7040
7041Node* BatchMatMulNode::clone() const {
7042 return new BatchMatMulNode(getName(), getResult().getType(), getLHS(), getRHS());
7043}
7044
7045llvm::hash_code BatchMatMulNode::getHash() const {
7046 return llvm::hash_combine(
7047 LHS_,
7048 RHS_);
7049}
7050
7051unsigned BatchedReduceAddNode::getNumInputs() const {
7052 return 1;
7053}
7054
7055std::string BatchedReduceAddNode::getInputName(unsigned idx) const {
7056 if (idx == 0) { return "Batch"; }
7057 idx -= 1;
7058 llvm_unreachable("Invalid index");
7059}
7060
7061NodeValue BatchedReduceAddNode::getNthInput(unsigned idx) {
7062 if (idx == 0) { return Batch_; }
7063 idx -= 1;
7064 llvm_unreachable("Invalid index");
7065}
7066
7067void BatchedReduceAddNode::setNthInput(unsigned idx, NodeValue val) {
7068 if (idx == 0) { Batch_ = val; return; }
7069 idx -= 1;
7070 llvm_unreachable("Invalid index");
7071}
7072
7073llvm::StringRef BatchedReduceAddNode::getOutputName(unsigned idx) const {
7074 if (idx == 0) { return "Result"; }
7075 llvm_unreachable("Invalid index");
7076}
7077
7078std::string BatchedReduceAddNode::getDebugDesc() const {
7079 DescriptionBuilder db(getKindName());
7080 db.addParam("Name", separateString(getName(), 100, "\n"));
7081 if (hasPredicate()) db.addParam("Predicate", "Yes");
7082 db
7083 .addParam("Batch", *(getBatch().getType()))
7084 .addParam("Axis", getAxis())
7085 .addParam("Users", getNumUsers());
7086 db.addParam("Result", *(getResult().getType()));
7087 return db;
7088}
7089
7090void BatchedReduceAddNode::visit(Node *parent, NodeWalker *visitor) {
7091 if (!visitor->shouldVisit(parent, this)) { return; }
7092 visitor->pre(parent, this);
7093if (hasPredicate())
7094 getPredicate().getNode()->visit(this, visitor);
7095 getBatch().getNode()->visit(this, visitor);
7096 visitor->post(parent, this);
7097}
7098
7099bool BatchedReduceAddNode::isEqual(const BatchedReduceAddNode &other) const {
7100 return true &&
7101 Batch_ == other.Batch_ &&
7102 predicate_ == other.predicate_ &&
7103 Axis_ == other.Axis_ &&
7104 getType(0) == other.getType(0);
7105}
7106
7107Node* BatchedReduceAddNode::clone() const {
7108 return new BatchedReduceAddNode(getName(), getResult().getType(), getBatch(), getAxis());
7109}
7110
7111llvm::hash_code BatchedReduceAddNode::getHash() const {
7112 return llvm::hash_combine(
7113 Axis_,
7114 Batch_);
7115}
7116
7117unsigned BatchedReduceSumSquareNode::getNumInputs() const {
7118 return 1;
7119}
7120
7121std::string BatchedReduceSumSquareNode::getInputName(unsigned idx) const {
7122 if (idx == 0) { return "Batch"; }
7123 idx -= 1;
7124 llvm_unreachable("Invalid index");
7125}
7126
7127NodeValue BatchedReduceSumSquareNode::getNthInput(unsigned idx) {
7128 if (idx == 0) { return Batch_; }
7129 idx -= 1;
7130 llvm_unreachable("Invalid index");
7131}
7132
7133void BatchedReduceSumSquareNode::setNthInput(unsigned idx, NodeValue val) {
7134 if (idx == 0) { Batch_ = val; return; }
7135 idx -= 1;
7136 llvm_unreachable("Invalid index");
7137}
7138
7139llvm::StringRef BatchedReduceSumSquareNode::getOutputName(unsigned idx) const {
7140 if (idx == 0) { return "Result"; }
7141 llvm_unreachable("Invalid index");
7142}
7143
7144std::string BatchedReduceSumSquareNode::getDebugDesc() const {
7145 DescriptionBuilder db(getKindName());
7146 db.addParam("Name", separateString(getName(), 100, "\n"));
7147 if (hasPredicate()) db.addParam("Predicate", "Yes");
7148 db
7149 .addParam("Batch", *(getBatch().getType()))
7150 .addParam("Axis", getAxis())
7151 .addParam("Users", getNumUsers());
7152 db.addParam("Result", *(getResult().getType()));
7153 return db;
7154}
7155
7156void BatchedReduceSumSquareNode::visit(Node *parent, NodeWalker *visitor) {
7157 if (!visitor->shouldVisit(parent, this)) { return; }
7158 visitor->pre(parent, this);
7159if (hasPredicate())
7160 getPredicate().getNode()->visit(this, visitor);
7161 getBatch().getNode()->visit(this, visitor);
7162 visitor->post(parent, this);
7163}
7164
7165bool BatchedReduceSumSquareNode::isEqual(const BatchedReduceSumSquareNode &other) const {
7166 return true &&
7167 Batch_ == other.Batch_ &&
7168 predicate_ == other.predicate_ &&
7169 Axis_ == other.Axis_ &&
7170 getType(0) == other.getType(0);
7171}
7172
7173Node* BatchedReduceSumSquareNode::clone() const {
7174 return new BatchedReduceSumSquareNode(getName(), getResult().getType(), getBatch(), getAxis());
7175}
7176
7177llvm::hash_code BatchedReduceSumSquareNode::getHash() const {
7178 return llvm::hash_combine(
7179 Axis_,
7180 Batch_);
7181}
7182
7183unsigned BatchedReduceMeanNode::getNumInputs() const {
7184 return 1;
7185}
7186
7187std::string BatchedReduceMeanNode::getInputName(unsigned idx) const {
7188 if (idx == 0) { return "Batch"; }
7189 idx -= 1;
7190 llvm_unreachable("Invalid index");
7191}
7192
7193NodeValue BatchedReduceMeanNode::getNthInput(unsigned idx) {
7194 if (idx == 0) { return Batch_; }
7195 idx -= 1;
7196 llvm_unreachable("Invalid index");
7197}
7198
7199void BatchedReduceMeanNode::setNthInput(unsigned idx, NodeValue val) {
7200 if (idx == 0) { Batch_ = val; return; }
7201 idx -= 1;
7202 llvm_unreachable("Invalid index");
7203}
7204
7205llvm::StringRef BatchedReduceMeanNode::getOutputName(unsigned idx) const {
7206 if (idx == 0) { return "Result"; }
7207 llvm_unreachable("Invalid index");
7208}
7209
7210std::string BatchedReduceMeanNode::getDebugDesc() const {
7211 DescriptionBuilder db(getKindName());
7212 db.addParam("Name", separateString(getName(), 100, "\n"));
7213 if (hasPredicate()) db.addParam("Predicate", "Yes");
7214 db
7215 .addParam("Batch", *(getBatch().getType()))
7216 .addParam("Axes", getAxes())
7217 .addParam("Users", getNumUsers());
7218 db.addParam("Result", *(getResult().getType()));
7219 return db;
7220}
7221
7222void BatchedReduceMeanNode::visit(Node *parent, NodeWalker *visitor) {
7223 if (!visitor->shouldVisit(parent, this)) { return; }
7224 visitor->pre(parent, this);
7225if (hasPredicate())
7226 getPredicate().getNode()->visit(this, visitor);
7227 getBatch().getNode()->visit(this, visitor);
7228 visitor->post(parent, this);
7229}
7230
7231bool BatchedReduceMeanNode::isEqual(const BatchedReduceMeanNode &other) const {
7232 return true &&
7233 Batch_ == other.Batch_ &&
7234 predicate_ == other.predicate_ &&
7235 Axes_ == other.Axes_ &&
7236 getType(0) == other.getType(0);
7237}
7238
7239Node* BatchedReduceMeanNode::clone() const {
7240 return new BatchedReduceMeanNode(getName(), getResult().getType(), getBatch(), getAxes());
7241}
7242
7243llvm::hash_code BatchedReduceMeanNode::getHash() const {
7244 return llvm::hash_combine(
7245 llvm::hash_combine_range(Axes_.begin(), Axes_.end()),
7246 Batch_);
7247}
7248
7249unsigned BatchedReduceMinNode::getNumInputs() const {
7250 return 1;
7251}
7252
7253std::string BatchedReduceMinNode::getInputName(unsigned idx) const {
7254 if (idx == 0) { return "Batch"; }
7255 idx -= 1;
7256 llvm_unreachable("Invalid index");
7257}
7258
7259NodeValue BatchedReduceMinNode::getNthInput(unsigned idx) {
7260 if (idx == 0) { return Batch_; }
7261 idx -= 1;
7262 llvm_unreachable("Invalid index");
7263}
7264
7265void BatchedReduceMinNode::setNthInput(unsigned idx, NodeValue val) {
7266 if (idx == 0) { Batch_ = val; return; }
7267 idx -= 1;
7268 llvm_unreachable("Invalid index");
7269}
7270
7271llvm::StringRef BatchedReduceMinNode::getOutputName(unsigned idx) const {
7272 if (idx == 0) { return "Result"; }
7273 llvm_unreachable("Invalid index");
7274}
7275
7276std::string BatchedReduceMinNode::getDebugDesc() const {
7277 DescriptionBuilder db(getKindName());
7278 db.addParam("Name", separateString(getName(), 100, "\n"));
7279 if (hasPredicate()) db.addParam("Predicate", "Yes");
7280 db
7281 .addParam("Batch", *(getBatch().getType()))
7282 .addParam("Axes", getAxes())
7283 .addParam("Users", getNumUsers());
7284 db.addParam("Result", *(getResult().getType()));
7285 return db;
7286}
7287
7288void BatchedReduceMinNode::visit(Node *parent, NodeWalker *visitor) {
7289 if (!visitor->shouldVisit(parent, this)) { return; }
7290 visitor->pre(parent, this);
7291if (hasPredicate())
7292 getPredicate().getNode()->visit(this, visitor);
7293 getBatch().getNode()->visit(this, visitor);
7294 visitor->post(parent, this);
7295}
7296
7297bool BatchedReduceMinNode::isEqual(const BatchedReduceMinNode &other) const {
7298 return true &&
7299 Batch_ == other.Batch_ &&
7300 predicate_ == other.predicate_ &&
7301 Axes_ == other.Axes_ &&
7302 getType(0) == other.getType(0);
7303}
7304
7305Node* BatchedReduceMinNode::clone() const {
7306 return new BatchedReduceMinNode(getName(), getResult().getType(), getBatch(), getAxes());
7307}
7308
7309llvm::hash_code BatchedReduceMinNode::getHash() const {
7310 return llvm::hash_combine(
7311 llvm::hash_combine_range(Axes_.begin(), Axes_.end()),
7312 Batch_);
7313}
7314
7315unsigned BatchedReduceMaxNode::getNumInputs() const {
7316 return 1;
7317}
7318
7319std::string BatchedReduceMaxNode::getInputName(unsigned idx) const {
7320 if (idx == 0) { return "Batch"; }
7321 idx -= 1;
7322 llvm_unreachable("Invalid index");
7323}
7324
7325NodeValue BatchedReduceMaxNode::getNthInput(unsigned idx) {
7326 if (idx == 0) { return Batch_; }
7327 idx -= 1;
7328 llvm_unreachable("Invalid index");
7329}
7330
7331void BatchedReduceMaxNode::setNthInput(unsigned idx, NodeValue val) {
7332 if (idx == 0) { Batch_ = val; return; }
7333 idx -= 1;
7334 llvm_unreachable("Invalid index");
7335}
7336
7337llvm::StringRef BatchedReduceMaxNode::getOutputName(unsigned idx) const {
7338 if (idx == 0) { return "Result"; }
7339 llvm_unreachable("Invalid index");
7340}
7341
7342std::string BatchedReduceMaxNode::getDebugDesc() const {
7343 DescriptionBuilder db(getKindName());
7344 db.addParam("Name", separateString(getName(), 100, "\n"));
7345 if (hasPredicate()) db.addParam("Predicate", "Yes");
7346 db
7347 .addParam("Batch", *(getBatch().getType()))
7348 .addParam("Axes", getAxes())
7349 .addParam("Users", getNumUsers());
7350 db.addParam("Result", *(getResult().getType()));
7351 return db;
7352}
7353
7354void BatchedReduceMaxNode::visit(Node *parent, NodeWalker *visitor) {
7355 if (!visitor->shouldVisit(parent, this)) { return; }
7356 visitor->pre(parent, this);
7357if (hasPredicate())
7358 getPredicate().getNode()->visit(this, visitor);
7359 getBatch().getNode()->visit(this, visitor);
7360 visitor->post(parent, this);
7361}
7362
7363bool BatchedReduceMaxNode::isEqual(const BatchedReduceMaxNode &other) const {
7364 return true &&
7365 Batch_ == other.Batch_ &&
7366 predicate_ == other.predicate_ &&
7367 Axes_ == other.Axes_ &&
7368 getType(0) == other.getType(0);
7369}
7370
7371Node* BatchedReduceMaxNode::clone() const {
7372 return new BatchedReduceMaxNode(getName(), getResult().getType(), getBatch(), getAxes());
7373}
7374
7375llvm::hash_code BatchedReduceMaxNode::getHash() const {
7376 return llvm::hash_combine(
7377 llvm::hash_combine_range(Axes_.begin(), Axes_.end()),
7378 Batch_);
7379}
7380
7381unsigned BatchedReduceProdNode::getNumInputs() const {
7382 return 1;
7383}
7384
7385std::string BatchedReduceProdNode::getInputName(unsigned idx) const {
7386 if (idx == 0) { return "Batch"; }
7387 idx -= 1;
7388 llvm_unreachable("Invalid index");
7389}
7390
7391NodeValue BatchedReduceProdNode::getNthInput(unsigned idx) {
7392 if (idx == 0) { return Batch_; }
7393 idx -= 1;
7394 llvm_unreachable("Invalid index");
7395}
7396
7397void BatchedReduceProdNode::setNthInput(unsigned idx, NodeValue val) {
7398 if (idx == 0) { Batch_ = val; return; }
7399 idx -= 1;
7400 llvm_unreachable("Invalid index");
7401}
7402
7403llvm::StringRef BatchedReduceProdNode::getOutputName(unsigned idx) const {
7404 if (idx == 0) { return "Result"; }
7405 llvm_unreachable("Invalid index");
7406}
7407
7408std::string BatchedReduceProdNode::getDebugDesc() const {
7409 DescriptionBuilder db(getKindName());
7410 db.addParam("Name", separateString(getName(), 100, "\n"));
7411 if (hasPredicate()) db.addParam("Predicate", "Yes");
7412 db
7413 .addParam("Batch", *(getBatch().getType()))
7414 .addParam("Axis", getAxis())
7415 .addParam("Users", getNumUsers());
7416 db.addParam("Result", *(getResult().getType()));
7417 return db;
7418}
7419
7420void BatchedReduceProdNode::visit(Node *parent, NodeWalker *visitor) {
7421 if (!visitor->shouldVisit(parent, this)) { return; }
7422 visitor->pre(parent, this);
7423if (hasPredicate())
7424 getPredicate().getNode()->visit(this, visitor);
7425 getBatch().getNode()->visit(this, visitor);
7426 visitor->post(parent, this);
7427}
7428
7429bool BatchedReduceProdNode::isEqual(const BatchedReduceProdNode &other) const {
7430 return true &&
7431 Batch_ == other.Batch_ &&
7432 predicate_ == other.predicate_ &&
7433 Axis_ == other.Axis_ &&
7434 getType(0) == other.getType(0);
7435}
7436
7437Node* BatchedReduceProdNode::clone() const {
7438 return new BatchedReduceProdNode(getName(), getResult().getType(), getBatch(), getAxis());
7439}
7440
7441llvm::hash_code BatchedReduceProdNode::getHash() const {
7442 return llvm::hash_combine(
7443 Axis_,
7444 Batch_);
7445}
7446
7447unsigned ChannelShuffleNode::getNumInputs() const {
7448 return 1;
7449}
7450
7451std::string ChannelShuffleNode::getInputName(unsigned idx) const {
7452 if (idx == 0) { return "Input"; }
7453 idx -= 1;
7454 llvm_unreachable("Invalid index");
7455}
7456
7457NodeValue ChannelShuffleNode::getNthInput(unsigned idx) {
7458 if (idx == 0) { return Input_; }
7459 idx -= 1;
7460 llvm_unreachable("Invalid index");
7461}
7462
7463void ChannelShuffleNode::setNthInput(unsigned idx, NodeValue val) {
7464 if (idx == 0) { Input_ = val; return; }
7465 idx -= 1;
7466 llvm_unreachable("Invalid index");
7467}
7468
7469llvm::StringRef ChannelShuffleNode::getOutputName(unsigned idx) const {
7470 if (idx == 0) { return "Result"; }
7471 llvm_unreachable("Invalid index");
7472}
7473
7474std::string ChannelShuffleNode::getDebugDesc() const {
7475 DescriptionBuilder db(getKindName());
7476 db.addParam("Name", separateString(getName(), 100, "\n"));
7477 if (hasPredicate()) db.addParam("Predicate", "Yes");
7478 db
7479 .addParam("Input", *(getInput().getType()))
7480 .addParam("Group", getGroup())
7481 .addParam("Kernel", getKernel())
7482 .addParam("Users", getNumUsers());
7483 db.addParam("Result", *(getResult().getType()));
7484 return db;
7485}
7486
7487void ChannelShuffleNode::visit(Node *parent, NodeWalker *visitor) {
7488 if (!visitor->shouldVisit(parent, this)) { return; }
7489 visitor->pre(parent, this);
7490if (hasPredicate())
7491 getPredicate().getNode()->visit(this, visitor);
7492 getInput().getNode()->visit(this, visitor);
7493 visitor->post(parent, this);
7494}
7495
7496bool ChannelShuffleNode::isEqual(const ChannelShuffleNode &other) const {
7497 return true &&
7498 Input_ == other.Input_ &&
7499 predicate_ == other.predicate_ &&
7500 Group_ == other.Group_ &&
7501 Kernel_ == other.Kernel_ &&
7502 getType(0) == other.getType(0);
7503}
7504
7505Node* ChannelShuffleNode::clone() const {
7506 return new ChannelShuffleNode(getName(), getResult().getType(), getInput(), getGroup(), getKernel());
7507}
7508
7509llvm::hash_code ChannelShuffleNode::getHash() const {
7510 return llvm::hash_combine(
7511 Group_,
7512 Kernel_,
7513 Input_);
7514}
7515
7516unsigned CumSumNode::getNumInputs() const {
7517 return 1;
7518}
7519
7520std::string CumSumNode::getInputName(unsigned idx) const {
7521 if (idx == 0) { return "Input"; }
7522 idx -= 1;
7523 llvm_unreachable("Invalid index");
7524}
7525
7526NodeValue CumSumNode::getNthInput(unsigned idx) {
7527 if (idx == 0) { return Input_; }
7528 idx -= 1;
7529 llvm_unreachable("Invalid index");
7530}
7531
7532void CumSumNode::setNthInput(unsigned idx, NodeValue val) {
7533 if (idx == 0) { Input_ = val; return; }
7534 idx -= 1;
7535 llvm_unreachable("Invalid index");
7536}
7537
7538llvm::StringRef CumSumNode::getOutputName(unsigned idx) const {
7539 if (idx == 0) { return "Result"; }
7540 llvm_unreachable("Invalid index");
7541}
7542
7543std::string CumSumNode::getDebugDesc() const {
7544 DescriptionBuilder db(getKindName());
7545 db.addParam("Name", separateString(getName(), 100, "\n"));
7546 if (hasPredicate()) db.addParam("Predicate", "Yes");
7547 db
7548 .addParam("Input", *(getInput().getType()))
7549 .addParam("Dim", getDim())
7550 .addParam("Exclusive", getExclusive())
7551 .addParam("Reverse", getReverse())
7552 .addParam("Users", getNumUsers());
7553 db.addParam("Result", *(getResult().getType()));
7554 return db;
7555}
7556
7557void CumSumNode::visit(Node *parent, NodeWalker *visitor) {
7558 if (!visitor->shouldVisit(parent, this)) { return; }
7559 visitor->pre(parent, this);
7560if (hasPredicate())
7561 getPredicate().getNode()->visit(this, visitor);
7562 getInput().getNode()->visit(this, visitor);
7563 visitor->post(parent, this);
7564}
7565
7566bool CumSumNode::isEqual(const CumSumNode &other) const {
7567 return true &&
7568 Input_ == other.Input_ &&
7569 predicate_ == other.predicate_ &&
7570 Dim_ == other.Dim_ &&
7571 Exclusive_ == other.Exclusive_ &&
7572 Reverse_ == other.Reverse_ &&
7573 getType(0) == other.getType(0);
7574}
7575
7576Node* CumSumNode::clone() const {
7577 return new CumSumNode(getName(), getResult().getType(), getInput(), getDim(), getExclusive(), getReverse());
7578}
7579
7580llvm::hash_code CumSumNode::getHash() const {
7581 return llvm::hash_combine(
7582 Dim_,
7583 Exclusive_,
7584 Reverse_,
7585 Input_);
7586}
7587
7588unsigned LengthsSumNode::getNumInputs() const {
7589 return 2;
7590}
7591
7592std::string LengthsSumNode::getInputName(unsigned idx) const {
7593 if (idx == 0) { return "Data"; }
7594 if (idx == 1) { return "Lengths"; }
7595 idx -= 2;
7596 llvm_unreachable("Invalid index");
7597}
7598
7599NodeValue LengthsSumNode::getNthInput(unsigned idx) {
7600 if (idx == 0) { return Data_; }
7601 if (idx == 1) { return Lengths_; }
7602 idx -= 2;
7603 llvm_unreachable("Invalid index");
7604}
7605
7606void LengthsSumNode::setNthInput(unsigned idx, NodeValue val) {
7607 if (idx == 0) { Data_ = val; return; }
7608 if (idx == 1) { Lengths_ = val; return; }
7609 idx -= 2;
7610 llvm_unreachable("Invalid index");
7611}
7612
7613llvm::StringRef LengthsSumNode::getOutputName(unsigned idx) const {
7614 if (idx == 0) { return "Result"; }
7615 llvm_unreachable("Invalid index");
7616}
7617
7618std::string LengthsSumNode::getDebugDesc() const {
7619 DescriptionBuilder db(getKindName());
7620 db.addParam("Name", separateString(getName(), 100, "\n"));
7621 if (hasPredicate()) db.addParam("Predicate", "Yes");
7622 db
7623 .addParam("Data", *(getData().getType()))
7624 .addParam("Lengths", *(getLengths().getType()))
7625 .addParam("Users", getNumUsers());
7626 db.addParam("Result", *(getResult().getType()));
7627 return db;
7628}
7629
7630void LengthsSumNode::visit(Node *parent, NodeWalker *visitor) {
7631 if (!visitor->shouldVisit(parent, this)) { return; }
7632 visitor->pre(parent, this);
7633if (hasPredicate())
7634 getPredicate().getNode()->visit(this, visitor);
7635 getData().getNode()->visit(this, visitor);
7636 getLengths().getNode()->visit(this, visitor);
7637 visitor->post(parent, this);
7638}
7639
7640bool LengthsSumNode::isEqual(const LengthsSumNode &other) const {
7641 return true &&
7642 Data_ == other.Data_ &&
7643 Lengths_ == other.Lengths_ &&
7644 predicate_ == other.predicate_ &&
7645 getType(0) == other.getType(0);
7646}
7647
7648Node* LengthsSumNode::clone() const {
7649 return new LengthsSumNode(getName(), getResult().getType(), getData(), getLengths());
7650}
7651
7652llvm::hash_code LengthsSumNode::getHash() const {
7653 return llvm::hash_combine(
7654 Data_,
7655 Lengths_);
7656}
7657
7658unsigned SparseLengthsSumGradNode::getNumInputs() const {
7659 return 5;
7660}
7661
7662std::string SparseLengthsSumGradNode::getInputName(unsigned idx) const {
7663 if (idx == 0) { return "Data"; }
7664 if (idx == 1) { return "Indices"; }
7665 if (idx == 2) { return "Lengths"; }
7666 if (idx == 3) { return "OriginalOutputForResult"; }
7667 if (idx == 4) { return "GradOfOriginalOutputNamedResult"; }
7668 idx -= 5;
7669 llvm_unreachable("Invalid index");
7670}
7671
7672NodeValue SparseLengthsSumGradNode::getNthInput(unsigned idx) {
7673 if (idx == 0) { return Data_; }
7674 if (idx == 1) { return Indices_; }
7675 if (idx == 2) { return Lengths_; }
7676 if (idx == 3) { return OriginalOutputForResult_; }
7677 if (idx == 4) { return GradOfOriginalOutputNamedResult_; }
7678 idx -= 5;
7679 llvm_unreachable("Invalid index");
7680}
7681
7682void SparseLengthsSumGradNode::setNthInput(unsigned idx, NodeValue val) {
7683 if (idx == 0) { Data_ = val; return; }
7684 if (idx == 1) { Indices_ = val; return; }
7685 if (idx == 2) { Lengths_ = val; return; }
7686 if (idx == 3) { OriginalOutputForResult_ = val; return; }
7687 if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; }
7688 idx -= 5;
7689 llvm_unreachable("Invalid index");
7690}
7691
7692llvm::StringRef SparseLengthsSumGradNode::getOutputName(unsigned idx) const {
7693 if (idx == 0) { return "GradOfInputNamedData"; }
7694 if (idx == 1) { return "GradOfInputNamedIndices"; }
7695 if (idx == 2) { return "GradOfInputNamedLengths"; }
7696 llvm_unreachable("Invalid index");
7697}
7698
7699std::string SparseLengthsSumGradNode::getDebugDesc() const {
7700 DescriptionBuilder db(getKindName());
7701 db.addParam("Name", separateString(getName(), 100, "\n"));
7702 if (hasPredicate()) db.addParam("Predicate", "Yes");
7703 db
7704 .addParam("Data", *(getData().getType()))
7705 .addParam("Indices", *(getIndices().getType()))
7706 .addParam("Lengths", *(getLengths().getType()))
7707 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
7708 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
7709 .addParam("LengthsMode", getLengthsMode())
7710 .addParam("AvgLength", getAvgLength())
7711 .addParam("Users", getNumUsers());
7712 db.addParam("GradOfInputNamedData", *(getGradOfInputNamedData().getType()));
7713 db.addParam("GradOfInputNamedIndices", *(getGradOfInputNamedIndices().getType()));
7714 db.addParam("GradOfInputNamedLengths", *(getGradOfInputNamedLengths().getType()));
7715 return db;
7716}
7717
7718void SparseLengthsSumGradNode::visit(Node *parent, NodeWalker *visitor) {
7719 if (!visitor->shouldVisit(parent, this)) { return; }
7720 visitor->pre(parent, this);
7721if (hasPredicate())
7722 getPredicate().getNode()->visit(this, visitor);
7723 getData().getNode()->visit(this, visitor);
7724 getIndices().getNode()->visit(this, visitor);
7725 getLengths().getNode()->visit(this, visitor);
7726 getOriginalOutputForResult().getNode()->visit(this, visitor);
7727 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
7728 visitor->post(parent, this);
7729}
7730
7731bool SparseLengthsSumGradNode::isEqual(const SparseLengthsSumGradNode &other) const {
7732 return true &&
7733 Data_ == other.Data_ &&
7734 Indices_ == other.Indices_ &&
7735 Lengths_ == other.Lengths_ &&
7736 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
7737 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
7738 predicate_ == other.predicate_ &&
7739 LengthsMode_ == other.LengthsMode_ &&
7740 AvgLength_ == other.AvgLength_ &&
7741 getType(0) == other.getType(0) &&
7742 getType(1) == other.getType(1) &&
7743 getType(2) == other.getType(2);
7744}
7745
7746Node* SparseLengthsSumGradNode::clone() const {
7747 return new SparseLengthsSumGradNode(getName(), getData(), getIndices(), getLengths(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getLengthsMode(), getAvgLength());
7748}
7749
7750llvm::hash_code SparseLengthsSumGradNode::getHash() const {
7751 return llvm::hash_combine(
7752 LengthsMode_,
7753 toBinary(AvgLength_),
7754 Data_,
7755 Indices_,
7756 Lengths_,
7757 OriginalOutputForResult_,
7758 GradOfOriginalOutputNamedResult_);
7759}
7760
7761unsigned SparseLengthsSumNode::getNumInputs() const {
7762 return 3;
7763}
7764
7765std::string SparseLengthsSumNode::getInputName(unsigned idx) const {
7766 if (idx == 0) { return "Data"; }
7767 if (idx == 1) { return "Indices"; }
7768 if (idx == 2) { return "Lengths"; }
7769 idx -= 3;
7770 llvm_unreachable("Invalid index");
7771}
7772
7773NodeValue SparseLengthsSumNode::getNthInput(unsigned idx) {
7774 if (idx == 0) { return Data_; }
7775 if (idx == 1) { return Indices_; }
7776 if (idx == 2) { return Lengths_; }
7777 idx -= 3;
7778 llvm_unreachable("Invalid index");
7779}
7780
7781void SparseLengthsSumNode::setNthInput(unsigned idx, NodeValue val) {
7782 if (idx == 0) { Data_ = val; return; }
7783 if (idx == 1) { Indices_ = val; return; }
7784 if (idx == 2) { Lengths_ = val; return; }
7785 idx -= 3;
7786 llvm_unreachable("Invalid index");
7787}
7788
7789llvm::StringRef SparseLengthsSumNode::getOutputName(unsigned idx) const {
7790 if (idx == 0) { return "Result"; }
7791 llvm_unreachable("Invalid index");
7792}
7793
7794std::string SparseLengthsSumNode::getDebugDesc() const {
7795 DescriptionBuilder db(getKindName());
7796 db.addParam("Name", separateString(getName(), 100, "\n"));
7797 if (hasPredicate()) db.addParam("Predicate", "Yes");
7798 db
7799 .addParam("Data", *(getData().getType()))
7800 .addParam("Indices", *(getIndices().getType()))
7801 .addParam("Lengths", *(getLengths().getType()))
7802 .addParam("LengthsMode", getLengthsMode())
7803 .addParam("AvgLength", getAvgLength())
7804 .addParam("Users", getNumUsers());
7805 db.addParam("Result", *(getResult().getType()));
7806 return db;
7807}
7808
7809void SparseLengthsSumNode::visit(Node *parent, NodeWalker *visitor) {
7810 if (!visitor->shouldVisit(parent, this)) { return; }
7811 visitor->pre(parent, this);
7812if (hasPredicate())
7813 getPredicate().getNode()->visit(this, visitor);
7814 getData().getNode()->visit(this, visitor);
7815 getIndices().getNode()->visit(this, visitor);
7816 getLengths().getNode()->visit(this, visitor);
7817 visitor->post(parent, this);
7818}
7819
7820bool SparseLengthsSumNode::isEqual(const SparseLengthsSumNode &other) const {
7821 return true &&
7822 Data_ == other.Data_ &&
7823 Indices_ == other.Indices_ &&
7824 Lengths_ == other.Lengths_ &&
7825 predicate_ == other.predicate_ &&
7826 LengthsMode_ == other.LengthsMode_ &&
7827 AvgLength_ == other.AvgLength_ &&
7828 getType(0) == other.getType(0);
7829}
7830
7831Node* SparseLengthsSumNode::clone() const {
7832 return new SparseLengthsSumNode(getName(), getResult().getType(), getData(), getIndices(), getLengths(), getLengthsMode(), getAvgLength());
7833}
7834
7835llvm::hash_code SparseLengthsSumNode::getHash() const {
7836 return llvm::hash_combine(
7837 LengthsMode_,
7838 toBinary(AvgLength_),
7839 Data_,
7840 Indices_,
7841 Lengths_);
7842}
7843
7844SparseLengthsSumGradNode *SparseLengthsSumNode::getGrad(GraphGradMapper &builder) {
7845 auto *x = new SparseLengthsSumGradNode(getName().str() + "_grad", getData(), getIndices(), getLengths(), getResult(), builder.getGradient(getResult()), getLengthsMode(), getAvgLength());
7846 builder.addGradient(getData(), x->getGradOfInputNamedData());
7847 builder.addGradient(getIndices(), x->getGradOfInputNamedIndices());
7848 builder.addGradient(getLengths(), x->getGradOfInputNamedLengths());
7849 return x;
7850}
7851
7852unsigned SparseLengthsWeightedSumGradNode::getNumInputs() const {
7853 return 6;
7854}
7855
7856std::string SparseLengthsWeightedSumGradNode::getInputName(unsigned idx) const {
7857 if (idx == 0) { return "Data"; }
7858 if (idx == 1) { return "Weights"; }
7859 if (idx == 2) { return "Indices"; }
7860 if (idx == 3) { return "Lengths"; }
7861 if (idx == 4) { return "OriginalOutputForResult"; }
7862 if (idx == 5) { return "GradOfOriginalOutputNamedResult"; }
7863 idx -= 6;
7864 llvm_unreachable("Invalid index");
7865}
7866
7867NodeValue SparseLengthsWeightedSumGradNode::getNthInput(unsigned idx) {
7868 if (idx == 0) { return Data_; }
7869 if (idx == 1) { return Weights_; }
7870 if (idx == 2) { return Indices_; }
7871 if (idx == 3) { return Lengths_; }
7872 if (idx == 4) { return OriginalOutputForResult_; }
7873 if (idx == 5) { return GradOfOriginalOutputNamedResult_; }
7874 idx -= 6;
7875 llvm_unreachable("Invalid index");
7876}
7877
7878void SparseLengthsWeightedSumGradNode::setNthInput(unsigned idx, NodeValue val) {
7879 if (idx == 0) { Data_ = val; return; }
7880 if (idx == 1) { Weights_ = val; return; }
7881 if (idx == 2) { Indices_ = val; return; }
7882 if (idx == 3) { Lengths_ = val; return; }
7883 if (idx == 4) { OriginalOutputForResult_ = val; return; }
7884 if (idx == 5) { GradOfOriginalOutputNamedResult_ = val; return; }
7885 idx -= 6;
7886 llvm_unreachable("Invalid index");
7887}
7888
7889llvm::StringRef SparseLengthsWeightedSumGradNode::getOutputName(unsigned idx) const {
7890 if (idx == 0) { return "GradOfInputNamedData"; }
7891 if (idx == 1) { return "GradOfInputNamedWeights"; }
7892 if (idx == 2) { return "GradOfInputNamedIndices"; }
7893 if (idx == 3) { return "GradOfInputNamedLengths"; }
7894 llvm_unreachable("Invalid index");
7895}
7896
7897std::string SparseLengthsWeightedSumGradNode::getDebugDesc() const {
7898 DescriptionBuilder db(getKindName());
7899 db.addParam("Name", separateString(getName(), 100, "\n"));
7900 if (hasPredicate()) db.addParam("Predicate", "Yes");
7901 db
7902 .addParam("Data", *(getData().getType()))
7903 .addParam("Weights", *(getWeights().getType()))
7904 .addParam("Indices", *(getIndices().getType()))
7905 .addParam("Lengths", *(getLengths().getType()))
7906 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
7907 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
7908 .addParam("LengthsMode", getLengthsMode())
7909 .addParam("AvgLength", getAvgLength())
7910 .addParam("Users", getNumUsers());
7911 db.addParam("GradOfInputNamedData", *(getGradOfInputNamedData().getType()));
7912 db.addParam("GradOfInputNamedWeights", *(getGradOfInputNamedWeights().getType()));
7913 db.addParam("GradOfInputNamedIndices", *(getGradOfInputNamedIndices().getType()));
7914 db.addParam("GradOfInputNamedLengths", *(getGradOfInputNamedLengths().getType()));
7915 return db;
7916}
7917
7918void SparseLengthsWeightedSumGradNode::visit(Node *parent, NodeWalker *visitor) {
7919 if (!visitor->shouldVisit(parent, this)) { return; }
7920 visitor->pre(parent, this);
7921if (hasPredicate())
7922 getPredicate().getNode()->visit(this, visitor);
7923 getData().getNode()->visit(this, visitor);
7924 getWeights().getNode()->visit(this, visitor);
7925 getIndices().getNode()->visit(this, visitor);
7926 getLengths().getNode()->visit(this, visitor);
7927 getOriginalOutputForResult().getNode()->visit(this, visitor);
7928 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
7929 visitor->post(parent, this);
7930}
7931
7932bool SparseLengthsWeightedSumGradNode::isEqual(const SparseLengthsWeightedSumGradNode &other) const {
7933 return true &&
7934 Data_ == other.Data_ &&
7935 Weights_ == other.Weights_ &&
7936 Indices_ == other.Indices_ &&
7937 Lengths_ == other.Lengths_ &&
7938 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
7939 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
7940 predicate_ == other.predicate_ &&
7941 LengthsMode_ == other.LengthsMode_ &&
7942 AvgLength_ == other.AvgLength_ &&
7943 getType(0) == other.getType(0) &&
7944 getType(1) == other.getType(1) &&
7945 getType(2) == other.getType(2) &&
7946 getType(3) == other.getType(3);
7947}
7948
7949Node* SparseLengthsWeightedSumGradNode::clone() const {
7950 return new SparseLengthsWeightedSumGradNode(getName(), getData(), getWeights(), getIndices(), getLengths(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getLengthsMode(), getAvgLength());
7951}
7952
7953llvm::hash_code SparseLengthsWeightedSumGradNode::getHash() const {
7954 return llvm::hash_combine(
7955 LengthsMode_,
7956 toBinary(AvgLength_),
7957 Data_,
7958 Weights_,
7959 Indices_,
7960 Lengths_,
7961 OriginalOutputForResult_,
7962 GradOfOriginalOutputNamedResult_);
7963}
7964
7965unsigned SparseLengthsWeightedSumNode::getNumInputs() const {
7966 return 4;
7967}
7968
7969std::string SparseLengthsWeightedSumNode::getInputName(unsigned idx) const {
7970 if (idx == 0) { return "Data"; }
7971 if (idx == 1) { return "Weights"; }
7972 if (idx == 2) { return "Indices"; }
7973 if (idx == 3) { return "Lengths"; }
7974 idx -= 4;
7975 llvm_unreachable("Invalid index");
7976}
7977
7978NodeValue SparseLengthsWeightedSumNode::getNthInput(unsigned idx) {
7979 if (idx == 0) { return Data_; }
7980 if (idx == 1) { return Weights_; }
7981 if (idx == 2) { return Indices_; }
7982 if (idx == 3) { return Lengths_; }
7983 idx -= 4;
7984 llvm_unreachable("Invalid index");
7985}
7986
7987void SparseLengthsWeightedSumNode::setNthInput(unsigned idx, NodeValue val) {
7988 if (idx == 0) { Data_ = val; return; }
7989 if (idx == 1) { Weights_ = val; return; }
7990 if (idx == 2) { Indices_ = val; return; }
7991 if (idx == 3) { Lengths_ = val; return; }
7992 idx -= 4;
7993 llvm_unreachable("Invalid index");
7994}
7995
7996llvm::StringRef SparseLengthsWeightedSumNode::getOutputName(unsigned idx) const {
7997 if (idx == 0) { return "Result"; }
7998 llvm_unreachable("Invalid index");
7999}
8000
8001std::string SparseLengthsWeightedSumNode::getDebugDesc() const {
8002 DescriptionBuilder db(getKindName());
8003 db.addParam("Name", separateString(getName(), 100, "\n"));
8004 if (hasPredicate()) db.addParam("Predicate", "Yes");
8005 db
8006 .addParam("Data", *(getData().getType()))
8007 .addParam("Weights", *(getWeights().getType()))
8008 .addParam("Indices", *(getIndices().getType()))
8009 .addParam("Lengths", *(getLengths().getType()))
8010 .addParam("LengthsMode", getLengthsMode())
8011 .addParam("AvgLength", getAvgLength())
8012 .addParam("Users", getNumUsers());
8013 db.addParam("Result", *(getResult().getType()));
8014 return db;
8015}
8016
8017void SparseLengthsWeightedSumNode::visit(Node *parent, NodeWalker *visitor) {
8018 if (!visitor->shouldVisit(parent, this)) { return; }
8019 visitor->pre(parent, this);
8020if (hasPredicate())
8021 getPredicate().getNode()->visit(this, visitor);
8022 getData().getNode()->visit(this, visitor);
8023 getWeights().getNode()->visit(this, visitor);
8024 getIndices().getNode()->visit(this, visitor);
8025 getLengths().getNode()->visit(this, visitor);
8026 visitor->post(parent, this);
8027}
8028
8029bool SparseLengthsWeightedSumNode::isEqual(const SparseLengthsWeightedSumNode &other) const {
8030 return true &&
8031 Data_ == other.Data_ &&
8032 Weights_ == other.Weights_ &&
8033 Indices_ == other.Indices_ &&
8034 Lengths_ == other.Lengths_ &&
8035 predicate_ == other.predicate_ &&
8036 LengthsMode_ == other.LengthsMode_ &&
8037 AvgLength_ == other.AvgLength_ &&
8038 getType(0) == other.getType(0);
8039}
8040
8041Node* SparseLengthsWeightedSumNode::clone() const {
8042 return new SparseLengthsWeightedSumNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getLengths(), getLengthsMode(), getAvgLength());
8043}
8044
8045llvm::hash_code SparseLengthsWeightedSumNode::getHash() const {
8046 return llvm::hash_combine(
8047 LengthsMode_,
8048 toBinary(AvgLength_),
8049 Data_,
8050 Weights_,
8051 Indices_,
8052 Lengths_);
8053}
8054
8055SparseLengthsWeightedSumGradNode *SparseLengthsWeightedSumNode::getGrad(GraphGradMapper &builder) {
8056 auto *x = new SparseLengthsWeightedSumGradNode(getName().str() + "_grad", getData(), getWeights(), getIndices(), getLengths(), getResult(), builder.getGradient(getResult()), getLengthsMode(), getAvgLength());
8057 builder.addGradient(getData(), x->getGradOfInputNamedData());
8058 builder.addGradient(getWeights(), x->getGradOfInputNamedWeights());
8059 builder.addGradient(getIndices(), x->getGradOfInputNamedIndices());
8060 builder.addGradient(getLengths(), x->getGradOfInputNamedLengths());
8061 return x;
8062}
8063
8064unsigned EmbeddingNode::getNumInputs() const {
8065 return 2;
8066}
8067
8068std::string EmbeddingNode::getInputName(unsigned idx) const {
8069 if (idx == 0) { return "Weights"; }
8070 if (idx == 1) { return "Indices"; }
8071 idx -= 2;
8072 llvm_unreachable("Invalid index");
8073}
8074
8075NodeValue EmbeddingNode::getNthInput(unsigned idx) {
8076 if (idx == 0) { return Weights_; }
8077 if (idx == 1) { return Indices_; }
8078 idx -= 2;
8079 llvm_unreachable("Invalid index");
8080}
8081
8082void EmbeddingNode::setNthInput(unsigned idx, NodeValue val) {
8083 if (idx == 0) { Weights_ = val; return; }
8084 if (idx == 1) { Indices_ = val; return; }
8085 idx -= 2;
8086 llvm_unreachable("Invalid index");
8087}
8088
8089llvm::StringRef EmbeddingNode::getOutputName(unsigned idx) const {
8090 if (idx == 0) { return "Result"; }
8091 llvm_unreachable("Invalid index");
8092}
8093
8094std::string EmbeddingNode::getDebugDesc() const {
8095 DescriptionBuilder db(getKindName());
8096 db.addParam("Name", separateString(getName(), 100, "\n"));
8097 if (hasPredicate()) db.addParam("Predicate", "Yes");
8098 db
8099 .addParam("Weights", *(getWeights().getType()))
8100 .addParam("Indices", *(getIndices().getType()))
8101 .addParam("PadIdx", getPadIdx())
8102 .addParam("Scale", getScale())
8103 .addParam("Sparse", getSparse())
8104 .addParam("Users", getNumUsers());
8105 db.addParam("Result", *(getResult().getType()));
8106 return db;
8107}
8108
8109void EmbeddingNode::visit(Node *parent, NodeWalker *visitor) {
8110 if (!visitor->shouldVisit(parent, this)) { return; }
8111 visitor->pre(parent, this);
8112if (hasPredicate())
8113 getPredicate().getNode()->visit(this, visitor);
8114 getWeights().getNode()->visit(this, visitor);
8115 getIndices().getNode()->visit(this, visitor);
8116 visitor->post(parent, this);
8117}
8118
8119bool EmbeddingNode::isEqual(const EmbeddingNode &other) const {
8120 return true &&
8121 Weights_ == other.Weights_ &&
8122 Indices_ == other.Indices_ &&
8123 predicate_ == other.predicate_ &&
8124 PadIdx_ == other.PadIdx_ &&
8125 Scale_ == other.Scale_ &&
8126 Sparse_ == other.Sparse_ &&
8127 getType(0) == other.getType(0);
8128}
8129
8130Node* EmbeddingNode::clone() const {
8131 return new EmbeddingNode(getName(), getResult().getType(), getWeights(), getIndices(), getPadIdx(), getScale(), getSparse());
8132}
8133
8134llvm::hash_code EmbeddingNode::getHash() const {
8135 return llvm::hash_combine(
8136 PadIdx_,
8137 Scale_,
8138 Sparse_,
8139 Weights_,
8140 Indices_);
8141}
8142
8143unsigned EmbeddingBagNode::getNumInputs() const {
8144 return 4;
8145}
8146
8147std::string EmbeddingBagNode::getInputName(unsigned idx) const {
8148 if (idx == 0) { return "Data"; }
8149 if (idx == 1) { return "Weights"; }
8150 if (idx == 2) { return "Indices"; }
8151 if (idx == 3) { return "Offsets"; }
8152 idx -= 4;
8153 llvm_unreachable("Invalid index");
8154}
8155
8156NodeValue EmbeddingBagNode::getNthInput(unsigned idx) {
8157 if (idx == 0) { return Data_; }
8158 if (idx == 1) { return Weights_; }
8159 if (idx == 2) { return Indices_; }
8160 if (idx == 3) { return Offsets_; }
8161 idx -= 4;
8162 llvm_unreachable("Invalid index");
8163}
8164
8165void EmbeddingBagNode::setNthInput(unsigned idx, NodeValue val) {
8166 if (idx == 0) { Data_ = val; return; }
8167 if (idx == 1) { Weights_ = val; return; }
8168 if (idx == 2) { Indices_ = val; return; }
8169 if (idx == 3) { Offsets_ = val; return; }
8170 idx -= 4;
8171 llvm_unreachable("Invalid index");
8172}
8173
8174llvm::StringRef EmbeddingBagNode::getOutputName(unsigned idx) const {
8175 if (idx == 0) { return "Result"; }
8176 llvm_unreachable("Invalid index");
8177}
8178
8179std::string EmbeddingBagNode::getDebugDesc() const {
8180 DescriptionBuilder db(getKindName());
8181 db.addParam("Name", separateString(getName(), 100, "\n"));
8182 if (hasPredicate()) db.addParam("Predicate", "Yes");
8183 db
8184 .addParam("Data", *(getData().getType()))
8185 .addParam("Weights", *(getWeights().getType()))
8186 .addParam("Indices", *(getIndices().getType()))
8187 .addParam("Offsets", *(getOffsets().getType()))
8188 .addParam("HasEndOffset", getHasEndOffset())
8189 .addParam("LengthsMode", getLengthsMode())
8190 .addParam("AvgLength", getAvgLength())
8191 .addParam("Users", getNumUsers());
8192 db.addParam("Result", *(getResult().getType()));
8193 return db;
8194}
8195
8196void EmbeddingBagNode::visit(Node *parent, NodeWalker *visitor) {
8197 if (!visitor->shouldVisit(parent, this)) { return; }
8198 visitor->pre(parent, this);
8199if (hasPredicate())
8200 getPredicate().getNode()->visit(this, visitor);
8201 getData().getNode()->visit(this, visitor);
8202 getWeights().getNode()->visit(this, visitor);
8203 getIndices().getNode()->visit(this, visitor);
8204 getOffsets().getNode()->visit(this, visitor);
8205 visitor->post(parent, this);
8206}
8207
8208bool EmbeddingBagNode::isEqual(const EmbeddingBagNode &other) const {
8209 return true &&
8210 Data_ == other.Data_ &&
8211 Weights_ == other.Weights_ &&
8212 Indices_ == other.Indices_ &&
8213 Offsets_ == other.Offsets_ &&
8214 predicate_ == other.predicate_ &&
8215 HasEndOffset_ == other.HasEndOffset_ &&
8216 LengthsMode_ == other.LengthsMode_ &&
8217 AvgLength_ == other.AvgLength_ &&
8218 getType(0) == other.getType(0);
8219}
8220
8221Node* EmbeddingBagNode::clone() const {
8222 return new EmbeddingBagNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getOffsets(), getHasEndOffset(), getLengthsMode(), getAvgLength());
8223}
8224
8225llvm::hash_code EmbeddingBagNode::getHash() const {
8226 return llvm::hash_combine(
8227 HasEndOffset_,
8228 LengthsMode_,
8229 toBinary(AvgLength_),
8230 Data_,
8231 Weights_,
8232 Indices_,
8233 Offsets_);
8234}
8235
8236unsigned EmbeddingBagByteRowwiseOffsetsNode::getNumInputs() const {
8237 return 4;
8238}
8239
8240std::string EmbeddingBagByteRowwiseOffsetsNode::getInputName(unsigned idx) const {
8241 if (idx == 0) { return "Data"; }
8242 if (idx == 1) { return "Weights"; }
8243 if (idx == 2) { return "Indices"; }
8244 if (idx == 3) { return "Offsets"; }
8245 idx -= 4;
8246 llvm_unreachable("Invalid index");
8247}
8248
8249NodeValue EmbeddingBagByteRowwiseOffsetsNode::getNthInput(unsigned idx) {
8250 if (idx == 0) { return Data_; }
8251 if (idx == 1) { return Weights_; }
8252 if (idx == 2) { return Indices_; }
8253 if (idx == 3) { return Offsets_; }
8254 idx -= 4;
8255 llvm_unreachable("Invalid index");
8256}
8257
8258void EmbeddingBagByteRowwiseOffsetsNode::setNthInput(unsigned idx, NodeValue val) {
8259 if (idx == 0) { Data_ = val; return; }
8260 if (idx == 1) { Weights_ = val; return; }
8261 if (idx == 2) { Indices_ = val; return; }
8262 if (idx == 3) { Offsets_ = val; return; }
8263 idx -= 4;
8264 llvm_unreachable("Invalid index");
8265}
8266
8267llvm::StringRef EmbeddingBagByteRowwiseOffsetsNode::getOutputName(unsigned idx) const {
8268 if (idx == 0) { return "Result"; }
8269 llvm_unreachable("Invalid index");
8270}
8271
8272std::string EmbeddingBagByteRowwiseOffsetsNode::getDebugDesc() const {
8273 DescriptionBuilder db(getKindName());
8274 db.addParam("Name", separateString(getName(), 100, "\n"));
8275 if (hasPredicate()) db.addParam("Predicate", "Yes");
8276 db
8277 .addParam("Data", *(getData().getType()))
8278 .addParam("Weights", *(getWeights().getType()))
8279 .addParam("Indices", *(getIndices().getType()))
8280 .addParam("Offsets", *(getOffsets().getType()))
8281 .addParam("UseFP16Accumulation", getUseFP16Accumulation())
8282 .addParam("HasEndOffset", getHasEndOffset())
8283 .addParam("LengthsMode", getLengthsMode())
8284 .addParam("AvgLength", getAvgLength())
8285 .addParam("Users", getNumUsers());
8286 db.addParam("Result", *(getResult().getType()));
8287 return db;
8288}
8289
8290void EmbeddingBagByteRowwiseOffsetsNode::visit(Node *parent, NodeWalker *visitor) {
8291 if (!visitor->shouldVisit(parent, this)) { return; }
8292 visitor->pre(parent, this);
8293if (hasPredicate())
8294 getPredicate().getNode()->visit(this, visitor);
8295 getData().getNode()->visit(this, visitor);
8296 getWeights().getNode()->visit(this, visitor);
8297 getIndices().getNode()->visit(this, visitor);
8298 getOffsets().getNode()->visit(this, visitor);
8299 visitor->post(parent, this);
8300}
8301
8302bool EmbeddingBagByteRowwiseOffsetsNode::isEqual(const EmbeddingBagByteRowwiseOffsetsNode &other) const {
8303 return true &&
8304 Data_ == other.Data_ &&
8305 Weights_ == other.Weights_ &&
8306 Indices_ == other.Indices_ &&
8307 Offsets_ == other.Offsets_ &&
8308 predicate_ == other.predicate_ &&
8309 UseFP16Accumulation_ == other.UseFP16Accumulation_ &&
8310 HasEndOffset_ == other.HasEndOffset_ &&
8311 LengthsMode_ == other.LengthsMode_ &&
8312 AvgLength_ == other.AvgLength_ &&
8313 getType(0) == other.getType(0);
8314}
8315
8316Node* EmbeddingBagByteRowwiseOffsetsNode::clone() const {
8317 return new EmbeddingBagByteRowwiseOffsetsNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getOffsets(), getUseFP16Accumulation(), getHasEndOffset(), getLengthsMode(), getAvgLength());
8318}
8319
8320llvm::hash_code EmbeddingBagByteRowwiseOffsetsNode::getHash() const {
8321 return llvm::hash_combine(
8322 UseFP16Accumulation_,
8323 HasEndOffset_,
8324 LengthsMode_,
8325 toBinary(AvgLength_),
8326 Data_,
8327 Weights_,
8328 Indices_,
8329 Offsets_);
8330}
8331
8332unsigned RowwiseQuantizedSparseLengthsWeightedSumNode::getNumInputs() const {
8333 return 6;
8334}
8335
8336std::string RowwiseQuantizedSparseLengthsWeightedSumNode::getInputName(unsigned idx) const {
8337 if (idx == 0) { return "Data"; }
8338 if (idx == 1) { return "Scales"; }
8339 if (idx == 2) { return "Offsets"; }
8340 if (idx == 3) { return "Weights"; }
8341 if (idx == 4) { return "Indices"; }
8342 if (idx == 5) { return "Lengths"; }
8343 idx -= 6;
8344 llvm_unreachable("Invalid index");
8345}
8346
8347NodeValue RowwiseQuantizedSparseLengthsWeightedSumNode::getNthInput(unsigned idx) {
8348 if (idx == 0) { return Data_; }
8349 if (idx == 1) { return Scales_; }
8350 if (idx == 2) { return Offsets_; }
8351 if (idx == 3) { return Weights_; }
8352 if (idx == 4) { return Indices_; }
8353 if (idx == 5) { return Lengths_; }
8354 idx -= 6;
8355 llvm_unreachable("Invalid index");
8356}
8357
8358void RowwiseQuantizedSparseLengthsWeightedSumNode::setNthInput(unsigned idx, NodeValue val) {
8359 if (idx == 0) { Data_ = val; return; }
8360 if (idx == 1) { Scales_ = val; return; }
8361 if (idx == 2) { Offsets_ = val; return; }
8362 if (idx == 3) { Weights_ = val; return; }
8363 if (idx == 4) { Indices_ = val; return; }
8364 if (idx == 5) { Lengths_ = val; return; }
8365 idx -= 6;
8366 llvm_unreachable("Invalid index");
8367}
8368
8369llvm::StringRef RowwiseQuantizedSparseLengthsWeightedSumNode::getOutputName(unsigned idx) const {
8370 if (idx == 0) { return "Result"; }
8371 llvm_unreachable("Invalid index");
8372}
8373
8374std::string RowwiseQuantizedSparseLengthsWeightedSumNode::getDebugDesc() const {
8375 DescriptionBuilder db(getKindName());
8376 db.addParam("Name", separateString(getName(), 100, "\n"));
8377 if (hasPredicate()) db.addParam("Predicate", "Yes");
8378 db
8379 .addParam("Data", *(getData().getType()))
8380 .addParam("Scales", *(getScales().getType()))
8381 .addParam("Offsets", *(getOffsets().getType()))
8382 .addParam("Weights", *(getWeights().getType()))
8383 .addParam("Indices", *(getIndices().getType()))
8384 .addParam("Lengths", *(getLengths().getType()))
8385 .addParam("UseFP16Accumulation", getUseFP16Accumulation())
8386 .addParam("LengthsMode", getLengthsMode())
8387 .addParam("AvgLength", getAvgLength())
8388 .addParam("Users", getNumUsers());
8389 db.addParam("Result", *(getResult().getType()));
8390 return db;
8391}
8392
8393void RowwiseQuantizedSparseLengthsWeightedSumNode::visit(Node *parent, NodeWalker *visitor) {
8394 if (!visitor->shouldVisit(parent, this)) { return; }
8395 visitor->pre(parent, this);
8396if (hasPredicate())
8397 getPredicate().getNode()->visit(this, visitor);
8398 getData().getNode()->visit(this, visitor);
8399 getScales().getNode()->visit(this, visitor);
8400 getOffsets().getNode()->visit(this, visitor);
8401 getWeights().getNode()->visit(this, visitor);
8402 getIndices().getNode()->visit(this, visitor);
8403 getLengths().getNode()->visit(this, visitor);
8404 visitor->post(parent, this);
8405}
8406
8407bool RowwiseQuantizedSparseLengthsWeightedSumNode::isEqual(const RowwiseQuantizedSparseLengthsWeightedSumNode &other) const {
8408 return true &&
8409 Data_ == other.Data_ &&
8410 Scales_ == other.Scales_ &&
8411 Offsets_ == other.Offsets_ &&
8412 Weights_ == other.Weights_ &&
8413 Indices_ == other.Indices_ &&
8414 Lengths_ == other.Lengths_ &&
8415 predicate_ == other.predicate_ &&
8416 UseFP16Accumulation_ == other.UseFP16Accumulation_ &&
8417 LengthsMode_ == other.LengthsMode_ &&
8418 AvgLength_ == other.AvgLength_ &&
8419 getType(0) == other.getType(0);
8420}
8421
8422Node* RowwiseQuantizedSparseLengthsWeightedSumNode::clone() const {
8423 return new RowwiseQuantizedSparseLengthsWeightedSumNode(getName(), getResult().getType(), getData(), getScales(), getOffsets(), getWeights(), getIndices(), getLengths(), getUseFP16Accumulation(), getLengthsMode(), getAvgLength());
8424}
8425
8426llvm::hash_code RowwiseQuantizedSparseLengthsWeightedSumNode::getHash() const {
8427 return llvm::hash_combine(
8428 UseFP16Accumulation_,
8429 LengthsMode_,
8430 toBinary(AvgLength_),
8431 Data_,
8432 Scales_,
8433 Offsets_,
8434 Weights_,
8435 Indices_,
8436 Lengths_);
8437}
8438
8439unsigned FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getNumInputs() const {
8440 return 4;
8441}
8442
8443std::string FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getInputName(unsigned idx) const {
8444 if (idx == 0) { return "Data"; }
8445 if (idx == 1) { return "Weights"; }
8446 if (idx == 2) { return "Indices"; }
8447 if (idx == 3) { return "Lengths"; }
8448 idx -= 4;
8449 llvm_unreachable("Invalid index");
8450}
8451
8452NodeValue FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getNthInput(unsigned idx) {
8453 if (idx == 0) { return Data_; }
8454 if (idx == 1) { return Weights_; }
8455 if (idx == 2) { return Indices_; }
8456 if (idx == 3) { return Lengths_; }
8457 idx -= 4;
8458 llvm_unreachable("Invalid index");
8459}
8460
8461void FusedRowwiseQuantizedSparseLengthsWeightedSumNode::setNthInput(unsigned idx, NodeValue val) {
8462 if (idx == 0) { Data_ = val; return; }
8463 if (idx == 1) { Weights_ = val; return; }
8464 if (idx == 2) { Indices_ = val; return; }
8465 if (idx == 3) { Lengths_ = val; return; }
8466 idx -= 4;
8467 llvm_unreachable("Invalid index");
8468}
8469
8470llvm::StringRef FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getOutputName(unsigned idx) const {
8471 if (idx == 0) { return "Result"; }
8472 llvm_unreachable("Invalid index");
8473}
8474
8475std::string FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getDebugDesc() const {
8476 DescriptionBuilder db(getKindName());
8477 db.addParam("Name", separateString(getName(), 100, "\n"));
8478 if (hasPredicate()) db.addParam("Predicate", "Yes");
8479 db
8480 .addParam("Data", *(getData().getType()))
8481 .addParam("Weights", *(getWeights().getType()))
8482 .addParam("Indices", *(getIndices().getType()))
8483 .addParam("Lengths", *(getLengths().getType()))
8484 .addParam("UseFP16Accumulation", getUseFP16Accumulation())
8485 .addParam("LengthsMode", getLengthsMode())
8486 .addParam("AvgLength", getAvgLength())
8487 .addParam("Users", getNumUsers());
8488 db.addParam("Result", *(getResult().getType()));
8489 return db;
8490}
8491
8492void FusedRowwiseQuantizedSparseLengthsWeightedSumNode::visit(Node *parent, NodeWalker *visitor) {
8493 if (!visitor->shouldVisit(parent, this)) { return; }
8494 visitor->pre(parent, this);
8495if (hasPredicate())
8496 getPredicate().getNode()->visit(this, visitor);
8497 getData().getNode()->visit(this, visitor);
8498 getWeights().getNode()->visit(this, visitor);
8499 getIndices().getNode()->visit(this, visitor);
8500 getLengths().getNode()->visit(this, visitor);
8501 visitor->post(parent, this);
8502}
8503
8504bool FusedRowwiseQuantizedSparseLengthsWeightedSumNode::isEqual(const FusedRowwiseQuantizedSparseLengthsWeightedSumNode &other) const {
8505 return true &&
8506 Data_ == other.Data_ &&
8507 Weights_ == other.Weights_ &&
8508 Indices_ == other.Indices_ &&
8509 Lengths_ == other.Lengths_ &&
8510 predicate_ == other.predicate_ &&
8511 UseFP16Accumulation_ == other.UseFP16Accumulation_ &&
8512 LengthsMode_ == other.LengthsMode_ &&
8513 AvgLength_ == other.AvgLength_ &&
8514 getType(0) == other.getType(0);
8515}
8516
8517Node* FusedRowwiseQuantizedSparseLengthsWeightedSumNode::clone() const {
8518 return new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getLengths(), getUseFP16Accumulation(), getLengthsMode(), getAvgLength());
8519}
8520
8521llvm::hash_code FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getHash() const {
8522 return llvm::hash_combine(
8523 UseFP16Accumulation_,
8524 LengthsMode_,
8525 toBinary(AvgLength_),
8526 Data_,
8527 Weights_,
8528 Indices_,
8529 Lengths_);
8530}
8531
8532unsigned FusedRowwiseQuantizedSparseLengthsSumNode::getNumInputs() const {
8533 return 3;
8534}
8535
8536std::string FusedRowwiseQuantizedSparseLengthsSumNode::getInputName(unsigned idx) const {
8537 if (idx == 0) { return "Data"; }
8538 if (idx == 1) { return "Indices"; }
8539 if (idx == 2) { return "Lengths"; }
8540 idx -= 3;
8541 llvm_unreachable("Invalid index");
8542}
8543
8544NodeValue FusedRowwiseQuantizedSparseLengthsSumNode::getNthInput(unsigned idx) {
8545 if (idx == 0) { return Data_; }
8546 if (idx == 1) { return Indices_; }
8547 if (idx == 2) { return Lengths_; }
8548 idx -= 3;
8549 llvm_unreachable("Invalid index");
8550}
8551
8552void FusedRowwiseQuantizedSparseLengthsSumNode::setNthInput(unsigned idx, NodeValue val) {
8553 if (idx == 0) { Data_ = val; return; }
8554 if (idx == 1) { Indices_ = val; return; }
8555 if (idx == 2) { Lengths_ = val; return; }
8556 idx -= 3;
8557 llvm_unreachable("Invalid index");
8558}
8559
8560llvm::StringRef FusedRowwiseQuantizedSparseLengthsSumNode::getOutputName(unsigned idx) const {
8561 if (idx == 0) { return "Result"; }
8562 llvm_unreachable("Invalid index");
8563}
8564
8565std::string FusedRowwiseQuantizedSparseLengthsSumNode::getDebugDesc() const {
8566 DescriptionBuilder db(getKindName());
8567 db.addParam("Name", separateString(getName(), 100, "\n"));
8568 if (hasPredicate()) db.addParam("Predicate", "Yes");
8569 db
8570 .addParam("Data", *(getData().getType()))
8571 .addParam("Indices", *(getIndices().getType()))
8572 .addParam("Lengths", *(getLengths().getType()))
8573 .addParam("UseFP16Accumulation", getUseFP16Accumulation())
8574 .addParam("LengthsMode", getLengthsMode())
8575 .addParam("AvgLength", getAvgLength())
8576 .addParam("Users", getNumUsers());
8577 db.addParam("Result", *(getResult().getType()));
8578 return db;
8579}
8580
8581void FusedRowwiseQuantizedSparseLengthsSumNode::visit(Node *parent, NodeWalker *visitor) {
8582 if (!visitor->shouldVisit(parent, this)) { return; }
8583 visitor->pre(parent, this);
8584if (hasPredicate())
8585 getPredicate().getNode()->visit(this, visitor);
8586 getData().getNode()->visit(this, visitor);
8587 getIndices().getNode()->visit(this, visitor);
8588 getLengths().getNode()->visit(this, visitor);
8589 visitor->post(parent, this);
8590}
8591
8592bool FusedRowwiseQuantizedSparseLengthsSumNode::isEqual(const FusedRowwiseQuantizedSparseLengthsSumNode &other) const {
8593 return true &&
8594 Data_ == other.Data_ &&
8595 Indices_ == other.Indices_ &&
8596 Lengths_ == other.Lengths_ &&
8597 predicate_ == other.predicate_ &&
8598 UseFP16Accumulation_ == other.UseFP16Accumulation_ &&
8599 LengthsMode_ == other.LengthsMode_ &&
8600 AvgLength_ == other.AvgLength_ &&
8601 getType(0) == other.getType(0);
8602}
8603
8604Node* FusedRowwiseQuantizedSparseLengthsSumNode::clone() const {
8605 return new FusedRowwiseQuantizedSparseLengthsSumNode(getName(), getResult().getType(), getData(), getIndices(), getLengths(), getUseFP16Accumulation(), getLengthsMode(), getAvgLength());
8606}
8607
8608llvm::hash_code FusedRowwiseQuantizedSparseLengthsSumNode::getHash() const {
8609 return llvm::hash_combine(
8610 UseFP16Accumulation_,
8611 LengthsMode_,
8612 toBinary(AvgLength_),
8613 Data_,
8614 Indices_,
8615 Lengths_);
8616}
8617
8618unsigned LengthsToRangesNode::getNumInputs() const {
8619 return 1;
8620}
8621
8622std::string LengthsToRangesNode::getInputName(unsigned idx) const {
8623 if (idx == 0) { return "Lengths"; }
8624 idx -= 1;
8625 llvm_unreachable("Invalid index");
8626}
8627
8628NodeValue LengthsToRangesNode::getNthInput(unsigned idx) {
8629 if (idx == 0) { return Lengths_; }
8630 idx -= 1;
8631 llvm_unreachable("Invalid index");
8632}
8633
8634void LengthsToRangesNode::setNthInput(unsigned idx, NodeValue val) {
8635 if (idx == 0) { Lengths_ = val; return; }
8636 idx -= 1;
8637 llvm_unreachable("Invalid index");
8638}
8639
8640llvm::StringRef LengthsToRangesNode::getOutputName(unsigned idx) const {
8641 if (idx == 0) { return "Result"; }
8642 llvm_unreachable("Invalid index");
8643}
8644
8645std::string LengthsToRangesNode::getDebugDesc() const {
8646 DescriptionBuilder db(getKindName());
8647 db.addParam("Name", separateString(getName(), 100, "\n"));
8648 if (hasPredicate()) db.addParam("Predicate", "Yes");
8649 db
8650 .addParam("Lengths", *(getLengths().getType()))
8651 .addParam("Users", getNumUsers());
8652 db.addParam("Result", *(getResult().getType()));
8653 return db;
8654}
8655
8656void LengthsToRangesNode::visit(Node *parent, NodeWalker *visitor) {
8657 if (!visitor->shouldVisit(parent, this)) { return; }
8658 visitor->pre(parent, this);
8659if (hasPredicate())
8660 getPredicate().getNode()->visit(this, visitor);
8661 getLengths().getNode()->visit(this, visitor);
8662 visitor->post(parent, this);
8663}
8664
8665bool LengthsToRangesNode::isEqual(const LengthsToRangesNode &other) const {
8666 return true &&
8667 Lengths_ == other.Lengths_ &&
8668 predicate_ == other.predicate_ &&
8669 getType(0) == other.getType(0);
8670}
8671
8672Node* LengthsToRangesNode::clone() const {
8673 return new LengthsToRangesNode(getName(), getResult().getType(), getLengths());
8674}
8675
8676llvm::hash_code LengthsToRangesNode::getHash() const {
8677 return llvm::hash_combine(
8678 Lengths_);
8679}
8680
8681unsigned LengthsRangeFillNode::getNumInputs() const {
8682 return 1;
8683}
8684
8685std::string LengthsRangeFillNode::getInputName(unsigned idx) const {
8686 if (idx == 0) { return "Lengths"; }
8687 idx -= 1;
8688 llvm_unreachable("Invalid index");
8689}
8690
8691NodeValue LengthsRangeFillNode::getNthInput(unsigned idx) {
8692 if (idx == 0) { return Lengths_; }
8693 idx -= 1;
8694 llvm_unreachable("Invalid index");
8695}
8696
8697void LengthsRangeFillNode::setNthInput(unsigned idx, NodeValue val) {
8698 if (idx == 0) { Lengths_ = val; return; }
8699 idx -= 1;
8700 llvm_unreachable("Invalid index");
8701}
8702
8703llvm::StringRef LengthsRangeFillNode::getOutputName(unsigned idx) const {
8704 if (idx == 0) { return "Result"; }
8705 llvm_unreachable("Invalid index");
8706}
8707
8708std::string LengthsRangeFillNode::getDebugDesc() const {
8709 DescriptionBuilder db(getKindName());
8710 db.addParam("Name", separateString(getName(), 100, "\n"));
8711 if (hasPredicate()) db.addParam("Predicate", "Yes");
8712 db
8713 .addParam("Lengths", *(getLengths().getType()))
8714 .addParam("Users", getNumUsers());
8715 db.addParam("Result", *(getResult().getType()));
8716 return db;
8717}
8718
8719void LengthsRangeFillNode::visit(Node *parent, NodeWalker *visitor) {
8720 if (!visitor->shouldVisit(parent, this)) { return; }
8721 visitor->pre(parent, this);
8722if (hasPredicate())
8723 getPredicate().getNode()->visit(this, visitor);
8724 getLengths().getNode()->visit(this, visitor);
8725 visitor->post(parent, this);
8726}
8727
8728bool LengthsRangeFillNode::isEqual(const LengthsRangeFillNode &other) const {
8729 return true &&
8730 Lengths_ == other.Lengths_ &&
8731 predicate_ == other.predicate_ &&
8732 getType(0) == other.getType(0);
8733}
8734
8735Node* LengthsRangeFillNode::clone() const {
8736 return new LengthsRangeFillNode(getName(), getResult().getType(), getLengths());
8737}
8738
8739llvm::hash_code LengthsRangeFillNode::getHash() const {
8740 return llvm::hash_combine(
8741 Lengths_);
8742}
8743
8744unsigned BatchSparseToDenseNode::getNumInputs() const {
8745 return 3;
8746}
8747
8748std::string BatchSparseToDenseNode::getInputName(unsigned idx) const {
8749 if (idx == 0) { return "Lengths"; }
8750 if (idx == 1) { return "Indices"; }
8751 if (idx == 2) { return "Values"; }
8752 idx -= 3;
8753 llvm_unreachable("Invalid index");
8754}
8755
8756NodeValue BatchSparseToDenseNode::getNthInput(unsigned idx) {
8757 if (idx == 0) { return Lengths_; }
8758 if (idx == 1) { return Indices_; }
8759 if (idx == 2) { return Values_; }
8760 idx -= 3;
8761 llvm_unreachable("Invalid index");
8762}
8763
8764void BatchSparseToDenseNode::setNthInput(unsigned idx, NodeValue val) {
8765 if (idx == 0) { Lengths_ = val; return; }
8766 if (idx == 1) { Indices_ = val; return; }
8767 if (idx == 2) { Values_ = val; return; }
8768 idx -= 3;
8769 llvm_unreachable("Invalid index");
8770}
8771
8772llvm::StringRef BatchSparseToDenseNode::getOutputName(unsigned idx) const {
8773 if (idx == 0) { return "Result"; }
8774 llvm_unreachable("Invalid index");
8775}
8776
8777std::string BatchSparseToDenseNode::getDebugDesc() const {
8778 DescriptionBuilder db(getKindName());
8779 db.addParam("Name", separateString(getName(), 100, "\n"));
8780 if (hasPredicate()) db.addParam("Predicate", "Yes");
8781 db
8782 .addParam("Lengths", *(getLengths().getType()))
8783 .addParam("Indices", *(getIndices().getType()))
8784 .addParam("Values", *(getValues().getType()))
8785 .addParam("DefaultValue", getDefaultValue())
8786 .addParam("DenseLastDim", getDenseLastDim())
8787 .addParam("Users", getNumUsers());
8788 db.addParam("Result", *(getResult().getType()));
8789 return db;
8790}
8791
8792void BatchSparseToDenseNode::visit(Node *parent, NodeWalker *visitor) {
8793 if (!visitor->shouldVisit(parent, this)) { return; }
8794 visitor->pre(parent, this);
8795if (hasPredicate())
8796 getPredicate().getNode()->visit(this, visitor);
8797 getLengths().getNode()->visit(this, visitor);
8798 getIndices().getNode()->visit(this, visitor);
8799 getValues().getNode()->visit(this, visitor);
8800 visitor->post(parent, this);
8801}
8802
8803bool BatchSparseToDenseNode::isEqual(const BatchSparseToDenseNode &other) const {
8804 return true &&
8805 Lengths_ == other.Lengths_ &&
8806 Indices_ == other.Indices_ &&
8807 Values_ == other.Values_ &&
8808 predicate_ == other.predicate_ &&
8809 DefaultValue_ == other.DefaultValue_ &&
8810 DenseLastDim_ == other.DenseLastDim_ &&
8811 getType(0) == other.getType(0);
8812}
8813
8814Node* BatchSparseToDenseNode::clone() const {
8815 return new BatchSparseToDenseNode(getName(), getResult().getType(), getLengths(), getIndices(), getValues(), getDefaultValue(), getDenseLastDim());
8816}
8817
8818llvm::hash_code BatchSparseToDenseNode::getHash() const {
8819 return llvm::hash_combine(
8820 toBinary(DefaultValue_),
8821 DenseLastDim_,
8822 Lengths_,
8823 Indices_,
8824 Values_);
8825}
8826
8827unsigned FillExamplesWithIndicatorNode::getNumInputs() const {
8828 return 2;
8829}
8830
8831std::string FillExamplesWithIndicatorNode::getInputName(unsigned idx) const {
8832 if (idx == 0) { return "Data"; }
8833 if (idx == 1) { return "Indicator"; }
8834 idx -= 2;
8835 llvm_unreachable("Invalid index");
8836}
8837
8838NodeValue FillExamplesWithIndicatorNode::getNthInput(unsigned idx) {
8839 if (idx == 0) { return Data_; }
8840 if (idx == 1) { return Indicator_; }
8841 idx -= 2;
8842 llvm_unreachable("Invalid index");
8843}
8844
8845void FillExamplesWithIndicatorNode::setNthInput(unsigned idx, NodeValue val) {
8846 if (idx == 0) { Data_ = val; return; }
8847 if (idx == 1) { Indicator_ = val; return; }
8848 idx -= 2;
8849 llvm_unreachable("Invalid index");
8850}
8851
8852llvm::StringRef FillExamplesWithIndicatorNode::getOutputName(unsigned idx) const {
8853 if (idx == 0) { return "Result"; }
8854 llvm_unreachable("Invalid index");
8855}
8856
8857std::string FillExamplesWithIndicatorNode::getDebugDesc() const {
8858 DescriptionBuilder db(getKindName());
8859 db.addParam("Name", separateString(getName(), 100, "\n"));
8860 if (hasPredicate()) db.addParam("Predicate", "Yes");
8861 db
8862 .addParam("Data", *(getData().getType()))
8863 .addParam("Indicator", *(getIndicator().getType()))
8864 .addParam("Users", getNumUsers());
8865 db.addParam("Result", *(getResult().getType()));
8866 return db;
8867}
8868
8869void FillExamplesWithIndicatorNode::visit(Node *parent, NodeWalker *visitor) {
8870 if (!visitor->shouldVisit(parent, this)) { return; }
8871 visitor->pre(parent, this);
8872if (hasPredicate())
8873 getPredicate().getNode()->visit(this, visitor);
8874 getData().getNode()->visit(this, visitor);
8875 getIndicator().getNode()->visit(this, visitor);
8876 visitor->post(parent, this);
8877}
8878
8879bool FillExamplesWithIndicatorNode::isEqual(const FillExamplesWithIndicatorNode &other) const {
8880 return true &&
8881 Data_ == other.Data_ &&
8882 Indicator_ == other.Indicator_ &&
8883 predicate_ == other.predicate_ &&
8884 getType(0) == other.getType(0);
8885}
8886
8887Node* FillExamplesWithIndicatorNode::clone() const {
8888 return new FillExamplesWithIndicatorNode(getName(), getResult().getType(), getData(), getIndicator());
8889}
8890
8891llvm::hash_code FillExamplesWithIndicatorNode::getHash() const {
8892 return llvm::hash_combine(
8893 Data_,
8894 Indicator_);
8895}
8896
8897unsigned SparseToDenseMaskNode::getNumInputs() const {
8898 return 4;
8899}
8900
8901std::string SparseToDenseMaskNode::getInputName(unsigned idx) const {
8902 if (idx == 0) { return "Indices"; }
8903 if (idx == 1) { return "Values"; }
8904 if (idx == 2) { return "DefaultValue"; }
8905 if (idx == 3) { return "Lengths"; }
8906 idx -= 4;
8907 llvm_unreachable("Invalid index");
8908}
8909
8910NodeValue SparseToDenseMaskNode::getNthInput(unsigned idx) {
8911 if (idx == 0) { return Indices_; }
8912 if (idx == 1) { return Values_; }
8913 if (idx == 2) { return DefaultValue_; }
8914 if (idx == 3) { return Lengths_; }
8915 idx -= 4;
8916 llvm_unreachable("Invalid index");
8917}
8918
8919void SparseToDenseMaskNode::setNthInput(unsigned idx, NodeValue val) {
8920 if (idx == 0) { Indices_ = val; return; }
8921 if (idx == 1) { Values_ = val; return; }
8922 if (idx == 2) { DefaultValue_ = val; return; }
8923 if (idx == 3) { Lengths_ = val; return; }
8924 idx -= 4;
8925 llvm_unreachable("Invalid index");
8926}
8927
8928llvm::StringRef SparseToDenseMaskNode::getOutputName(unsigned idx) const {
8929 if (idx == 0) { return "Result"; }
8930 llvm_unreachable("Invalid index");
8931}
8932
8933std::string SparseToDenseMaskNode::getDebugDesc() const {
8934 DescriptionBuilder db(getKindName());
8935 db.addParam("Name", separateString(getName(), 100, "\n"));
8936 if (hasPredicate()) db.addParam("Predicate", "Yes");
8937 db
8938 .addParam("Indices", *(getIndices().getType()))
8939 .addParam("Values", *(getValues().getType()))
8940 .addParam("DefaultValue", *(getDefaultValue().getType()))
8941 .addParam("Lengths", *(getLengths().getType()))
8942 .addParam("Mask", getMask())
8943 .addParam("Users", getNumUsers());
8944 db.addParam("Result", *(getResult().getType()));
8945 return db;
8946}
8947
8948void SparseToDenseMaskNode::visit(Node *parent, NodeWalker *visitor) {
8949 if (!visitor->shouldVisit(parent, this)) { return; }
8950 visitor->pre(parent, this);
8951if (hasPredicate())
8952 getPredicate().getNode()->visit(this, visitor);
8953 getIndices().getNode()->visit(this, visitor);
8954 getValues().getNode()->visit(this, visitor);
8955 getDefaultValue().getNode()->visit(this, visitor);
8956 getLengths().getNode()->visit(this, visitor);
8957 visitor->post(parent, this);
8958}
8959
8960bool SparseToDenseMaskNode::isEqual(const SparseToDenseMaskNode &other) const {
8961 return true &&
8962 Indices_ == other.Indices_ &&
8963 Values_ == other.Values_ &&
8964 DefaultValue_ == other.DefaultValue_ &&
8965 Lengths_ == other.Lengths_ &&
8966 predicate_ == other.predicate_ &&
8967 Mask_ == other.Mask_ &&
8968 getType(0) == other.getType(0);
8969}
8970
8971Node* SparseToDenseMaskNode::clone() const {
8972 return new SparseToDenseMaskNode(getName(), getResult().getType(), getIndices(), getValues(), getDefaultValue(), getLengths(), getMask());
8973}
8974
8975llvm::hash_code SparseToDenseMaskNode::getHash() const {
8976 return llvm::hash_combine(
8977 llvm::hash_combine_range(Mask_.begin(), Mask_.end()),
8978 Indices_,
8979 Values_,
8980 DefaultValue_,
8981 Lengths_);
8982}
8983
8984unsigned IsNaNNode::getNumInputs() const {
8985 return 1;
8986}
8987
8988std::string IsNaNNode::getInputName(unsigned idx) const {
8989 if (idx == 0) { return "Input"; }
8990 idx -= 1;
8991 llvm_unreachable("Invalid index");
8992}
8993
8994NodeValue IsNaNNode::getNthInput(unsigned idx) {
8995 if (idx == 0) { return Input_; }
8996 idx -= 1;
8997 llvm_unreachable("Invalid index");
8998}
8999
9000void IsNaNNode::setNthInput(unsigned idx, NodeValue val) {
9001 if (idx == 0) { Input_ = val; return; }
9002 idx -= 1;
9003 llvm_unreachable("Invalid index");
9004}
9005
9006llvm::StringRef IsNaNNode::getOutputName(unsigned idx) const {
9007 if (idx == 0) { return "Result"; }
9008 llvm_unreachable("Invalid index");
9009}
9010
9011std::string IsNaNNode::getDebugDesc() const {
9012 DescriptionBuilder db(getKindName());
9013 db.addParam("Name", separateString(getName(), 100, "\n"));
9014 if (hasPredicate()) db.addParam("Predicate", "Yes");
9015 db
9016 .addParam("Input", *(getInput().getType()))
9017 .addParam("Users", getNumUsers());
9018 db.addParam("Result", *(getResult().getType()));
9019 return db;
9020}
9021
9022void IsNaNNode::visit(Node *parent, NodeWalker *visitor) {
9023 if (!visitor->shouldVisit(parent, this)) { return; }
9024 visitor->pre(parent, this);
9025if (hasPredicate())
9026 getPredicate().getNode()->visit(this, visitor);
9027 getInput().getNode()->visit(this, visitor);
9028 visitor->post(parent, this);
9029}
9030
9031bool IsNaNNode::isEqual(const IsNaNNode &other) const {
9032 return true &&
9033 Input_ == other.Input_ &&
9034 predicate_ == other.predicate_ &&
9035 getType(0) == other.getType(0);
9036}
9037
9038Node* IsNaNNode::clone() const {
9039 return new IsNaNNode(getName(), getResult().getType(), getInput());
9040}
9041
9042llvm::hash_code IsNaNNode::getHash() const {
9043 return llvm::hash_combine(
9044 Input_);
9045}
9046
9047unsigned ReplaceNaNNode::getNumInputs() const {
9048 return 1;
9049}
9050
9051std::string ReplaceNaNNode::getInputName(unsigned idx) const {
9052 if (idx == 0) { return "Input"; }
9053 idx -= 1;
9054 llvm_unreachable("Invalid index");
9055}
9056
9057NodeValue ReplaceNaNNode::getNthInput(unsigned idx) {
9058 if (idx == 0) { return Input_; }
9059 idx -= 1;
9060 llvm_unreachable("Invalid index");
9061}
9062
9063void ReplaceNaNNode::setNthInput(unsigned idx, NodeValue val) {
9064 if (idx == 0) { Input_ = val; return; }
9065 idx -= 1;
9066 llvm_unreachable("Invalid index");
9067}
9068
9069llvm::StringRef ReplaceNaNNode::getOutputName(unsigned idx) const {
9070 if (idx == 0) { return "Result"; }
9071 llvm_unreachable("Invalid index");
9072}
9073
9074std::string ReplaceNaNNode::getDebugDesc() const {
9075 DescriptionBuilder db(getKindName());
9076 db.addParam("Name", separateString(getName(), 100, "\n"));
9077 if (hasPredicate()) db.addParam("Predicate", "Yes");
9078 db
9079 .addParam("Input", *(getInput().getType()))
9080 .addParam("Value", getValue())
9081 .addParam("Users", getNumUsers());
9082 db.addParam("Result", *(getResult().getType()));
9083 return db;
9084}
9085
9086void ReplaceNaNNode::visit(Node *parent, NodeWalker *visitor) {
9087 if (!visitor->shouldVisit(parent, this)) { return; }
9088 visitor->pre(parent, this);
9089if (hasPredicate())
9090 getPredicate().getNode()->visit(this, visitor);
9091 getInput().getNode()->visit(this, visitor);
9092 visitor->post(parent, this);
9093}
9094
9095bool ReplaceNaNNode::isEqual(const ReplaceNaNNode &other) const {
9096 return true &&
9097 Input_ == other.Input_ &&
9098 predicate_ == other.predicate_ &&
9099 Value_ == other.Value_ &&
9100 getType(0) == other.getType(0);
9101}
9102
9103Node* ReplaceNaNNode::clone() const {
9104 return new ReplaceNaNNode(getName(), getResult().getType(), getInput(), getValue());
9105}
9106
9107llvm::hash_code ReplaceNaNNode::getHash() const {
9108 return llvm::hash_combine(
9109 toBinary(Value_),
9110 Input_);
9111}
9112
9113unsigned ModuloNode::getNumInputs() const {
9114 return 1;
9115}
9116
9117std::string ModuloNode::getInputName(unsigned idx) const {
9118 if (idx == 0) { return "Input"; }
9119 idx -= 1;
9120 llvm_unreachable("Invalid index");
9121}
9122
9123NodeValue ModuloNode::getNthInput(unsigned idx) {
9124 if (idx == 0) { return Input_; }
9125 idx -= 1;
9126 llvm_unreachable("Invalid index");
9127}
9128
9129void ModuloNode::setNthInput(unsigned idx, NodeValue val) {
9130 if (idx == 0) { Input_ = val; return; }
9131 idx -= 1;
9132 llvm_unreachable("Invalid index");
9133}
9134
9135llvm::StringRef ModuloNode::getOutputName(unsigned idx) const {
9136 if (idx == 0) { return "Result"; }
9137 llvm_unreachable("Invalid index");
9138}
9139
9140std::string ModuloNode::getDebugDesc() const {
9141 DescriptionBuilder db(getKindName());
9142 db.addParam("Name", separateString(getName(), 100, "\n"));
9143 if (hasPredicate()) db.addParam("Predicate", "Yes");
9144 db
9145 .addParam("Input", *(getInput().getType()))
9146 .addParam("Divisor", getDivisor())
9147 .addParam("SignFollowDivisor", getSignFollowDivisor())
9148 .addParam("Users", getNumUsers());
9149 db.addParam("Result", *(getResult().getType()));
9150 return db;
9151}
9152
9153void ModuloNode::visit(Node *parent, NodeWalker *visitor) {
9154 if (!visitor->shouldVisit(parent, this)) { return; }
9155 visitor->pre(parent, this);
9156if (hasPredicate())
9157 getPredicate().getNode()->visit(this, visitor);
9158 getInput().getNode()->visit(this, visitor);
9159 visitor->post(parent, this);
9160}
9161
9162bool ModuloNode::isEqual(const ModuloNode &other) const {
9163 return true &&
9164 Input_ == other.Input_ &&
9165 predicate_ == other.predicate_ &&
9166 Divisor_ == other.Divisor_ &&
9167 SignFollowDivisor_ == other.SignFollowDivisor_ &&
9168 getType(0) == other.getType(0);
9169}
9170
9171Node* ModuloNode::clone() const {
9172 return new ModuloNode(getName(), getResult().getType(), getInput(), getDivisor(), getSignFollowDivisor());
9173}
9174
9175llvm::hash_code ModuloNode::getHash() const {
9176 return llvm::hash_combine(
9177 Divisor_,
9178 SignFollowDivisor_,
9179 Input_);
9180}
9181
9182unsigned BatchedPairwiseDotProductNode::getNumInputs() const {
9183 return 0 + Inputs_.size();
9184}
9185
9186std::string BatchedPairwiseDotProductNode::getInputName(unsigned idx) const {
9187 idx -= 0;
9188 if (idx < Inputs_.size()) { return "Inputs" + std::to_string(idx); }
9189 idx -= Inputs_.size();
9190 llvm_unreachable("Invalid index");
9191}
9192
9193NodeValue BatchedPairwiseDotProductNode::getNthInput(unsigned idx) {
9194 idx -= 0;
9195 if (idx < Inputs_.size()) { return Inputs_[idx]; }
9196 idx -= Inputs_.size();
9197 llvm_unreachable("Invalid index");
9198}
9199
9200void BatchedPairwiseDotProductNode::setNthInput(unsigned idx, NodeValue val) {
9201 idx -= 0;
9202 if (idx < Inputs_.size()) { Inputs_[idx] = val; return; }
9203 idx -= Inputs_.size();
9204 llvm_unreachable("Invalid index");
9205}
9206
9207llvm::StringRef BatchedPairwiseDotProductNode::getOutputName(unsigned idx) const {
9208 if (idx == 0) { return "Result"; }
9209 llvm_unreachable("Invalid index");
9210}
9211
9212std::string BatchedPairwiseDotProductNode::getDebugDesc() const {
9213 DescriptionBuilder db(getKindName());
9214 db.addParam("Name", separateString(getName(), 100, "\n"));
9215 if (hasPredicate()) db.addParam("Predicate", "Yes");
9216 db
9217 .addParam("Users", getNumUsers());
9218 {
9219 unsigned mIndex = 0;
9220 for (const auto &II : getInputs()) {
9221 db.addParam("Inputs"+std::to_string(mIndex++), *II.getType());
9222 }
9223 }
9224 db.addParam("Result", *(getResult().getType()));
9225 return db;
9226}
9227
9228void BatchedPairwiseDotProductNode::visit(Node *parent, NodeWalker *visitor) {
9229 if (!visitor->shouldVisit(parent, this)) { return; }
9230 visitor->pre(parent, this);
9231if (hasPredicate())
9232 getPredicate().getNode()->visit(this, visitor);
9233 for (auto &I : Inputs_) { I.getNode()->visit(this, visitor); }
9234 visitor->post(parent, this);
9235}
9236
9237bool BatchedPairwiseDotProductNode::isEqual(const BatchedPairwiseDotProductNode &other) const {
9238 return true &&
9239 predicate_ == other.predicate_ &&
9240 Inputs_ == other.Inputs_ &&
9241 getType(0) == other.getType(0);
9242}
9243
9244Node* BatchedPairwiseDotProductNode::clone() const {
9245 return new BatchedPairwiseDotProductNode(getName(), getResult().getType(), getInputs());
9246}
9247
9248llvm::hash_code BatchedPairwiseDotProductNode::getHash() const {
9249 return llvm::hash_combine(
9250 llvm::hash_combine_range(Inputs_.begin(), Inputs_.end()));
9251}
9252
9253unsigned BatchedPairwiseDotProductGradNode::getNumInputs() const {
9254 return 1 + OriginalInputs_.size();
9255}
9256
9257std::string BatchedPairwiseDotProductGradNode::getInputName(unsigned idx) const {
9258 if (idx == 0) { return "OutputGrad"; }
9259 idx -= 1;
9260 if (idx < OriginalInputs_.size()) { return "OriginalInputs" + std::to_string(idx); }
9261 idx -= OriginalInputs_.size();
9262 llvm_unreachable("Invalid index");
9263}
9264
9265NodeValue BatchedPairwiseDotProductGradNode::getNthInput(unsigned idx) {
9266 if (idx == 0) { return OutputGrad_; }
9267 idx -= 1;
9268 if (idx < OriginalInputs_.size()) { return OriginalInputs_[idx]; }
9269 idx -= OriginalInputs_.size();
9270 llvm_unreachable("Invalid index");
9271}
9272
9273void BatchedPairwiseDotProductGradNode::setNthInput(unsigned idx, NodeValue val) {
9274 if (idx == 0) { OutputGrad_ = val; return; }
9275 idx -= 1;
9276 if (idx < OriginalInputs_.size()) { OriginalInputs_[idx] = val; return; }
9277 idx -= OriginalInputs_.size();
9278 llvm_unreachable("Invalid index");
9279}
9280
9281llvm::StringRef BatchedPairwiseDotProductGradNode::getOutputName(unsigned idx) const {
9282 llvm_unreachable("Invalid index");
9283}
9284
9285std::string BatchedPairwiseDotProductGradNode::getDebugDesc() const {
9286 DescriptionBuilder db(getKindName());
9287 db.addParam("Name", separateString(getName(), 100, "\n"));
9288 if (hasPredicate()) db.addParam("Predicate", "Yes");
9289 db
9290 .addParam("OutputGrad", *(getOutputGrad().getType()))
9291 .addParam("Users", getNumUsers());
9292 {
9293 unsigned mIndex = 0;
9294 for (const auto &II : getOriginalInputs()) {
9295 db.addParam("OriginalInputs"+std::to_string(mIndex++), *II.getType());
9296 }
9297 }
9298 return db;
9299}
9300
9301void BatchedPairwiseDotProductGradNode::visit(Node *parent, NodeWalker *visitor) {
9302 if (!visitor->shouldVisit(parent, this)) { return; }
9303 visitor->pre(parent, this);
9304if (hasPredicate())
9305 getPredicate().getNode()->visit(this, visitor);
9306 getOutputGrad().getNode()->visit(this, visitor);
9307 for (auto &I : OriginalInputs_) { I.getNode()->visit(this, visitor); }
9308 visitor->post(parent, this);
9309}
9310
9311bool BatchedPairwiseDotProductGradNode::isEqual(const BatchedPairwiseDotProductGradNode &other) const {
9312 return true &&
9313 OutputGrad_ == other.OutputGrad_ &&
9314 predicate_ == other.predicate_ &&
9315 OriginalInputs_ == other.OriginalInputs_;
9316}
9317
9318Node* BatchedPairwiseDotProductGradNode::clone() const {
9319 return new BatchedPairwiseDotProductGradNode(getName(), getOutputGrad(), getOriginalInputs());
9320}
9321
9322llvm::hash_code BatchedPairwiseDotProductGradNode::getHash() const {
9323 return llvm::hash_combine(
9324 llvm::hash_combine_range(OriginalInputs_.begin(), OriginalInputs_.end()),
9325 OutputGrad_);
9326}
9327
9328unsigned BatchedUnaryEmbeddingsBagsNode::getNumInputs() const {
9329 return 4;
9330}
9331
9332std::string BatchedUnaryEmbeddingsBagsNode::getInputName(unsigned idx) const {
9333 if (idx == 0) { return "Weights"; }
9334 if (idx == 1) { return "TableOffsets"; }
9335 if (idx == 2) { return "Offsets"; }
9336 if (idx == 3) { return "Indices"; }
9337 idx -= 4;
9338 llvm_unreachable("Invalid index");
9339}
9340
9341NodeValue BatchedUnaryEmbeddingsBagsNode::getNthInput(unsigned idx) {
9342 if (idx == 0) { return Weights_; }
9343 if (idx == 1) { return TableOffsets_; }
9344 if (idx == 2) { return Offsets_; }
9345 if (idx == 3) { return Indices_; }
9346 idx -= 4;
9347 llvm_unreachable("Invalid index");
9348}
9349
9350void BatchedUnaryEmbeddingsBagsNode::setNthInput(unsigned idx, NodeValue val) {
9351 if (idx == 0) { Weights_ = val; return; }
9352 if (idx == 1) { TableOffsets_ = val; return; }
9353 if (idx == 2) { Offsets_ = val; return; }
9354 if (idx == 3) { Indices_ = val; return; }
9355 idx -= 4;
9356 llvm_unreachable("Invalid index");
9357}
9358
9359llvm::StringRef BatchedUnaryEmbeddingsBagsNode::getOutputName(unsigned idx) const {
9360 if (idx == 0) { return "Result"; }
9361 llvm_unreachable("Invalid index");
9362}
9363
9364std::string BatchedUnaryEmbeddingsBagsNode::getDebugDesc() const {
9365 DescriptionBuilder db(getKindName());
9366 db.addParam("Name", separateString(getName(), 100, "\n"));
9367 if (hasPredicate()) db.addParam("Predicate", "Yes");
9368 db
9369 .addParam("Weights", *(getWeights().getType()))
9370 .addParam("TableOffsets", *(getTableOffsets().getType()))
9371 .addParam("Offsets", *(getOffsets().getType()))
9372 .addParam("Indices", *(getIndices().getType()))
9373 .addParam("Users", getNumUsers());
9374 db.addParam("Result", *(getResult().getType()));
9375 return db;
9376}
9377
9378void BatchedUnaryEmbeddingsBagsNode::visit(Node *parent, NodeWalker *visitor) {
9379 if (!visitor->shouldVisit(parent, this)) { return; }
9380 visitor->pre(parent, this);
9381if (hasPredicate())
9382 getPredicate().getNode()->visit(this, visitor);
9383 getWeights().getNode()->visit(this, visitor);
9384 getTableOffsets().getNode()->visit(this, visitor);
9385 getOffsets().getNode()->visit(this, visitor);
9386 getIndices().getNode()->visit(this, visitor);
9387 visitor->post(parent, this);
9388}
9389
9390bool BatchedUnaryEmbeddingsBagsNode::isEqual(const BatchedUnaryEmbeddingsBagsNode &other) const {
9391 return true &&
9392 Weights_ == other.Weights_ &&
9393 TableOffsets_ == other.TableOffsets_ &&
9394 Offsets_ == other.Offsets_ &&
9395 Indices_ == other.Indices_ &&
9396 predicate_ == other.predicate_ &&
9397 getType(0) == other.getType(0);
9398}
9399
9400Node* BatchedUnaryEmbeddingsBagsNode::clone() const {
9401 return new BatchedUnaryEmbeddingsBagsNode(getName(), getResult().getType(), getWeights(), getTableOffsets(), getOffsets(), getIndices());
9402}
9403
9404llvm::hash_code BatchedUnaryEmbeddingsBagsNode::getHash() const {
9405 return llvm::hash_combine(
9406 Weights_,
9407 TableOffsets_,
9408 Offsets_,
9409 Indices_);
9410}
9411
9412unsigned IntNBitSplitEmbeddingBagsNode::getNumInputs() const {
9413 return 8;
9414}
9415
9416std::string IntNBitSplitEmbeddingBagsNode::getInputName(unsigned idx) const {
9417 if (idx == 0) { return "DevWeights"; }
9418 if (idx == 1) { return "UvmWeights"; }
9419 if (idx == 2) { return "WeightsPlacements"; }
9420 if (idx == 3) { return "WeightsOffsets"; }
9421 if (idx == 4) { return "WeightsTys"; }
9422 if (idx == 5) { return "DimOffsets"; }
9423 if (idx == 6) { return "Indices"; }
9424 if (idx == 7) { return "Offsets"; }
9425 idx -= 8;
9426 llvm_unreachable("Invalid index");
9427}
9428
9429NodeValue IntNBitSplitEmbeddingBagsNode::getNthInput(unsigned idx) {
9430 if (idx == 0) { return DevWeights_; }
9431 if (idx == 1) { return UvmWeights_; }
9432 if (idx == 2) { return WeightsPlacements_; }
9433 if (idx == 3) { return WeightsOffsets_; }
9434 if (idx == 4) { return WeightsTys_; }
9435 if (idx == 5) { return DimOffsets_; }
9436 if (idx == 6) { return Indices_; }
9437 if (idx == 7) { return Offsets_; }
9438 idx -= 8;
9439 llvm_unreachable("Invalid index");
9440}
9441
9442void IntNBitSplitEmbeddingBagsNode::setNthInput(unsigned idx, NodeValue val) {
9443 if (idx == 0) { DevWeights_ = val; return; }
9444 if (idx == 1) { UvmWeights_ = val; return; }
9445 if (idx == 2) { WeightsPlacements_ = val; return; }
9446 if (idx == 3) { WeightsOffsets_ = val; return; }
9447 if (idx == 4) { WeightsTys_ = val; return; }
9448 if (idx == 5) { DimOffsets_ = val; return; }
9449 if (idx == 6) { Indices_ = val; return; }
9450 if (idx == 7) { Offsets_ = val; return; }
9451 idx -= 8;
9452 llvm_unreachable("Invalid index");
9453}
9454
9455llvm::StringRef IntNBitSplitEmbeddingBagsNode::getOutputName(unsigned idx) const {
9456 if (idx == 0) { return "Result"; }
9457 llvm_unreachable("Invalid index");
9458}
9459
9460std::string IntNBitSplitEmbeddingBagsNode::getDebugDesc() const {
9461 DescriptionBuilder db(getKindName());
9462 db.addParam("Name", separateString(getName(), 100, "\n"));
9463 if (hasPredicate()) db.addParam("Predicate", "Yes");
9464 db
9465 .addParam("DevWeights", *(getDevWeights().getType()))
9466 .addParam("UvmWeights", *(getUvmWeights().getType()))
9467 .addParam("WeightsPlacements", *(getWeightsPlacements().getType()))
9468 .addParam("WeightsOffsets", *(getWeightsOffsets().getType()))
9469 .addParam("WeightsTys", *(getWeightsTys().getType()))
9470 .addParam("DimOffsets", *(getDimOffsets().getType()))
9471 .addParam("Indices", *(getIndices().getType()))
9472 .addParam("Offsets", *(getOffsets().getType()))
9473 .addParam("TotalDims", getTotalDims())
9474 .addParam("PoolingMode", getPoolingMode())
9475 .addParam("OutputDType", getOutputDType())
9476 .addParam("Users", getNumUsers());
9477 db.addParam("Result", *(getResult().getType()));
9478 return db;
9479}
9480
9481void IntNBitSplitEmbeddingBagsNode::visit(Node *parent, NodeWalker *visitor) {
9482 if (!visitor->shouldVisit(parent, this)) { return; }
9483 visitor->pre(parent, this);
9484if (hasPredicate())
9485 getPredicate().getNode()->visit(this, visitor);
9486 getDevWeights().getNode()->visit(this, visitor);
9487 getUvmWeights().getNode()->visit(this, visitor);
9488 getWeightsPlacements().getNode()->visit(this, visitor);
9489 getWeightsOffsets().getNode()->visit(this, visitor);
9490 getWeightsTys().getNode()->visit(this, visitor);
9491 getDimOffsets().getNode()->visit(this, visitor);
9492 getIndices().getNode()->visit(this, visitor);
9493 getOffsets().getNode()->visit(this, visitor);
9494 visitor->post(parent, this);
9495}
9496
9497bool IntNBitSplitEmbeddingBagsNode::isEqual(const IntNBitSplitEmbeddingBagsNode &other) const {
9498 return true &&
9499 DevWeights_ == other.DevWeights_ &&
9500 UvmWeights_ == other.UvmWeights_ &&
9501 WeightsPlacements_ == other.WeightsPlacements_ &&
9502 WeightsOffsets_ == other.WeightsOffsets_ &&
9503 WeightsTys_ == other.WeightsTys_ &&
9504 DimOffsets_ == other.DimOffsets_ &&
9505 Indices_ == other.Indices_ &&
9506 Offsets_ == other.Offsets_ &&
9507 predicate_ == other.predicate_ &&
9508 TotalDims_ == other.TotalDims_ &&
9509 PoolingMode_ == other.PoolingMode_ &&
9510 OutputDType_ == other.OutputDType_ &&
9511 getType(0) == other.getType(0);
9512}
9513
9514Node* IntNBitSplitEmbeddingBagsNode::clone() const {
9515 return new IntNBitSplitEmbeddingBagsNode(getName(), getResult().getType(), getDevWeights(), getUvmWeights(), getWeightsPlacements(), getWeightsOffsets(), getWeightsTys(), getDimOffsets(), getIndices(), getOffsets(), getTotalDims(), getPoolingMode(), getOutputDType());
9516}
9517
9518llvm::hash_code IntNBitSplitEmbeddingBagsNode::getHash() const {
9519 return llvm::hash_combine(
9520 TotalDims_,
9521 PoolingMode_,
9522 OutputDType_,
9523 DevWeights_,
9524 UvmWeights_,
9525 WeightsPlacements_,
9526 WeightsOffsets_,
9527 WeightsTys_,
9528 DimOffsets_,
9529 Indices_,
9530 Offsets_);
9531}
9532
9533unsigned IntNBitSplitEmbeddingWeightedBagsNode::getNumInputs() const {
9534 return 9;
9535}
9536
9537std::string IntNBitSplitEmbeddingWeightedBagsNode::getInputName(unsigned idx) const {
9538 if (idx == 0) { return "DevWeights"; }
9539 if (idx == 1) { return "UvmWeights"; }
9540 if (idx == 2) { return "WeightsPlacements"; }
9541 if (idx == 3) { return "WeightsOffsets"; }
9542 if (idx == 4) { return "WeightsTys"; }
9543 if (idx == 5) { return "DimOffsets"; }
9544 if (idx == 6) { return "Indices"; }
9545 if (idx == 7) { return "Offsets"; }
9546 if (idx == 8) { return "IndiceWeight"; }
9547 idx -= 9;
9548 llvm_unreachable("Invalid index");
9549}
9550
9551NodeValue IntNBitSplitEmbeddingWeightedBagsNode::getNthInput(unsigned idx) {
9552 if (idx == 0) { return DevWeights_; }
9553 if (idx == 1) { return UvmWeights_; }
9554 if (idx == 2) { return WeightsPlacements_; }
9555 if (idx == 3) { return WeightsOffsets_; }
9556 if (idx == 4) { return WeightsTys_; }
9557 if (idx == 5) { return DimOffsets_; }
9558 if (idx == 6) { return Indices_; }
9559 if (idx == 7) { return Offsets_; }
9560 if (idx == 8) { return IndiceWeight_; }
9561 idx -= 9;
9562 llvm_unreachable("Invalid index");
9563}
9564
9565void IntNBitSplitEmbeddingWeightedBagsNode::setNthInput(unsigned idx, NodeValue val) {
9566 if (idx == 0) { DevWeights_ = val; return; }
9567 if (idx == 1) { UvmWeights_ = val; return; }
9568 if (idx == 2) { WeightsPlacements_ = val; return; }
9569 if (idx == 3) { WeightsOffsets_ = val; return; }
9570 if (idx == 4) { WeightsTys_ = val; return; }
9571 if (idx == 5) { DimOffsets_ = val; return; }
9572 if (idx == 6) { Indices_ = val; return; }
9573 if (idx == 7) { Offsets_ = val; return; }
9574 if (idx == 8) { IndiceWeight_ = val; return; }
9575 idx -= 9;
9576 llvm_unreachable("Invalid index");
9577}
9578
9579llvm::StringRef IntNBitSplitEmbeddingWeightedBagsNode::getOutputName(unsigned idx) const {
9580 if (idx == 0) { return "Result"; }
9581 llvm_unreachable("Invalid index");
9582}
9583
9584std::string IntNBitSplitEmbeddingWeightedBagsNode::getDebugDesc() const {
9585 DescriptionBuilder db(getKindName());
9586 db.addParam("Name", separateString(getName(), 100, "\n"));
9587 if (hasPredicate()) db.addParam("Predicate", "Yes");
9588 db
9589 .addParam("DevWeights", *(getDevWeights().getType()))
9590 .addParam("UvmWeights", *(getUvmWeights().getType()))
9591 .addParam("WeightsPlacements", *(getWeightsPlacements().getType()))
9592 .addParam("WeightsOffsets", *(getWeightsOffsets().getType()))
9593 .addParam("WeightsTys", *(getWeightsTys().getType()))
9594 .addParam("DimOffsets", *(getDimOffsets().getType()))
9595 .addParam("Indices", *(getIndices().getType()))
9596 .addParam("Offsets", *(getOffsets().getType()))
9597 .addParam("IndiceWeight", *(getIndiceWeight().getType()))
9598 .addParam("TotalDims", getTotalDims())
9599 .addParam("PoolingMode", getPoolingMode())
9600 .addParam("OutputDType", getOutputDType())
9601 .addParam("Users", getNumUsers());
9602 db.addParam("Result", *(getResult().getType()));
9603 return db;
9604}
9605
9606void IntNBitSplitEmbeddingWeightedBagsNode::visit(Node *parent, NodeWalker *visitor) {
9607 if (!visitor->shouldVisit(parent, this)) { return; }
9608 visitor->pre(parent, this);
9609if (hasPredicate())
9610 getPredicate().getNode()->visit(this, visitor);
9611 getDevWeights().getNode()->visit(this, visitor);
9612 getUvmWeights().getNode()->visit(this, visitor);
9613 getWeightsPlacements().getNode()->visit(this, visitor);
9614 getWeightsOffsets().getNode()->visit(this, visitor);
9615 getWeightsTys().getNode()->visit(this, visitor);
9616 getDimOffsets().getNode()->visit(this, visitor);
9617 getIndices().getNode()->visit(this, visitor);
9618 getOffsets().getNode()->visit(this, visitor);
9619 getIndiceWeight().getNode()->visit(this, visitor);
9620 visitor->post(parent, this);
9621}
9622
9623bool IntNBitSplitEmbeddingWeightedBagsNode::isEqual(const IntNBitSplitEmbeddingWeightedBagsNode &other) const {
9624 return true &&
9625 DevWeights_ == other.DevWeights_ &&
9626 UvmWeights_ == other.UvmWeights_ &&
9627 WeightsPlacements_ == other.WeightsPlacements_ &&
9628 WeightsOffsets_ == other.WeightsOffsets_ &&
9629 WeightsTys_ == other.WeightsTys_ &&
9630 DimOffsets_ == other.DimOffsets_ &&
9631 Indices_ == other.Indices_ &&
9632 Offsets_ == other.Offsets_ &&
9633 IndiceWeight_ == other.IndiceWeight_ &&
9634 predicate_ == other.predicate_ &&
9635 TotalDims_ == other.TotalDims_ &&
9636 PoolingMode_ == other.PoolingMode_ &&
9637 OutputDType_ == other.OutputDType_ &&
9638 getType(0) == other.getType(0);
9639}
9640
9641Node* IntNBitSplitEmbeddingWeightedBagsNode::clone() const {
9642 return new IntNBitSplitEmbeddingWeightedBagsNode(getName(), getResult().getType(), getDevWeights(), getUvmWeights(), getWeightsPlacements(), getWeightsOffsets(), getWeightsTys(), getDimOffsets(), getIndices(), getOffsets(), getIndiceWeight(), getTotalDims(), getPoolingMode(), getOutputDType());
9643}
9644
9645llvm::hash_code IntNBitSplitEmbeddingWeightedBagsNode::getHash() const {
9646 return llvm::hash_combine(
9647 TotalDims_,
9648 PoolingMode_,
9649 OutputDType_,
9650 DevWeights_,
9651 UvmWeights_,
9652 WeightsPlacements_,
9653 WeightsOffsets_,
9654 WeightsTys_,
9655 DimOffsets_,
9656 Indices_,
9657 Offsets_,
9658 IndiceWeight_);
9659}
9660
9661unsigned GaussianFillNode::getNumInputs() const {
9662 return 1;
9663}
9664
9665std::string GaussianFillNode::getInputName(unsigned idx) const {
9666 if (idx == 0) { return "Input"; }
9667 idx -= 1;
9668 llvm_unreachable("Invalid index");
9669}
9670
9671NodeValue GaussianFillNode::getNthInput(unsigned idx) {
9672 if (idx == 0) { return Input_; }
9673 idx -= 1;
9674 llvm_unreachable("Invalid index");
9675}
9676
9677void GaussianFillNode::setNthInput(unsigned idx, NodeValue val) {
9678 if (idx == 0) { Input_ = val; return; }
9679 idx -= 1;
9680 llvm_unreachable("Invalid index");
9681}
9682
9683llvm::StringRef GaussianFillNode::getOutputName(unsigned idx) const {
9684 if (idx == 0) { return "Result"; }
9685 llvm_unreachable("Invalid index");
9686}
9687
9688std::string GaussianFillNode::getDebugDesc() const {
9689 DescriptionBuilder db(getKindName());
9690 db.addParam("Name", separateString(getName(), 100, "\n"));
9691 if (hasPredicate()) db.addParam("Predicate", "Yes");
9692 db
9693 .addParam("Input", *(getInput().getType()))
9694 .addParam("Mean", getMean())
9695 .addParam("Scale", getScale())
9696 .addParam("Seed", getSeed())
9697 .addParam("Users", getNumUsers());
9698 db.addParam("Result", *(getResult().getType()));
9699 return db;
9700}
9701
9702void GaussianFillNode::visit(Node *parent, NodeWalker *visitor) {
9703 if (!visitor->shouldVisit(parent, this)) { return; }
9704 visitor->pre(parent, this);
9705if (hasPredicate())
9706 getPredicate().getNode()->visit(this, visitor);
9707 getInput().getNode()->visit(this, visitor);
9708 visitor->post(parent, this);
9709}
9710
9711bool GaussianFillNode::isEqual(const GaussianFillNode &other) const {
9712 return true &&
9713 Input_ == other.Input_ &&
9714 predicate_ == other.predicate_ &&
9715 Mean_ == other.Mean_ &&
9716 Scale_ == other.Scale_ &&
9717 Seed_ == other.Seed_ &&
9718 getType(0) == other.getType(0);
9719}
9720
9721Node* GaussianFillNode::clone() const {
9722 return new GaussianFillNode(getName(), getResult().getType(), getInput(), getMean(), getScale(), getSeed());
9723}
9724
9725llvm::hash_code GaussianFillNode::getHash() const {
9726 return llvm::hash_combine(
9727 toBinary(Mean_),
9728 toBinary(Scale_),
9729 toBinary(Seed_),
9730 Input_);
9731}
9732
9733unsigned ReluGradNode::getNumInputs() const {
9734 return 3;
9735}
9736
9737std::string ReluGradNode::getInputName(unsigned idx) const {
9738 if (idx == 0) { return "Input"; }
9739 if (idx == 1) { return "OriginalOutputForResult"; }
9740 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
9741 idx -= 3;
9742 llvm_unreachable("Invalid index");
9743}
9744
9745NodeValue ReluGradNode::getNthInput(unsigned idx) {
9746 if (idx == 0) { return Input_; }
9747 if (idx == 1) { return OriginalOutputForResult_; }
9748 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
9749 idx -= 3;
9750 llvm_unreachable("Invalid index");
9751}
9752
9753void ReluGradNode::setNthInput(unsigned idx, NodeValue val) {
9754 if (idx == 0) { Input_ = val; return; }
9755 if (idx == 1) { OriginalOutputForResult_ = val; return; }
9756 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
9757 idx -= 3;
9758 llvm_unreachable("Invalid index");
9759}
9760
9761llvm::StringRef ReluGradNode::getOutputName(unsigned idx) const {
9762 if (idx == 0) { return "GradOfInputNamedInput"; }
9763 llvm_unreachable("Invalid index");
9764}
9765
9766std::string ReluGradNode::getDebugDesc() const {
9767 DescriptionBuilder db(getKindName());
9768 db.addParam("Name", separateString(getName(), 100, "\n"));
9769 if (hasPredicate()) db.addParam("Predicate", "Yes");
9770 db
9771 .addParam("Input", *(getInput().getType()))
9772 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
9773 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
9774 .addParam("Users", getNumUsers());
9775 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
9776 return db;
9777}
9778
9779void ReluGradNode::visit(Node *parent, NodeWalker *visitor) {
9780 if (!visitor->shouldVisit(parent, this)) { return; }
9781 visitor->pre(parent, this);
9782if (hasPredicate())
9783 getPredicate().getNode()->visit(this, visitor);
9784 getInput().getNode()->visit(this, visitor);
9785 getOriginalOutputForResult().getNode()->visit(this, visitor);
9786 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
9787 visitor->post(parent, this);
9788}
9789
9790bool ReluGradNode::isEqual(const ReluGradNode &other) const {
9791 return true &&
9792 Input_ == other.Input_ &&
9793 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
9794 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
9795 predicate_ == other.predicate_ &&
9796 getType(0) == other.getType(0);
9797}
9798
9799Node* ReluGradNode::clone() const {
9800 return new ReluGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
9801}
9802
9803llvm::hash_code ReluGradNode::getHash() const {
9804 return llvm::hash_combine(
9805 Input_,
9806 OriginalOutputForResult_,
9807 GradOfOriginalOutputNamedResult_);
9808}
9809
9810unsigned ReluNode::getNumInputs() const {
9811 return 1;
9812}
9813
9814std::string ReluNode::getInputName(unsigned idx) const {
9815 if (idx == 0) { return "Input"; }
9816 idx -= 1;
9817 llvm_unreachable("Invalid index");
9818}
9819
9820NodeValue ReluNode::getNthInput(unsigned idx) {
9821 if (idx == 0) { return Input_; }
9822 idx -= 1;
9823 llvm_unreachable("Invalid index");
9824}
9825
9826void ReluNode::setNthInput(unsigned idx, NodeValue val) {
9827 if (idx == 0) { Input_ = val; return; }
9828 idx -= 1;
9829 llvm_unreachable("Invalid index");
9830}
9831
9832llvm::StringRef ReluNode::getOutputName(unsigned idx) const {
9833 if (idx == 0) { return "Result"; }
9834 llvm_unreachable("Invalid index");
9835}
9836
9837std::string ReluNode::getDebugDesc() const {
9838 DescriptionBuilder db(getKindName());
9839 db.addParam("Name", separateString(getName(), 100, "\n"));
9840 if (hasPredicate()) db.addParam("Predicate", "Yes");
9841 db
9842 .addParam("Input", *(getInput().getType()))
9843 .addParam("Users", getNumUsers());
9844 db.addParam("Result", *(getResult().getType()));
9845 return db;
9846}
9847
9848void ReluNode::visit(Node *parent, NodeWalker *visitor) {
9849 if (!visitor->shouldVisit(parent, this)) { return; }
9850 visitor->pre(parent, this);
9851if (hasPredicate())
9852 getPredicate().getNode()->visit(this, visitor);
9853 getInput().getNode()->visit(this, visitor);
9854 visitor->post(parent, this);
9855}
9856
9857bool ReluNode::isEqual(const ReluNode &other) const {
9858 return true &&
9859 Input_ == other.Input_ &&
9860 predicate_ == other.predicate_ &&
9861 getType(0) == other.getType(0);
9862}
9863
9864Node* ReluNode::clone() const {
9865 return new ReluNode(getName(), getResult().getType(), getInput());
9866}
9867
9868llvm::hash_code ReluNode::getHash() const {
9869 return llvm::hash_combine(
9870 Input_);
9871}
9872
9873ReluGradNode *ReluNode::getGrad(GraphGradMapper &builder) {
9874 auto *x = new ReluGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()));
9875 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
9876 return x;
9877}
9878
9879unsigned HardSwishNode::getNumInputs() const {
9880 return 1;
9881}
9882
9883std::string HardSwishNode::getInputName(unsigned idx) const {
9884 if (idx == 0) { return "Input"; }
9885 idx -= 1;
9886 llvm_unreachable("Invalid index");
9887}
9888
9889NodeValue HardSwishNode::getNthInput(unsigned idx) {
9890 if (idx == 0) { return Input_; }
9891 idx -= 1;
9892 llvm_unreachable("Invalid index");
9893}
9894
9895void HardSwishNode::setNthInput(unsigned idx, NodeValue val) {
9896 if (idx == 0) { Input_ = val; return; }
9897 idx -= 1;
9898 llvm_unreachable("Invalid index");
9899}
9900
9901llvm::StringRef HardSwishNode::getOutputName(unsigned idx) const {
9902 if (idx == 0) { return "Result"; }
9903 llvm_unreachable("Invalid index");
9904}
9905
9906std::string HardSwishNode::getDebugDesc() const {
9907 DescriptionBuilder db(getKindName());
9908 db.addParam("Name", separateString(getName(), 100, "\n"));
9909 if (hasPredicate()) db.addParam("Predicate", "Yes");
9910 db
9911 .addParam("Input", *(getInput().getType()))
9912 .addParam("Users", getNumUsers());
9913 db.addParam("Result", *(getResult().getType()));
9914 return db;
9915}
9916
9917void HardSwishNode::visit(Node *parent, NodeWalker *visitor) {
9918 if (!visitor->shouldVisit(parent, this)) { return; }
9919 visitor->pre(parent, this);
9920if (hasPredicate())
9921 getPredicate().getNode()->visit(this, visitor);
9922 getInput().getNode()->visit(this, visitor);
9923 visitor->post(parent, this);
9924}
9925
9926bool HardSwishNode::isEqual(const HardSwishNode &other) const {
9927 return true &&
9928 Input_ == other.Input_ &&
9929 predicate_ == other.predicate_ &&
9930 getType(0) == other.getType(0);
9931}
9932
9933Node* HardSwishNode::clone() const {
9934 return new HardSwishNode(getName(), getResult().getType(), getInput());
9935}
9936
9937llvm::hash_code HardSwishNode::getHash() const {
9938 return llvm::hash_combine(
9939 Input_);
9940}
9941
9942unsigned GeluNode::getNumInputs() const {
9943 return 1;
9944}
9945
9946std::string GeluNode::getInputName(unsigned idx) const {
9947 if (idx == 0) { return "Input"; }
9948 idx -= 1;
9949 llvm_unreachable("Invalid index");
9950}
9951
9952NodeValue GeluNode::getNthInput(unsigned idx) {
9953 if (idx == 0) { return Input_; }
9954 idx -= 1;
9955 llvm_unreachable("Invalid index");
9956}
9957
9958void GeluNode::setNthInput(unsigned idx, NodeValue val) {
9959 if (idx == 0) { Input_ = val; return; }
9960 idx -= 1;
9961 llvm_unreachable("Invalid index");
9962}
9963
9964llvm::StringRef GeluNode::getOutputName(unsigned idx) const {
9965 if (idx == 0) { return "Result"; }
9966 llvm_unreachable("Invalid index");
9967}
9968
9969std::string GeluNode::getDebugDesc() const {
9970 DescriptionBuilder db(getKindName());
9971 db.addParam("Name", separateString(getName(), 100, "\n"));
9972 if (hasPredicate()) db.addParam("Predicate", "Yes");
9973 db
9974 .addParam("Input", *(getInput().getType()))
9975 .addParam("Users", getNumUsers());
9976 db.addParam("Result", *(getResult().getType()));
9977 return db;
9978}
9979
9980void GeluNode::visit(Node *parent, NodeWalker *visitor) {
9981 if (!visitor->shouldVisit(parent, this)) { return; }
9982 visitor->pre(parent, this);
9983if (hasPredicate())
9984 getPredicate().getNode()->visit(this, visitor);
9985 getInput().getNode()->visit(this, visitor);
9986 visitor->post(parent, this);
9987}
9988
9989bool GeluNode::isEqual(const GeluNode &other) const {
9990 return true &&
9991 Input_ == other.Input_ &&
9992 predicate_ == other.predicate_ &&
9993 getType(0) == other.getType(0);
9994}
9995
9996Node* GeluNode::clone() const {
9997 return new GeluNode(getName(), getResult().getType(), getInput());
9998}
9999
10000llvm::hash_code GeluNode::getHash() const {
10001 return llvm::hash_combine(
10002 Input_);
10003}
10004
10005unsigned ClipNode::getNumInputs() const {
10006 return 1;
10007}
10008
10009std::string ClipNode::getInputName(unsigned idx) const {
10010 if (idx == 0) { return "Input"; }
10011 idx -= 1;
10012 llvm_unreachable("Invalid index");
10013}
10014
10015NodeValue ClipNode::getNthInput(unsigned idx) {
10016 if (idx == 0) { return Input_; }
10017 idx -= 1;
10018 llvm_unreachable("Invalid index");
10019}
10020
10021void ClipNode::setNthInput(unsigned idx, NodeValue val) {
10022 if (idx == 0) { Input_ = val; return; }
10023 idx -= 1;
10024 llvm_unreachable("Invalid index");
10025}
10026
10027llvm::StringRef ClipNode::getOutputName(unsigned idx) const {
10028 if (idx == 0) { return "Result"; }
10029 llvm_unreachable("Invalid index");
10030}
10031
10032std::string ClipNode::getDebugDesc() const {
10033 DescriptionBuilder db(getKindName());
10034 db.addParam("Name", separateString(getName(), 100, "\n"));
10035 if (hasPredicate()) db.addParam("Predicate", "Yes");
10036 db
10037 .addParam("Input", *(getInput().getType()))
10038 .addParam("Min", getMin())
10039 .addParam("Max", getMax())
10040 .addParam("Users", getNumUsers());
10041 db.addParam("Result", *(getResult().getType()));
10042 return db;
10043}
10044
10045void ClipNode::visit(Node *parent, NodeWalker *visitor) {
10046 if (!visitor->shouldVisit(parent, this)) { return; }
10047 visitor->pre(parent, this);
10048if (hasPredicate())
10049 getPredicate().getNode()->visit(this, visitor);
10050 getInput().getNode()->visit(this, visitor);
10051 visitor->post(parent, this);
10052}
10053
10054bool ClipNode::isEqual(const ClipNode &other) const {
10055 return true &&
10056 Input_ == other.Input_ &&
10057 predicate_ == other.predicate_ &&
10058 Min_ == other.Min_ &&
10059 Max_ == other.Max_ &&
10060 getType(0) == other.getType(0);
10061}
10062
10063Node* ClipNode::clone() const {
10064 return new ClipNode(getName(), getResult().getType(), getInput(), getMin(), getMax());
10065}
10066
10067llvm::hash_code ClipNode::getHash() const {
10068 return llvm::hash_combine(
10069 toBinary(Min_),
10070 toBinary(Max_),
10071 Input_);
10072}
10073
10074unsigned PReluNode::getNumInputs() const {
10075 return 2;
10076}
10077
10078std::string PReluNode::getInputName(unsigned idx) const {
10079 if (idx == 0) { return "Input"; }
10080 if (idx == 1) { return "Slope"; }
10081 idx -= 2;
10082 llvm_unreachable("Invalid index");
10083}
10084
10085NodeValue PReluNode::getNthInput(unsigned idx) {
10086 if (idx == 0) { return Input_; }
10087 if (idx == 1) { return Slope_; }
10088 idx -= 2;
10089 llvm_unreachable("Invalid index");
10090}
10091
10092void PReluNode::setNthInput(unsigned idx, NodeValue val) {
10093 if (idx == 0) { Input_ = val; return; }
10094 if (idx == 1) { Slope_ = val; return; }
10095 idx -= 2;
10096 llvm_unreachable("Invalid index");
10097}
10098
10099llvm::StringRef PReluNode::getOutputName(unsigned idx) const {
10100 if (idx == 0) { return "Result"; }
10101 llvm_unreachable("Invalid index");
10102}
10103
10104std::string PReluNode::getDebugDesc() const {
10105 DescriptionBuilder db(getKindName());
10106 db.addParam("Name", separateString(getName(), 100, "\n"));
10107 if (hasPredicate()) db.addParam("Predicate", "Yes");
10108 db
10109 .addParam("Input", *(getInput().getType()))
10110 .addParam("Slope", *(getSlope().getType()))
10111 .addParam("Users", getNumUsers());
10112 db.addParam("Result", *(getResult().getType()));
10113 return db;
10114}
10115
10116void PReluNode::visit(Node *parent, NodeWalker *visitor) {
10117 if (!visitor->shouldVisit(parent, this)) { return; }
10118 visitor->pre(parent, this);
10119if (hasPredicate())
10120 getPredicate().getNode()->visit(this, visitor);
10121 getInput().getNode()->visit(this, visitor);
10122 getSlope().getNode()->visit(this, visitor);
10123 visitor->post(parent, this);
10124}
10125
10126bool PReluNode::isEqual(const PReluNode &other) const {
10127 return true &&
10128 Input_ == other.Input_ &&
10129 Slope_ == other.Slope_ &&
10130 predicate_ == other.predicate_ &&
10131 getType(0) == other.getType(0);
10132}
10133
10134Node* PReluNode::clone() const {
10135 return new PReluNode(getName(), getResult().getType(), getInput(), getSlope());
10136}
10137
10138llvm::hash_code PReluNode::getHash() const {
10139 return llvm::hash_combine(
10140 Input_,
10141 Slope_);
10142}
10143
10144unsigned SigmoidGradNode::getNumInputs() const {
10145 return 3;
10146}
10147
10148std::string SigmoidGradNode::getInputName(unsigned idx) const {
10149 if (idx == 0) { return "Input"; }
10150 if (idx == 1) { return "OriginalOutputForResult"; }
10151 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
10152 idx -= 3;
10153 llvm_unreachable("Invalid index");
10154}
10155
10156NodeValue SigmoidGradNode::getNthInput(unsigned idx) {
10157 if (idx == 0) { return Input_; }
10158 if (idx == 1) { return OriginalOutputForResult_; }
10159 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
10160 idx -= 3;
10161 llvm_unreachable("Invalid index");
10162}
10163
10164void SigmoidGradNode::setNthInput(unsigned idx, NodeValue val) {
10165 if (idx == 0) { Input_ = val; return; }
10166 if (idx == 1) { OriginalOutputForResult_ = val; return; }
10167 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
10168 idx -= 3;
10169 llvm_unreachable("Invalid index");
10170}
10171
10172llvm::StringRef SigmoidGradNode::getOutputName(unsigned idx) const {
10173 if (idx == 0) { return "GradOfInputNamedInput"; }
10174 llvm_unreachable("Invalid index");
10175}
10176
10177std::string SigmoidGradNode::getDebugDesc() const {
10178 DescriptionBuilder db(getKindName());
10179 db.addParam("Name", separateString(getName(), 100, "\n"));
10180 if (hasPredicate()) db.addParam("Predicate", "Yes");
10181 db
10182 .addParam("Input", *(getInput().getType()))
10183 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
10184 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
10185 .addParam("Users", getNumUsers());
10186 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
10187 return db;
10188}
10189
10190void SigmoidGradNode::visit(Node *parent, NodeWalker *visitor) {
10191 if (!visitor->shouldVisit(parent, this)) { return; }
10192 visitor->pre(parent, this);
10193if (hasPredicate())
10194 getPredicate().getNode()->visit(this, visitor);
10195 getInput().getNode()->visit(this, visitor);
10196 getOriginalOutputForResult().getNode()->visit(this, visitor);
10197 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
10198 visitor->post(parent, this);
10199}
10200
10201bool SigmoidGradNode::isEqual(const SigmoidGradNode &other) const {
10202 return true &&
10203 Input_ == other.Input_ &&
10204 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
10205 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
10206 predicate_ == other.predicate_ &&
10207 getType(0) == other.getType(0);
10208}
10209
10210Node* SigmoidGradNode::clone() const {
10211 return new SigmoidGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
10212}
10213
10214llvm::hash_code SigmoidGradNode::getHash() const {
10215 return llvm::hash_combine(
10216 Input_,
10217 OriginalOutputForResult_,
10218 GradOfOriginalOutputNamedResult_);
10219}
10220
10221unsigned SigmoidNode::getNumInputs() const {
10222 return 1;
10223}
10224
10225std::string SigmoidNode::getInputName(unsigned idx) const {
10226 if (idx == 0) { return "Input"; }
10227 idx -= 1;
10228 llvm_unreachable("Invalid index");
10229}
10230
10231NodeValue SigmoidNode::getNthInput(unsigned idx) {
10232 if (idx == 0) { return Input_; }
10233 idx -= 1;
10234 llvm_unreachable("Invalid index");
10235}
10236
10237void SigmoidNode::setNthInput(unsigned idx, NodeValue val) {
10238 if (idx == 0) { Input_ = val; return; }
10239 idx -= 1;
10240 llvm_unreachable("Invalid index");
10241}
10242
10243llvm::StringRef SigmoidNode::getOutputName(unsigned idx) const {
10244 if (idx == 0) { return "Result"; }
10245 llvm_unreachable("Invalid index");
10246}
10247
10248std::string SigmoidNode::getDebugDesc() const {
10249 DescriptionBuilder db(getKindName());
10250 db.addParam("Name", separateString(getName(), 100, "\n"));
10251 if (hasPredicate()) db.addParam("Predicate", "Yes");
10252 db
10253 .addParam("Input", *(getInput().getType()))
10254 .addParam("Users", getNumUsers());
10255 db.addParam("Result", *(getResult().getType()));
10256 return db;
10257}
10258
10259void SigmoidNode::visit(Node *parent, NodeWalker *visitor) {
10260 if (!visitor->shouldVisit(parent, this)) { return; }
10261 visitor->pre(parent, this);
10262if (hasPredicate())
10263 getPredicate().getNode()->visit(this, visitor);
10264 getInput().getNode()->visit(this, visitor);
10265 visitor->post(parent, this);
10266}
10267
10268bool SigmoidNode::isEqual(const SigmoidNode &other) const {
10269 return true &&
10270 Input_ == other.Input_ &&
10271 predicate_ == other.predicate_ &&
10272 getType(0) == other.getType(0);
10273}
10274
10275Node* SigmoidNode::clone() const {
10276 return new SigmoidNode(getName(), getResult().getType(), getInput());
10277}
10278
10279llvm::hash_code SigmoidNode::getHash() const {
10280 return llvm::hash_combine(
10281 Input_);
10282}
10283
10284SigmoidGradNode *SigmoidNode::getGrad(GraphGradMapper &builder) {
10285 auto *x = new SigmoidGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()));
10286 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
10287 return x;
10288}
10289
10290unsigned SwishNode::getNumInputs() const {
10291 return 1;
10292}
10293
10294std::string SwishNode::getInputName(unsigned idx) const {
10295 if (idx == 0) { return "Input"; }
10296 idx -= 1;
10297 llvm_unreachable("Invalid index");
10298}
10299
10300NodeValue SwishNode::getNthInput(unsigned idx) {
10301 if (idx == 0) { return Input_; }
10302 idx -= 1;
10303 llvm_unreachable("Invalid index");
10304}
10305
10306void SwishNode::setNthInput(unsigned idx, NodeValue val) {
10307 if (idx == 0) { Input_ = val; return; }
10308 idx -= 1;
10309 llvm_unreachable("Invalid index");
10310}
10311
10312llvm::StringRef SwishNode::getOutputName(unsigned idx) const {
10313 if (idx == 0) { return "Result"; }
10314 llvm_unreachable("Invalid index");
10315}
10316
10317std::string SwishNode::getDebugDesc() const {
10318 DescriptionBuilder db(getKindName());
10319 db.addParam("Name", separateString(getName(), 100, "\n"));
10320 if (hasPredicate()) db.addParam("Predicate", "Yes");
10321 db
10322 .addParam("Input", *(getInput().getType()))
10323 .addParam("Users", getNumUsers());
10324 db.addParam("Result", *(getResult().getType()));
10325 return db;
10326}
10327
10328void SwishNode::visit(Node *parent, NodeWalker *visitor) {
10329 if (!visitor->shouldVisit(parent, this)) { return; }
10330 visitor->pre(parent, this);
10331if (hasPredicate())
10332 getPredicate().getNode()->visit(this, visitor);
10333 getInput().getNode()->visit(this, visitor);
10334 visitor->post(parent, this);
10335}
10336
10337bool SwishNode::isEqual(const SwishNode &other) const {
10338 return true &&
10339 Input_ == other.Input_ &&
10340 predicate_ == other.predicate_ &&
10341 getType(0) == other.getType(0);
10342}
10343
10344Node* SwishNode::clone() const {
10345 return new SwishNode(getName(), getResult().getType(), getInput());
10346}
10347
10348llvm::hash_code SwishNode::getHash() const {
10349 return llvm::hash_combine(
10350 Input_);
10351}
10352
10353unsigned TanhGradNode::getNumInputs() const {
10354 return 3;
10355}
10356
10357std::string TanhGradNode::getInputName(unsigned idx) const {
10358 if (idx == 0) { return "Input"; }
10359 if (idx == 1) { return "OriginalOutputForResult"; }
10360 if (idx == 2) { return "GradOfOriginalOutputNamedResult"; }
10361 idx -= 3;
10362 llvm_unreachable("Invalid index");
10363}
10364
10365NodeValue TanhGradNode::getNthInput(unsigned idx) {
10366 if (idx == 0) { return Input_; }
10367 if (idx == 1) { return OriginalOutputForResult_; }
10368 if (idx == 2) { return GradOfOriginalOutputNamedResult_; }
10369 idx -= 3;
10370 llvm_unreachable("Invalid index");
10371}
10372
10373void TanhGradNode::setNthInput(unsigned idx, NodeValue val) {
10374 if (idx == 0) { Input_ = val; return; }
10375 if (idx == 1) { OriginalOutputForResult_ = val; return; }
10376 if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; }
10377 idx -= 3;
10378 llvm_unreachable("Invalid index");
10379}
10380
10381llvm::StringRef TanhGradNode::getOutputName(unsigned idx) const {
10382 if (idx == 0) { return "GradOfInputNamedInput"; }
10383 llvm_unreachable("Invalid index");
10384}
10385
10386std::string TanhGradNode::getDebugDesc() const {
10387 DescriptionBuilder db(getKindName());
10388 db.addParam("Name", separateString(getName(), 100, "\n"));
10389 if (hasPredicate()) db.addParam("Predicate", "Yes");
10390 db
10391 .addParam("Input", *(getInput().getType()))
10392 .addParam("OriginalOutputForResult", *(getOriginalOutputForResult().getType()))
10393 .addParam("GradOfOriginalOutputNamedResult", *(getGradOfOriginalOutputNamedResult().getType()))
10394 .addParam("Users", getNumUsers());
10395 db.addParam("GradOfInputNamedInput", *(getGradOfInputNamedInput().getType()));
10396 return db;
10397}
10398
10399void TanhGradNode::visit(Node *parent, NodeWalker *visitor) {
10400 if (!visitor->shouldVisit(parent, this)) { return; }
10401 visitor->pre(parent, this);
10402if (hasPredicate())
10403 getPredicate().getNode()->visit(this, visitor);
10404 getInput().getNode()->visit(this, visitor);
10405 getOriginalOutputForResult().getNode()->visit(this, visitor);
10406 getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor);
10407 visitor->post(parent, this);
10408}
10409
10410bool TanhGradNode::isEqual(const TanhGradNode &other) const {
10411 return true &&
10412 Input_ == other.Input_ &&
10413 OriginalOutputForResult_ == other.OriginalOutputForResult_ &&
10414 GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ &&
10415 predicate_ == other.predicate_ &&
10416 getType(0) == other.getType(0);
10417}
10418
10419Node* TanhGradNode::clone() const {
10420 return new TanhGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult());
10421}
10422
10423llvm::hash_code TanhGradNode::getHash() const {
10424 return llvm::hash_combine(
10425 Input_,
10426 OriginalOutputForResult_,
10427 GradOfOriginalOutputNamedResult_);
10428}
10429
10430unsigned TanhNode::getNumInputs() const {
10431 return 1;
10432}
10433
10434std::string TanhNode::getInputName(unsigned idx) const {
10435 if (idx == 0) { return "Input"; }
10436 idx -= 1;
10437 llvm_unreachable("Invalid index");
10438}
10439
10440NodeValue TanhNode::getNthInput(unsigned idx) {
10441 if (idx == 0) { return Input_; }
10442 idx -= 1;
10443 llvm_unreachable("Invalid index");
10444}
10445
10446void TanhNode::setNthInput(unsigned idx, NodeValue val) {
10447 if (idx == 0) { Input_ = val; return; }
10448 idx -= 1;
10449 llvm_unreachable("Invalid index");
10450}
10451
10452llvm::StringRef TanhNode::getOutputName(unsigned idx) const {
10453 if (idx == 0) { return "Result"; }
10454 llvm_unreachable("Invalid index");
10455}
10456
10457std::string TanhNode::getDebugDesc() const {
10458 DescriptionBuilder db(getKindName());
10459 db.addParam("Name", separateString(getName(), 100, "\n"));
10460 if (hasPredicate()) db.addParam("Predicate", "Yes");
10461 db
10462 .addParam("Input", *(getInput().getType()))
10463 .addParam("Users", getNumUsers());
10464 db.addParam("Result", *(getResult().getType()));
10465 return db;
10466}
10467
10468void TanhNode::visit(Node *parent, NodeWalker *visitor) {
10469 if (!visitor->shouldVisit(parent, this)) { return; }
10470 visitor->pre(parent, this);
10471if (hasPredicate())
10472 getPredicate().getNode()->visit(this, visitor);
10473 getInput().getNode()->visit(this, visitor);
10474 visitor->post(parent, this);
10475}
10476
10477bool TanhNode::isEqual(const TanhNode &other) const {
10478 return true &&
10479 Input_ == other.Input_ &&
10480 predicate_ == other.predicate_ &&
10481 getType(0) == other.getType(0);
10482}
10483
10484Node* TanhNode::clone() const {
10485 return new TanhNode(getName(), getResult().getType(), getInput());
10486}
10487
10488llvm::hash_code TanhNode::getHash() const {
10489 return llvm::hash_combine(
10490 Input_);
10491}
10492
10493TanhGradNode *TanhNode::getGrad(GraphGradMapper &builder) {
10494 auto *x = new TanhGradNode(getName().str() + "_grad", getInput(), getResult(), builder.getGradient(getResult()));
10495 builder.addGradient(getInput(), x->getGradOfInputNamedInput());
10496 return x;
10497}
10498
10499unsigned LeakyReluNode::getNumInputs() const {
10500 return 1;
10501}
10502
10503std::string LeakyReluNode::getInputName(unsigned idx) const {
10504 if (idx == 0) { return "Input"; }
10505 idx -= 1;
10506 llvm_unreachable("Invalid index");
10507}
10508
10509NodeValue LeakyReluNode::getNthInput(unsigned idx) {
10510 if (idx == 0) { return Input_; }
10511 idx -= 1;
10512 llvm_unreachable("Invalid index");
10513}
10514
10515void LeakyReluNode::setNthInput(unsigned idx, NodeValue val) {
10516 if (idx == 0) { Input_ = val; return; }
10517 idx -= 1;
10518 llvm_unreachable("Invalid index");
10519}
10520
10521llvm::StringRef LeakyReluNode::getOutputName(unsigned idx) const {
10522 if (idx == 0) { return "Result"; }
10523 llvm_unreachable("Invalid index");
10524}
10525
10526std::string LeakyReluNode::getDebugDesc() const {
10527 DescriptionBuilder db(getKindName());
10528 db.addParam("Name", separateString(getName(), 100, "\n"));
10529 if (hasPredicate()) db.addParam("Predicate", "Yes");
10530 db
10531 .addParam("Input", *(getInput().getType()))
10532 .addParam("Alpha", getAlpha())
10533 .addParam("Users", getNumUsers());
10534 db.addParam("Result", *(getResult().getType()));
10535 return db;
10536}
10537
10538void LeakyReluNode::visit(Node *parent, NodeWalker *visitor) {
10539 if (!visitor->shouldVisit(parent, this)) { return; }
10540 visitor->pre(parent, this);
10541if (hasPredicate())
10542 getPredicate().getNode()->visit(this, visitor);
10543 getInput().getNode()->visit(this, visitor);
10544 visitor->post(parent, this);
10545}
10546
10547bool LeakyReluNode::isEqual(const LeakyReluNode &other) const {
10548 return true &&
10549 Input_ == other.Input_ &&
10550 predicate_ == other.predicate_ &&
10551 Alpha_ == other.Alpha_ &&
10552 getType(0) == other.getType(0);
10553}
10554
10555Node* LeakyReluNode::clone() const {
10556 return new LeakyReluNode(getName(), getResult().getType(), getInput(), getAlpha());
10557}
10558
10559llvm::hash_code LeakyReluNode::getHash() const {
10560 return llvm::hash_combine(
10561 toBinary(Alpha_),
10562 Input_);
10563}
10564
10565unsigned SoftPlusNode::getNumInputs() const {
10566 return 1;
10567}
10568
10569std::string SoftPlusNode::getInputName(unsigned idx) const {
10570 if (idx == 0) { return "Input"; }
10571 idx -= 1;
10572 llvm_unreachable("Invalid index");
10573}
10574
10575NodeValue SoftPlusNode::getNthInput(unsigned idx) {
10576 if (idx == 0) { return Input_; }
10577 idx -= 1;
10578 llvm_unreachable("Invalid index");
10579}
10580
10581void SoftPlusNode::setNthInput(unsigned idx, NodeValue val) {
10582 if (idx == 0) { Input_ = val; return; }
10583 idx -= 1;
10584 llvm_unreachable("Invalid index");
10585}
10586
10587llvm::StringRef SoftPlusNode::getOutputName(unsigned idx) const {
10588 if (idx == 0) { return "Result"; }
10589 llvm_unreachable("Invalid index");
10590}
10591
10592std::string SoftPlusNode::getDebugDesc() const {
10593 DescriptionBuilder db(getKindName());
10594 db.addParam("Name", separateString(getName(), 100, "\n"));
10595 if (hasPredicate()) db.addParam("Predicate", "Yes");
10596 db
10597 .addParam("Input", *(getInput().getType()))
10598 .addParam("Users", getNumUsers());
10599 db.addParam("Result", *(getResult().getType()));
10600 return db;
10601}
10602
10603void SoftPlusNode::visit(Node *parent, NodeWalker *visitor) {
10604 if (!visitor->shouldVisit(parent, this)) { return; }
10605 visitor->pre(parent, this);
10606if (hasPredicate())
10607 getPredicate().getNode()->visit(this, visitor);
10608 getInput().getNode()->visit(this, visitor);
10609 visitor->post(parent, this);
10610}
10611
10612bool SoftPlusNode::isEqual(const SoftPlusNode &other) const {
10613 return true &&
10614 Input_ == other.Input_ &&
10615 predicate_ == other.predicate_ &&
10616 getType(0) == other.getType(0);
10617}
10618
10619Node* SoftPlusNode::clone() const {
10620 return new SoftPlusNode(getName(), getResult().getType(), getInput());
10621}
10622
10623llvm::hash_code SoftPlusNode::getHash() const {
10624 return llvm::hash_combine(
10625 Input_);
10626}
10627
10628unsigned ReshapeNode::getNumInputs() const {
10629 return 1;
10630}
10631
10632std::string ReshapeNode::getInputName(unsigned idx) const {
10633 if (idx == 0) { return "Input"; }
10634 idx -= 1;
10635 llvm_unreachable("Invalid index");
10636}
10637
10638NodeValue ReshapeNode::getNthInput(unsigned idx) {
10639 if (idx == 0) { return Input_; }
10640 idx -= 1;
10641 llvm_unreachable("Invalid index");
10642}
10643
10644void ReshapeNode::setNthInput(unsigned idx, NodeValue val) {
10645 if (idx == 0) { Input_ = val; return; }
10646 idx -= 1;
10647 llvm_unreachable("Invalid index");
10648}
10649
10650llvm::StringRef ReshapeNode::getOutputName(unsigned idx) const {
10651 if (idx == 0) { return "Result"; }
10652 llvm_unreachable("Invalid index");
10653}
10654
10655std::string ReshapeNode::getDebugDesc() const {
10656 DescriptionBuilder db(getKindName());
10657 db.addParam("Name", separateString(getName(), 100, "\n"));
10658 if (hasPredicate()) db.addParam("Predicate", "Yes");
10659 db
10660 .addParam("Input", *(getInput().getType()))
10661 .addParam("Dims", getDims())
10662 .addParam("Layout", getLayout())
10663 .addParam("Users", getNumUsers());
10664 db.addParam("Result", *(getResult().getType()));
10665 return db;
10666}
10667
10668void ReshapeNode::visit(Node *parent, NodeWalker *visitor) {
10669 if (!visitor->shouldVisit(parent, this)) { return; }
10670 visitor->pre(parent, this);
10671if (hasPredicate())
10672 getPredicate().getNode()->visit(this, visitor);
10673 getInput().getNode()->visit(this, visitor);
10674 visitor->post(parent, this);
10675}
10676
10677bool ReshapeNode::isEqual(const ReshapeNode &other) const {
10678 return true &&
10679 Input_ == other.Input_ &&
10680 predicate_ == other.predicate_ &&
10681 Dims_ == other.Dims_ &&
10682 Layout_ == other.Layout_ &&
10683 getType(0) == other.getType(0);
10684}
10685
10686Node* ReshapeNode::clone() const {
10687 return new ReshapeNode(getName(), getResult().getType(), getInput(), getDims(), getLayout());
10688}
10689
10690llvm::hash_code ReshapeNode::getHash() const {
10691 return llvm::hash_combine(
10692 llvm::hash_combine_range(Dims_.begin(), Dims_.end()),
10693 Layout_,
10694 Input_);
10695}
10696
10697unsigned TransposeNode::getNumInputs() const {
10698 return 1;
10699}
10700
10701std::string TransposeNode::getInputName(unsigned idx) const {
10702 if (idx == 0) { return "Input"; }
10703 idx -= 1;
10704 llvm_unreachable("Invalid index");
10705}
10706
10707NodeValue TransposeNode::getNthInput(unsigned idx) {
10708 if (idx == 0) { return Input_; }
10709 idx -= 1;
10710 llvm_unreachable("Invalid index");
10711}
10712
10713void TransposeNode::setNthInput(unsigned idx, NodeValue val) {
10714 if (idx == 0) { Input_ = val; return; }
10715 idx -= 1;
10716 llvm_unreachable("Invalid index");
10717}
10718
10719llvm::StringRef TransposeNode::getOutputName(unsigned idx) const {
10720 if (idx == 0) { return "Result"; }
10721 llvm_unreachable("Invalid index");
10722}
10723
10724std::string TransposeNode::getDebugDesc() const {
10725 DescriptionBuilder db(getKindName());
10726 db.addParam("Name", separateString(getName(), 100, "\n"));
10727 if (hasPredicate()) db.addParam("Predicate", "Yes");
10728 db
10729 .addParam("Input", *(getInput().getType()))
10730 .addParam("Shuffle", getShuffle())
10731 .addParam("Layout", getLayout())
10732 .addParam("Users", getNumUsers());
10733 db.addParam("Result", *(getResult().getType()));
10734 return db;
10735}
10736
10737void TransposeNode::visit(Node *parent, NodeWalker *visitor) {
10738 if (!visitor->shouldVisit(parent, this)) { return; }
10739 visitor->pre(parent, this);
10740if (hasPredicate())
10741 getPredicate().getNode()->visit(this, visitor);
10742 getInput().getNode()->visit(this, visitor);
10743 visitor->post(parent, this);
10744}
10745
10746bool TransposeNode::isEqual(const TransposeNode &other) const {
10747 return true &&
10748 Input_ == other.Input_ &&
10749 predicate_ == other.predicate_ &&
10750 Shuffle_ == other.Shuffle_ &&
10751 Layout_ == other.Layout_ &&
10752 getType(0) == other.getType(0);
10753}
10754
10755Node* TransposeNode::clone() const {
10756 return new TransposeNode(getName(), getResult().getType(), getInput(), getShuffle(), getLayout());
10757}
10758
10759llvm::hash_code TransposeNode::getHash() const {
10760 return llvm::hash_combine(
10761 llvm::hash_combine_range(Shuffle_.begin(), Shuffle_.end()),
10762 Layout_,
10763 Input_);
10764}
10765
10766unsigned ConcatNode::getNumInputs() const {
10767 return 0 + Inputs_.size();
10768}
10769
10770std::string ConcatNode::getInputName(unsigned idx) const {
10771 idx -= 0;
10772 if (idx < Inputs_.size()) { return "Inputs" + std::to_string(idx); }
10773 idx -= Inputs_.size();
10774 llvm_unreachable("Invalid index");
10775}
10776
10777NodeValue ConcatNode::getNthInput(unsigned idx) {
10778 idx -= 0;
10779 if (idx < Inputs_.size()) { return Inputs_[idx]; }
10780 idx -= Inputs_.size();
10781 llvm_unreachable("Invalid index");
10782}
10783
10784void ConcatNode::setNthInput(unsigned idx, NodeValue val) {
10785 idx -= 0;
10786 if (idx < Inputs_.size()) { Inputs_[idx] = val; return; }
10787 idx -= Inputs_.size();
10788 llvm_unreachable("Invalid index");
10789}
10790
10791llvm::StringRef ConcatNode::getOutputName(unsigned idx) const {
10792 if (idx == 0) { return "Result"; }
10793 llvm_unreachable("Invalid index");
10794}
10795
10796std::string ConcatNode::getDebugDesc() const {
10797 DescriptionBuilder db(getKindName());
10798 db.addParam("Name", separateString(getName(), 100, "\n"));
10799 if (hasPredicate()) db.addParam("Predicate", "Yes");
10800 db
10801 .addParam("Dim", getDim())
10802 .addParam("Users", getNumUsers());
10803 {
10804 unsigned mIndex = 0;
10805 for (const auto &II : getInputs()) {
10806 db.addParam("Inputs"+std::to_string(mIndex++), *II.getType());
10807 }
10808 }
10809 db.addParam("Result", *(getResult().getType()));
10810 return db;
10811}
10812
10813void ConcatNode::visit(Node *parent, NodeWalker *visitor) {
10814 if (!visitor->shouldVisit(parent, this)) { return; }
10815 visitor->pre(parent, this);
10816if (hasPredicate())
10817 getPredicate().getNode()->visit(this, visitor);
10818 for (auto &I : Inputs_) { I.getNode()->visit(this, visitor); }
10819 visitor->post(parent, this);
10820}
10821
10822bool ConcatNode::isEqual(const ConcatNode &other) const {
10823 return true &&
10824 predicate_ == other.predicate_ &&
10825 Inputs_ == other.Inputs_ &&
10826 Dim_ == other.Dim_ &&
10827 getType(0) == other.getType(0);
10828}
10829
10830Node* ConcatNode::clone() const {
10831 return new ConcatNode(getName(), getResult().getType(), getInputs(), getDim());
10832}
10833
10834llvm::hash_code ConcatNode::getHash() const {
10835 return llvm::hash_combine(
10836 llvm::hash_combine_range(Inputs_.begin(), Inputs_.end()),
10837 Dim_);
10838}
10839
10840unsigned SliceNode::getNumInputs() const {
10841 return 1;
10842}
10843
10844std::string SliceNode::getInputName(unsigned idx) const {
10845 if (idx == 0) { return "Input"; }
10846 idx -= 1;
10847 llvm_unreachable("Invalid index");
10848}
10849
10850NodeValue SliceNode::getNthInput(unsigned idx) {
10851 if (idx == 0) { return Input_; }
10852 idx -= 1;
10853 llvm_unreachable("Invalid index");
10854}
10855
10856void SliceNode::setNthInput(unsigned idx, NodeValue val) {
10857 if (idx == 0) { Input_ = val; return; }
10858 idx -= 1;
10859 llvm_unreachable("Invalid index");
10860}
10861
10862llvm::StringRef SliceNode::getOutputName(unsigned idx) const {
10863 if (idx == 0) { return "Result"; }
10864 llvm_unreachable("Invalid index");
10865}
10866
10867std::string SliceNode::getDebugDesc() const {
10868 DescriptionBuilder db(getKindName());
10869 db.addParam("Name", separateString(getName(), 100, "\n"));
10870 if (hasPredicate()) db.addParam("Predicate", "Yes");
10871 db
10872 .addParam("Input", *(getInput().getType()))
10873 .addParam("Start", getStart())
10874 .addParam("Users", getNumUsers());
10875 db.addParam("Result", *(getResult().getType()));
10876 return db;
10877}
10878
10879void SliceNode::visit(Node *parent, NodeWalker *visitor) {
10880 if (!visitor->shouldVisit(parent, this)) { return; }
10881 visitor->pre(parent, this);
10882if (hasPredicate())
10883 getPredicate().getNode()->visit(this, visitor);
10884 getInput().getNode()->visit(this, visitor);
10885 visitor->post(parent, this);
10886}
10887
10888bool SliceNode::isEqual(const SliceNode &other) const {
10889 return true &&
10890 Input_ == other.Input_ &&
10891 predicate_ == other.predicate_ &&
10892 Start_ == other.Start_ &&
10893 getType(0) == other.getType(0);
10894}
10895
10896Node* SliceNode::clone() const {
10897 return new SliceNode(getName(), getResult().getType(), getInput(), getStart());
10898}
10899
10900llvm::hash_code SliceNode::getHash() const {
10901 return llvm::hash_combine(
10902 llvm::hash_combine_range(Start_.begin(), Start_.end()),
10903 Input_);
10904}
10905
10906unsigned InsertTensorNode::getNumInputs() const {
10907 return 2;
10908}
10909
10910std::string InsertTensorNode::getInputName(unsigned idx) const {
10911 if (idx == 0) { return "Big"; }
10912 if (idx == 1) { return "Small"; }
10913 idx -= 2;
10914 llvm_unreachable("Invalid index");
10915}
10916
10917NodeValue InsertTensorNode::getNthInput(unsigned idx) {
10918 if (idx == 0) { return Big_; }
10919 if (idx == 1) { return Small_; }
10920 idx -= 2;
10921 llvm_unreachable("Invalid index");
10922}
10923
10924void InsertTensorNode::setNthInput(unsigned idx, NodeValue val) {
10925 if (idx == 0) { Big_ = val; return; }
10926 if (idx == 1) { Small_ = val; return; }
10927 idx -= 2;
10928 llvm_unreachable("Invalid index");
10929}
10930
10931llvm::StringRef InsertTensorNode::getOutputName(unsigned idx) const {
10932 if (idx == 0) { return "Result"; }
10933 llvm_unreachable("Invalid index");
10934}
10935
10936std::string InsertTensorNode::getDebugDesc() const {
10937 DescriptionBuilder db(getKindName());
10938 db.addParam("Name", separateString(getName(), 100, "\n"));
10939 if (hasPredicate()) db.addParam("Predicate", "Yes");
10940 db
10941 .addParam("Big", *(getBig().getType()))
10942 .addParam("Small", *(getSmall().getType()))
10943 .addParam("Start", getStart())
10944 .addParam("Count", getCount())
10945 .addParam("Axis", getAxis())
10946 .addParam("Users", getNumUsers());
10947 db.addParam("Result", *(getResult().getType()));
10948 return db;
10949}
10950
10951void InsertTensorNode::visit(Node *parent, NodeWalker *visitor) {
10952 if (!visitor->shouldVisit(parent, this)) { return; }
10953 visitor->pre(parent, this);
10954if (hasPredicate())
10955 getPredicate().getNode()->visit(this, visitor);
10956 getBig().getNode()->visit(this, visitor);
10957 getSmall().getNode()->visit(this, visitor);
10958 visitor->post(parent, this);
10959}
10960
10961bool InsertTensorNode::isEqual(const InsertTensorNode &other) const {
10962 return true &&
10963 Big_ == other.Big_ &&
10964 Small_ == other.Small_ &&
10965 predicate_ == other.predicate_ &&
10966 Start_ == other.Start_ &&
10967 Count_ == other.Count_ &&
10968 Axis_ == other.Axis_ &&
10969 getType(0) == other.getType(0);
10970}
10971
10972Node* InsertTensorNode::clone() const {
10973 return new InsertTensorNode(getName(), getBig(), getSmall(), getStart(), getCount(), getAxis());
10974}
10975
10976llvm::hash_code InsertTensorNode::getHash() const {
10977 return llvm::hash_combine(
10978 llvm::hash_combine_range(Start_.begin(), Start_.end()),
10979 Count_,
10980 Axis_,
10981 Big_,
10982 Small_);
10983}
10984
10985unsigned GatherNode::getNumInputs() const {
10986 return 2;
10987}
10988
10989std::string GatherNode::getInputName(unsigned idx) const {
10990 if (idx == 0) { return "Data"; }
10991 if (idx == 1) { return "Indices"; }
10992 idx -= 2;
10993 llvm_unreachable("Invalid index");
10994}
10995
10996NodeValue GatherNode::getNthInput(unsigned idx) {
10997 if (idx == 0) { return Data_; }
10998 if (idx == 1) { return Indices_; }
10999 idx -= 2;
11000 llvm_unreachable("Invalid index");
11001}
11002
11003void GatherNode::setNthInput(unsigned idx, NodeValue val) {
11004 if (idx == 0) { Data_ = val; return; }
11005 if (idx == 1) { Indices_ = val; return; }
11006 idx -= 2;
11007 llvm_unreachable("Invalid index");
11008}
11009
11010llvm::StringRef GatherNode::getOutputName(unsigned idx) const {
11011 if (idx == 0) { return "Result"; }
11012 llvm_unreachable("Invalid index");
11013}
11014
11015std::string GatherNode::getDebugDesc() const {
11016 DescriptionBuilder db(getKindName());
11017 db.addParam("Name", separateString(getName(), 100, "\n"));
11018 if (hasPredicate()) db.addParam("Predicate", "Yes");
11019 db
11020 .addParam("Data", *(getData().getType()))
11021 .addParam("Indices", *(getIndices().getType()))
11022 .addParam("BatchDims", getBatchDims())
11023 .addParam("Users", getNumUsers());
11024 db.addParam("Result", *(getResult().getType()));
11025 return db;
11026}
11027
11028void GatherNode::visit(Node *parent, NodeWalker *visitor) {
11029 if (!visitor->shouldVisit(parent, this)) { return; }
11030 visitor->pre(parent, this);
11031if (hasPredicate())
11032 getPredicate().getNode()->visit(this, visitor);
11033 getData().getNode()->visit(this, visitor);
11034 getIndices().getNode()->visit(this, visitor);
11035 visitor->post(parent, this);
11036}
11037
11038bool GatherNode::isEqual(const GatherNode &other) const {
11039 return true &&
11040 Data_ == other.Data_ &&
11041 Indices_ == other.Indices_ &&
11042 predicate_ == other.predicate_ &&
11043 BatchDims_ == other.BatchDims_ &&
11044 getType(0) == other.getType(0);
11045}
11046
11047Node* GatherNode::clone() const {
11048 return new GatherNode(getName(), getResult().getType(), getData(), getIndices(), getBatchDims());
11049}
11050
11051llvm::hash_code GatherNode::getHash() const {
11052 return llvm::hash_combine(
11053 BatchDims_,
11054 Data_,
11055 Indices_);
11056}
11057
11058unsigned GatherNDNode::getNumInputs() const {
11059 return 2;
11060}
11061
11062std::string GatherNDNode::getInputName(unsigned idx) const {
11063 if (idx == 0) { return "Data"; }
11064 if (idx == 1) { return "Indices"; }
11065 idx -= 2;
11066 llvm_unreachable("Invalid index");
11067}
11068
11069NodeValue GatherNDNode::getNthInput(unsigned idx) {
11070 if (idx == 0) { return Data_; }
11071 if (idx == 1) { return Indices_; }
11072 idx -= 2;
11073 llvm_unreachable("Invalid index");
11074}
11075
11076void GatherNDNode::setNthInput(unsigned idx, NodeValue val) {
11077 if (idx == 0) { Data_ = val; return; }
11078 if (idx == 1) { Indices_ = val; return; }
11079 idx -= 2;
11080 llvm_unreachable("Invalid index");
11081}
11082
11083llvm::StringRef GatherNDNode::getOutputName(unsigned idx) const {
11084 if (idx == 0) { return "Result"; }
11085 llvm_unreachable("Invalid index");
11086}
11087
11088std::string GatherNDNode::getDebugDesc() const {
11089 DescriptionBuilder db(getKindName());
11090 db.addParam("Name", separateString(getName(), 100, "\n"));
11091 if (hasPredicate()) db.addParam("Predicate", "Yes");
11092 db
11093 .addParam("Data", *(getData().getType()))
11094 .addParam("Indices", *(getIndices().getType()))
11095 .addParam("BatchDims", getBatchDims())
11096 .addParam("Users", getNumUsers());
11097 db.addParam("Result", *(getResult().getType()));
11098 return db;
11099}
11100
11101void GatherNDNode::visit(Node *parent, NodeWalker *visitor) {
11102 if (!visitor->shouldVisit(parent, this)) { return; }
11103 visitor->pre(parent, this);
11104if (hasPredicate())
11105 getPredicate().getNode()->visit(this, visitor);
11106 getData().getNode()->visit(this, visitor);
11107 getIndices().getNode()->visit(this, visitor);
11108 visitor->post(parent, this);
11109}
11110
11111bool GatherNDNode::isEqual(const GatherNDNode &other) const {
11112 return true &&
11113 Data_ == other.Data_ &&
11114 Indices_ == other.Indices_ &&
11115 predicate_ == other.predicate_ &&
11116 BatchDims_ == other.BatchDims_ &&
11117 getType(0) == other.getType(0);
11118}
11119
11120Node* GatherNDNode::clone() const {
11121 return new GatherNDNode(getName(), getResult().getType(), getData(), getIndices(), getBatchDims());
11122}
11123
11124llvm::hash_code GatherNDNode::getHash() const {
11125 return llvm::hash_combine(
11126 BatchDims_,
11127 Data_,
11128 Indices_);
11129}
11130
11131unsigned GatherElementsNode::getNumInputs() const {
11132 return 2;
11133}
11134
11135std::string GatherElementsNode::getInputName(unsigned idx) const {
11136 if (idx == 0) { return "Data"; }
11137 if (idx == 1) { return "Indices"; }
11138 idx -= 2;
11139 llvm_unreachable("Invalid index");
11140}
11141
11142NodeValue GatherElementsNode::getNthInput(unsigned idx) {
11143 if (idx == 0) { return Data_; }
11144 if (idx == 1) { return Indices_; }
11145 idx -= 2;
11146 llvm_unreachable("Invalid index");
11147}
11148
11149void GatherElementsNode::setNthInput(unsigned idx, NodeValue val) {
11150 if (idx == 0) { Data_ = val; return; }
11151 if (idx == 1) { Indices_ = val; return; }
11152 idx -= 2;
11153 llvm_unreachable("Invalid index");
11154}
11155
11156llvm::StringRef GatherElementsNode::getOutputName(unsigned idx) const {
11157 if (idx == 0) { return "Result"; }
11158 llvm_unreachable("Invalid index");
11159}
11160
11161std::string GatherElementsNode::getDebugDesc() const {
11162 DescriptionBuilder db(getKindName());
11163 db.addParam("Name", separateString(getName(), 100, "\n"));
11164 if (hasPredicate()) db.addParam("Predicate", "Yes");
11165 db
11166 .addParam("Data", *(getData().getType()))
11167 .addParam("Indices", *(getIndices().getType()))
11168 .addParam("Dim", getDim())
11169 .addParam("Users", getNumUsers());
11170 db.addParam("Result", *(getResult().getType()));
11171 return db;
11172}
11173
11174void GatherElementsNode::visit(Node *parent, NodeWalker *visitor) {
11175 if (!visitor->shouldVisit(parent, this)) { return; }
11176 visitor->pre(parent, this);
11177if (hasPredicate())
11178 getPredicate().getNode()->visit(this, visitor);
11179 getData().getNode()->visit(this, visitor);
11180 getIndices().getNode()->visit(this, visitor);
11181 visitor->post(parent, this);
11182}
11183
11184bool GatherElementsNode::isEqual(const GatherElementsNode &other) const {
11185 return true &&
11186 Data_ == other.Data_ &&
11187 Indices_ == other.Indices_ &&
11188 predicate_ == other.predicate_ &&
11189 Dim_ == other.Dim_ &&
11190 getType(0) == other.getType(0);
11191}
11192
11193Node* GatherElementsNode::clone() const {
11194 return new GatherElementsNode(getName(), getResult().getType(), getData(), getIndices(), getDim());
11195}
11196
11197llvm::hash_code GatherElementsNode::getHash() const {
11198 return llvm::hash_combine(
11199 Dim_,
11200 Data_,
11201 Indices_);
11202}
11203
11204unsigned GatherRangesNode::getNumInputs() const {
11205 return 2;
11206}
11207
11208std::string GatherRangesNode::getInputName(unsigned idx) const {
11209 if (idx == 0) { return "Data"; }
11210 if (idx == 1) { return "Ranges"; }
11211 idx -= 2;
11212 llvm_unreachable("Invalid index");
11213}
11214
11215NodeValue GatherRangesNode::getNthInput(unsigned idx) {
11216 if (idx == 0) { return Data_; }
11217 if (idx == 1) { return Ranges_; }
11218 idx -= 2;
11219 llvm_unreachable("Invalid index");
11220}
11221
11222void GatherRangesNode::setNthInput(unsigned idx, NodeValue val) {
11223 if (idx == 0) { Data_ = val; return; }
11224 if (idx == 1) { Ranges_ = val; return; }
11225 idx -= 2;
11226 llvm_unreachable("Invalid index");
11227}
11228
11229llvm::StringRef GatherRangesNode::getOutputName(unsigned idx) const {
11230 if (idx == 0) { return "Output"; }
11231 if (idx == 1) { return "Lengths"; }
11232 llvm_unreachable("Invalid index");
11233}
11234
11235std::string GatherRangesNode::getDebugDesc() const {
11236 DescriptionBuilder db(getKindName());
11237 db.addParam("Name", separateString(getName(), 100, "\n"));
11238 if (hasPredicate()) db.addParam("Predicate", "Yes");
11239 db
11240 .addParam("Data", *(getData().getType()))
11241 .addParam("Ranges", *(getRanges().getType()))
11242 .addParam("Users", getNumUsers());
11243 db.addParam("Output", *(getOutput().getType()));
11244 db.addParam("Lengths", *(getLengths().getType()));
11245 return db;
11246}
11247
11248void GatherRangesNode::visit(Node *parent, NodeWalker *visitor) {
11249 if (!visitor->shouldVisit(parent, this)) { return; }
11250 visitor->pre(parent, this);
11251if (hasPredicate())
11252 getPredicate().getNode()->visit(this, visitor);
11253 getData().getNode()->visit(this, visitor);
11254 getRanges().getNode()->visit(this, visitor);
11255 visitor->post(parent, this);
11256}
11257
11258bool GatherRangesNode::isEqual(const GatherRangesNode &other) const {
11259 return true &&
11260 Data_ == other.Data_ &&
11261 Ranges_ == other.Ranges_ &&
11262 predicate_ == other.predicate_ &&
11263 getType(0) == other.getType(0) &&
11264 getType(1) == other.getType(1);
11265}
11266
11267Node* GatherRangesNode::clone() const {
11268 return new GatherRangesNode(getName(), getOutput().getType(), getLengths().getType(), getData(), getRanges());
11269}
11270
11271llvm::hash_code GatherRangesNode::getHash() const {
11272 return llvm::hash_combine(
11273 Data_,
11274 Ranges_);
11275}
11276
11277unsigned ScatterDataNode::getNumInputs() const {
11278 return 3;
11279}
11280
11281std::string ScatterDataNode::getInputName(unsigned idx) const {
11282 if (idx == 0) { return "Data"; }
11283 if (idx == 1) { return "Indices"; }
11284 if (idx == 2) { return "Slices"; }
11285 idx -= 3;
11286 llvm_unreachable("Invalid index");
11287}
11288
11289NodeValue ScatterDataNode::getNthInput(unsigned idx) {
11290 if (idx == 0) { return Data_; }
11291 if (idx == 1) { return Indices_; }
11292 if (idx == 2) { return Slices_; }
11293 idx -= 3;
11294 llvm_unreachable("Invalid index");
11295}
11296
11297void ScatterDataNode::setNthInput(unsigned idx, NodeValue val) {
11298 if (idx == 0) { Data_ = val; return; }
11299 if (idx == 1) { Indices_ = val; return; }
11300 if (idx == 2) { Slices_ = val; return; }
11301 idx -= 3;
11302 llvm_unreachable("Invalid index");
11303}
11304
11305llvm::StringRef ScatterDataNode::getOutputName(unsigned idx) const {
11306 if (idx == 0) { return "Result"; }
11307 llvm_unreachable("Invalid index");
11308}
11309
11310std::string ScatterDataNode::getDebugDesc() const {
11311 DescriptionBuilder db(getKindName());
11312 db.addParam("Name", separateString(getName(), 100, "\n"));
11313 if (hasPredicate()) db.addParam("Predicate", "Yes");
11314 db
11315 .addParam("Data", *(getData().getType()))
11316 .addParam("Indices", *(getIndices().getType()))
11317 .addParam("Slices", *(getSlices().getType()))
11318 .addParam("Cumulative", getCumulative())
11319 .addParam("Users", getNumUsers());
11320 db.addParam("Result", *(getResult().getType()));
11321 return db;
11322}
11323
11324void ScatterDataNode::visit(Node *parent, NodeWalker *visitor) {
11325 if (!visitor->shouldVisit(parent, this)) { return; }
11326 visitor->pre(parent, this);
11327if (hasPredicate())
11328 getPredicate().getNode()->visit(this, visitor);
11329 getData().getNode()->visit(this, visitor);
11330 getIndices().getNode()->visit(this, visitor);
11331 getSlices().getNode()->visit(this, visitor);
11332 visitor->post(parent, this);
11333}
11334
11335bool ScatterDataNode::isEqual(const ScatterDataNode &other) const {
11336 return true &&
11337 Data_ == other.Data_ &&
11338 Indices_ == other.Indices_ &&
11339 Slices_ == other.Slices_ &&
11340 predicate_ == other.predicate_ &&
11341 Cumulative_ == other.Cumulative_ &&
11342 getType(0) == other.getType(0);
11343}
11344
11345Node* ScatterDataNode::clone() const {
11346 return new ScatterDataNode(getName(), getData(), getIndices(), getSlices(), getCumulative());
11347}
11348
11349llvm::hash_code ScatterDataNode::getHash() const {
11350 return llvm::hash_combine(
11351 Cumulative_,
11352 Data_,
11353 Indices_,
11354 Slices_);
11355}
11356
11357unsigned TileNode::getNumInputs() const {
11358 return 1;
11359}
11360
11361std::string TileNode::getInputName(unsigned idx) const {
11362 if (idx == 0) { return "Input"; }
11363 idx -= 1;
11364 llvm_unreachable("Invalid index");
11365}
11366
11367NodeValue TileNode::getNthInput(unsigned idx) {
11368 if (idx == 0) { return Input_; }
11369 idx -= 1;
11370 llvm_unreachable("Invalid index");
11371}
11372
11373void TileNode::setNthInput(unsigned idx, NodeValue val) {
11374 if (idx == 0) { Input_ = val; return; }
11375 idx -= 1;
11376 llvm_unreachable("Invalid index");
11377}
11378
11379llvm::StringRef TileNode::getOutputName(unsigned idx) const {
11380 if (idx == 0) { return "Result"; }
11381 llvm_unreachable("Invalid index");
11382}
11383
11384std::string TileNode::getDebugDesc() const {
11385 DescriptionBuilder db(getKindName());
11386 db.addParam("Name", separateString(getName(), 100, "\n"));
11387 if (hasPredicate()) db.addParam("Predicate", "Yes");
11388 db
11389 .addParam("Input", *(getInput().getType()))
11390 .addParam("Count", getCount())
11391 .addParam("Axis", getAxis())
11392 .addParam("Users", getNumUsers());
11393 db.addParam("Result", *(getResult().getType()));
11394 return db;
11395}
11396
11397void TileNode::visit(Node *parent, NodeWalker *visitor) {
11398 if (!visitor->shouldVisit(parent, this)) { return; }
11399 visitor->pre(parent, this);
11400if (hasPredicate())
11401 getPredicate().getNode()->visit(this, visitor);
11402 getInput().getNode()->visit(this, visitor);
11403 visitor->post(parent, this);
11404}
11405
11406bool TileNode::isEqual(const TileNode &other) const {
11407 return true &&
11408 Input_ == other.Input_ &&
11409 predicate_ == other.predicate_ &&
11410 Count_ == other.Count_ &&
11411 Axis_ == other.Axis_ &&
11412 getType(0) == other.getType(0);
11413}
11414
11415Node* TileNode::clone() const {
11416 return new TileNode(getName(), getResult().getType(), getInput(), getCount(), getAxis());
11417}
11418
11419llvm::hash_code TileNode::getHash() const {
11420 return llvm::hash_combine(
11421 Count_,
11422 Axis_,
11423 Input_);
11424}
11425
11426unsigned BatchOneHotNode::getNumInputs() const {
11427 return 3;
11428}
11429
11430std::string BatchOneHotNode::getInputName(unsigned idx) const {
11431 if (idx == 0) { return "Data"; }
11432 if (idx == 1) { return "Lengths"; }
11433 if (idx == 2) { return "Values"; }
11434 idx -= 3;
11435 llvm_unreachable("Invalid index");
11436}
11437
11438NodeValue BatchOneHotNode::getNthInput(unsigned idx) {
11439 if (idx == 0) { return Data_; }
11440 if (idx == 1) { return Lengths_; }
11441 if (idx == 2) { return Values_; }
11442 idx -= 3;
11443 llvm_unreachable("Invalid index");
11444}
11445
11446void BatchOneHotNode::setNthInput(unsigned idx, NodeValue val) {
11447 if (idx == 0) { Data_ = val; return; }
11448 if (idx == 1) { Lengths_ = val; return; }
11449 if (idx == 2) { Values_ = val; return; }
11450 idx -= 3;
11451 llvm_unreachable("Invalid index");
11452}
11453
11454llvm::StringRef BatchOneHotNode::getOutputName(unsigned idx) const {
11455 if (idx == 0) { return "Result"; }
11456 llvm_unreachable("Invalid index");
11457}
11458
11459std::string BatchOneHotNode::getDebugDesc() const {
11460 DescriptionBuilder db(getKindName());
11461 db.addParam("Name", separateString(getName(), 100, "\n"));
11462 if (hasPredicate()) db.addParam("Predicate", "Yes");
11463 db
11464 .addParam("Data", *(getData().getType()))
11465 .addParam("Lengths", *(getLengths().getType()))
11466 .addParam("Values", *(getValues().getType()))
11467 .addParam("Users", getNumUsers());
11468 db.addParam("Result", *(getResult().getType()));
11469 return db;
11470}
11471
11472void BatchOneHotNode::visit(Node *parent, NodeWalker *visitor) {
11473 if (!visitor->shouldVisit(parent, this)) { return; }
11474 visitor->pre(parent, this);
11475if (hasPredicate())
11476 getPredicate().getNode()->visit(this, visitor);
11477 getData().getNode()->visit(this, visitor);
11478 getLengths().getNode()->visit(this, visitor);
11479 getValues().getNode()->visit(this, visitor);
11480 visitor->post(parent, this);
11481}
11482
11483bool BatchOneHotNode::isEqual(const BatchOneHotNode &other) const {
11484 return true &&
11485 Data_ == other.Data_ &&
11486 Lengths_ == other.Lengths_ &&
11487 Values_ == other.Values_ &&
11488 predicate_ == other.predicate_ &&
11489 getType(0) == other.getType(0);
11490}
11491
11492Node* BatchOneHotNode::clone() const {
11493 return new BatchOneHotNode(getName(), getResult().getType(), getData(), getLengths(), getValues());
11494}
11495
11496llvm::hash_code BatchOneHotNode::getHash() const {
11497 return llvm::hash_combine(
11498 Data_,
11499 Lengths_,
11500 Values_);
11501}
11502
11503unsigned SpaceToDepthNode::getNumInputs() const {
11504 return 1;
11505}
11506
11507std::string SpaceToDepthNode::getInputName(unsigned idx) const {
11508 if (idx == 0) { return "Input"; }
11509 idx -= 1;
11510 llvm_unreachable("Invalid index");
11511}
11512
11513NodeValue SpaceToDepthNode::getNthInput(unsigned idx) {
11514 if (idx == 0) { return Input_; }
11515 idx -= 1;
11516 llvm_unreachable("Invalid index");
11517}
11518
11519void SpaceToDepthNode::setNthInput(unsigned idx, NodeValue val) {
11520 if (idx == 0) { Input_ = val; return; }
11521 idx -= 1;
11522 llvm_unreachable("Invalid index");
11523}
11524
11525llvm::StringRef SpaceToDepthNode::getOutputName(unsigned idx) const {
11526 if (idx == 0) { return "Result"; }
11527 llvm_unreachable("Invalid index");
11528}
11529
11530std::string SpaceToDepthNode::getDebugDesc() const {
11531 DescriptionBuilder db(getKindName());
11532 db.addParam("Name", separateString(getName(), 100, "\n"));
11533 if (hasPredicate()) db.addParam("Predicate", "Yes");
11534 db
11535 .addParam("Input", *(getInput().getType()))
11536 .addParam("BlockSize", getBlockSize())
11537 .addParam("Users", getNumUsers());
11538 db.addParam("Result", *(getResult().getType()));
11539 return db;
11540}
11541
11542void SpaceToDepthNode::visit(Node *parent, NodeWalker *visitor) {
11543 if (!visitor->shouldVisit(parent, this)) { return; }
11544 visitor->pre(parent, this);
11545if (hasPredicate())
11546 getPredicate().getNode()->visit(this, visitor);
11547 getInput().getNode()->visit(this, visitor);
11548 visitor->post(parent, this);
11549}
11550
11551bool SpaceToDepthNode::isEqual(const SpaceToDepthNode &other) const {
11552 return true &&
11553 Input_ == other.Input_ &&
11554 predicate_ == other.predicate_ &&
11555 BlockSize_ == other.BlockSize_ &&
11556 getType(0) == other.getType(0);
11557}
11558
11559Node* SpaceToDepthNode::clone() const {
11560 return new SpaceToDepthNode(getName(), getResult().getType(), getInput(), getBlockSize());
11561}
11562
11563llvm::hash_code SpaceToDepthNode::getHash() const {
11564 return llvm::hash_combine(
11565 BlockSize_,
11566 Input_);
11567}
11568
11569unsigned ResizeNearestNode::getNumInputs() const {
11570 return 1;
11571}
11572
11573std::string ResizeNearestNode::getInputName(unsigned idx) const {
11574 if (idx == 0) { return "Input"; }
11575 idx -= 1;
11576 llvm_unreachable("Invalid index");
11577}
11578
11579NodeValue ResizeNearestNode::getNthInput(unsigned idx) {
11580 if (idx == 0) { return Input_; }
11581 idx -= 1;
11582 llvm_unreachable("Invalid index");
11583}
11584
11585void ResizeNearestNode::setNthInput(unsigned idx, NodeValue val) {
11586 if (idx == 0) { Input_ = val; return; }
11587 idx -= 1;
11588 llvm_unreachable("Invalid index");
11589}
11590
11591llvm::StringRef ResizeNearestNode::getOutputName(unsigned idx) const {
11592 if (idx == 0) { return "Result"; }
11593 llvm_unreachable("Invalid index");
11594}
11595
11596std::string ResizeNearestNode::getDebugDesc() const {
11597 DescriptionBuilder db(getKindName());
11598 db.addParam("Name", separateString(getName(), 100, "\n"));
11599 if (hasPredicate()) db.addParam("Predicate", "Yes");
11600 db
11601 .addParam("Input", *(getInput().getType()))
11602 .addParam("Scale", getScale())
11603 .addParam("Users", getNumUsers());
11604 db.addParam("Result", *(getResult().getType()));
11605 return db;
11606}
11607
11608void ResizeNearestNode::visit(Node *parent, NodeWalker *visitor) {
11609 if (!visitor->shouldVisit(parent, this)) { return; }
11610 visitor->pre(parent, this);
11611if (hasPredicate())
11612 getPredicate().getNode()->visit(this, visitor);
11613 getInput().getNode()->visit(this, visitor);
11614 visitor->post(parent, this);
11615}
11616
11617bool ResizeNearestNode::isEqual(const ResizeNearestNode &other) const {
11618 return true &&
11619 Input_ == other.Input_ &&
11620 predicate_ == other.predicate_ &&
11621 Scale_ == other.Scale_ &&
11622 getType(0) == other.getType(0);
11623}
11624
11625Node* ResizeNearestNode::clone() const {
11626 return new ResizeNearestNode(getName(), getResult().getType(), getInput(), getScale());
11627}
11628
11629llvm::hash_code ResizeNearestNode::getHash() const {
11630 return llvm::hash_combine(
11631 [](const std::vector<float>& floatVec) -> llvm::hash_code {
11632 std::vector<size_t> sizeVec = toBinary(floatVec);
11633 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
11634 }(Scale_),
11635 Input_);
11636}
11637
11638unsigned ResizeBilinearNode::getNumInputs() const {
11639 return 1;
11640}
11641
11642std::string ResizeBilinearNode::getInputName(unsigned idx) const {
11643 if (idx == 0) { return "Input"; }
11644 idx -= 1;
11645 llvm_unreachable("Invalid index");
11646}
11647
11648NodeValue ResizeBilinearNode::getNthInput(unsigned idx) {
11649 if (idx == 0) { return Input_; }
11650 idx -= 1;
11651 llvm_unreachable("Invalid index");
11652}
11653
11654void ResizeBilinearNode::setNthInput(unsigned idx, NodeValue val) {
11655 if (idx == 0) { Input_ = val; return; }
11656 idx -= 1;
11657 llvm_unreachable("Invalid index");
11658}
11659
11660llvm::StringRef ResizeBilinearNode::getOutputName(unsigned idx) const {
11661 if (idx == 0) { return "Result"; }
11662 llvm_unreachable("Invalid index");
11663}
11664
11665std::string ResizeBilinearNode::getDebugDesc() const {
11666 DescriptionBuilder db(getKindName());
11667 db.addParam("Name", separateString(getName(), 100, "\n"));
11668 if (hasPredicate()) db.addParam("Predicate", "Yes");
11669 db
11670 .addParam("Input", *(getInput().getType()))
11671 .addParam("Scale", getScale())
11672 .addParam("Users", getNumUsers());
11673 db.addParam("Result", *(getResult().getType()));
11674 return db;
11675}
11676
11677void ResizeBilinearNode::visit(Node *parent, NodeWalker *visitor) {
11678 if (!visitor->shouldVisit(parent, this)) { return; }
11679 visitor->pre(parent, this);
11680if (hasPredicate())
11681 getPredicate().getNode()->visit(this, visitor);
11682 getInput().getNode()->visit(this, visitor);
11683 visitor->post(parent, this);
11684}
11685
11686bool ResizeBilinearNode::isEqual(const ResizeBilinearNode &other) const {
11687 return true &&
11688 Input_ == other.Input_ &&
11689 predicate_ == other.predicate_ &&
11690 Scale_ == other.Scale_ &&
11691 getType(0) == other.getType(0);
11692}
11693
11694Node* ResizeBilinearNode::clone() const {
11695 return new ResizeBilinearNode(getName(), getResult().getType(), getInput(), getScale());
11696}
11697
11698llvm::hash_code ResizeBilinearNode::getHash() const {
11699 return llvm::hash_combine(
11700 [](const std::vector<float>& floatVec) -> llvm::hash_code {
11701 std::vector<size_t> sizeVec = toBinary(floatVec);
11702 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
11703 }(Scale_),
11704 Input_);
11705}
11706
11707unsigned BroadcastNode::getNumInputs() const {
11708 return 1;
11709}
11710
11711std::string BroadcastNode::getInputName(unsigned idx) const {
11712 if (idx == 0) { return "Input"; }
11713 idx -= 1;
11714 llvm_unreachable("Invalid index");
11715}
11716
11717NodeValue BroadcastNode::getNthInput(unsigned idx) {
11718 if (idx == 0) { return Input_; }
11719 idx -= 1;
11720 llvm_unreachable("Invalid index");
11721}
11722
11723void BroadcastNode::setNthInput(unsigned idx, NodeValue val) {
11724 if (idx == 0) { Input_ = val; return; }
11725 idx -= 1;
11726 llvm_unreachable("Invalid index");
11727}
11728
11729llvm::StringRef BroadcastNode::getOutputName(unsigned idx) const {
11730 if (idx == 0) { return "Result"; }
11731 llvm_unreachable("Invalid index");
11732}
11733
11734std::string BroadcastNode::getDebugDesc() const {
11735 DescriptionBuilder db(getKindName());
11736 db.addParam("Name", separateString(getName(), 100, "\n"));
11737 if (hasPredicate()) db.addParam("Predicate", "Yes");
11738 db
11739 .addParam("Input", *(getInput().getType()))
11740 .addParam("Axis", getAxis())
11741 .addParam("TargetDim", getTargetDim())
11742 .addParam("Users", getNumUsers());
11743 db.addParam("Result", *(getResult().getType()));
11744 return db;
11745}
11746
11747void BroadcastNode::visit(Node *parent, NodeWalker *visitor) {
11748 if (!visitor->shouldVisit(parent, this)) { return; }
11749 visitor->pre(parent, this);
11750if (hasPredicate())
11751 getPredicate().getNode()->visit(this, visitor);
11752 getInput().getNode()->visit(this, visitor);
11753 visitor->post(parent, this);
11754}
11755
11756bool BroadcastNode::isEqual(const BroadcastNode &other) const {
11757 return true &&
11758 Input_ == other.Input_ &&
11759 predicate_ == other.predicate_ &&
11760 Axis_ == other.Axis_ &&
11761 TargetDim_ == other.TargetDim_ &&
11762 getType(0) == other.getType(0);
11763}
11764
11765Node* BroadcastNode::clone() const {
11766 return new BroadcastNode(getName(), getResult().getType(), getInput(), getAxis(), getTargetDim());
11767}
11768
11769llvm::hash_code BroadcastNode::getHash() const {
11770 return llvm::hash_combine(
11771 Axis_,
11772 llvm::hash_combine_range(TargetDim_.begin(), TargetDim_.end()),
11773 Input_);
11774}
11775
11776unsigned SparseLabelSplitNode::getNumInputs() const {
11777 return 3;
11778}
11779
11780std::string SparseLabelSplitNode::getInputName(unsigned idx) const {
11781 if (idx == 0) { return "Lengths"; }
11782 if (idx == 1) { return "Indices"; }
11783 if (idx == 2) { return "Values"; }
11784 idx -= 3;
11785 llvm_unreachable("Invalid index");
11786}
11787
11788NodeValue SparseLabelSplitNode::getNthInput(unsigned idx) {
11789 if (idx == 0) { return Lengths_; }
11790 if (idx == 1) { return Indices_; }
11791 if (idx == 2) { return Values_; }
11792 idx -= 3;
11793 llvm_unreachable("Invalid index");
11794}
11795
11796void SparseLabelSplitNode::setNthInput(unsigned idx, NodeValue val) {
11797 if (idx == 0) { Lengths_ = val; return; }
11798 if (idx == 1) { Indices_ = val; return; }
11799 if (idx == 2) { Values_ = val; return; }
11800 idx -= 3;
11801 llvm_unreachable("Invalid index");
11802}
11803
11804llvm::StringRef SparseLabelSplitNode::getOutputName(unsigned idx) const {
11805 if (idx == 0) { return "LabelValues"; }
11806 if (idx == 1) { return "ExampleIds"; }
11807 if (idx == 2) { return "GradientOffsetMap"; }
11808 llvm_unreachable("Invalid index");
11809}
11810
11811std::string SparseLabelSplitNode::getDebugDesc() const {
11812 DescriptionBuilder db(getKindName());
11813 db.addParam("Name", separateString(getName(), 100, "\n"));
11814 if (hasPredicate()) db.addParam("Predicate", "Yes");
11815 db
11816 .addParam("Lengths", *(getLengths().getType()))
11817 .addParam("Indices", *(getIndices().getType()))
11818 .addParam("Values", *(getValues().getType()))
11819 .addParam("NumLabels", getNumLabels())
11820 .addParam("Users", getNumUsers());
11821 db.addParam("LabelValues", *(getLabelValues().getType()));
11822 db.addParam("ExampleIds", *(getExampleIds().getType()));
11823 db.addParam("GradientOffsetMap", *(getGradientOffsetMap().getType()));
11824 return db;
11825}
11826
11827void SparseLabelSplitNode::visit(Node *parent, NodeWalker *visitor) {
11828 if (!visitor->shouldVisit(parent, this)) { return; }
11829 visitor->pre(parent, this);
11830if (hasPredicate())
11831 getPredicate().getNode()->visit(this, visitor);
11832 getLengths().getNode()->visit(this, visitor);
11833 getIndices().getNode()->visit(this, visitor);
11834 getValues().getNode()->visit(this, visitor);
11835 visitor->post(parent, this);
11836}
11837
11838bool SparseLabelSplitNode::isEqual(const SparseLabelSplitNode &other) const {
11839 return true &&
11840 Lengths_ == other.Lengths_ &&
11841 Indices_ == other.Indices_ &&
11842 Values_ == other.Values_ &&
11843 predicate_ == other.predicate_ &&
11844 NumLabels_ == other.NumLabels_ &&
11845 getType(0) == other.getType(0) &&
11846 getType(1) == other.getType(1) &&
11847 getType(2) == other.getType(2);
11848}
11849
11850Node* SparseLabelSplitNode::clone() const {
11851 return new SparseLabelSplitNode(getName(), getLabelValues().getType(), getExampleIds().getType(), getGradientOffsetMap().getType(), getLengths(), getIndices(), getValues(), getNumLabels());
11852}
11853
11854llvm::hash_code SparseLabelSplitNode::getHash() const {
11855 return llvm::hash_combine(
11856 NumLabels_,
11857 Lengths_,
11858 Indices_,
11859 Values_);
11860}
11861
11862unsigned FlipNode::getNumInputs() const {
11863 return 1;
11864}
11865
11866std::string FlipNode::getInputName(unsigned idx) const {
11867 if (idx == 0) { return "Input"; }
11868 idx -= 1;
11869 llvm_unreachable("Invalid index");
11870}
11871
11872NodeValue FlipNode::getNthInput(unsigned idx) {
11873 if (idx == 0) { return Input_; }
11874 idx -= 1;
11875 llvm_unreachable("Invalid index");
11876}
11877
11878void FlipNode::setNthInput(unsigned idx, NodeValue val) {
11879 if (idx == 0) { Input_ = val; return; }
11880 idx -= 1;
11881 llvm_unreachable("Invalid index");
11882}
11883
11884llvm::StringRef FlipNode::getOutputName(unsigned idx) const {
11885 if (idx == 0) { return "Result"; }
11886 llvm_unreachable("Invalid index");
11887}
11888
11889std::string FlipNode::getDebugDesc() const {
11890 DescriptionBuilder db(getKindName());
11891 db.addParam("Name", separateString(getName(), 100, "\n"));
11892 if (hasPredicate()) db.addParam("Predicate", "Yes");
11893 db
11894 .addParam("Input", *(getInput().getType()))
11895 .addParam("Axis", getAxis())
11896 .addParam("Users", getNumUsers());
11897 db.addParam("Result", *(getResult().getType()));
11898 return db;
11899}
11900
11901void FlipNode::visit(Node *parent, NodeWalker *visitor) {
11902 if (!visitor->shouldVisit(parent, this)) { return; }
11903 visitor->pre(parent, this);
11904if (hasPredicate())
11905 getPredicate().getNode()->visit(this, visitor);
11906 getInput().getNode()->visit(this, visitor);
11907 visitor->post(parent, this);
11908}
11909
11910bool FlipNode::isEqual(const FlipNode &other) const {
11911 return true &&
11912 Input_ == other.Input_ &&
11913 predicate_ == other.predicate_ &&
11914 Axis_ == other.Axis_ &&
11915 getType(0) == other.getType(0);
11916}
11917
11918Node* FlipNode::clone() const {
11919 return new FlipNode(getName(), getResult().getType(), getInput(), getAxis());
11920}
11921
11922llvm::hash_code FlipNode::getHash() const {
11923 return llvm::hash_combine(
11924 Axis_,
11925 Input_);
11926}
11927
11928unsigned SplatNode::getNumInputs() const {
11929 return 0;
11930}
11931
11932std::string SplatNode::getInputName(unsigned idx) const {
11933 idx -= 0;
11934 llvm_unreachable("Invalid index");
11935}
11936
11937NodeValue SplatNode::getNthInput(unsigned idx) {
11938 idx -= 0;
11939 llvm_unreachable("Invalid index");
11940}
11941
11942void SplatNode::setNthInput(unsigned idx, NodeValue val) {
11943 idx -= 0;
11944 llvm_unreachable("Invalid index");
11945}
11946
11947llvm::StringRef SplatNode::getOutputName(unsigned idx) const {
11948 if (idx == 0) { return "Result"; }
11949 llvm_unreachable("Invalid index");
11950}
11951
11952std::string SplatNode::getDebugDesc() const {
11953 DescriptionBuilder db(getKindName());
11954 db.addParam("Name", separateString(getName(), 100, "\n"));
11955 if (hasPredicate()) db.addParam("Predicate", "Yes");
11956 db
11957 .addParam("Value", getValue())
11958 .addParam("Users", getNumUsers());
11959 db.addParam("Result", *(getResult().getType()));
11960 return db;
11961}
11962
11963void SplatNode::visit(Node *parent, NodeWalker *visitor) {
11964 if (!visitor->shouldVisit(parent, this)) { return; }
11965 visitor->pre(parent, this);
11966if (hasPredicate())
11967 getPredicate().getNode()->visit(this, visitor);
11968 visitor->post(parent, this);
11969}
11970
11971bool SplatNode::isEqual(const SplatNode &other) const {
11972 return true &&
11973 predicate_ == other.predicate_ &&
11974 Value_ == other.Value_ &&
11975 getType(0) == other.getType(0);
11976}
11977
11978Node* SplatNode::clone() const {
11979 return new SplatNode(getName(), getResult().getType(), getValue());
11980}
11981
11982llvm::hash_code SplatNode::getHash() const {
11983 return llvm::hash_combine(
11984 toBinary(Value_));
11985}
11986
11987unsigned TouchNode::getNumInputs() const {
11988 return 0;
11989}
11990
11991std::string TouchNode::getInputName(unsigned idx) const {
11992 idx -= 0;
11993 llvm_unreachable("Invalid index");
11994}
11995
11996NodeValue TouchNode::getNthInput(unsigned idx) {
11997 idx -= 0;
11998 llvm_unreachable("Invalid index");
11999}
12000
12001void TouchNode::setNthInput(unsigned idx, NodeValue val) {
12002 idx -= 0;
12003 llvm_unreachable("Invalid index");
12004}
12005
12006llvm::StringRef TouchNode::getOutputName(unsigned idx) const {
12007 if (idx == 0) { return "Result"; }
12008 llvm_unreachable("Invalid index");
12009}
12010
12011std::string TouchNode::getDebugDesc() const {
12012 DescriptionBuilder db(getKindName());
12013 db.addParam("Name", separateString(getName(), 100, "\n"));
12014 if (hasPredicate()) db.addParam("Predicate", "Yes");
12015 db
12016 .addParam("Users", getNumUsers());
12017 db.addParam("Result", *(getResult().getType()));
12018 return db;
12019}
12020
12021void TouchNode::visit(Node *parent, NodeWalker *visitor) {
12022 if (!visitor->shouldVisit(parent, this)) { return; }
12023 visitor->pre(parent, this);
12024if (hasPredicate())
12025 getPredicate().getNode()->visit(this, visitor);
12026 visitor->post(parent, this);
12027}
12028
12029bool TouchNode::isEqual(const TouchNode &other) const {
12030 return true &&
12031 predicate_ == other.predicate_ &&
12032 getType(0) == other.getType(0);
12033}
12034
12035Node* TouchNode::clone() const {
12036 return new TouchNode(getName(), getResult().getType());
12037}
12038
12039llvm::hash_code TouchNode::getHash() const {
12040 return llvm::hash_combine(0);
12041 }
12042
12043unsigned SGDNode::getNumInputs() const {
12044 return 2;
12045}
12046
12047std::string SGDNode::getInputName(unsigned idx) const {
12048 if (idx == 0) { return "Gradient"; }
12049 if (idx == 1) { return "Weight"; }
12050 idx -= 2;
12051 llvm_unreachable("Invalid index");
12052}
12053
12054NodeValue SGDNode::getNthInput(unsigned idx) {
12055 if (idx == 0) { return Gradient_; }
12056 if (idx == 1) { return Weight_; }
12057 idx -= 2;
12058 llvm_unreachable("Invalid index");
12059}
12060
12061void SGDNode::setNthInput(unsigned idx, NodeValue val) {
12062 if (idx == 0) { Gradient_ = val; return; }
12063 if (idx == 1) { Weight_ = val; return; }
12064 idx -= 2;
12065 llvm_unreachable("Invalid index");
12066}
12067
12068llvm::StringRef SGDNode::getOutputName(unsigned idx) const {
12069 if (idx == 0) { return "UpdatedWeight"; }
12070 llvm_unreachable("Invalid index");
12071}
12072
12073std::string SGDNode::getDebugDesc() const {
12074 DescriptionBuilder db(getKindName());
12075 db.addParam("Name", separateString(getName(), 100, "\n"));
12076 if (hasPredicate()) db.addParam("Predicate", "Yes");
12077 db
12078 .addParam("Gradient", *(getGradient().getType()))
12079 .addParam("Weight", *(getWeight().getType()))
12080 .addParam("L1Decay", getL1Decay())
12081 .addParam("L2Decay", getL2Decay())
12082 .addParam("LearningRate", getLearningRate())
12083 .addParam("Momentum", getMomentum())
12084 .addParam("BatchSize", getBatchSize())
12085 .addParam("Users", getNumUsers());
12086 db.addParam("UpdatedWeight", *(getUpdatedWeight().getType()));
12087 return db;
12088}
12089
12090void SGDNode::visit(Node *parent, NodeWalker *visitor) {
12091 if (!visitor->shouldVisit(parent, this)) { return; }
12092 visitor->pre(parent, this);
12093if (hasPredicate())
12094 getPredicate().getNode()->visit(this, visitor);
12095 getGradient().getNode()->visit(this, visitor);
12096 getWeight().getNode()->visit(this, visitor);
12097 visitor->post(parent, this);
12098}
12099
12100bool SGDNode::isEqual(const SGDNode &other) const {
12101 return true &&
12102 Gradient_ == other.Gradient_ &&
12103 Weight_ == other.Weight_ &&
12104 predicate_ == other.predicate_ &&
12105 L1Decay_ == other.L1Decay_ &&
12106 L2Decay_ == other.L2Decay_ &&
12107 LearningRate_ == other.LearningRate_ &&
12108 Momentum_ == other.Momentum_ &&
12109 BatchSize_ == other.BatchSize_ &&
12110 getType(0) == other.getType(0);
12111}
12112
12113Node* SGDNode::clone() const {
12114 return new SGDNode(getName(), getGradient(), getWeight(), getL1Decay(), getL2Decay(), getLearningRate(), getMomentum(), getBatchSize());
12115}
12116
12117llvm::hash_code SGDNode::getHash() const {
12118 return llvm::hash_combine(
12119 toBinary(L1Decay_),
12120 toBinary(L2Decay_),
12121 toBinary(LearningRate_),
12122 toBinary(Momentum_),
12123 BatchSize_,
12124 Gradient_,
12125 Weight_);
12126}
12127
12128unsigned TraceEventNode::getNumInputs() const {
12129 return 1;
12130}
12131
12132std::string TraceEventNode::getInputName(unsigned idx) const {
12133 if (idx == 0) { return "Data"; }
12134 idx -= 1;
12135 llvm_unreachable("Invalid index");
12136}
12137
12138NodeValue TraceEventNode::getNthInput(unsigned idx) {
12139 if (idx == 0) { return Data_; }
12140 idx -= 1;
12141 llvm_unreachable("Invalid index");
12142}
12143
12144void TraceEventNode::setNthInput(unsigned idx, NodeValue val) {
12145 if (idx == 0) { Data_ = val; return; }
12146 idx -= 1;
12147 llvm_unreachable("Invalid index");
12148}
12149
12150llvm::StringRef TraceEventNode::getOutputName(unsigned idx) const {
12151 llvm_unreachable("Invalid index");
12152}
12153
12154std::string TraceEventNode::getDebugDesc() const {
12155 DescriptionBuilder db(getKindName());
12156 db.addParam("Name", separateString(getName(), 100, "\n"));
12157 if (hasPredicate()) db.addParam("Predicate", "Yes");
12158 db
12159 .addParam("Data", *(getData().getType()))
12160 .addParam("EventName", getEventName())
12161 .addParam("EventType", getEventType())
12162 .addParam("Index", getIndex())
12163 .addParam("Users", getNumUsers());
12164 return db;
12165}
12166
12167void TraceEventNode::visit(Node *parent, NodeWalker *visitor) {
12168 if (!visitor->shouldVisit(parent, this)) { return; }
12169 visitor->pre(parent, this);
12170if (hasPredicate())
12171 getPredicate().getNode()->visit(this, visitor);
12172 getData().getNode()->visit(this, visitor);
12173 visitor->post(parent, this);
12174}
12175
12176bool TraceEventNode::isEqual(const TraceEventNode &other) const {
12177 return true &&
12178 Data_ == other.Data_ &&
12179 predicate_ == other.predicate_ &&
12180 EventName_ == other.EventName_ &&
12181 EventType_ == other.EventType_ &&
12182 Index_ == other.Index_;
12183}
12184
12185Node* TraceEventNode::clone() const {
12186 return new TraceEventNode(getName(), getData(), getEventName(), getEventType(), getIndex());
12187}
12188
12189llvm::hash_code TraceEventNode::getHash() const {
12190 return llvm::hash_combine(
12191 EventName_,
12192 EventType_,
12193 Index_,
12194 Data_);
12195}
12196
12197unsigned QuantizationProfileNode::getNumInputs() const {
12198 return 3;
12199}
12200
12201std::string QuantizationProfileNode::getInputName(unsigned idx) const {
12202 if (idx == 0) { return "Input"; }
12203 if (idx == 1) { return "Histogram"; }
12204 if (idx == 2) { return "ComputationInfo"; }
12205 idx -= 3;
12206 llvm_unreachable("Invalid index");
12207}
12208
12209NodeValue QuantizationProfileNode::getNthInput(unsigned idx) {
12210 if (idx == 0) { return Input_; }
12211 if (idx == 1) { return Histogram_; }
12212 if (idx == 2) { return ComputationInfo_; }
12213 idx -= 3;
12214 llvm_unreachable("Invalid index");
12215}
12216
12217void QuantizationProfileNode::setNthInput(unsigned idx, NodeValue val) {
12218 if (idx == 0) { Input_ = val; return; }
12219 if (idx == 1) { Histogram_ = val; return; }
12220 if (idx == 2) { ComputationInfo_ = val; return; }
12221 idx -= 3;
12222 llvm_unreachable("Invalid index");
12223}
12224
12225llvm::StringRef QuantizationProfileNode::getOutputName(unsigned idx) const {
12226 llvm_unreachable("Invalid index");
12227}
12228
12229std::string QuantizationProfileNode::getDebugDesc() const {
12230 DescriptionBuilder db(getKindName());
12231 db.addParam("Name", separateString(getName(), 100, "\n"));
12232 if (hasPredicate()) db.addParam("Predicate", "Yes");
12233 db
12234 .addParam("Input", *(getInput().getType()))
12235 .addParam("Histogram", *(getHistogram().getType()))
12236 .addParam("ComputationInfo", *(getComputationInfo().getType()))
12237 .addParam("ProfiledNodeName", getProfiledNodeName())
12238 .addParam("ProfiledOutputNumber", getProfiledOutputNumber())
12239 .addParam("Users", getNumUsers());
12240 return db;
12241}
12242
12243void QuantizationProfileNode::visit(Node *parent, NodeWalker *visitor) {
12244 if (!visitor->shouldVisit(parent, this)) { return; }
12245 visitor->pre(parent, this);
12246if (hasPredicate())
12247 getPredicate().getNode()->visit(this, visitor);
12248 getInput().getNode()->visit(this, visitor);
12249 getHistogram().getNode()->visit(this, visitor);
12250 getComputationInfo().getNode()->visit(this, visitor);
12251 visitor->post(parent, this);
12252}
12253
12254bool QuantizationProfileNode::isEqual(const QuantizationProfileNode &other) const {
12255 return true &&
12256 Input_ == other.Input_ &&
12257 Histogram_ == other.Histogram_ &&
12258 ComputationInfo_ == other.ComputationInfo_ &&
12259 predicate_ == other.predicate_ &&
12260 ProfiledNodeName_ == other.ProfiledNodeName_ &&
12261 ProfiledOutputNumber_ == other.ProfiledOutputNumber_;
12262}
12263
12264Node* QuantizationProfileNode::clone() const {
12265 return new QuantizationProfileNode(getName(), getInput(), getHistogram(), getComputationInfo(), getProfiledNodeName(), getProfiledOutputNumber());
12266}
12267
12268llvm::hash_code QuantizationProfileNode::getHash() const {
12269 return llvm::hash_combine(
12270 ProfiledNodeName_,
12271 ProfiledOutputNumber_,
12272 Input_,
12273 Histogram_,
12274 ComputationInfo_);
12275}
12276Placeholder *QuantizationProfileNode::getHistogramPlaceholder() const { return llvm::cast<Placeholder>(Histogram_.getNode()); };
12277Placeholder *QuantizationProfileNode::getComputationInfoPlaceholder() const { return llvm::cast<Placeholder>(ComputationInfo_.getNode()); };
12278
12279unsigned IntLookupTableNode::getNumInputs() const {
12280 return 2;
12281}
12282
12283std::string IntLookupTableNode::getInputName(unsigned idx) const {
12284 if (idx == 0) { return "Input"; }
12285 if (idx == 1) { return "Mapping"; }
12286 idx -= 2;
12287 llvm_unreachable("Invalid index");
12288}
12289
12290NodeValue IntLookupTableNode::getNthInput(unsigned idx) {
12291 if (idx == 0) { return Input_; }
12292 if (idx == 1) { return Mapping_; }
12293 idx -= 2;
12294 llvm_unreachable("Invalid index");
12295}
12296
12297void IntLookupTableNode::setNthInput(unsigned idx, NodeValue val) {
12298 if (idx == 0) { Input_ = val; return; }
12299 if (idx == 1) { Mapping_ = val; return; }
12300 idx -= 2;
12301 llvm_unreachable("Invalid index");
12302}
12303
12304llvm::StringRef IntLookupTableNode::getOutputName(unsigned idx) const {
12305 if (idx == 0) { return "Result"; }
12306 llvm_unreachable("Invalid index");
12307}
12308
12309std::string IntLookupTableNode::getDebugDesc() const {
12310 DescriptionBuilder db(getKindName());
12311 db.addParam("Name", separateString(getName(), 100, "\n"));
12312 if (hasPredicate()) db.addParam("Predicate", "Yes");
12313 db
12314 .addParam("Input", *(getInput().getType()))
12315 .addParam("Mapping", *(getMapping().getType()))
12316 .addParam("Users", getNumUsers());
12317 db.addParam("Result", *(getResult().getType()));
12318 return db;
12319}
12320
12321void IntLookupTableNode::visit(Node *parent, NodeWalker *visitor) {
12322 if (!visitor->shouldVisit(parent, this)) { return; }
12323 visitor->pre(parent, this);
12324if (hasPredicate())
12325 getPredicate().getNode()->visit(this, visitor);
12326 getInput().getNode()->visit(this, visitor);
12327 getMapping().getNode()->visit(this, visitor);
12328 visitor->post(parent, this);
12329}
12330
12331bool IntLookupTableNode::isEqual(const IntLookupTableNode &other) const {
12332 return true &&
12333 Input_ == other.Input_ &&
12334 Mapping_ == other.Mapping_ &&
12335 predicate_ == other.predicate_ &&
12336 getType(0) == other.getType(0);
12337}
12338
12339Node* IntLookupTableNode::clone() const {
12340 return new IntLookupTableNode(getName(), getResult().getType(), getInput(), getMapping());
12341}
12342
12343llvm::hash_code IntLookupTableNode::getHash() const {
12344 return llvm::hash_combine(
12345 Input_,
12346 Mapping_);
12347}
12348
12349unsigned QuantizeNode::getNumInputs() const {
12350 return 1;
12351}
12352
12353std::string QuantizeNode::getInputName(unsigned idx) const {
12354 if (idx == 0) { return "Input"; }
12355 idx -= 1;
12356 llvm_unreachable("Invalid index");
12357}
12358
12359NodeValue QuantizeNode::getNthInput(unsigned idx) {
12360 if (idx == 0) { return Input_; }
12361 idx -= 1;
12362 llvm_unreachable("Invalid index");
12363}
12364
12365void QuantizeNode::setNthInput(unsigned idx, NodeValue val) {
12366 if (idx == 0) { Input_ = val; return; }
12367 idx -= 1;
12368 llvm_unreachable("Invalid index");
12369}
12370
12371llvm::StringRef QuantizeNode::getOutputName(unsigned idx) const {
12372 if (idx == 0) { return "Result"; }
12373 llvm_unreachable("Invalid index");
12374}
12375
12376std::string QuantizeNode::getDebugDesc() const {
12377 DescriptionBuilder db(getKindName());
12378 db.addParam("Name", separateString(getName(), 100, "\n"));
12379 if (hasPredicate()) db.addParam("Predicate", "Yes");
12380 db
12381 .addParam("Input", *(getInput().getType()))
12382 .addParam("Users", getNumUsers());
12383 db.addParam("Result", *(getResult().getType()));
12384 return db;
12385}
12386
12387void QuantizeNode::visit(Node *parent, NodeWalker *visitor) {
12388 if (!visitor->shouldVisit(parent, this)) { return; }
12389 visitor->pre(parent, this);
12390if (hasPredicate())
12391 getPredicate().getNode()->visit(this, visitor);
12392 getInput().getNode()->visit(this, visitor);
12393 visitor->post(parent, this);
12394}
12395
12396bool QuantizeNode::isEqual(const QuantizeNode &other) const {
12397 return true &&
12398 Input_ == other.Input_ &&
12399 predicate_ == other.predicate_ &&
12400 getType(0) == other.getType(0);
12401}
12402
12403Node* QuantizeNode::clone() const {
12404 return new QuantizeNode(getName(), getResult().getType(), getInput());
12405}
12406
12407llvm::hash_code QuantizeNode::getHash() const {
12408 return llvm::hash_combine(
12409 Input_);
12410}
12411
12412unsigned DequantizeNode::getNumInputs() const {
12413 return 1;
12414}
12415
12416std::string DequantizeNode::getInputName(unsigned idx) const {
12417 if (idx == 0) { return "Input"; }
12418 idx -= 1;
12419 llvm_unreachable("Invalid index");
12420}
12421
12422NodeValue DequantizeNode::getNthInput(unsigned idx) {
12423 if (idx == 0) { return Input_; }
12424 idx -= 1;
12425 llvm_unreachable("Invalid index");
12426}
12427
12428void DequantizeNode::setNthInput(unsigned idx, NodeValue val) {
12429 if (idx == 0) { Input_ = val; return; }
12430 idx -= 1;
12431 llvm_unreachable("Invalid index");
12432}
12433
12434llvm::StringRef DequantizeNode::getOutputName(unsigned idx) const {
12435 if (idx == 0) { return "Result"; }
12436 llvm_unreachable("Invalid index");
12437}
12438
12439std::string DequantizeNode::getDebugDesc() const {
12440 DescriptionBuilder db(getKindName());
12441 db.addParam("Name", separateString(getName(), 100, "\n"));
12442 if (hasPredicate()) db.addParam("Predicate", "Yes");
12443 db
12444 .addParam("Input", *(getInput().getType()))
12445 .addParam("Users", getNumUsers());
12446 db.addParam("Result", *(getResult().getType()));
12447 return db;
12448}
12449
12450void DequantizeNode::visit(Node *parent, NodeWalker *visitor) {
12451 if (!visitor->shouldVisit(parent, this)) { return; }
12452 visitor->pre(parent, this);
12453if (hasPredicate())
12454 getPredicate().getNode()->visit(this, visitor);
12455 getInput().getNode()->visit(this, visitor);
12456 visitor->post(parent, this);
12457}
12458
12459bool DequantizeNode::isEqual(const DequantizeNode &other) const {
12460 return true &&
12461 Input_ == other.Input_ &&
12462 predicate_ == other.predicate_ &&
12463 getType(0) == other.getType(0);
12464}
12465
12466Node* DequantizeNode::clone() const {
12467 return new DequantizeNode(getName(), getResult().getType(), getInput());
12468}
12469
12470llvm::hash_code DequantizeNode::getHash() const {
12471 return llvm::hash_combine(
12472 Input_);
12473}
12474
12475unsigned RescaleQuantizedNode::getNumInputs() const {
12476 return 1;
12477}
12478
12479std::string RescaleQuantizedNode::getInputName(unsigned idx) const {
12480 if (idx == 0) { return "Input"; }
12481 idx -= 1;
12482 llvm_unreachable("Invalid index");
12483}
12484
12485NodeValue RescaleQuantizedNode::getNthInput(unsigned idx) {
12486 if (idx == 0) { return Input_; }
12487 idx -= 1;
12488 llvm_unreachable("Invalid index");
12489}
12490
12491void RescaleQuantizedNode::setNthInput(unsigned idx, NodeValue val) {
12492 if (idx == 0) { Input_ = val; return; }
12493 idx -= 1;
12494 llvm_unreachable("Invalid index");
12495}
12496
12497llvm::StringRef RescaleQuantizedNode::getOutputName(unsigned idx) const {
12498 if (idx == 0) { return "Result"; }
12499 llvm_unreachable("Invalid index");
12500}
12501
12502std::string RescaleQuantizedNode::getDebugDesc() const {
12503 DescriptionBuilder db(getKindName());
12504 db.addParam("Name", separateString(getName(), 100, "\n"));
12505 if (hasPredicate()) db.addParam("Predicate", "Yes");
12506 db
12507 .addParam("Input", *(getInput().getType()))
12508 .addParam("Users", getNumUsers());
12509 db.addParam("Result", *(getResult().getType()));
12510 return db;
12511}
12512
12513void RescaleQuantizedNode::visit(Node *parent, NodeWalker *visitor) {
12514 if (!visitor->shouldVisit(parent, this)) { return; }
12515 visitor->pre(parent, this);
12516if (hasPredicate())
12517 getPredicate().getNode()->visit(this, visitor);
12518 getInput().getNode()->visit(this, visitor);
12519 visitor->post(parent, this);
12520}
12521
12522bool RescaleQuantizedNode::isEqual(const RescaleQuantizedNode &other) const {
12523 return true &&
12524 Input_ == other.Input_ &&
12525 predicate_ == other.predicate_ &&
12526 getType(0) == other.getType(0);
12527}
12528
12529Node* RescaleQuantizedNode::clone() const {
12530 return new RescaleQuantizedNode(getName(), getResult().getType(), getInput());
12531}
12532
12533llvm::hash_code RescaleQuantizedNode::getHash() const {
12534 return llvm::hash_combine(
12535 Input_);
12536}
12537
12538unsigned TopKNode::getNumInputs() const {
12539 return 1;
12540}
12541
12542std::string TopKNode::getInputName(unsigned idx) const {
12543 if (idx == 0) { return "Input"; }
12544 idx -= 1;
12545 llvm_unreachable("Invalid index");
12546}
12547
12548NodeValue TopKNode::getNthInput(unsigned idx) {
12549 if (idx == 0) { return Input_; }
12550 idx -= 1;
12551 llvm_unreachable("Invalid index");
12552}
12553
12554void TopKNode::setNthInput(unsigned idx, NodeValue val) {
12555 if (idx == 0) { Input_ = val; return; }
12556 idx -= 1;
12557 llvm_unreachable("Invalid index");
12558}
12559
12560llvm::StringRef TopKNode::getOutputName(unsigned idx) const {
12561 if (idx == 0) { return "Values"; }
12562 if (idx == 1) { return "Indices"; }
12563 llvm_unreachable("Invalid index");
12564}
12565
12566std::string TopKNode::getDebugDesc() const {
12567 DescriptionBuilder db(getKindName());
12568 db.addParam("Name", separateString(getName(), 100, "\n"));
12569 if (hasPredicate()) db.addParam("Predicate", "Yes");
12570 db
12571 .addParam("Input", *(getInput().getType()))
12572 .addParam("K", getK())
12573 .addParam("Users", getNumUsers());
12574 db.addParam("Values", *(getValues().getType()));
12575 db.addParam("Indices", *(getIndices().getType()));
12576 return db;
12577}
12578
12579void TopKNode::visit(Node *parent, NodeWalker *visitor) {
12580 if (!visitor->shouldVisit(parent, this)) { return; }
12581 visitor->pre(parent, this);
12582if (hasPredicate())
12583 getPredicate().getNode()->visit(this, visitor);
12584 getInput().getNode()->visit(this, visitor);
12585 visitor->post(parent, this);
12586}
12587
12588bool TopKNode::isEqual(const TopKNode &other) const {
12589 return true &&
12590 Input_ == other.Input_ &&
12591 predicate_ == other.predicate_ &&
12592 K_ == other.K_ &&
12593 getType(0) == other.getType(0) &&
12594 getType(1) == other.getType(1);
12595}
12596
12597Node* TopKNode::clone() const {
12598 return new TopKNode(getName(), getValues().getType(), getIndices().getType(), getInput(), getK());
12599}
12600
12601llvm::hash_code TopKNode::getHash() const {
12602 return llvm::hash_combine(
12603 K_,
12604 Input_);
12605}
12606
12607unsigned LSTMUnitNode::getNumInputs() const {
12608 return 2;
12609}
12610
12611std::string LSTMUnitNode::getInputName(unsigned idx) const {
12612 if (idx == 0) { return "Input"; }
12613 if (idx == 1) { return "C"; }
12614 idx -= 2;
12615 llvm_unreachable("Invalid index");
12616}
12617
12618NodeValue LSTMUnitNode::getNthInput(unsigned idx) {
12619 if (idx == 0) { return Input_; }
12620 if (idx == 1) { return C_; }
12621 idx -= 2;
12622 llvm_unreachable("Invalid index");
12623}
12624
12625void LSTMUnitNode::setNthInput(unsigned idx, NodeValue val) {
12626 if (idx == 0) { Input_ = val; return; }
12627 if (idx == 1) { C_ = val; return; }
12628 idx -= 2;
12629 llvm_unreachable("Invalid index");
12630}
12631
12632llvm::StringRef LSTMUnitNode::getOutputName(unsigned idx) const {
12633 if (idx == 0) { return "newC"; }
12634 if (idx == 1) { return "newH"; }
12635 llvm_unreachable("Invalid index");
12636}
12637
12638std::string LSTMUnitNode::getDebugDesc() const {
12639 DescriptionBuilder db(getKindName());
12640 db.addParam("Name", separateString(getName(), 100, "\n"));
12641 if (hasPredicate()) db.addParam("Predicate", "Yes");
12642 db
12643 .addParam("Input", *(getInput().getType()))
12644 .addParam("C", *(getC().getType()))
12645 .addParam("Users", getNumUsers());
12646 db.addParam("newC", *(getnewC().getType()));
12647 db.addParam("newH", *(getnewH().getType()));
12648 return db;
12649}
12650
12651void LSTMUnitNode::visit(Node *parent, NodeWalker *visitor) {
12652 if (!visitor->shouldVisit(parent, this)) { return; }
12653 visitor->pre(parent, this);
12654if (hasPredicate())
12655 getPredicate().getNode()->visit(this, visitor);
12656 getInput().getNode()->visit(this, visitor);
12657 getC().getNode()->visit(this, visitor);
12658 visitor->post(parent, this);
12659}
12660
12661bool LSTMUnitNode::isEqual(const LSTMUnitNode &other) const {
12662 return true &&
12663 Input_ == other.Input_ &&
12664 C_ == other.C_ &&
12665 predicate_ == other.predicate_ &&
12666 getType(0) == other.getType(0) &&
12667 getType(1) == other.getType(1);
12668}
12669
12670Node* LSTMUnitNode::clone() const {
12671 return new LSTMUnitNode(getName(), getInput(), getC());
12672}
12673
12674llvm::hash_code LSTMUnitNode::getHash() const {
12675 return llvm::hash_combine(
12676 Input_,
12677 C_);
12678}
12679
12680unsigned ConvertToNode::getNumInputs() const {
12681 return 1;
12682}
12683
12684std::string ConvertToNode::getInputName(unsigned idx) const {
12685 if (idx == 0) { return "Input"; }
12686 idx -= 1;
12687 llvm_unreachable("Invalid index");
12688}
12689
12690NodeValue ConvertToNode::getNthInput(unsigned idx) {
12691 if (idx == 0) { return Input_; }
12692 idx -= 1;
12693 llvm_unreachable("Invalid index");
12694}
12695
12696void ConvertToNode::setNthInput(unsigned idx, NodeValue val) {
12697 if (idx == 0) { Input_ = val; return; }
12698 idx -= 1;
12699 llvm_unreachable("Invalid index");
12700}
12701
12702llvm::StringRef ConvertToNode::getOutputName(unsigned idx) const {
12703 if (idx == 0) { return "Result"; }
12704 llvm_unreachable("Invalid index");
12705}
12706
12707std::string ConvertToNode::getDebugDesc() const {
12708 DescriptionBuilder db(getKindName());
12709 db.addParam("Name", separateString(getName(), 100, "\n"));
12710 if (hasPredicate()) db.addParam("Predicate", "Yes");
12711 db
12712 .addParam("Input", *(getInput().getType()))
12713 .addParam("Users", getNumUsers());
12714 db.addParam("Result", *(getResult().getType()));
12715 return db;
12716}
12717
12718void ConvertToNode::visit(Node *parent, NodeWalker *visitor) {
12719 if (!visitor->shouldVisit(parent, this)) { return; }
12720 visitor->pre(parent, this);
12721if (hasPredicate())
12722 getPredicate().getNode()->visit(this, visitor);
12723 getInput().getNode()->visit(this, visitor);
12724 visitor->post(parent, this);
12725}
12726
12727bool ConvertToNode::isEqual(const ConvertToNode &other) const {
12728 return true &&
12729 Input_ == other.Input_ &&
12730 predicate_ == other.predicate_ &&
12731 getType(0) == other.getType(0);
12732}
12733
12734Node* ConvertToNode::clone() const {
12735 return new ConvertToNode(getName(), getResult().getType(), getInput());
12736}
12737
12738llvm::hash_code ConvertToNode::getHash() const {
12739 return llvm::hash_combine(
12740 Input_);
12741}
12742
12743unsigned ExternalFunctionCallNode::getNumInputs() const {
12744 return 0 + Inputs_.size();
12745}
12746
12747std::string ExternalFunctionCallNode::getInputName(unsigned idx) const {
12748 idx -= 0;
12749 if (idx < Inputs_.size()) { return "Inputs" + std::to_string(idx); }
12750 idx -= Inputs_.size();
12751 llvm_unreachable("Invalid index");
12752}
12753
12754NodeValue ExternalFunctionCallNode::getNthInput(unsigned idx) {
12755 idx -= 0;
12756 if (idx < Inputs_.size()) { return Inputs_[idx]; }
12757 idx -= Inputs_.size();
12758 llvm_unreachable("Invalid index");
12759}
12760
12761void ExternalFunctionCallNode::setNthInput(unsigned idx, NodeValue val) {
12762 idx -= 0;
12763 if (idx < Inputs_.size()) { Inputs_[idx] = val; return; }
12764 idx -= Inputs_.size();
12765 llvm_unreachable("Invalid index");
12766}
12767
12768llvm::StringRef ExternalFunctionCallNode::getOutputName(unsigned idx) const {
12769 if (idx == 0) { return "Result"; }
12770 llvm_unreachable("Invalid index");
12771}
12772
12773std::string ExternalFunctionCallNode::getDebugDesc() const {
12774 DescriptionBuilder db(getKindName());
12775 db.addParam("Name", separateString(getName(), 100, "\n"));
12776 if (hasPredicate()) db.addParam("Predicate", "Yes");
12777 db
12778 .addParam("FunctionName", getFunctionName())
12779 .addParam("FunctionImpl", getFunctionImpl())
12780 .addParam("FunctionKind", getFunctionKind())
12781 .addParam("Users", getNumUsers());
12782 {
12783 unsigned mIndex = 0;
12784 for (const auto &II : getInputs()) {
12785 db.addParam("Inputs"+std::to_string(mIndex++), *II.getType());
12786 }
12787 }
12788 db.addParam("Result", *(getResult().getType()));
12789 return db;
12790}
12791
12792void ExternalFunctionCallNode::visit(Node *parent, NodeWalker *visitor) {
12793 if (!visitor->shouldVisit(parent, this)) { return; }
12794 visitor->pre(parent, this);
12795if (hasPredicate())
12796 getPredicate().getNode()->visit(this, visitor);
12797 for (auto &I : Inputs_) { I.getNode()->visit(this, visitor); }
12798 visitor->post(parent, this);
12799}
12800
12801bool ExternalFunctionCallNode::isEqual(const ExternalFunctionCallNode &other) const {
12802 return true &&
12803 predicate_ == other.predicate_ &&
12804 Inputs_ == other.Inputs_ &&
12805 FunctionName_ == other.FunctionName_ &&
12806 FunctionImpl_ == other.FunctionImpl_ &&
12807 FunctionKind_ == other.FunctionKind_ &&
12808 getType(0) == other.getType(0);
12809}
12810
12811Node* ExternalFunctionCallNode::clone() const {
12812 return new ExternalFunctionCallNode(getName(), getResult().getType(), getInputs(), getFunctionName(), getFunctionImpl(), getFunctionKind());
12813}
12814
12815llvm::hash_code ExternalFunctionCallNode::getHash() const {
12816 return llvm::hash_combine(
12817 llvm::hash_combine_range(Inputs_.begin(), Inputs_.end()),
12818 FunctionName_,
12819 FunctionImpl_,
12820 FunctionKind_);
12821}
12822
12823unsigned AudioSpectrogramNode::getNumInputs() const {
12824 return 5;
12825}
12826
12827std::string AudioSpectrogramNode::getInputName(unsigned idx) const {
12828 if (idx == 0) { return "Input"; }
12829 if (idx == 1) { return "Window"; }
12830 if (idx == 2) { return "TwiddleFactors"; }
12831 if (idx == 3) { return "BitReverseIndices"; }
12832 if (idx == 4) { return "ComplexToRealWeights"; }
12833 idx -= 5;
12834 llvm_unreachable("Invalid index");
12835}
12836
12837NodeValue AudioSpectrogramNode::getNthInput(unsigned idx) {
12838 if (idx == 0) { return Input_; }
12839 if (idx == 1) { return Window_; }
12840 if (idx == 2) { return TwiddleFactors_; }
12841 if (idx == 3) { return BitReverseIndices_; }
12842 if (idx == 4) { return ComplexToRealWeights_; }
12843 idx -= 5;
12844 llvm_unreachable("Invalid index");
12845}
12846
12847void AudioSpectrogramNode::setNthInput(unsigned idx, NodeValue val) {
12848 if (idx == 0) { Input_ = val; return; }
12849 if (idx == 1) { Window_ = val; return; }
12850 if (idx == 2) { TwiddleFactors_ = val; return; }
12851 if (idx == 3) { BitReverseIndices_ = val; return; }
12852 if (idx == 4) { ComplexToRealWeights_ = val; return; }
12853 idx -= 5;
12854 llvm_unreachable("Invalid index");
12855}
12856
12857llvm::StringRef AudioSpectrogramNode::getOutputName(unsigned idx) const {
12858 if (idx == 0) { return "Spectrogram"; }
12859 llvm_unreachable("Invalid index");
12860}
12861
12862std::string AudioSpectrogramNode::getDebugDesc() const {
12863 DescriptionBuilder db(getKindName());
12864 db.addParam("Name", separateString(getName(), 100, "\n"));
12865 if (hasPredicate()) db.addParam("Predicate", "Yes");
12866 db
12867 .addParam("Input", *(getInput().getType()))
12868 .addParam("Window", *(getWindow().getType()))
12869 .addParam("TwiddleFactors", *(getTwiddleFactors().getType()))
12870 .addParam("BitReverseIndices", *(getBitReverseIndices().getType()))
12871 .addParam("ComplexToRealWeights", *(getComplexToRealWeights().getType()))
12872 .addParam("WindowSize", getWindowSize())
12873 .addParam("WindowStride", getWindowStride())
12874 .addParam("MagnitudeSquared", getMagnitudeSquared())
12875 .addParam("Users", getNumUsers());
12876 db.addParam("Spectrogram", *(getSpectrogram().getType()));
12877 return db;
12878}
12879
12880void AudioSpectrogramNode::visit(Node *parent, NodeWalker *visitor) {
12881 if (!visitor->shouldVisit(parent, this)) { return; }
12882 visitor->pre(parent, this);
12883if (hasPredicate())
12884 getPredicate().getNode()->visit(this, visitor);
12885 getInput().getNode()->visit(this, visitor);
12886 getWindow().getNode()->visit(this, visitor);
12887 getTwiddleFactors().getNode()->visit(this, visitor);
12888 getBitReverseIndices().getNode()->visit(this, visitor);
12889 getComplexToRealWeights().getNode()->visit(this, visitor);
12890 visitor->post(parent, this);
12891}
12892
12893bool AudioSpectrogramNode::isEqual(const AudioSpectrogramNode &other) const {
12894 return true &&
12895 Input_ == other.Input_ &&
12896 Window_ == other.Window_ &&
12897 TwiddleFactors_ == other.TwiddleFactors_ &&
12898 BitReverseIndices_ == other.BitReverseIndices_ &&
12899 ComplexToRealWeights_ == other.ComplexToRealWeights_ &&
12900 predicate_ == other.predicate_ &&
12901 WindowSize_ == other.WindowSize_ &&
12902 WindowStride_ == other.WindowStride_ &&
12903 MagnitudeSquared_ == other.MagnitudeSquared_ &&
12904 getType(0) == other.getType(0);
12905}
12906
12907Node* AudioSpectrogramNode::clone() const {
12908 return new AudioSpectrogramNode(getName(), getSpectrogram().getType(), getInput(), getWindow(), getTwiddleFactors(), getBitReverseIndices(), getComplexToRealWeights(), getWindowSize(), getWindowStride(), getMagnitudeSquared());
12909}
12910
12911llvm::hash_code AudioSpectrogramNode::getHash() const {
12912 return llvm::hash_combine(
12913 WindowSize_,
12914 WindowStride_,
12915 MagnitudeSquared_,
12916 Input_,
12917 Window_,
12918 TwiddleFactors_,
12919 BitReverseIndices_,
12920 ComplexToRealWeights_);
12921}
12922
12923unsigned MFCCNode::getNumInputs() const {
12924 return 4;
12925}
12926
12927std::string MFCCNode::getInputName(unsigned idx) const {
12928 if (idx == 0) { return "Spectrogram"; }
12929 if (idx == 1) { return "MelWeights"; }
12930 if (idx == 2) { return "MelRanges"; }
12931 if (idx == 3) { return "DctMat"; }
12932 idx -= 4;
12933 llvm_unreachable("Invalid index");
12934}
12935
12936NodeValue MFCCNode::getNthInput(unsigned idx) {
12937 if (idx == 0) { return Spectrogram_; }
12938 if (idx == 1) { return MelWeights_; }
12939 if (idx == 2) { return MelRanges_; }
12940 if (idx == 3) { return DctMat_; }
12941 idx -= 4;
12942 llvm_unreachable("Invalid index");
12943}
12944
12945void MFCCNode::setNthInput(unsigned idx, NodeValue val) {
12946 if (idx == 0) { Spectrogram_ = val; return; }
12947 if (idx == 1) { MelWeights_ = val; return; }
12948 if (idx == 2) { MelRanges_ = val; return; }
12949 if (idx == 3) { DctMat_ = val; return; }
12950 idx -= 4;
12951 llvm_unreachable("Invalid index");
12952}
12953
12954llvm::StringRef MFCCNode::getOutputName(unsigned idx) const {
12955 if (idx == 0) { return "Coefficients"; }
12956 llvm_unreachable("Invalid index");
12957}
12958
12959std::string MFCCNode::getDebugDesc() const {
12960 DescriptionBuilder db(getKindName());
12961 db.addParam("Name", separateString(getName(), 100, "\n"));
12962 if (hasPredicate()) db.addParam("Predicate", "Yes");
12963 db
12964 .addParam("Spectrogram", *(getSpectrogram().getType()))
12965 .addParam("MelWeights", *(getMelWeights().getType()))
12966 .addParam("MelRanges", *(getMelRanges().getType()))
12967 .addParam("DctMat", *(getDctMat().getType()))
12968 .addParam("SampleRate", getSampleRate())
12969 .addParam("LowerFrequency", getLowerFrequency())
12970 .addParam("UpperFrequency", getUpperFrequency())
12971 .addParam("FilterBankCount", getFilterBankCount())
12972 .addParam("NumCoefficients", getNumCoefficients())
12973 .addParam("Users", getNumUsers());
12974 db.addParam("Coefficients", *(getCoefficients().getType()));
12975 return db;
12976}
12977
12978void MFCCNode::visit(Node *parent, NodeWalker *visitor) {
12979 if (!visitor->shouldVisit(parent, this)) { return; }
12980 visitor->pre(parent, this);
12981if (hasPredicate())
12982 getPredicate().getNode()->visit(this, visitor);
12983 getSpectrogram().getNode()->visit(this, visitor);
12984 getMelWeights().getNode()->visit(this, visitor);
12985 getMelRanges().getNode()->visit(this, visitor);
12986 getDctMat().getNode()->visit(this, visitor);
12987 visitor->post(parent, this);
12988}
12989
12990bool MFCCNode::isEqual(const MFCCNode &other) const {
12991 return true &&
12992 Spectrogram_ == other.Spectrogram_ &&
12993 MelWeights_ == other.MelWeights_ &&
12994 MelRanges_ == other.MelRanges_ &&
12995 DctMat_ == other.DctMat_ &&
12996 predicate_ == other.predicate_ &&
12997 SampleRate_ == other.SampleRate_ &&
12998 LowerFrequency_ == other.LowerFrequency_ &&
12999 UpperFrequency_ == other.UpperFrequency_ &&
13000 FilterBankCount_ == other.FilterBankCount_ &&
13001 NumCoefficients_ == other.NumCoefficients_ &&
13002 getType(0) == other.getType(0);
13003}
13004
13005Node* MFCCNode::clone() const {
13006 return new MFCCNode(getName(), getCoefficients().getType(), getSpectrogram(), getMelWeights(), getMelRanges(), getDctMat(), getSampleRate(), getLowerFrequency(), getUpperFrequency(), getFilterBankCount(), getNumCoefficients());
13007}
13008
13009llvm::hash_code MFCCNode::getHash() const {
13010 return llvm::hash_combine(
13011 toBinary(SampleRate_),
13012 toBinary(LowerFrequency_),
13013 toBinary(UpperFrequency_),
13014 FilterBankCount_,
13015 NumCoefficients_,
13016 Spectrogram_,
13017 MelWeights_,
13018 MelRanges_,
13019 DctMat_);
13020}
13021
13022unsigned NonMaxSuppressionNode::getNumInputs() const {
13023 return 2;
13024}
13025
13026std::string NonMaxSuppressionNode::getInputName(unsigned idx) const {
13027 if (idx == 0) { return "Boxes"; }
13028 if (idx == 1) { return "Scores"; }
13029 idx -= 2;
13030 llvm_unreachable("Invalid index");
13031}
13032
13033NodeValue NonMaxSuppressionNode::getNthInput(unsigned idx) {
13034 if (idx == 0) { return Boxes_; }
13035 if (idx == 1) { return Scores_; }
13036 idx -= 2;
13037 llvm_unreachable("Invalid index");
13038}
13039
13040void NonMaxSuppressionNode::setNthInput(unsigned idx, NodeValue val) {
13041 if (idx == 0) { Boxes_ = val; return; }
13042 if (idx == 1) { Scores_ = val; return; }
13043 idx -= 2;
13044 llvm_unreachable("Invalid index");
13045}
13046
13047llvm::StringRef NonMaxSuppressionNode::getOutputName(unsigned idx) const {
13048 if (idx == 0) { return "Indices"; }
13049 if (idx == 1) { return "NumberOfSelectedIndices"; }
13050 llvm_unreachable("Invalid index");
13051}
13052
13053std::string NonMaxSuppressionNode::getDebugDesc() const {
13054 DescriptionBuilder db(getKindName());
13055 db.addParam("Name", separateString(getName(), 100, "\n"));
13056 if (hasPredicate()) db.addParam("Predicate", "Yes");
13057 db
13058 .addParam("Boxes", *(getBoxes().getType()))
13059 .addParam("Scores", *(getScores().getType()))
13060 .addParam("CenterPointBox", getCenterPointBox())
13061 .addParam("MaxOutputBoxesPerClass", getMaxOutputBoxesPerClass())
13062 .addParam("IouThreshold", getIouThreshold())
13063 .addParam("ScoreThreshold", getScoreThreshold())
13064 .addParam("IsTFVersion4", getIsTFVersion4())
13065 .addParam("Users", getNumUsers());
13066 db.addParam("Indices", *(getIndices().getType()));
13067 db.addParam("NumberOfSelectedIndices", *(getNumberOfSelectedIndices().getType()));
13068 return db;
13069}
13070
13071void NonMaxSuppressionNode::visit(Node *parent, NodeWalker *visitor) {
13072 if (!visitor->shouldVisit(parent, this)) { return; }
13073 visitor->pre(parent, this);
13074if (hasPredicate())
13075 getPredicate().getNode()->visit(this, visitor);
13076 getBoxes().getNode()->visit(this, visitor);
13077 getScores().getNode()->visit(this, visitor);
13078 visitor->post(parent, this);
13079}
13080
13081bool NonMaxSuppressionNode::isEqual(const NonMaxSuppressionNode &other) const {
13082 return true &&
13083 Boxes_ == other.Boxes_ &&
13084 Scores_ == other.Scores_ &&
13085 predicate_ == other.predicate_ &&
13086 CenterPointBox_ == other.CenterPointBox_ &&
13087 MaxOutputBoxesPerClass_ == other.MaxOutputBoxesPerClass_ &&
13088 IouThreshold_ == other.IouThreshold_ &&
13089 ScoreThreshold_ == other.ScoreThreshold_ &&
13090 IsTFVersion4_ == other.IsTFVersion4_ &&
13091 getType(0) == other.getType(0) &&
13092 getType(1) == other.getType(1);
13093}
13094
13095Node* NonMaxSuppressionNode::clone() const {
13096 return new NonMaxSuppressionNode(getName(), getIndices().getType(), getNumberOfSelectedIndices().getType(), getBoxes(), getScores(), getCenterPointBox(), getMaxOutputBoxesPerClass(), getIouThreshold(), getScoreThreshold(), getIsTFVersion4());
13097}
13098
13099llvm::hash_code NonMaxSuppressionNode::getHash() const {
13100 return llvm::hash_combine(
13101 CenterPointBox_,
13102 MaxOutputBoxesPerClass_,
13103 toBinary(IouThreshold_),
13104 toBinary(ScoreThreshold_),
13105 IsTFVersion4_,
13106 Boxes_,
13107 Scores_);
13108}
13109
13110unsigned TFLiteDetectionPostProcessNode::getNumInputs() const {
13111 return 3;
13112}
13113
13114std::string TFLiteDetectionPostProcessNode::getInputName(unsigned idx) const {
13115 if (idx == 0) { return "Boxes"; }
13116 if (idx == 1) { return "Scores"; }
13117 if (idx == 2) { return "Anchors"; }
13118 idx -= 3;
13119 llvm_unreachable("Invalid index");
13120}
13121
13122NodeValue TFLiteDetectionPostProcessNode::getNthInput(unsigned idx) {
13123 if (idx == 0) { return Boxes_; }
13124 if (idx == 1) { return Scores_; }
13125 if (idx == 2) { return Anchors_; }
13126 idx -= 3;
13127 llvm_unreachable("Invalid index");
13128}
13129
13130void TFLiteDetectionPostProcessNode::setNthInput(unsigned idx, NodeValue val) {
13131 if (idx == 0) { Boxes_ = val; return; }
13132 if (idx == 1) { Scores_ = val; return; }
13133 if (idx == 2) { Anchors_ = val; return; }
13134 idx -= 3;
13135 llvm_unreachable("Invalid index");
13136}
13137
13138llvm::StringRef TFLiteDetectionPostProcessNode::getOutputName(unsigned idx) const {
13139 if (idx == 0) { return "DetectionBoxes"; }
13140 if (idx == 1) { return "DetectionClasses"; }
13141 if (idx == 2) { return "DetectionScores"; }
13142 if (idx == 3) { return "NumDetections"; }
13143 llvm_unreachable("Invalid index");
13144}
13145
13146std::string TFLiteDetectionPostProcessNode::getDebugDesc() const {
13147 DescriptionBuilder db(getKindName());
13148 db.addParam("Name", separateString(getName(), 100, "\n"));
13149 if (hasPredicate()) db.addParam("Predicate", "Yes");
13150 db
13151 .addParam("Boxes", *(getBoxes().getType()))
13152 .addParam("Scores", *(getScores().getType()))
13153 .addParam("Anchors", *(getAnchors().getType()))
13154 .addParam("NumClasses", getNumClasses())
13155 .addParam("MaxDetections", getMaxDetections())
13156 .addParam("MaxClassesPerDetection", getMaxClassesPerDetection())
13157 .addParam("MaxDetectionsPerClass", getMaxDetectionsPerClass())
13158 .addParam("IouThreshold", getIouThreshold())
13159 .addParam("ScoreThreshold", getScoreThreshold())
13160 .addParam("XScale", getXScale())
13161 .addParam("YScale", getYScale())
13162 .addParam("HScale", getHScale())
13163 .addParam("WScale", getWScale())
13164 .addParam("RegularNMS", getRegularNMS())
13165 .addParam("Users", getNumUsers());
13166 db.addParam("DetectionBoxes", *(getDetectionBoxes().getType()));
13167 db.addParam("DetectionClasses", *(getDetectionClasses().getType()));
13168 db.addParam("DetectionScores", *(getDetectionScores().getType()));
13169 db.addParam("NumDetections", *(getNumDetections().getType()));
13170 return db;
13171}
13172
13173void TFLiteDetectionPostProcessNode::visit(Node *parent, NodeWalker *visitor) {
13174 if (!visitor->shouldVisit(parent, this)) { return; }
13175 visitor->pre(parent, this);
13176if (hasPredicate())
13177 getPredicate().getNode()->visit(this, visitor);
13178 getBoxes().getNode()->visit(this, visitor);
13179 getScores().getNode()->visit(this, visitor);
13180 getAnchors().getNode()->visit(this, visitor);
13181 visitor->post(parent, this);
13182}
13183
13184bool TFLiteDetectionPostProcessNode::isEqual(const TFLiteDetectionPostProcessNode &other) const {
13185 return true &&
13186 Boxes_ == other.Boxes_ &&
13187 Scores_ == other.Scores_ &&
13188 Anchors_ == other.Anchors_ &&
13189 predicate_ == other.predicate_ &&
13190 NumClasses_ == other.NumClasses_ &&
13191 MaxDetections_ == other.MaxDetections_ &&
13192 MaxClassesPerDetection_ == other.MaxClassesPerDetection_ &&
13193 MaxDetectionsPerClass_ == other.MaxDetectionsPerClass_ &&
13194 IouThreshold_ == other.IouThreshold_ &&
13195 ScoreThreshold_ == other.ScoreThreshold_ &&
13196 XScale_ == other.XScale_ &&
13197 YScale_ == other.YScale_ &&
13198 HScale_ == other.HScale_ &&
13199 WScale_ == other.WScale_ &&
13200 RegularNMS_ == other.RegularNMS_ &&
13201 getType(0) == other.getType(0) &&
13202 getType(1) == other.getType(1) &&
13203 getType(2) == other.getType(2) &&
13204 getType(3) == other.getType(3);
13205}
13206
13207Node* TFLiteDetectionPostProcessNode::clone() const {
13208 return new TFLiteDetectionPostProcessNode(getName(), getDetectionBoxes().getType(), getDetectionClasses().getType(), getDetectionScores().getType(), getNumDetections().getType(), getBoxes(), getScores(), getAnchors(), getNumClasses(), getMaxDetections(), getMaxClassesPerDetection(), getMaxDetectionsPerClass(), getIouThreshold(), getScoreThreshold(), getXScale(), getYScale(), getHScale(), getWScale(), getRegularNMS());
13209}
13210
13211llvm::hash_code TFLiteDetectionPostProcessNode::getHash() const {
13212 return llvm::hash_combine(
13213 NumClasses_,
13214 MaxDetections_,
13215 MaxClassesPerDetection_,
13216 MaxDetectionsPerClass_,
13217 toBinary(IouThreshold_),
13218 toBinary(ScoreThreshold_),
13219 toBinary(XScale_),
13220 toBinary(YScale_),
13221 toBinary(HScale_),
13222 toBinary(WScale_),
13223 RegularNMS_,
13224 Boxes_,
13225 Scores_,
13226 Anchors_);
13227}
13228
13229unsigned ROIAlignNode::getNumInputs() const {
13230 return 3;
13231}
13232
13233std::string ROIAlignNode::getInputName(unsigned idx) const {
13234 if (idx == 0) { return "FeatureMap"; }
13235 if (idx == 1) { return "Boxes"; }
13236 if (idx == 2) { return "BatchIndices"; }
13237 idx -= 3;
13238 llvm_unreachable("Invalid index");
13239}
13240
13241NodeValue ROIAlignNode::getNthInput(unsigned idx) {
13242 if (idx == 0) { return FeatureMap_; }
13243 if (idx == 1) { return Boxes_; }
13244 if (idx == 2) { return BatchIndices_; }
13245 idx -= 3;
13246 llvm_unreachable("Invalid index");
13247}
13248
13249void ROIAlignNode::setNthInput(unsigned idx, NodeValue val) {
13250 if (idx == 0) { FeatureMap_ = val; return; }
13251 if (idx == 1) { Boxes_ = val; return; }
13252 if (idx == 2) { BatchIndices_ = val; return; }
13253 idx -= 3;
13254 llvm_unreachable("Invalid index");
13255}
13256
13257llvm::StringRef ROIAlignNode::getOutputName(unsigned idx) const {
13258 if (idx == 0) { return "Result"; }
13259 llvm_unreachable("Invalid index");
13260}
13261
13262std::string ROIAlignNode::getDebugDesc() const {
13263 DescriptionBuilder db(getKindName());
13264 db.addParam("Name", separateString(getName(), 100, "\n"));
13265 if (hasPredicate()) db.addParam("Predicate", "Yes");
13266 db
13267 .addParam("FeatureMap", *(getFeatureMap().getType()))
13268 .addParam("Boxes", *(getBoxes().getType()))
13269 .addParam("BatchIndices", *(getBatchIndices().getType()))
13270 .addParam("Mode", getMode())
13271 .addParam("OutputHeight", getOutputHeight())
13272 .addParam("OutputWidth", getOutputWidth())
13273 .addParam("SamplingRatio", getSamplingRatio())
13274 .addParam("SpatialScale", getSpatialScale())
13275 .addParam("Aligned", getAligned())
13276 .addParam("Rotated", getRotated())
13277 .addParam("Users", getNumUsers());
13278 db.addParam("Result", *(getResult().getType()));
13279 return db;
13280}
13281
13282void ROIAlignNode::visit(Node *parent, NodeWalker *visitor) {
13283 if (!visitor->shouldVisit(parent, this)) { return; }
13284 visitor->pre(parent, this);
13285if (hasPredicate())
13286 getPredicate().getNode()->visit(this, visitor);
13287 getFeatureMap().getNode()->visit(this, visitor);
13288 getBoxes().getNode()->visit(this, visitor);
13289 getBatchIndices().getNode()->visit(this, visitor);
13290 visitor->post(parent, this);
13291}
13292
13293bool ROIAlignNode::isEqual(const ROIAlignNode &other) const {
13294 return true &&
13295 FeatureMap_ == other.FeatureMap_ &&
13296 Boxes_ == other.Boxes_ &&
13297 BatchIndices_ == other.BatchIndices_ &&
13298 predicate_ == other.predicate_ &&
13299 Mode_ == other.Mode_ &&
13300 OutputHeight_ == other.OutputHeight_ &&
13301 OutputWidth_ == other.OutputWidth_ &&
13302 SamplingRatio_ == other.SamplingRatio_ &&
13303 SpatialScale_ == other.SpatialScale_ &&
13304 Aligned_ == other.Aligned_ &&
13305 Rotated_ == other.Rotated_ &&
13306 getType(0) == other.getType(0);
13307}
13308
13309Node* ROIAlignNode::clone() const {
13310 return new ROIAlignNode(getName(), getResult().getType(), getFeatureMap(), getBoxes(), getBatchIndices(), getMode(), getOutputHeight(), getOutputWidth(), getSamplingRatio(), getSpatialScale(), getAligned(), getRotated());
13311}
13312
13313llvm::hash_code ROIAlignNode::getHash() const {
13314 return llvm::hash_combine(
13315 Mode_,
13316 OutputHeight_,
13317 OutputWidth_,
13318 SamplingRatio_,
13319 toBinary(SpatialScale_),
13320 Aligned_,
13321 Rotated_,
13322 FeatureMap_,
13323 Boxes_,
13324 BatchIndices_);
13325}
13326
13327unsigned BBoxTransformNode::getNumInputs() const {
13328 return 3;
13329}
13330
13331std::string BBoxTransformNode::getInputName(unsigned idx) const {
13332 if (idx == 0) { return "Rois"; }
13333 if (idx == 1) { return "Deltas"; }
13334 if (idx == 2) { return "ImInfo"; }
13335 idx -= 3;
13336 llvm_unreachable("Invalid index");
13337}
13338
13339NodeValue BBoxTransformNode::getNthInput(unsigned idx) {
13340 if (idx == 0) { return Rois_; }
13341 if (idx == 1) { return Deltas_; }
13342 if (idx == 2) { return ImInfo_; }
13343 idx -= 3;
13344 llvm_unreachable("Invalid index");
13345}
13346
13347void BBoxTransformNode::setNthInput(unsigned idx, NodeValue val) {
13348 if (idx == 0) { Rois_ = val; return; }
13349 if (idx == 1) { Deltas_ = val; return; }
13350 if (idx == 2) { ImInfo_ = val; return; }
13351 idx -= 3;
13352 llvm_unreachable("Invalid index");
13353}
13354
13355llvm::StringRef BBoxTransformNode::getOutputName(unsigned idx) const {
13356 if (idx == 0) { return "BoxOut"; }
13357 if (idx == 1) { return "RoiBatchSplits"; }
13358 llvm_unreachable("Invalid index");
13359}
13360
13361std::string BBoxTransformNode::getDebugDesc() const {
13362 DescriptionBuilder db(getKindName());
13363 db.addParam("Name", separateString(getName(), 100, "\n"));
13364 if (hasPredicate()) db.addParam("Predicate", "Yes");
13365 db
13366 .addParam("Rois", *(getRois().getType()))
13367 .addParam("Deltas", *(getDeltas().getType()))
13368 .addParam("ImInfo", *(getImInfo().getType()))
13369 .addParam("Weights", getWeights())
13370 .addParam("ApplyScale", getApplyScale())
13371 .addParam("Rotated", getRotated())
13372 .addParam("AngleBoundOn", getAngleBoundOn())
13373 .addParam("AngleBoundLo", getAngleBoundLo())
13374 .addParam("AngleBoundHi", getAngleBoundHi())
13375 .addParam("ClipAngleThresh", getClipAngleThresh())
13376 .addParam("LegacyPlusOne", getLegacyPlusOne())
13377 .addParam("Users", getNumUsers());
13378 db.addParam("BoxOut", *(getBoxOut().getType()));
13379 db.addParam("RoiBatchSplits", *(getRoiBatchSplits().getType()));
13380 return db;
13381}
13382
13383void BBoxTransformNode::visit(Node *parent, NodeWalker *visitor) {
13384 if (!visitor->shouldVisit(parent, this)) { return; }
13385 visitor->pre(parent, this);
13386if (hasPredicate())
13387 getPredicate().getNode()->visit(this, visitor);
13388 getRois().getNode()->visit(this, visitor);
13389 getDeltas().getNode()->visit(this, visitor);
13390 getImInfo().getNode()->visit(this, visitor);
13391 visitor->post(parent, this);
13392}
13393
13394bool BBoxTransformNode::isEqual(const BBoxTransformNode &other) const {
13395 return true &&
13396 Rois_ == other.Rois_ &&
13397 Deltas_ == other.Deltas_ &&
13398 ImInfo_ == other.ImInfo_ &&
13399 predicate_ == other.predicate_ &&
13400 Weights_ == other.Weights_ &&
13401 ApplyScale_ == other.ApplyScale_ &&
13402 Rotated_ == other.Rotated_ &&
13403 AngleBoundOn_ == other.AngleBoundOn_ &&
13404 AngleBoundLo_ == other.AngleBoundLo_ &&
13405 AngleBoundHi_ == other.AngleBoundHi_ &&
13406 ClipAngleThresh_ == other.ClipAngleThresh_ &&
13407 LegacyPlusOne_ == other.LegacyPlusOne_ &&
13408 getType(0) == other.getType(0) &&
13409 getType(1) == other.getType(1);
13410}
13411
13412Node* BBoxTransformNode::clone() const {
13413 return new BBoxTransformNode(getName(), getBoxOut().getType(), getRoiBatchSplits().getType(), getRois(), getDeltas(), getImInfo(), getWeights(), getApplyScale(), getRotated(), getAngleBoundOn(), getAngleBoundLo(), getAngleBoundHi(), getClipAngleThresh(), getLegacyPlusOne());
13414}
13415
13416llvm::hash_code BBoxTransformNode::getHash() const {
13417 return llvm::hash_combine(
13418 [](const std::vector<float>& floatVec) -> llvm::hash_code {
13419 std::vector<size_t> sizeVec = toBinary(floatVec);
13420 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
13421 }(Weights_),
13422 ApplyScale_,
13423 Rotated_,
13424 AngleBoundOn_,
13425 AngleBoundLo_,
13426 AngleBoundHi_,
13427 toBinary(ClipAngleThresh_),
13428 LegacyPlusOne_,
13429 Rois_,
13430 Deltas_,
13431 ImInfo_);
13432}
13433
13434unsigned CollectRpnProposalsNode::getNumInputs() const {
13435 return 0 + RoisIn_.size() + RoisProbsIn_.size();
13436}
13437
13438std::string CollectRpnProposalsNode::getInputName(unsigned idx) const {
13439 idx -= 0;
13440 if (idx < RoisIn_.size()) { return "RoisIn" + std::to_string(idx); }
13441 idx -= RoisIn_.size();
13442 if (idx < RoisProbsIn_.size()) { return "RoisProbsIn" + std::to_string(idx); }
13443 idx -= RoisProbsIn_.size();
13444 llvm_unreachable("Invalid index");
13445}
13446
13447NodeValue CollectRpnProposalsNode::getNthInput(unsigned idx) {
13448 idx -= 0;
13449 if (idx < RoisIn_.size()) { return RoisIn_[idx]; }
13450 idx -= RoisIn_.size();
13451 if (idx < RoisProbsIn_.size()) { return RoisProbsIn_[idx]; }
13452 idx -= RoisProbsIn_.size();
13453 llvm_unreachable("Invalid index");
13454}
13455
13456void CollectRpnProposalsNode::setNthInput(unsigned idx, NodeValue val) {
13457 idx -= 0;
13458 if (idx < RoisIn_.size()) { RoisIn_[idx] = val; return; }
13459 idx -= RoisIn_.size();
13460 if (idx < RoisProbsIn_.size()) { RoisProbsIn_[idx] = val; return; }
13461 idx -= RoisProbsIn_.size();
13462 llvm_unreachable("Invalid index");
13463}
13464
13465llvm::StringRef CollectRpnProposalsNode::getOutputName(unsigned idx) const {
13466 if (idx == 0) { return "Result"; }
13467 llvm_unreachable("Invalid index");
13468}
13469
13470std::string CollectRpnProposalsNode::getDebugDesc() const {
13471 DescriptionBuilder db(getKindName());
13472 db.addParam("Name", separateString(getName(), 100, "\n"));
13473 if (hasPredicate()) db.addParam("Predicate", "Yes");
13474 db
13475 .addParam("RpnMaxLevel", getRpnMaxLevel())
13476 .addParam("RpnMinLevel", getRpnMinLevel())
13477 .addParam("RpnPostNmsTopN", getRpnPostNmsTopN())
13478 .addParam("Users", getNumUsers());
13479 {
13480 unsigned mIndex = 0;
13481 for (const auto &II : getRoisIn()) {
13482 db.addParam("RoisIn"+std::to_string(mIndex++), *II.getType());
13483 }
13484 }
13485 {
13486 unsigned mIndex = 0;
13487 for (const auto &II : getRoisProbsIn()) {
13488 db.addParam("RoisProbsIn"+std::to_string(mIndex++), *II.getType());
13489 }
13490 }
13491 db.addParam("Result", *(getResult().getType()));
13492 return db;
13493}
13494
13495void CollectRpnProposalsNode::visit(Node *parent, NodeWalker *visitor) {
13496 if (!visitor->shouldVisit(parent, this)) { return; }
13497 visitor->pre(parent, this);
13498if (hasPredicate())
13499 getPredicate().getNode()->visit(this, visitor);
13500 for (auto &I : RoisIn_) { I.getNode()->visit(this, visitor); }
13501 for (auto &I : RoisProbsIn_) { I.getNode()->visit(this, visitor); }
13502 visitor->post(parent, this);
13503}
13504
13505bool CollectRpnProposalsNode::isEqual(const CollectRpnProposalsNode &other) const {
13506 return true &&
13507 predicate_ == other.predicate_ &&
13508 RoisIn_ == other.RoisIn_ &&
13509 RoisProbsIn_ == other.RoisProbsIn_ &&
13510 RpnMaxLevel_ == other.RpnMaxLevel_ &&
13511 RpnMinLevel_ == other.RpnMinLevel_ &&
13512 RpnPostNmsTopN_ == other.RpnPostNmsTopN_ &&
13513 getType(0) == other.getType(0);
13514}
13515
13516Node* CollectRpnProposalsNode::clone() const {
13517 return new CollectRpnProposalsNode(getName(), getResult().getType(), getRoisIn(), getRoisProbsIn(), getRpnMaxLevel(), getRpnMinLevel(), getRpnPostNmsTopN());
13518}
13519
13520llvm::hash_code CollectRpnProposalsNode::getHash() const {
13521 return llvm::hash_combine(
13522 llvm::hash_combine_range(RoisIn_.begin(), RoisIn_.end()),
13523 llvm::hash_combine_range(RoisProbsIn_.begin(), RoisProbsIn_.end()),
13524 RpnMaxLevel_,
13525 RpnMinLevel_,
13526 RpnPostNmsTopN_);
13527}
13528
13529unsigned LookupTableNode::getNumInputs() const {
13530 return 3;
13531}
13532
13533std::string LookupTableNode::getInputName(unsigned idx) const {
13534 if (idx == 0) { return "Input"; }
13535 if (idx == 1) { return "Table"; }
13536 if (idx == 2) { return "TableIdx"; }
13537 idx -= 3;
13538 llvm_unreachable("Invalid index");
13539}
13540
13541NodeValue LookupTableNode::getNthInput(unsigned idx) {
13542 if (idx == 0) { return Input_; }
13543 if (idx == 1) { return Table_; }
13544 if (idx == 2) { return TableIdx_; }
13545 idx -= 3;
13546 llvm_unreachable("Invalid index");
13547}
13548
13549void LookupTableNode::setNthInput(unsigned idx, NodeValue val) {
13550 if (idx == 0) { Input_ = val; return; }
13551 if (idx == 1) { Table_ = val; return; }
13552 if (idx == 2) { TableIdx_ = val; return; }
13553 idx -= 3;
13554 llvm_unreachable("Invalid index");
13555}
13556
13557llvm::StringRef LookupTableNode::getOutputName(unsigned idx) const {
13558 if (idx == 0) { return "Result"; }
13559 llvm_unreachable("Invalid index");
13560}
13561
13562std::string LookupTableNode::getDebugDesc() const {
13563 DescriptionBuilder db(getKindName());
13564 db.addParam("Name", separateString(getName(), 100, "\n"));
13565 if (hasPredicate()) db.addParam("Predicate", "Yes");
13566 db
13567 .addParam("Input", *(getInput().getType()))
13568 .addParam("Table", *(getTable().getType()))
13569 .addParam("TableIdx", *(getTableIdx().getType()))
13570 .addParam("Operator", getOperator())
13571 .addParam("OperatorArgs", getOperatorArgs())
13572 .addParam("Users", getNumUsers());
13573 db.addParam("Result", *(getResult().getType()));
13574 return db;
13575}
13576
13577void LookupTableNode::visit(Node *parent, NodeWalker *visitor) {
13578 if (!visitor->shouldVisit(parent, this)) { return; }
13579 visitor->pre(parent, this);
13580if (hasPredicate())
13581 getPredicate().getNode()->visit(this, visitor);
13582 getInput().getNode()->visit(this, visitor);
13583 getTable().getNode()->visit(this, visitor);
13584 getTableIdx().getNode()->visit(this, visitor);
13585 visitor->post(parent, this);
13586}
13587
13588bool LookupTableNode::isEqual(const LookupTableNode &other) const {
13589 return true &&
13590 Input_ == other.Input_ &&
13591 Table_ == other.Table_ &&
13592 TableIdx_ == other.TableIdx_ &&
13593 predicate_ == other.predicate_ &&
13594 Operator_ == other.Operator_ &&
13595 OperatorArgs_ == other.OperatorArgs_ &&
13596 getType(0) == other.getType(0);
13597}
13598
13599Node* LookupTableNode::clone() const {
13600 return new LookupTableNode(getName(), getResult().getType(), getInput(), getTable(), getTableIdx(), getOperator(), getOperatorArgs());
13601}
13602
13603llvm::hash_code LookupTableNode::getHash() const {
13604 return llvm::hash_combine(
13605 Operator_,
13606 [](const std::vector<float>& floatVec) -> llvm::hash_code {
13607 std::vector<size_t> sizeVec = toBinary(floatVec);
13608 return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end());
13609 }(OperatorArgs_),
13610 Input_,
13611 Table_,
13612 TableIdx_);
13613}
13614
13615unsigned CPUMaxSplatNode::getNumInputs() const {
13616 return 1;
13617}
13618
13619std::string CPUMaxSplatNode::getInputName(unsigned idx) const {
13620 if (idx == 0) { return "Input"; }
13621 idx -= 1;
13622 llvm_unreachable("Invalid index");
13623}
13624
13625NodeValue CPUMaxSplatNode::getNthInput(unsigned idx) {
13626 if (idx == 0) { return Input_; }
13627 idx -= 1;
13628 llvm_unreachable("Invalid index");
13629}
13630
13631void CPUMaxSplatNode::setNthInput(unsigned idx, NodeValue val) {
13632 if (idx == 0) { Input_ = val; return; }
13633 idx -= 1;
13634 llvm_unreachable("Invalid index");
13635}
13636
13637llvm::StringRef CPUMaxSplatNode::getOutputName(unsigned idx) const {
13638 if (idx == 0) { return "Result"; }
13639 llvm_unreachable("Invalid index");
13640}
13641
13642std::string CPUMaxSplatNode::getDebugDesc() const {
13643 DescriptionBuilder db(getKindName());
13644 db.addParam("Name", separateString(getName(), 100, "\n"));
13645 if (hasPredicate()) db.addParam("Predicate", "Yes");
13646 db
13647 .addParam("Input", *(getInput().getType()))
13648 .addParam("SplatValue", getSplatValue())
13649 .addParam("Users", getNumUsers());
13650 db.addParam("Result", *(getResult().getType()));
13651 return db;
13652}
13653
13654void CPUMaxSplatNode::visit(Node *parent, NodeWalker *visitor) {
13655 if (!visitor->shouldVisit(parent, this)) { return; }
13656 visitor->pre(parent, this);
13657if (hasPredicate())
13658 getPredicate().getNode()->visit(this, visitor);
13659 getInput().getNode()->visit(this, visitor);
13660 visitor->post(parent, this);
13661}
13662
13663bool CPUMaxSplatNode::isEqual(const CPUMaxSplatNode &other) const {
13664 return true &&
13665 Input_ == other.Input_ &&
13666 predicate_ == other.predicate_ &&
13667 SplatValue_ == other.SplatValue_ &&
13668 getType(0) == other.getType(0);
13669}
13670
13671Node* CPUMaxSplatNode::clone() const {
13672 return new CPUMaxSplatNode(getName(), getInput(), getSplatValue());
13673}
13674
13675llvm::hash_code CPUMaxSplatNode::getHash() const {
13676 return llvm::hash_combine(
13677 toBinary(SplatValue_),
13678 Input_);
13679}
13680
13681unsigned CPUConvDKKC8Node::getNumInputs() const {
13682 return 3;
13683}
13684
13685std::string CPUConvDKKC8Node::getInputName(unsigned idx) const {
13686 if (idx == 0) { return "Input"; }
13687 if (idx == 1) { return "Filter"; }
13688 if (idx == 2) { return "Bias"; }
13689 idx -= 3;
13690 llvm_unreachable("Invalid index");
13691}
13692
13693NodeValue CPUConvDKKC8Node::getNthInput(unsigned idx) {
13694 if (idx == 0) { return Input_; }
13695 if (idx == 1) { return Filter_; }
13696 if (idx == 2) { return Bias_; }
13697 idx -= 3;
13698 llvm_unreachable("Invalid index");
13699}
13700
13701void CPUConvDKKC8Node::setNthInput(unsigned idx, NodeValue val) {
13702 if (idx == 0) { Input_ = val; return; }
13703 if (idx == 1) { Filter_ = val; return; }
13704 if (idx == 2) { Bias_ = val; return; }
13705 idx -= 3;
13706 llvm_unreachable("Invalid index");
13707}
13708
13709llvm::StringRef CPUConvDKKC8Node::getOutputName(unsigned idx) const {
13710 if (idx == 0) { return "Result"; }
13711 llvm_unreachable("Invalid index");
13712}
13713
13714std::string CPUConvDKKC8Node::getDebugDesc() const {
13715 DescriptionBuilder db(getKindName());
13716 db.addParam("Name", separateString(getName(), 100, "\n"));
13717 if (hasPredicate()) db.addParam("Predicate", "Yes");
13718 db
13719 .addParam("Input", *(getInput().getType()))
13720 .addParam("Filter", *(getFilter().getType()))
13721 .addParam("Bias", *(getBias().getType()))
13722 .addParam("Kernels", getKernels())
13723 .addParam("Strides", getStrides())
13724 .addParam("Pads", getPads())
13725 .addParam("Group", getGroup())
13726 .addParam("Users", getNumUsers());
13727 db.addParam("Result", *(getResult().getType()));
13728 return db;
13729}
13730
13731void CPUConvDKKC8Node::visit(Node *parent, NodeWalker *visitor) {
13732 if (!visitor->shouldVisit(parent, this)) { return; }
13733 visitor->pre(parent, this);
13734if (hasPredicate())
13735 getPredicate().getNode()->visit(this, visitor);
13736 getInput().getNode()->visit(this, visitor);
13737 getFilter().getNode()->visit(this, visitor);
13738 getBias().getNode()->visit(this, visitor);
13739 visitor->post(parent, this);
13740}
13741
13742bool CPUConvDKKC8Node::isEqual(const CPUConvDKKC8Node &other) const {
13743 return true &&
13744 Input_ == other.Input_ &&
13745 Filter_ == other.Filter_ &&
13746 Bias_ == other.Bias_ &&
13747 predicate_ == other.predicate_ &&
13748 Kernels_ == other.Kernels_ &&
13749 Strides_ == other.Strides_ &&
13750 Pads_ == other.Pads_ &&
13751 Group_ == other.Group_ &&
13752 getType(0) == other.getType(0);
13753}
13754
13755Node* CPUConvDKKC8Node::clone() const {
13756 return new CPUConvDKKC8Node(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup());
13757}
13758
13759llvm::hash_code CPUConvDKKC8Node::getHash() const {
13760 return llvm::hash_combine(
13761 llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()),
13762 llvm::hash_combine_range(Strides_.begin(), Strides_.end()),
13763 llvm::hash_combine_range(Pads_.begin(), Pads_.end()),
13764 Group_,
13765 Input_,
13766 Filter_,
13767 Bias_);
13768}
13769
13770#include "glow/CPUSpecificNodesVerification.h"
13771