1 | #include "glow/Graph/Nodes.h" |
2 | #include "glow/Base/Type.h" |
3 | #include "glow/Support/Support.h" |
4 | using namespace glow; |
5 | |
6 | unsigned SaveNode::getNumInputs() const { |
7 | return 2; |
8 | } |
9 | |
10 | std::string SaveNode::getInputName(unsigned idx) const { |
11 | if (idx == 0) { return "Input" ; } |
12 | if (idx == 1) { return "Output" ; } |
13 | idx -= 2; |
14 | llvm_unreachable("Invalid index" ); |
15 | } |
16 | |
17 | NodeValue SaveNode::getNthInput(unsigned idx) { |
18 | if (idx == 0) { return Input_; } |
19 | if (idx == 1) { return Output_; } |
20 | idx -= 2; |
21 | llvm_unreachable("Invalid index" ); |
22 | } |
23 | |
24 | void SaveNode::setNthInput(unsigned idx, NodeValue val) { |
25 | if (idx == 0) { Input_ = val; return; } |
26 | if (idx == 1) { Output_ = val; return; } |
27 | idx -= 2; |
28 | llvm_unreachable("Invalid index" ); |
29 | } |
30 | |
31 | llvm::StringRef SaveNode::getOutputName(unsigned idx) const { |
32 | llvm_unreachable("Invalid index" ); |
33 | } |
34 | |
35 | std::string SaveNode::getDebugDesc() const { |
36 | DescriptionBuilder db(getKindName()); |
37 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
38 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
39 | db |
40 | .addParam("Input" , *(getInput().getType())) |
41 | .addParam("Output" , *(getOutput().getType())) |
42 | .addParam("Users" , getNumUsers()); |
43 | return db; |
44 | } |
45 | |
46 | void SaveNode::visit(Node *parent, NodeWalker *visitor) { |
47 | if (!visitor->shouldVisit(parent, this)) { return; } |
48 | visitor->pre(parent, this); |
49 | if (hasPredicate()) |
50 | getPredicate().getNode()->visit(this, visitor); |
51 | getInput().getNode()->visit(this, visitor); |
52 | getOutput().getNode()->visit(this, visitor); |
53 | visitor->post(parent, this); |
54 | } |
55 | |
56 | bool SaveNode::isEqual(const SaveNode &other) const { |
57 | return true && |
58 | Input_ == other.Input_ && |
59 | Output_ == other.Output_ && |
60 | predicate_ == other.predicate_; |
61 | } |
62 | |
63 | Node* SaveNode::clone() const { |
64 | return new SaveNode(getName(), getInput(), getOutput()); |
65 | } |
66 | |
67 | llvm::hash_code SaveNode::getHash() const { |
68 | return llvm::hash_combine( |
69 | Input_, |
70 | Output_); |
71 | } |
72 | Placeholder *SaveNode::getPlaceholder() const { return llvm::cast<Placeholder>(Output_.getNode()); }; |
73 | unsigned PadNode::getNumInputs() const { |
74 | return 1; |
75 | } |
76 | |
77 | std::string PadNode::getInputName(unsigned idx) const { |
78 | if (idx == 0) { return "Input" ; } |
79 | idx -= 1; |
80 | llvm_unreachable("Invalid index" ); |
81 | } |
82 | |
83 | NodeValue PadNode::getNthInput(unsigned idx) { |
84 | if (idx == 0) { return Input_; } |
85 | idx -= 1; |
86 | llvm_unreachable("Invalid index" ); |
87 | } |
88 | |
89 | void PadNode::setNthInput(unsigned idx, NodeValue val) { |
90 | if (idx == 0) { Input_ = val; return; } |
91 | idx -= 1; |
92 | llvm_unreachable("Invalid index" ); |
93 | } |
94 | |
95 | llvm::StringRef PadNode::getOutputName(unsigned idx) const { |
96 | if (idx == 0) { return "Result" ; } |
97 | llvm_unreachable("Invalid index" ); |
98 | } |
99 | |
100 | std::string PadNode::getDebugDesc() const { |
101 | DescriptionBuilder db(getKindName()); |
102 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
103 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
104 | db |
105 | .addParam("Input" , *(getInput().getType())) |
106 | .addParam("Mode" , getMode()) |
107 | .addParam("Pads" , getPads()) |
108 | .addParam("Value" , getValue()) |
109 | .addParam("Users" , getNumUsers()); |
110 | db.addParam("Result" , *(getResult().getType())); |
111 | return db; |
112 | } |
113 | |
114 | void PadNode::visit(Node *parent, NodeWalker *visitor) { |
115 | if (!visitor->shouldVisit(parent, this)) { return; } |
116 | visitor->pre(parent, this); |
117 | if (hasPredicate()) |
118 | getPredicate().getNode()->visit(this, visitor); |
119 | getInput().getNode()->visit(this, visitor); |
120 | visitor->post(parent, this); |
121 | } |
122 | |
123 | bool PadNode::isEqual(const PadNode &other) const { |
124 | return true && |
125 | Input_ == other.Input_ && |
126 | predicate_ == other.predicate_ && |
127 | Mode_ == other.Mode_ && |
128 | Pads_ == other.Pads_ && |
129 | Value_ == other.Value_ && |
130 | getType(0) == other.getType(0); |
131 | } |
132 | |
133 | Node* PadNode::clone() const { |
134 | return new PadNode(getName(), getResult().getType(), getInput(), getMode(), getPads(), getValue()); |
135 | } |
136 | |
137 | llvm::hash_code PadNode::getHash() const { |
138 | return llvm::hash_combine( |
139 | Mode_, |
140 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
141 | toBinary(Value_), |
142 | Input_); |
143 | } |
144 | |
145 | unsigned ConvolutionGradNode::getNumInputs() const { |
146 | return 5; |
147 | } |
148 | |
149 | std::string ConvolutionGradNode::getInputName(unsigned idx) const { |
150 | if (idx == 0) { return "Input" ; } |
151 | if (idx == 1) { return "Filter" ; } |
152 | if (idx == 2) { return "Bias" ; } |
153 | if (idx == 3) { return "OriginalOutputForResult" ; } |
154 | if (idx == 4) { return "GradOfOriginalOutputNamedResult" ; } |
155 | idx -= 5; |
156 | llvm_unreachable("Invalid index" ); |
157 | } |
158 | |
159 | NodeValue ConvolutionGradNode::getNthInput(unsigned idx) { |
160 | if (idx == 0) { return Input_; } |
161 | if (idx == 1) { return Filter_; } |
162 | if (idx == 2) { return Bias_; } |
163 | if (idx == 3) { return OriginalOutputForResult_; } |
164 | if (idx == 4) { return GradOfOriginalOutputNamedResult_; } |
165 | idx -= 5; |
166 | llvm_unreachable("Invalid index" ); |
167 | } |
168 | |
169 | void ConvolutionGradNode::setNthInput(unsigned idx, NodeValue val) { |
170 | if (idx == 0) { Input_ = val; return; } |
171 | if (idx == 1) { Filter_ = val; return; } |
172 | if (idx == 2) { Bias_ = val; return; } |
173 | if (idx == 3) { OriginalOutputForResult_ = val; return; } |
174 | if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; } |
175 | idx -= 5; |
176 | llvm_unreachable("Invalid index" ); |
177 | } |
178 | |
179 | llvm::StringRef ConvolutionGradNode::getOutputName(unsigned idx) const { |
180 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
181 | if (idx == 1) { return "GradOfInputNamedFilter" ; } |
182 | if (idx == 2) { return "GradOfInputNamedBias" ; } |
183 | llvm_unreachable("Invalid index" ); |
184 | } |
185 | |
186 | std::string ConvolutionGradNode::getDebugDesc() const { |
187 | DescriptionBuilder db(getKindName()); |
188 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
189 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
190 | db |
191 | .addParam("Input" , *(getInput().getType())) |
192 | .addParam("Filter" , *(getFilter().getType())) |
193 | .addParam("Bias" , *(getBias().getType())) |
194 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
195 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
196 | .addParam("Kernels" , getKernels()) |
197 | .addParam("Strides" , getStrides()) |
198 | .addParam("Pads" , getPads()) |
199 | .addParam("Group" , getGroup()) |
200 | .addParam("Dilation" , getDilation()) |
201 | .addParam("Layout" , getLayout()) |
202 | .addParam("FusedActivation" , getFusedActivation()) |
203 | .addParam("FusedActivationArgs" , getFusedActivationArgs()) |
204 | .addParam("Users" , getNumUsers()); |
205 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
206 | db.addParam("GradOfInputNamedFilter" , *(getGradOfInputNamedFilter().getType())); |
207 | db.addParam("GradOfInputNamedBias" , *(getGradOfInputNamedBias().getType())); |
208 | return db; |
209 | } |
210 | |
211 | void ConvolutionGradNode::visit(Node *parent, NodeWalker *visitor) { |
212 | if (!visitor->shouldVisit(parent, this)) { return; } |
213 | visitor->pre(parent, this); |
214 | if (hasPredicate()) |
215 | getPredicate().getNode()->visit(this, visitor); |
216 | getInput().getNode()->visit(this, visitor); |
217 | getFilter().getNode()->visit(this, visitor); |
218 | getBias().getNode()->visit(this, visitor); |
219 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
220 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
221 | visitor->post(parent, this); |
222 | } |
223 | |
224 | bool ConvolutionGradNode::isEqual(const ConvolutionGradNode &other) const { |
225 | return true && |
226 | Input_ == other.Input_ && |
227 | Filter_ == other.Filter_ && |
228 | Bias_ == other.Bias_ && |
229 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
230 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
231 | predicate_ == other.predicate_ && |
232 | Kernels_ == other.Kernels_ && |
233 | Strides_ == other.Strides_ && |
234 | Pads_ == other.Pads_ && |
235 | Group_ == other.Group_ && |
236 | Dilation_ == other.Dilation_ && |
237 | Layout_ == other.Layout_ && |
238 | FusedActivation_ == other.FusedActivation_ && |
239 | FusedActivationArgs_ == other.FusedActivationArgs_ && |
240 | getType(0) == other.getType(0) && |
241 | getType(1) == other.getType(1) && |
242 | getType(2) == other.getType(2); |
243 | } |
244 | |
245 | Node* ConvolutionGradNode::clone() const { |
246 | return new ConvolutionGradNode(getName(), getInput(), getFilter(), getBias(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getLayout(), getFusedActivation(), getFusedActivationArgs()); |
247 | } |
248 | |
249 | llvm::hash_code ConvolutionGradNode::getHash() const { |
250 | return llvm::hash_combine( |
251 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
252 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
253 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
254 | Group_, |
255 | llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()), |
256 | Layout_, |
257 | FusedActivation_, |
258 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
259 | std::vector<size_t> sizeVec = toBinary(floatVec); |
260 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
261 | }(FusedActivationArgs_), |
262 | Input_, |
263 | Filter_, |
264 | Bias_, |
265 | OriginalOutputForResult_, |
266 | GradOfOriginalOutputNamedResult_); |
267 | } |
268 | |
269 | unsigned ConvolutionNode::getNumInputs() const { |
270 | return 3; |
271 | } |
272 | |
273 | std::string ConvolutionNode::getInputName(unsigned idx) const { |
274 | if (idx == 0) { return "Input" ; } |
275 | if (idx == 1) { return "Filter" ; } |
276 | if (idx == 2) { return "Bias" ; } |
277 | idx -= 3; |
278 | llvm_unreachable("Invalid index" ); |
279 | } |
280 | |
281 | NodeValue ConvolutionNode::getNthInput(unsigned idx) { |
282 | if (idx == 0) { return Input_; } |
283 | if (idx == 1) { return Filter_; } |
284 | if (idx == 2) { return Bias_; } |
285 | idx -= 3; |
286 | llvm_unreachable("Invalid index" ); |
287 | } |
288 | |
289 | void ConvolutionNode::setNthInput(unsigned idx, NodeValue val) { |
290 | if (idx == 0) { Input_ = val; return; } |
291 | if (idx == 1) { Filter_ = val; return; } |
292 | if (idx == 2) { Bias_ = val; return; } |
293 | idx -= 3; |
294 | llvm_unreachable("Invalid index" ); |
295 | } |
296 | |
297 | llvm::StringRef ConvolutionNode::getOutputName(unsigned idx) const { |
298 | if (idx == 0) { return "Result" ; } |
299 | llvm_unreachable("Invalid index" ); |
300 | } |
301 | |
302 | std::string ConvolutionNode::getDebugDesc() const { |
303 | DescriptionBuilder db(getKindName()); |
304 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
305 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
306 | db |
307 | .addParam("Input" , *(getInput().getType())) |
308 | .addParam("Filter" , *(getFilter().getType())) |
309 | .addParam("Bias" , *(getBias().getType())) |
310 | .addParam("Kernels" , getKernels()) |
311 | .addParam("Strides" , getStrides()) |
312 | .addParam("Pads" , getPads()) |
313 | .addParam("Group" , getGroup()) |
314 | .addParam("Dilation" , getDilation()) |
315 | .addParam("Layout" , getLayout()) |
316 | .addParam("FusedActivation" , getFusedActivation()) |
317 | .addParam("FusedActivationArgs" , getFusedActivationArgs()) |
318 | .addParam("Users" , getNumUsers()); |
319 | db.addParam("Result" , *(getResult().getType())); |
320 | return db; |
321 | } |
322 | |
323 | void ConvolutionNode::visit(Node *parent, NodeWalker *visitor) { |
324 | if (!visitor->shouldVisit(parent, this)) { return; } |
325 | visitor->pre(parent, this); |
326 | if (hasPredicate()) |
327 | getPredicate().getNode()->visit(this, visitor); |
328 | getInput().getNode()->visit(this, visitor); |
329 | getFilter().getNode()->visit(this, visitor); |
330 | getBias().getNode()->visit(this, visitor); |
331 | visitor->post(parent, this); |
332 | } |
333 | |
334 | bool ConvolutionNode::isEqual(const ConvolutionNode &other) const { |
335 | return true && |
336 | Input_ == other.Input_ && |
337 | Filter_ == other.Filter_ && |
338 | Bias_ == other.Bias_ && |
339 | predicate_ == other.predicate_ && |
340 | Kernels_ == other.Kernels_ && |
341 | Strides_ == other.Strides_ && |
342 | Pads_ == other.Pads_ && |
343 | Group_ == other.Group_ && |
344 | Dilation_ == other.Dilation_ && |
345 | Layout_ == other.Layout_ && |
346 | FusedActivation_ == other.FusedActivation_ && |
347 | FusedActivationArgs_ == other.FusedActivationArgs_ && |
348 | getType(0) == other.getType(0); |
349 | } |
350 | |
351 | Node* ConvolutionNode::clone() const { |
352 | return new ConvolutionNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getLayout(), getFusedActivation(), getFusedActivationArgs()); |
353 | } |
354 | |
355 | llvm::hash_code ConvolutionNode::getHash() const { |
356 | return llvm::hash_combine( |
357 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
358 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
359 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
360 | Group_, |
361 | llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()), |
362 | Layout_, |
363 | FusedActivation_, |
364 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
365 | std::vector<size_t> sizeVec = toBinary(floatVec); |
366 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
367 | }(FusedActivationArgs_), |
368 | Input_, |
369 | Filter_, |
370 | Bias_); |
371 | } |
372 | bool ConvolutionNode::hasFusedActivation() const { return getFusedActivation() != FusedActivation::NONE; } |
373 | ConvolutionGradNode *ConvolutionNode::getGrad(GraphGradMapper &builder) { |
374 | auto *x = new ConvolutionGradNode(getName().str() + "_grad" , getInput(), getFilter(), getBias(), getResult(), builder.getGradient(getResult()), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getLayout(), getFusedActivation(), getFusedActivationArgs()); |
375 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
376 | builder.addGradient(getFilter(), x->getGradOfInputNamedFilter()); |
377 | builder.addGradient(getBias(), x->getGradOfInputNamedBias()); |
378 | return x; |
379 | } |
380 | |
381 | unsigned ChannelwiseQuantizedConvolutionNode::getNumInputs() const { |
382 | return 7; |
383 | } |
384 | |
385 | std::string ChannelwiseQuantizedConvolutionNode::getInputName(unsigned idx) const { |
386 | if (idx == 0) { return "Input" ; } |
387 | if (idx == 1) { return "Filter" ; } |
388 | if (idx == 2) { return "Bias" ; } |
389 | if (idx == 3) { return "FilterScales" ; } |
390 | if (idx == 4) { return "FilterOffsets" ; } |
391 | if (idx == 5) { return "BiasScales" ; } |
392 | if (idx == 6) { return "BiasOffsets" ; } |
393 | idx -= 7; |
394 | llvm_unreachable("Invalid index" ); |
395 | } |
396 | |
397 | NodeValue ChannelwiseQuantizedConvolutionNode::getNthInput(unsigned idx) { |
398 | if (idx == 0) { return Input_; } |
399 | if (idx == 1) { return Filter_; } |
400 | if (idx == 2) { return Bias_; } |
401 | if (idx == 3) { return FilterScales_; } |
402 | if (idx == 4) { return FilterOffsets_; } |
403 | if (idx == 5) { return BiasScales_; } |
404 | if (idx == 6) { return BiasOffsets_; } |
405 | idx -= 7; |
406 | llvm_unreachable("Invalid index" ); |
407 | } |
408 | |
409 | void ChannelwiseQuantizedConvolutionNode::setNthInput(unsigned idx, NodeValue val) { |
410 | if (idx == 0) { Input_ = val; return; } |
411 | if (idx == 1) { Filter_ = val; return; } |
412 | if (idx == 2) { Bias_ = val; return; } |
413 | if (idx == 3) { FilterScales_ = val; return; } |
414 | if (idx == 4) { FilterOffsets_ = val; return; } |
415 | if (idx == 5) { BiasScales_ = val; return; } |
416 | if (idx == 6) { BiasOffsets_ = val; return; } |
417 | idx -= 7; |
418 | llvm_unreachable("Invalid index" ); |
419 | } |
420 | |
421 | llvm::StringRef ChannelwiseQuantizedConvolutionNode::getOutputName(unsigned idx) const { |
422 | if (idx == 0) { return "Result" ; } |
423 | llvm_unreachable("Invalid index" ); |
424 | } |
425 | |
426 | std::string ChannelwiseQuantizedConvolutionNode::getDebugDesc() const { |
427 | DescriptionBuilder db(getKindName()); |
428 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
429 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
430 | db |
431 | .addParam("Input" , *(getInput().getType())) |
432 | .addParam("Filter" , *(getFilter().getType())) |
433 | .addParam("Bias" , *(getBias().getType())) |
434 | .addParam("FilterScales" , *(getFilterScales().getType())) |
435 | .addParam("FilterOffsets" , *(getFilterOffsets().getType())) |
436 | .addParam("BiasScales" , *(getBiasScales().getType())) |
437 | .addParam("BiasOffsets" , *(getBiasOffsets().getType())) |
438 | .addParam("Kernels" , getKernels()) |
439 | .addParam("Strides" , getStrides()) |
440 | .addParam("Pads" , getPads()) |
441 | .addParam("Group" , getGroup()) |
442 | .addParam("Dilation" , getDilation()) |
443 | .addParam("FusedActivation" , getFusedActivation()) |
444 | .addParam("FusedActivationArgs" , getFusedActivationArgs()) |
445 | .addParam("Users" , getNumUsers()); |
446 | db.addParam("Result" , *(getResult().getType())); |
447 | return db; |
448 | } |
449 | |
450 | void ChannelwiseQuantizedConvolutionNode::visit(Node *parent, NodeWalker *visitor) { |
451 | if (!visitor->shouldVisit(parent, this)) { return; } |
452 | visitor->pre(parent, this); |
453 | if (hasPredicate()) |
454 | getPredicate().getNode()->visit(this, visitor); |
455 | getInput().getNode()->visit(this, visitor); |
456 | getFilter().getNode()->visit(this, visitor); |
457 | getBias().getNode()->visit(this, visitor); |
458 | getFilterScales().getNode()->visit(this, visitor); |
459 | getFilterOffsets().getNode()->visit(this, visitor); |
460 | getBiasScales().getNode()->visit(this, visitor); |
461 | getBiasOffsets().getNode()->visit(this, visitor); |
462 | visitor->post(parent, this); |
463 | } |
464 | |
465 | bool ChannelwiseQuantizedConvolutionNode::isEqual(const ChannelwiseQuantizedConvolutionNode &other) const { |
466 | return true && |
467 | Input_ == other.Input_ && |
468 | Filter_ == other.Filter_ && |
469 | Bias_ == other.Bias_ && |
470 | FilterScales_ == other.FilterScales_ && |
471 | FilterOffsets_ == other.FilterOffsets_ && |
472 | BiasScales_ == other.BiasScales_ && |
473 | BiasOffsets_ == other.BiasOffsets_ && |
474 | predicate_ == other.predicate_ && |
475 | Kernels_ == other.Kernels_ && |
476 | Strides_ == other.Strides_ && |
477 | Pads_ == other.Pads_ && |
478 | Group_ == other.Group_ && |
479 | Dilation_ == other.Dilation_ && |
480 | FusedActivation_ == other.FusedActivation_ && |
481 | FusedActivationArgs_ == other.FusedActivationArgs_ && |
482 | getType(0) == other.getType(0); |
483 | } |
484 | |
485 | Node* ChannelwiseQuantizedConvolutionNode::clone() const { |
486 | return new ChannelwiseQuantizedConvolutionNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getFilterScales(), getFilterOffsets(), getBiasScales(), getBiasOffsets(), getKernels(), getStrides(), getPads(), getGroup(), getDilation(), getFusedActivation(), getFusedActivationArgs()); |
487 | } |
488 | |
489 | llvm::hash_code ChannelwiseQuantizedConvolutionNode::getHash() const { |
490 | return llvm::hash_combine( |
491 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
492 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
493 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
494 | Group_, |
495 | llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()), |
496 | FusedActivation_, |
497 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
498 | std::vector<size_t> sizeVec = toBinary(floatVec); |
499 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
500 | }(FusedActivationArgs_), |
501 | Input_, |
502 | Filter_, |
503 | Bias_, |
504 | FilterScales_, |
505 | FilterOffsets_, |
506 | BiasScales_, |
507 | BiasOffsets_); |
508 | } |
509 | bool ChannelwiseQuantizedConvolutionNode::hasFusedActivation() const { return getFusedActivation() != FusedActivation::NONE; } |
510 | unsigned ConvTransposeNode::getNumInputs() const { |
511 | return 3; |
512 | } |
513 | |
514 | std::string ConvTransposeNode::getInputName(unsigned idx) const { |
515 | if (idx == 0) { return "Input" ; } |
516 | if (idx == 1) { return "Filter" ; } |
517 | if (idx == 2) { return "Bias" ; } |
518 | idx -= 3; |
519 | llvm_unreachable("Invalid index" ); |
520 | } |
521 | |
522 | NodeValue ConvTransposeNode::getNthInput(unsigned idx) { |
523 | if (idx == 0) { return Input_; } |
524 | if (idx == 1) { return Filter_; } |
525 | if (idx == 2) { return Bias_; } |
526 | idx -= 3; |
527 | llvm_unreachable("Invalid index" ); |
528 | } |
529 | |
530 | void ConvTransposeNode::setNthInput(unsigned idx, NodeValue val) { |
531 | if (idx == 0) { Input_ = val; return; } |
532 | if (idx == 1) { Filter_ = val; return; } |
533 | if (idx == 2) { Bias_ = val; return; } |
534 | idx -= 3; |
535 | llvm_unreachable("Invalid index" ); |
536 | } |
537 | |
538 | llvm::StringRef ConvTransposeNode::getOutputName(unsigned idx) const { |
539 | if (idx == 0) { return "Result" ; } |
540 | llvm_unreachable("Invalid index" ); |
541 | } |
542 | |
543 | std::string ConvTransposeNode::getDebugDesc() const { |
544 | DescriptionBuilder db(getKindName()); |
545 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
546 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
547 | db |
548 | .addParam("Input" , *(getInput().getType())) |
549 | .addParam("Filter" , *(getFilter().getType())) |
550 | .addParam("Bias" , *(getBias().getType())) |
551 | .addParam("Kernels" , getKernels()) |
552 | .addParam("Strides" , getStrides()) |
553 | .addParam("Pads" , getPads()) |
554 | .addParam("Group" , getGroup()) |
555 | .addParam("Dilation" , getDilation()) |
556 | .addParam("Users" , getNumUsers()); |
557 | db.addParam("Result" , *(getResult().getType())); |
558 | return db; |
559 | } |
560 | |
561 | void ConvTransposeNode::visit(Node *parent, NodeWalker *visitor) { |
562 | if (!visitor->shouldVisit(parent, this)) { return; } |
563 | visitor->pre(parent, this); |
564 | if (hasPredicate()) |
565 | getPredicate().getNode()->visit(this, visitor); |
566 | getInput().getNode()->visit(this, visitor); |
567 | getFilter().getNode()->visit(this, visitor); |
568 | getBias().getNode()->visit(this, visitor); |
569 | visitor->post(parent, this); |
570 | } |
571 | |
572 | bool ConvTransposeNode::isEqual(const ConvTransposeNode &other) const { |
573 | return true && |
574 | Input_ == other.Input_ && |
575 | Filter_ == other.Filter_ && |
576 | Bias_ == other.Bias_ && |
577 | predicate_ == other.predicate_ && |
578 | Kernels_ == other.Kernels_ && |
579 | Strides_ == other.Strides_ && |
580 | Pads_ == other.Pads_ && |
581 | Group_ == other.Group_ && |
582 | Dilation_ == other.Dilation_ && |
583 | getType(0) == other.getType(0); |
584 | } |
585 | |
586 | Node* ConvTransposeNode::clone() const { |
587 | return new ConvTransposeNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup(), getDilation()); |
588 | } |
589 | |
590 | llvm::hash_code ConvTransposeNode::getHash() const { |
591 | return llvm::hash_combine( |
592 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
593 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
594 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
595 | Group_, |
596 | llvm::hash_combine_range(Dilation_.begin(), Dilation_.end()), |
597 | Input_, |
598 | Filter_, |
599 | Bias_); |
600 | } |
601 | |
602 | unsigned Convolution3DGradNode::getNumInputs() const { |
603 | return 5; |
604 | } |
605 | |
606 | std::string Convolution3DGradNode::getInputName(unsigned idx) const { |
607 | if (idx == 0) { return "Input" ; } |
608 | if (idx == 1) { return "Filter" ; } |
609 | if (idx == 2) { return "Bias" ; } |
610 | if (idx == 3) { return "OriginalOutputForResult" ; } |
611 | if (idx == 4) { return "GradOfOriginalOutputNamedResult" ; } |
612 | idx -= 5; |
613 | llvm_unreachable("Invalid index" ); |
614 | } |
615 | |
616 | NodeValue Convolution3DGradNode::getNthInput(unsigned idx) { |
617 | if (idx == 0) { return Input_; } |
618 | if (idx == 1) { return Filter_; } |
619 | if (idx == 2) { return Bias_; } |
620 | if (idx == 3) { return OriginalOutputForResult_; } |
621 | if (idx == 4) { return GradOfOriginalOutputNamedResult_; } |
622 | idx -= 5; |
623 | llvm_unreachable("Invalid index" ); |
624 | } |
625 | |
626 | void Convolution3DGradNode::setNthInput(unsigned idx, NodeValue val) { |
627 | if (idx == 0) { Input_ = val; return; } |
628 | if (idx == 1) { Filter_ = val; return; } |
629 | if (idx == 2) { Bias_ = val; return; } |
630 | if (idx == 3) { OriginalOutputForResult_ = val; return; } |
631 | if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; } |
632 | idx -= 5; |
633 | llvm_unreachable("Invalid index" ); |
634 | } |
635 | |
636 | llvm::StringRef Convolution3DGradNode::getOutputName(unsigned idx) const { |
637 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
638 | if (idx == 1) { return "GradOfInputNamedFilter" ; } |
639 | if (idx == 2) { return "GradOfInputNamedBias" ; } |
640 | llvm_unreachable("Invalid index" ); |
641 | } |
642 | |
643 | std::string Convolution3DGradNode::getDebugDesc() const { |
644 | DescriptionBuilder db(getKindName()); |
645 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
646 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
647 | db |
648 | .addParam("Input" , *(getInput().getType())) |
649 | .addParam("Filter" , *(getFilter().getType())) |
650 | .addParam("Bias" , *(getBias().getType())) |
651 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
652 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
653 | .addParam("Kernels" , getKernels()) |
654 | .addParam("Strides" , getStrides()) |
655 | .addParam("Pads" , getPads()) |
656 | .addParam("Group" , getGroup()) |
657 | .addParam("Users" , getNumUsers()); |
658 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
659 | db.addParam("GradOfInputNamedFilter" , *(getGradOfInputNamedFilter().getType())); |
660 | db.addParam("GradOfInputNamedBias" , *(getGradOfInputNamedBias().getType())); |
661 | return db; |
662 | } |
663 | |
664 | void Convolution3DGradNode::visit(Node *parent, NodeWalker *visitor) { |
665 | if (!visitor->shouldVisit(parent, this)) { return; } |
666 | visitor->pre(parent, this); |
667 | if (hasPredicate()) |
668 | getPredicate().getNode()->visit(this, visitor); |
669 | getInput().getNode()->visit(this, visitor); |
670 | getFilter().getNode()->visit(this, visitor); |
671 | getBias().getNode()->visit(this, visitor); |
672 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
673 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
674 | visitor->post(parent, this); |
675 | } |
676 | |
677 | bool Convolution3DGradNode::isEqual(const Convolution3DGradNode &other) const { |
678 | return true && |
679 | Input_ == other.Input_ && |
680 | Filter_ == other.Filter_ && |
681 | Bias_ == other.Bias_ && |
682 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
683 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
684 | predicate_ == other.predicate_ && |
685 | Kernels_ == other.Kernels_ && |
686 | Strides_ == other.Strides_ && |
687 | Pads_ == other.Pads_ && |
688 | Group_ == other.Group_ && |
689 | getType(0) == other.getType(0) && |
690 | getType(1) == other.getType(1) && |
691 | getType(2) == other.getType(2); |
692 | } |
693 | |
694 | Node* Convolution3DGradNode::clone() const { |
695 | return new Convolution3DGradNode(getName(), getInput(), getFilter(), getBias(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getKernels(), getStrides(), getPads(), getGroup()); |
696 | } |
697 | |
698 | llvm::hash_code Convolution3DGradNode::getHash() const { |
699 | return llvm::hash_combine( |
700 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
701 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
702 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
703 | Group_, |
704 | Input_, |
705 | Filter_, |
706 | Bias_, |
707 | OriginalOutputForResult_, |
708 | GradOfOriginalOutputNamedResult_); |
709 | } |
710 | |
711 | unsigned Convolution3DNode::getNumInputs() const { |
712 | return 3; |
713 | } |
714 | |
715 | std::string Convolution3DNode::getInputName(unsigned idx) const { |
716 | if (idx == 0) { return "Input" ; } |
717 | if (idx == 1) { return "Filter" ; } |
718 | if (idx == 2) { return "Bias" ; } |
719 | idx -= 3; |
720 | llvm_unreachable("Invalid index" ); |
721 | } |
722 | |
723 | NodeValue Convolution3DNode::getNthInput(unsigned idx) { |
724 | if (idx == 0) { return Input_; } |
725 | if (idx == 1) { return Filter_; } |
726 | if (idx == 2) { return Bias_; } |
727 | idx -= 3; |
728 | llvm_unreachable("Invalid index" ); |
729 | } |
730 | |
731 | void Convolution3DNode::setNthInput(unsigned idx, NodeValue val) { |
732 | if (idx == 0) { Input_ = val; return; } |
733 | if (idx == 1) { Filter_ = val; return; } |
734 | if (idx == 2) { Bias_ = val; return; } |
735 | idx -= 3; |
736 | llvm_unreachable("Invalid index" ); |
737 | } |
738 | |
739 | llvm::StringRef Convolution3DNode::getOutputName(unsigned idx) const { |
740 | if (idx == 0) { return "Result" ; } |
741 | llvm_unreachable("Invalid index" ); |
742 | } |
743 | |
744 | std::string Convolution3DNode::getDebugDesc() const { |
745 | DescriptionBuilder db(getKindName()); |
746 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
747 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
748 | db |
749 | .addParam("Input" , *(getInput().getType())) |
750 | .addParam("Filter" , *(getFilter().getType())) |
751 | .addParam("Bias" , *(getBias().getType())) |
752 | .addParam("Kernels" , getKernels()) |
753 | .addParam("Strides" , getStrides()) |
754 | .addParam("Pads" , getPads()) |
755 | .addParam("Group" , getGroup()) |
756 | .addParam("Users" , getNumUsers()); |
757 | db.addParam("Result" , *(getResult().getType())); |
758 | return db; |
759 | } |
760 | |
761 | void Convolution3DNode::visit(Node *parent, NodeWalker *visitor) { |
762 | if (!visitor->shouldVisit(parent, this)) { return; } |
763 | visitor->pre(parent, this); |
764 | if (hasPredicate()) |
765 | getPredicate().getNode()->visit(this, visitor); |
766 | getInput().getNode()->visit(this, visitor); |
767 | getFilter().getNode()->visit(this, visitor); |
768 | getBias().getNode()->visit(this, visitor); |
769 | visitor->post(parent, this); |
770 | } |
771 | |
772 | bool Convolution3DNode::isEqual(const Convolution3DNode &other) const { |
773 | return true && |
774 | Input_ == other.Input_ && |
775 | Filter_ == other.Filter_ && |
776 | Bias_ == other.Bias_ && |
777 | predicate_ == other.predicate_ && |
778 | Kernels_ == other.Kernels_ && |
779 | Strides_ == other.Strides_ && |
780 | Pads_ == other.Pads_ && |
781 | Group_ == other.Group_ && |
782 | getType(0) == other.getType(0); |
783 | } |
784 | |
785 | Node* Convolution3DNode::clone() const { |
786 | return new Convolution3DNode(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup()); |
787 | } |
788 | |
789 | llvm::hash_code Convolution3DNode::getHash() const { |
790 | return llvm::hash_combine( |
791 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
792 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
793 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
794 | Group_, |
795 | Input_, |
796 | Filter_, |
797 | Bias_); |
798 | } |
799 | |
800 | Convolution3DGradNode *Convolution3DNode::getGrad(GraphGradMapper &builder) { |
801 | auto *x = new Convolution3DGradNode(getName().str() + "_grad" , getInput(), getFilter(), getBias(), getResult(), builder.getGradient(getResult()), getKernels(), getStrides(), getPads(), getGroup()); |
802 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
803 | builder.addGradient(getFilter(), x->getGradOfInputNamedFilter()); |
804 | builder.addGradient(getBias(), x->getGradOfInputNamedBias()); |
805 | return x; |
806 | } |
807 | |
808 | unsigned MaxPoolGradNode::getNumInputs() const { |
809 | return 5; |
810 | } |
811 | |
812 | std::string MaxPoolGradNode::getInputName(unsigned idx) const { |
813 | if (idx == 0) { return "Input" ; } |
814 | if (idx == 1) { return "OriginalOutputForResult" ; } |
815 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
816 | if (idx == 3) { return "OriginalOutputForArgmax" ; } |
817 | if (idx == 4) { return "GradOfOriginalOutputNamedArgmax" ; } |
818 | idx -= 5; |
819 | llvm_unreachable("Invalid index" ); |
820 | } |
821 | |
822 | NodeValue MaxPoolGradNode::getNthInput(unsigned idx) { |
823 | if (idx == 0) { return Input_; } |
824 | if (idx == 1) { return OriginalOutputForResult_; } |
825 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
826 | if (idx == 3) { return OriginalOutputForArgmax_; } |
827 | if (idx == 4) { return GradOfOriginalOutputNamedArgmax_; } |
828 | idx -= 5; |
829 | llvm_unreachable("Invalid index" ); |
830 | } |
831 | |
832 | void MaxPoolGradNode::setNthInput(unsigned idx, NodeValue val) { |
833 | if (idx == 0) { Input_ = val; return; } |
834 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
835 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
836 | if (idx == 3) { OriginalOutputForArgmax_ = val; return; } |
837 | if (idx == 4) { GradOfOriginalOutputNamedArgmax_ = val; return; } |
838 | idx -= 5; |
839 | llvm_unreachable("Invalid index" ); |
840 | } |
841 | |
842 | llvm::StringRef MaxPoolGradNode::getOutputName(unsigned idx) const { |
843 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
844 | llvm_unreachable("Invalid index" ); |
845 | } |
846 | |
847 | std::string MaxPoolGradNode::getDebugDesc() const { |
848 | DescriptionBuilder db(getKindName()); |
849 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
850 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
851 | db |
852 | .addParam("Input" , *(getInput().getType())) |
853 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
854 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
855 | .addParam("OriginalOutputForArgmax" , *(getOriginalOutputForArgmax().getType())) |
856 | .addParam("GradOfOriginalOutputNamedArgmax" , *(getGradOfOriginalOutputNamedArgmax().getType())) |
857 | .addParam("Kernels" , getKernels()) |
858 | .addParam("Strides" , getStrides()) |
859 | .addParam("Pads" , getPads()) |
860 | .addParam("Layout" , getLayout()) |
861 | .addParam("Users" , getNumUsers()); |
862 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
863 | return db; |
864 | } |
865 | |
866 | void MaxPoolGradNode::visit(Node *parent, NodeWalker *visitor) { |
867 | if (!visitor->shouldVisit(parent, this)) { return; } |
868 | visitor->pre(parent, this); |
869 | if (hasPredicate()) |
870 | getPredicate().getNode()->visit(this, visitor); |
871 | getInput().getNode()->visit(this, visitor); |
872 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
873 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
874 | getOriginalOutputForArgmax().getNode()->visit(this, visitor); |
875 | getGradOfOriginalOutputNamedArgmax().getNode()->visit(this, visitor); |
876 | visitor->post(parent, this); |
877 | } |
878 | |
879 | bool MaxPoolGradNode::isEqual(const MaxPoolGradNode &other) const { |
880 | return true && |
881 | Input_ == other.Input_ && |
882 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
883 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
884 | OriginalOutputForArgmax_ == other.OriginalOutputForArgmax_ && |
885 | GradOfOriginalOutputNamedArgmax_ == other.GradOfOriginalOutputNamedArgmax_ && |
886 | predicate_ == other.predicate_ && |
887 | Kernels_ == other.Kernels_ && |
888 | Strides_ == other.Strides_ && |
889 | Pads_ == other.Pads_ && |
890 | Layout_ == other.Layout_ && |
891 | getType(0) == other.getType(0); |
892 | } |
893 | |
894 | Node* MaxPoolGradNode::clone() const { |
895 | return new MaxPoolGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getOriginalOutputForArgmax(), getGradOfOriginalOutputNamedArgmax(), getKernels(), getStrides(), getPads(), getLayout()); |
896 | } |
897 | |
898 | llvm::hash_code MaxPoolGradNode::getHash() const { |
899 | return llvm::hash_combine( |
900 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
901 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
902 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
903 | Layout_, |
904 | Input_, |
905 | OriginalOutputForResult_, |
906 | GradOfOriginalOutputNamedResult_, |
907 | OriginalOutputForArgmax_, |
908 | GradOfOriginalOutputNamedArgmax_); |
909 | } |
910 | |
911 | unsigned MaxPoolNode::getNumInputs() const { |
912 | return 1; |
913 | } |
914 | |
915 | std::string MaxPoolNode::getInputName(unsigned idx) const { |
916 | if (idx == 0) { return "Input" ; } |
917 | idx -= 1; |
918 | llvm_unreachable("Invalid index" ); |
919 | } |
920 | |
921 | NodeValue MaxPoolNode::getNthInput(unsigned idx) { |
922 | if (idx == 0) { return Input_; } |
923 | idx -= 1; |
924 | llvm_unreachable("Invalid index" ); |
925 | } |
926 | |
927 | void MaxPoolNode::setNthInput(unsigned idx, NodeValue val) { |
928 | if (idx == 0) { Input_ = val; return; } |
929 | idx -= 1; |
930 | llvm_unreachable("Invalid index" ); |
931 | } |
932 | |
933 | llvm::StringRef MaxPoolNode::getOutputName(unsigned idx) const { |
934 | if (idx == 0) { return "Result" ; } |
935 | if (idx == 1) { return "Argmax" ; } |
936 | llvm_unreachable("Invalid index" ); |
937 | } |
938 | |
939 | std::string MaxPoolNode::getDebugDesc() const { |
940 | DescriptionBuilder db(getKindName()); |
941 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
942 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
943 | db |
944 | .addParam("Input" , *(getInput().getType())) |
945 | .addParam("Kernels" , getKernels()) |
946 | .addParam("Strides" , getStrides()) |
947 | .addParam("Pads" , getPads()) |
948 | .addParam("Layout" , getLayout()) |
949 | .addParam("Users" , getNumUsers()); |
950 | db.addParam("Result" , *(getResult().getType())); |
951 | db.addParam("Argmax" , *(getArgmax().getType())); |
952 | return db; |
953 | } |
954 | |
955 | void MaxPoolNode::visit(Node *parent, NodeWalker *visitor) { |
956 | if (!visitor->shouldVisit(parent, this)) { return; } |
957 | visitor->pre(parent, this); |
958 | if (hasPredicate()) |
959 | getPredicate().getNode()->visit(this, visitor); |
960 | getInput().getNode()->visit(this, visitor); |
961 | visitor->post(parent, this); |
962 | } |
963 | |
964 | bool MaxPoolNode::isEqual(const MaxPoolNode &other) const { |
965 | return true && |
966 | Input_ == other.Input_ && |
967 | predicate_ == other.predicate_ && |
968 | Kernels_ == other.Kernels_ && |
969 | Strides_ == other.Strides_ && |
970 | Pads_ == other.Pads_ && |
971 | Layout_ == other.Layout_ && |
972 | getType(0) == other.getType(0) && |
973 | getType(1) == other.getType(1); |
974 | } |
975 | |
976 | Node* MaxPoolNode::clone() const { |
977 | return new MaxPoolNode(getName(), getResult().getType(), getArgmax().getType(), getInput(), getKernels(), getStrides(), getPads(), getLayout()); |
978 | } |
979 | |
980 | llvm::hash_code MaxPoolNode::getHash() const { |
981 | return llvm::hash_combine( |
982 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
983 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
984 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
985 | Layout_, |
986 | Input_); |
987 | } |
988 | |
989 | MaxPoolGradNode *MaxPoolNode::getGrad(GraphGradMapper &builder) { |
990 | auto *x = new MaxPoolGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult()), getArgmax(), builder.getGradient(getArgmax()), getKernels(), getStrides(), getPads(), getLayout()); |
991 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
992 | return x; |
993 | } |
994 | |
995 | unsigned ArgMaxNode::getNumInputs() const { |
996 | return 1; |
997 | } |
998 | |
999 | std::string ArgMaxNode::getInputName(unsigned idx) const { |
1000 | if (idx == 0) { return "Input" ; } |
1001 | idx -= 1; |
1002 | llvm_unreachable("Invalid index" ); |
1003 | } |
1004 | |
1005 | NodeValue ArgMaxNode::getNthInput(unsigned idx) { |
1006 | if (idx == 0) { return Input_; } |
1007 | idx -= 1; |
1008 | llvm_unreachable("Invalid index" ); |
1009 | } |
1010 | |
1011 | void ArgMaxNode::setNthInput(unsigned idx, NodeValue val) { |
1012 | if (idx == 0) { Input_ = val; return; } |
1013 | idx -= 1; |
1014 | llvm_unreachable("Invalid index" ); |
1015 | } |
1016 | |
1017 | llvm::StringRef ArgMaxNode::getOutputName(unsigned idx) const { |
1018 | if (idx == 0) { return "Result" ; } |
1019 | llvm_unreachable("Invalid index" ); |
1020 | } |
1021 | |
1022 | std::string ArgMaxNode::getDebugDesc() const { |
1023 | DescriptionBuilder db(getKindName()); |
1024 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1025 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1026 | db |
1027 | .addParam("Input" , *(getInput().getType())) |
1028 | .addParam("Axis" , getAxis()) |
1029 | .addParam("KeepDims" , getKeepDims()) |
1030 | .addParam("Users" , getNumUsers()); |
1031 | db.addParam("Result" , *(getResult().getType())); |
1032 | return db; |
1033 | } |
1034 | |
1035 | void ArgMaxNode::visit(Node *parent, NodeWalker *visitor) { |
1036 | if (!visitor->shouldVisit(parent, this)) { return; } |
1037 | visitor->pre(parent, this); |
1038 | if (hasPredicate()) |
1039 | getPredicate().getNode()->visit(this, visitor); |
1040 | getInput().getNode()->visit(this, visitor); |
1041 | visitor->post(parent, this); |
1042 | } |
1043 | |
1044 | bool ArgMaxNode::isEqual(const ArgMaxNode &other) const { |
1045 | return true && |
1046 | Input_ == other.Input_ && |
1047 | predicate_ == other.predicate_ && |
1048 | Axis_ == other.Axis_ && |
1049 | KeepDims_ == other.KeepDims_ && |
1050 | getType(0) == other.getType(0); |
1051 | } |
1052 | |
1053 | Node* ArgMaxNode::clone() const { |
1054 | return new ArgMaxNode(getName(), getResult().getType(), getInput(), getAxis(), getKeepDims()); |
1055 | } |
1056 | |
1057 | llvm::hash_code ArgMaxNode::getHash() const { |
1058 | return llvm::hash_combine( |
1059 | Axis_, |
1060 | KeepDims_, |
1061 | Input_); |
1062 | } |
1063 | |
1064 | unsigned ArgMinNode::getNumInputs() const { |
1065 | return 1; |
1066 | } |
1067 | |
1068 | std::string ArgMinNode::getInputName(unsigned idx) const { |
1069 | if (idx == 0) { return "Input" ; } |
1070 | idx -= 1; |
1071 | llvm_unreachable("Invalid index" ); |
1072 | } |
1073 | |
1074 | NodeValue ArgMinNode::getNthInput(unsigned idx) { |
1075 | if (idx == 0) { return Input_; } |
1076 | idx -= 1; |
1077 | llvm_unreachable("Invalid index" ); |
1078 | } |
1079 | |
1080 | void ArgMinNode::setNthInput(unsigned idx, NodeValue val) { |
1081 | if (idx == 0) { Input_ = val; return; } |
1082 | idx -= 1; |
1083 | llvm_unreachable("Invalid index" ); |
1084 | } |
1085 | |
1086 | llvm::StringRef ArgMinNode::getOutputName(unsigned idx) const { |
1087 | if (idx == 0) { return "Result" ; } |
1088 | llvm_unreachable("Invalid index" ); |
1089 | } |
1090 | |
1091 | std::string ArgMinNode::getDebugDesc() const { |
1092 | DescriptionBuilder db(getKindName()); |
1093 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1094 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1095 | db |
1096 | .addParam("Input" , *(getInput().getType())) |
1097 | .addParam("Axis" , getAxis()) |
1098 | .addParam("KeepDims" , getKeepDims()) |
1099 | .addParam("Users" , getNumUsers()); |
1100 | db.addParam("Result" , *(getResult().getType())); |
1101 | return db; |
1102 | } |
1103 | |
1104 | void ArgMinNode::visit(Node *parent, NodeWalker *visitor) { |
1105 | if (!visitor->shouldVisit(parent, this)) { return; } |
1106 | visitor->pre(parent, this); |
1107 | if (hasPredicate()) |
1108 | getPredicate().getNode()->visit(this, visitor); |
1109 | getInput().getNode()->visit(this, visitor); |
1110 | visitor->post(parent, this); |
1111 | } |
1112 | |
1113 | bool ArgMinNode::isEqual(const ArgMinNode &other) const { |
1114 | return true && |
1115 | Input_ == other.Input_ && |
1116 | predicate_ == other.predicate_ && |
1117 | Axis_ == other.Axis_ && |
1118 | KeepDims_ == other.KeepDims_ && |
1119 | getType(0) == other.getType(0); |
1120 | } |
1121 | |
1122 | Node* ArgMinNode::clone() const { |
1123 | return new ArgMinNode(getName(), getResult().getType(), getInput(), getAxis(), getKeepDims()); |
1124 | } |
1125 | |
1126 | llvm::hash_code ArgMinNode::getHash() const { |
1127 | return llvm::hash_combine( |
1128 | Axis_, |
1129 | KeepDims_, |
1130 | Input_); |
1131 | } |
1132 | |
1133 | unsigned AvgPoolGradNode::getNumInputs() const { |
1134 | return 3; |
1135 | } |
1136 | |
1137 | std::string AvgPoolGradNode::getInputName(unsigned idx) const { |
1138 | if (idx == 0) { return "Input" ; } |
1139 | if (idx == 1) { return "OriginalOutputForResult" ; } |
1140 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
1141 | idx -= 3; |
1142 | llvm_unreachable("Invalid index" ); |
1143 | } |
1144 | |
1145 | NodeValue AvgPoolGradNode::getNthInput(unsigned idx) { |
1146 | if (idx == 0) { return Input_; } |
1147 | if (idx == 1) { return OriginalOutputForResult_; } |
1148 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
1149 | idx -= 3; |
1150 | llvm_unreachable("Invalid index" ); |
1151 | } |
1152 | |
1153 | void AvgPoolGradNode::setNthInput(unsigned idx, NodeValue val) { |
1154 | if (idx == 0) { Input_ = val; return; } |
1155 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
1156 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
1157 | idx -= 3; |
1158 | llvm_unreachable("Invalid index" ); |
1159 | } |
1160 | |
1161 | llvm::StringRef AvgPoolGradNode::getOutputName(unsigned idx) const { |
1162 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
1163 | llvm_unreachable("Invalid index" ); |
1164 | } |
1165 | |
1166 | std::string AvgPoolGradNode::getDebugDesc() const { |
1167 | DescriptionBuilder db(getKindName()); |
1168 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1169 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1170 | db |
1171 | .addParam("Input" , *(getInput().getType())) |
1172 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
1173 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
1174 | .addParam("Kernels" , getKernels()) |
1175 | .addParam("Strides" , getStrides()) |
1176 | .addParam("Pads" , getPads()) |
1177 | .addParam("Layout" , getLayout()) |
1178 | .addParam("CountIncludePads" , getCountIncludePads()) |
1179 | .addParam("Users" , getNumUsers()); |
1180 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
1181 | return db; |
1182 | } |
1183 | |
1184 | void AvgPoolGradNode::visit(Node *parent, NodeWalker *visitor) { |
1185 | if (!visitor->shouldVisit(parent, this)) { return; } |
1186 | visitor->pre(parent, this); |
1187 | if (hasPredicate()) |
1188 | getPredicate().getNode()->visit(this, visitor); |
1189 | getInput().getNode()->visit(this, visitor); |
1190 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
1191 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
1192 | visitor->post(parent, this); |
1193 | } |
1194 | |
1195 | bool AvgPoolGradNode::isEqual(const AvgPoolGradNode &other) const { |
1196 | return true && |
1197 | Input_ == other.Input_ && |
1198 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
1199 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
1200 | predicate_ == other.predicate_ && |
1201 | Kernels_ == other.Kernels_ && |
1202 | Strides_ == other.Strides_ && |
1203 | Pads_ == other.Pads_ && |
1204 | Layout_ == other.Layout_ && |
1205 | CountIncludePads_ == other.CountIncludePads_ && |
1206 | getType(0) == other.getType(0); |
1207 | } |
1208 | |
1209 | Node* AvgPoolGradNode::clone() const { |
1210 | return new AvgPoolGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getKernels(), getStrides(), getPads(), getLayout(), getCountIncludePads()); |
1211 | } |
1212 | |
1213 | llvm::hash_code AvgPoolGradNode::getHash() const { |
1214 | return llvm::hash_combine( |
1215 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
1216 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
1217 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
1218 | Layout_, |
1219 | CountIncludePads_, |
1220 | Input_, |
1221 | OriginalOutputForResult_, |
1222 | GradOfOriginalOutputNamedResult_); |
1223 | } |
1224 | |
1225 | unsigned AvgPoolNode::getNumInputs() const { |
1226 | return 1; |
1227 | } |
1228 | |
1229 | std::string AvgPoolNode::getInputName(unsigned idx) const { |
1230 | if (idx == 0) { return "Input" ; } |
1231 | idx -= 1; |
1232 | llvm_unreachable("Invalid index" ); |
1233 | } |
1234 | |
1235 | NodeValue AvgPoolNode::getNthInput(unsigned idx) { |
1236 | if (idx == 0) { return Input_; } |
1237 | idx -= 1; |
1238 | llvm_unreachable("Invalid index" ); |
1239 | } |
1240 | |
1241 | void AvgPoolNode::setNthInput(unsigned idx, NodeValue val) { |
1242 | if (idx == 0) { Input_ = val; return; } |
1243 | idx -= 1; |
1244 | llvm_unreachable("Invalid index" ); |
1245 | } |
1246 | |
1247 | llvm::StringRef AvgPoolNode::getOutputName(unsigned idx) const { |
1248 | if (idx == 0) { return "Result" ; } |
1249 | llvm_unreachable("Invalid index" ); |
1250 | } |
1251 | |
1252 | std::string AvgPoolNode::getDebugDesc() const { |
1253 | DescriptionBuilder db(getKindName()); |
1254 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1255 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1256 | db |
1257 | .addParam("Input" , *(getInput().getType())) |
1258 | .addParam("Kernels" , getKernels()) |
1259 | .addParam("Strides" , getStrides()) |
1260 | .addParam("Pads" , getPads()) |
1261 | .addParam("Layout" , getLayout()) |
1262 | .addParam("CountIncludePads" , getCountIncludePads()) |
1263 | .addParam("Users" , getNumUsers()); |
1264 | db.addParam("Result" , *(getResult().getType())); |
1265 | return db; |
1266 | } |
1267 | |
1268 | void AvgPoolNode::visit(Node *parent, NodeWalker *visitor) { |
1269 | if (!visitor->shouldVisit(parent, this)) { return; } |
1270 | visitor->pre(parent, this); |
1271 | if (hasPredicate()) |
1272 | getPredicate().getNode()->visit(this, visitor); |
1273 | getInput().getNode()->visit(this, visitor); |
1274 | visitor->post(parent, this); |
1275 | } |
1276 | |
1277 | bool AvgPoolNode::isEqual(const AvgPoolNode &other) const { |
1278 | return true && |
1279 | Input_ == other.Input_ && |
1280 | predicate_ == other.predicate_ && |
1281 | Kernels_ == other.Kernels_ && |
1282 | Strides_ == other.Strides_ && |
1283 | Pads_ == other.Pads_ && |
1284 | Layout_ == other.Layout_ && |
1285 | CountIncludePads_ == other.CountIncludePads_ && |
1286 | getType(0) == other.getType(0); |
1287 | } |
1288 | |
1289 | Node* AvgPoolNode::clone() const { |
1290 | return new AvgPoolNode(getName(), getResult().getType(), getInput(), getKernels(), getStrides(), getPads(), getLayout(), getCountIncludePads()); |
1291 | } |
1292 | |
1293 | llvm::hash_code AvgPoolNode::getHash() const { |
1294 | return llvm::hash_combine( |
1295 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
1296 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
1297 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
1298 | Layout_, |
1299 | CountIncludePads_, |
1300 | Input_); |
1301 | } |
1302 | |
1303 | AvgPoolGradNode *AvgPoolNode::getGrad(GraphGradMapper &builder) { |
1304 | auto *x = new AvgPoolGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult()), getKernels(), getStrides(), getPads(), getLayout(), getCountIncludePads()); |
1305 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
1306 | return x; |
1307 | } |
1308 | |
1309 | unsigned AdaptiveAvgPoolGradNode::getNumInputs() const { |
1310 | return 3; |
1311 | } |
1312 | |
1313 | std::string AdaptiveAvgPoolGradNode::getInputName(unsigned idx) const { |
1314 | if (idx == 0) { return "Input" ; } |
1315 | if (idx == 1) { return "OriginalOutputForResult" ; } |
1316 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
1317 | idx -= 3; |
1318 | llvm_unreachable("Invalid index" ); |
1319 | } |
1320 | |
1321 | NodeValue AdaptiveAvgPoolGradNode::getNthInput(unsigned idx) { |
1322 | if (idx == 0) { return Input_; } |
1323 | if (idx == 1) { return OriginalOutputForResult_; } |
1324 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
1325 | idx -= 3; |
1326 | llvm_unreachable("Invalid index" ); |
1327 | } |
1328 | |
1329 | void AdaptiveAvgPoolGradNode::setNthInput(unsigned idx, NodeValue val) { |
1330 | if (idx == 0) { Input_ = val; return; } |
1331 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
1332 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
1333 | idx -= 3; |
1334 | llvm_unreachable("Invalid index" ); |
1335 | } |
1336 | |
1337 | llvm::StringRef AdaptiveAvgPoolGradNode::getOutputName(unsigned idx) const { |
1338 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
1339 | llvm_unreachable("Invalid index" ); |
1340 | } |
1341 | |
1342 | std::string AdaptiveAvgPoolGradNode::getDebugDesc() const { |
1343 | DescriptionBuilder db(getKindName()); |
1344 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1345 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1346 | db |
1347 | .addParam("Input" , *(getInput().getType())) |
1348 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
1349 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
1350 | .addParam("Users" , getNumUsers()); |
1351 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
1352 | return db; |
1353 | } |
1354 | |
1355 | void AdaptiveAvgPoolGradNode::visit(Node *parent, NodeWalker *visitor) { |
1356 | if (!visitor->shouldVisit(parent, this)) { return; } |
1357 | visitor->pre(parent, this); |
1358 | if (hasPredicate()) |
1359 | getPredicate().getNode()->visit(this, visitor); |
1360 | getInput().getNode()->visit(this, visitor); |
1361 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
1362 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
1363 | visitor->post(parent, this); |
1364 | } |
1365 | |
1366 | bool AdaptiveAvgPoolGradNode::isEqual(const AdaptiveAvgPoolGradNode &other) const { |
1367 | return true && |
1368 | Input_ == other.Input_ && |
1369 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
1370 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
1371 | predicate_ == other.predicate_ && |
1372 | getType(0) == other.getType(0); |
1373 | } |
1374 | |
1375 | Node* AdaptiveAvgPoolGradNode::clone() const { |
1376 | return new AdaptiveAvgPoolGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
1377 | } |
1378 | |
1379 | llvm::hash_code AdaptiveAvgPoolGradNode::getHash() const { |
1380 | return llvm::hash_combine( |
1381 | Input_, |
1382 | OriginalOutputForResult_, |
1383 | GradOfOriginalOutputNamedResult_); |
1384 | } |
1385 | |
1386 | unsigned AdaptiveAvgPoolNode::getNumInputs() const { |
1387 | return 1; |
1388 | } |
1389 | |
1390 | std::string AdaptiveAvgPoolNode::getInputName(unsigned idx) const { |
1391 | if (idx == 0) { return "Input" ; } |
1392 | idx -= 1; |
1393 | llvm_unreachable("Invalid index" ); |
1394 | } |
1395 | |
1396 | NodeValue AdaptiveAvgPoolNode::getNthInput(unsigned idx) { |
1397 | if (idx == 0) { return Input_; } |
1398 | idx -= 1; |
1399 | llvm_unreachable("Invalid index" ); |
1400 | } |
1401 | |
1402 | void AdaptiveAvgPoolNode::setNthInput(unsigned idx, NodeValue val) { |
1403 | if (idx == 0) { Input_ = val; return; } |
1404 | idx -= 1; |
1405 | llvm_unreachable("Invalid index" ); |
1406 | } |
1407 | |
1408 | llvm::StringRef AdaptiveAvgPoolNode::getOutputName(unsigned idx) const { |
1409 | if (idx == 0) { return "Result" ; } |
1410 | llvm_unreachable("Invalid index" ); |
1411 | } |
1412 | |
1413 | std::string AdaptiveAvgPoolNode::getDebugDesc() const { |
1414 | DescriptionBuilder db(getKindName()); |
1415 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1416 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1417 | db |
1418 | .addParam("Input" , *(getInput().getType())) |
1419 | .addParam("Users" , getNumUsers()); |
1420 | db.addParam("Result" , *(getResult().getType())); |
1421 | return db; |
1422 | } |
1423 | |
1424 | void AdaptiveAvgPoolNode::visit(Node *parent, NodeWalker *visitor) { |
1425 | if (!visitor->shouldVisit(parent, this)) { return; } |
1426 | visitor->pre(parent, this); |
1427 | if (hasPredicate()) |
1428 | getPredicate().getNode()->visit(this, visitor); |
1429 | getInput().getNode()->visit(this, visitor); |
1430 | visitor->post(parent, this); |
1431 | } |
1432 | |
1433 | bool AdaptiveAvgPoolNode::isEqual(const AdaptiveAvgPoolNode &other) const { |
1434 | return true && |
1435 | Input_ == other.Input_ && |
1436 | predicate_ == other.predicate_ && |
1437 | getType(0) == other.getType(0); |
1438 | } |
1439 | |
1440 | Node* AdaptiveAvgPoolNode::clone() const { |
1441 | return new AdaptiveAvgPoolNode(getName(), getResult().getType(), getInput()); |
1442 | } |
1443 | |
1444 | llvm::hash_code AdaptiveAvgPoolNode::getHash() const { |
1445 | return llvm::hash_combine( |
1446 | Input_); |
1447 | } |
1448 | |
1449 | AdaptiveAvgPoolGradNode *AdaptiveAvgPoolNode::getGrad(GraphGradMapper &builder) { |
1450 | auto *x = new AdaptiveAvgPoolGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult())); |
1451 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
1452 | return x; |
1453 | } |
1454 | |
1455 | unsigned GemmNode::getNumInputs() const { |
1456 | return 3; |
1457 | } |
1458 | |
1459 | std::string GemmNode::getInputName(unsigned idx) const { |
1460 | if (idx == 0) { return "A" ; } |
1461 | if (idx == 1) { return "B" ; } |
1462 | if (idx == 2) { return "C" ; } |
1463 | idx -= 3; |
1464 | llvm_unreachable("Invalid index" ); |
1465 | } |
1466 | |
1467 | NodeValue GemmNode::getNthInput(unsigned idx) { |
1468 | if (idx == 0) { return A_; } |
1469 | if (idx == 1) { return B_; } |
1470 | if (idx == 2) { return C_; } |
1471 | idx -= 3; |
1472 | llvm_unreachable("Invalid index" ); |
1473 | } |
1474 | |
1475 | void GemmNode::setNthInput(unsigned idx, NodeValue val) { |
1476 | if (idx == 0) { A_ = val; return; } |
1477 | if (idx == 1) { B_ = val; return; } |
1478 | if (idx == 2) { C_ = val; return; } |
1479 | idx -= 3; |
1480 | llvm_unreachable("Invalid index" ); |
1481 | } |
1482 | |
1483 | llvm::StringRef GemmNode::getOutputName(unsigned idx) const { |
1484 | if (idx == 0) { return "Result" ; } |
1485 | llvm_unreachable("Invalid index" ); |
1486 | } |
1487 | |
1488 | std::string GemmNode::getDebugDesc() const { |
1489 | DescriptionBuilder db(getKindName()); |
1490 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1491 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1492 | db |
1493 | .addParam("A" , *(getA().getType())) |
1494 | .addParam("B" , *(getB().getType())) |
1495 | .addParam("C" , *(getC().getType())) |
1496 | .addParam("Alpha" , getAlpha()) |
1497 | .addParam("Beta" , getBeta()) |
1498 | .addParam("TransposeA" , getTransposeA()) |
1499 | .addParam("TransposeB" , getTransposeB()) |
1500 | .addParam("Users" , getNumUsers()); |
1501 | db.addParam("Result" , *(getResult().getType())); |
1502 | return db; |
1503 | } |
1504 | |
1505 | void GemmNode::visit(Node *parent, NodeWalker *visitor) { |
1506 | if (!visitor->shouldVisit(parent, this)) { return; } |
1507 | visitor->pre(parent, this); |
1508 | if (hasPredicate()) |
1509 | getPredicate().getNode()->visit(this, visitor); |
1510 | getA().getNode()->visit(this, visitor); |
1511 | getB().getNode()->visit(this, visitor); |
1512 | getC().getNode()->visit(this, visitor); |
1513 | visitor->post(parent, this); |
1514 | } |
1515 | |
1516 | bool GemmNode::isEqual(const GemmNode &other) const { |
1517 | return true && |
1518 | A_ == other.A_ && |
1519 | B_ == other.B_ && |
1520 | C_ == other.C_ && |
1521 | predicate_ == other.predicate_ && |
1522 | Alpha_ == other.Alpha_ && |
1523 | Beta_ == other.Beta_ && |
1524 | TransposeA_ == other.TransposeA_ && |
1525 | TransposeB_ == other.TransposeB_ && |
1526 | getType(0) == other.getType(0); |
1527 | } |
1528 | |
1529 | Node* GemmNode::clone() const { |
1530 | return new GemmNode(getName(), getResult().getType(), getA(), getB(), getC(), getAlpha(), getBeta(), getTransposeA(), getTransposeB()); |
1531 | } |
1532 | |
1533 | llvm::hash_code GemmNode::getHash() const { |
1534 | return llvm::hash_combine( |
1535 | toBinary(Alpha_), |
1536 | toBinary(Beta_), |
1537 | TransposeA_, |
1538 | TransposeB_, |
1539 | A_, |
1540 | B_, |
1541 | C_); |
1542 | } |
1543 | |
1544 | unsigned FullyConnectedGradNode::getNumInputs() const { |
1545 | return 5; |
1546 | } |
1547 | |
1548 | std::string FullyConnectedGradNode::getInputName(unsigned idx) const { |
1549 | if (idx == 0) { return "Input" ; } |
1550 | if (idx == 1) { return "Weights" ; } |
1551 | if (idx == 2) { return "Bias" ; } |
1552 | if (idx == 3) { return "OriginalOutputForResult" ; } |
1553 | if (idx == 4) { return "GradOfOriginalOutputNamedResult" ; } |
1554 | idx -= 5; |
1555 | llvm_unreachable("Invalid index" ); |
1556 | } |
1557 | |
1558 | NodeValue FullyConnectedGradNode::getNthInput(unsigned idx) { |
1559 | if (idx == 0) { return Input_; } |
1560 | if (idx == 1) { return Weights_; } |
1561 | if (idx == 2) { return Bias_; } |
1562 | if (idx == 3) { return OriginalOutputForResult_; } |
1563 | if (idx == 4) { return GradOfOriginalOutputNamedResult_; } |
1564 | idx -= 5; |
1565 | llvm_unreachable("Invalid index" ); |
1566 | } |
1567 | |
1568 | void FullyConnectedGradNode::setNthInput(unsigned idx, NodeValue val) { |
1569 | if (idx == 0) { Input_ = val; return; } |
1570 | if (idx == 1) { Weights_ = val; return; } |
1571 | if (idx == 2) { Bias_ = val; return; } |
1572 | if (idx == 3) { OriginalOutputForResult_ = val; return; } |
1573 | if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; } |
1574 | idx -= 5; |
1575 | llvm_unreachable("Invalid index" ); |
1576 | } |
1577 | |
1578 | llvm::StringRef FullyConnectedGradNode::getOutputName(unsigned idx) const { |
1579 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
1580 | if (idx == 1) { return "GradOfInputNamedWeights" ; } |
1581 | if (idx == 2) { return "GradOfInputNamedBias" ; } |
1582 | llvm_unreachable("Invalid index" ); |
1583 | } |
1584 | |
1585 | std::string FullyConnectedGradNode::getDebugDesc() const { |
1586 | DescriptionBuilder db(getKindName()); |
1587 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1588 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1589 | db |
1590 | .addParam("Input" , *(getInput().getType())) |
1591 | .addParam("Weights" , *(getWeights().getType())) |
1592 | .addParam("Bias" , *(getBias().getType())) |
1593 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
1594 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
1595 | .addParam("Users" , getNumUsers()); |
1596 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
1597 | db.addParam("GradOfInputNamedWeights" , *(getGradOfInputNamedWeights().getType())); |
1598 | db.addParam("GradOfInputNamedBias" , *(getGradOfInputNamedBias().getType())); |
1599 | return db; |
1600 | } |
1601 | |
1602 | void FullyConnectedGradNode::visit(Node *parent, NodeWalker *visitor) { |
1603 | if (!visitor->shouldVisit(parent, this)) { return; } |
1604 | visitor->pre(parent, this); |
1605 | if (hasPredicate()) |
1606 | getPredicate().getNode()->visit(this, visitor); |
1607 | getInput().getNode()->visit(this, visitor); |
1608 | getWeights().getNode()->visit(this, visitor); |
1609 | getBias().getNode()->visit(this, visitor); |
1610 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
1611 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
1612 | visitor->post(parent, this); |
1613 | } |
1614 | |
1615 | bool FullyConnectedGradNode::isEqual(const FullyConnectedGradNode &other) const { |
1616 | return true && |
1617 | Input_ == other.Input_ && |
1618 | Weights_ == other.Weights_ && |
1619 | Bias_ == other.Bias_ && |
1620 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
1621 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
1622 | predicate_ == other.predicate_ && |
1623 | getType(0) == other.getType(0) && |
1624 | getType(1) == other.getType(1) && |
1625 | getType(2) == other.getType(2); |
1626 | } |
1627 | |
1628 | Node* FullyConnectedGradNode::clone() const { |
1629 | return new FullyConnectedGradNode(getName(), getInput(), getWeights(), getBias(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
1630 | } |
1631 | |
1632 | llvm::hash_code FullyConnectedGradNode::getHash() const { |
1633 | return llvm::hash_combine( |
1634 | Input_, |
1635 | Weights_, |
1636 | Bias_, |
1637 | OriginalOutputForResult_, |
1638 | GradOfOriginalOutputNamedResult_); |
1639 | } |
1640 | |
1641 | unsigned FullyConnectedNode::getNumInputs() const { |
1642 | return 3; |
1643 | } |
1644 | |
1645 | std::string FullyConnectedNode::getInputName(unsigned idx) const { |
1646 | if (idx == 0) { return "Input" ; } |
1647 | if (idx == 1) { return "Weights" ; } |
1648 | if (idx == 2) { return "Bias" ; } |
1649 | idx -= 3; |
1650 | llvm_unreachable("Invalid index" ); |
1651 | } |
1652 | |
1653 | NodeValue FullyConnectedNode::getNthInput(unsigned idx) { |
1654 | if (idx == 0) { return Input_; } |
1655 | if (idx == 1) { return Weights_; } |
1656 | if (idx == 2) { return Bias_; } |
1657 | idx -= 3; |
1658 | llvm_unreachable("Invalid index" ); |
1659 | } |
1660 | |
1661 | void FullyConnectedNode::setNthInput(unsigned idx, NodeValue val) { |
1662 | if (idx == 0) { Input_ = val; return; } |
1663 | if (idx == 1) { Weights_ = val; return; } |
1664 | if (idx == 2) { Bias_ = val; return; } |
1665 | idx -= 3; |
1666 | llvm_unreachable("Invalid index" ); |
1667 | } |
1668 | |
1669 | llvm::StringRef FullyConnectedNode::getOutputName(unsigned idx) const { |
1670 | if (idx == 0) { return "Result" ; } |
1671 | llvm_unreachable("Invalid index" ); |
1672 | } |
1673 | |
1674 | std::string FullyConnectedNode::getDebugDesc() const { |
1675 | DescriptionBuilder db(getKindName()); |
1676 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1677 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1678 | db |
1679 | .addParam("Input" , *(getInput().getType())) |
1680 | .addParam("Weights" , *(getWeights().getType())) |
1681 | .addParam("Bias" , *(getBias().getType())) |
1682 | .addParam("Users" , getNumUsers()); |
1683 | db.addParam("Result" , *(getResult().getType())); |
1684 | return db; |
1685 | } |
1686 | |
1687 | void FullyConnectedNode::visit(Node *parent, NodeWalker *visitor) { |
1688 | if (!visitor->shouldVisit(parent, this)) { return; } |
1689 | visitor->pre(parent, this); |
1690 | if (hasPredicate()) |
1691 | getPredicate().getNode()->visit(this, visitor); |
1692 | getInput().getNode()->visit(this, visitor); |
1693 | getWeights().getNode()->visit(this, visitor); |
1694 | getBias().getNode()->visit(this, visitor); |
1695 | visitor->post(parent, this); |
1696 | } |
1697 | |
1698 | bool FullyConnectedNode::isEqual(const FullyConnectedNode &other) const { |
1699 | return true && |
1700 | Input_ == other.Input_ && |
1701 | Weights_ == other.Weights_ && |
1702 | Bias_ == other.Bias_ && |
1703 | predicate_ == other.predicate_ && |
1704 | getType(0) == other.getType(0); |
1705 | } |
1706 | |
1707 | Node* FullyConnectedNode::clone() const { |
1708 | return new FullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getBias()); |
1709 | } |
1710 | |
1711 | llvm::hash_code FullyConnectedNode::getHash() const { |
1712 | return llvm::hash_combine( |
1713 | Input_, |
1714 | Weights_, |
1715 | Bias_); |
1716 | } |
1717 | |
1718 | FullyConnectedGradNode *FullyConnectedNode::getGrad(GraphGradMapper &builder) { |
1719 | auto *x = new FullyConnectedGradNode(getName().str() + "_grad" , getInput(), getWeights(), getBias(), getResult(), builder.getGradient(getResult())); |
1720 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
1721 | builder.addGradient(getWeights(), x->getGradOfInputNamedWeights()); |
1722 | builder.addGradient(getBias(), x->getGradOfInputNamedBias()); |
1723 | return x; |
1724 | } |
1725 | |
1726 | unsigned RowwiseQuantizedFullyConnectedNode::getNumInputs() const { |
1727 | return 5; |
1728 | } |
1729 | |
1730 | std::string RowwiseQuantizedFullyConnectedNode::getInputName(unsigned idx) const { |
1731 | if (idx == 0) { return "Input" ; } |
1732 | if (idx == 1) { return "Weights" ; } |
1733 | if (idx == 2) { return "Scales" ; } |
1734 | if (idx == 3) { return "Offsets" ; } |
1735 | if (idx == 4) { return "Bias" ; } |
1736 | idx -= 5; |
1737 | llvm_unreachable("Invalid index" ); |
1738 | } |
1739 | |
1740 | NodeValue RowwiseQuantizedFullyConnectedNode::getNthInput(unsigned idx) { |
1741 | if (idx == 0) { return Input_; } |
1742 | if (idx == 1) { return Weights_; } |
1743 | if (idx == 2) { return Scales_; } |
1744 | if (idx == 3) { return Offsets_; } |
1745 | if (idx == 4) { return Bias_; } |
1746 | idx -= 5; |
1747 | llvm_unreachable("Invalid index" ); |
1748 | } |
1749 | |
1750 | void RowwiseQuantizedFullyConnectedNode::setNthInput(unsigned idx, NodeValue val) { |
1751 | if (idx == 0) { Input_ = val; return; } |
1752 | if (idx == 1) { Weights_ = val; return; } |
1753 | if (idx == 2) { Scales_ = val; return; } |
1754 | if (idx == 3) { Offsets_ = val; return; } |
1755 | if (idx == 4) { Bias_ = val; return; } |
1756 | idx -= 5; |
1757 | llvm_unreachable("Invalid index" ); |
1758 | } |
1759 | |
1760 | llvm::StringRef RowwiseQuantizedFullyConnectedNode::getOutputName(unsigned idx) const { |
1761 | if (idx == 0) { return "Result" ; } |
1762 | llvm_unreachable("Invalid index" ); |
1763 | } |
1764 | |
1765 | std::string RowwiseQuantizedFullyConnectedNode::getDebugDesc() const { |
1766 | DescriptionBuilder db(getKindName()); |
1767 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1768 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1769 | db |
1770 | .addParam("Input" , *(getInput().getType())) |
1771 | .addParam("Weights" , *(getWeights().getType())) |
1772 | .addParam("Scales" , *(getScales().getType())) |
1773 | .addParam("Offsets" , *(getOffsets().getType())) |
1774 | .addParam("Bias" , *(getBias().getType())) |
1775 | .addParam("Users" , getNumUsers()); |
1776 | db.addParam("Result" , *(getResult().getType())); |
1777 | return db; |
1778 | } |
1779 | |
1780 | void RowwiseQuantizedFullyConnectedNode::visit(Node *parent, NodeWalker *visitor) { |
1781 | if (!visitor->shouldVisit(parent, this)) { return; } |
1782 | visitor->pre(parent, this); |
1783 | if (hasPredicate()) |
1784 | getPredicate().getNode()->visit(this, visitor); |
1785 | getInput().getNode()->visit(this, visitor); |
1786 | getWeights().getNode()->visit(this, visitor); |
1787 | getScales().getNode()->visit(this, visitor); |
1788 | getOffsets().getNode()->visit(this, visitor); |
1789 | getBias().getNode()->visit(this, visitor); |
1790 | visitor->post(parent, this); |
1791 | } |
1792 | |
1793 | bool RowwiseQuantizedFullyConnectedNode::isEqual(const RowwiseQuantizedFullyConnectedNode &other) const { |
1794 | return true && |
1795 | Input_ == other.Input_ && |
1796 | Weights_ == other.Weights_ && |
1797 | Scales_ == other.Scales_ && |
1798 | Offsets_ == other.Offsets_ && |
1799 | Bias_ == other.Bias_ && |
1800 | predicate_ == other.predicate_ && |
1801 | getType(0) == other.getType(0); |
1802 | } |
1803 | |
1804 | Node* RowwiseQuantizedFullyConnectedNode::clone() const { |
1805 | return new RowwiseQuantizedFullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getScales(), getOffsets(), getBias()); |
1806 | } |
1807 | |
1808 | llvm::hash_code RowwiseQuantizedFullyConnectedNode::getHash() const { |
1809 | return llvm::hash_combine( |
1810 | Input_, |
1811 | Weights_, |
1812 | Scales_, |
1813 | Offsets_, |
1814 | Bias_); |
1815 | } |
1816 | |
1817 | unsigned DynamicQuantizedFullyConnectedNode::getNumInputs() const { |
1818 | return 3; |
1819 | } |
1820 | |
1821 | std::string DynamicQuantizedFullyConnectedNode::getInputName(unsigned idx) const { |
1822 | if (idx == 0) { return "Input" ; } |
1823 | if (idx == 1) { return "Weights" ; } |
1824 | if (idx == 2) { return "Bias" ; } |
1825 | idx -= 3; |
1826 | llvm_unreachable("Invalid index" ); |
1827 | } |
1828 | |
1829 | NodeValue DynamicQuantizedFullyConnectedNode::getNthInput(unsigned idx) { |
1830 | if (idx == 0) { return Input_; } |
1831 | if (idx == 1) { return Weights_; } |
1832 | if (idx == 2) { return Bias_; } |
1833 | idx -= 3; |
1834 | llvm_unreachable("Invalid index" ); |
1835 | } |
1836 | |
1837 | void DynamicQuantizedFullyConnectedNode::setNthInput(unsigned idx, NodeValue val) { |
1838 | if (idx == 0) { Input_ = val; return; } |
1839 | if (idx == 1) { Weights_ = val; return; } |
1840 | if (idx == 2) { Bias_ = val; return; } |
1841 | idx -= 3; |
1842 | llvm_unreachable("Invalid index" ); |
1843 | } |
1844 | |
1845 | llvm::StringRef DynamicQuantizedFullyConnectedNode::getOutputName(unsigned idx) const { |
1846 | if (idx == 0) { return "Result" ; } |
1847 | llvm_unreachable("Invalid index" ); |
1848 | } |
1849 | |
1850 | std::string DynamicQuantizedFullyConnectedNode::getDebugDesc() const { |
1851 | DescriptionBuilder db(getKindName()); |
1852 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1853 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1854 | db |
1855 | .addParam("Input" , *(getInput().getType())) |
1856 | .addParam("Weights" , *(getWeights().getType())) |
1857 | .addParam("Bias" , *(getBias().getType())) |
1858 | .addParam("IsSymmetric" , getIsSymmetric()) |
1859 | .addParam("IsPerBatchElement" , getIsPerBatchElement()) |
1860 | .addParam("Users" , getNumUsers()); |
1861 | db.addParam("Result" , *(getResult().getType())); |
1862 | return db; |
1863 | } |
1864 | |
1865 | void DynamicQuantizedFullyConnectedNode::visit(Node *parent, NodeWalker *visitor) { |
1866 | if (!visitor->shouldVisit(parent, this)) { return; } |
1867 | visitor->pre(parent, this); |
1868 | if (hasPredicate()) |
1869 | getPredicate().getNode()->visit(this, visitor); |
1870 | getInput().getNode()->visit(this, visitor); |
1871 | getWeights().getNode()->visit(this, visitor); |
1872 | getBias().getNode()->visit(this, visitor); |
1873 | visitor->post(parent, this); |
1874 | } |
1875 | |
1876 | bool DynamicQuantizedFullyConnectedNode::isEqual(const DynamicQuantizedFullyConnectedNode &other) const { |
1877 | return true && |
1878 | Input_ == other.Input_ && |
1879 | Weights_ == other.Weights_ && |
1880 | Bias_ == other.Bias_ && |
1881 | predicate_ == other.predicate_ && |
1882 | IsSymmetric_ == other.IsSymmetric_ && |
1883 | IsPerBatchElement_ == other.IsPerBatchElement_ && |
1884 | getType(0) == other.getType(0); |
1885 | } |
1886 | |
1887 | Node* DynamicQuantizedFullyConnectedNode::clone() const { |
1888 | return new DynamicQuantizedFullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getBias(), getIsSymmetric(), getIsPerBatchElement()); |
1889 | } |
1890 | |
1891 | llvm::hash_code DynamicQuantizedFullyConnectedNode::getHash() const { |
1892 | return llvm::hash_combine( |
1893 | IsSymmetric_, |
1894 | IsPerBatchElement_, |
1895 | Input_, |
1896 | Weights_, |
1897 | Bias_); |
1898 | } |
1899 | |
1900 | unsigned DynamicRowwiseQuantizedFullyConnectedNode::getNumInputs() const { |
1901 | return 5; |
1902 | } |
1903 | |
1904 | std::string DynamicRowwiseQuantizedFullyConnectedNode::getInputName(unsigned idx) const { |
1905 | if (idx == 0) { return "Input" ; } |
1906 | if (idx == 1) { return "Weights" ; } |
1907 | if (idx == 2) { return "Bias" ; } |
1908 | if (idx == 3) { return "Scales" ; } |
1909 | if (idx == 4) { return "Offsets" ; } |
1910 | idx -= 5; |
1911 | llvm_unreachable("Invalid index" ); |
1912 | } |
1913 | |
1914 | NodeValue DynamicRowwiseQuantizedFullyConnectedNode::getNthInput(unsigned idx) { |
1915 | if (idx == 0) { return Input_; } |
1916 | if (idx == 1) { return Weights_; } |
1917 | if (idx == 2) { return Bias_; } |
1918 | if (idx == 3) { return Scales_; } |
1919 | if (idx == 4) { return Offsets_; } |
1920 | idx -= 5; |
1921 | llvm_unreachable("Invalid index" ); |
1922 | } |
1923 | |
1924 | void DynamicRowwiseQuantizedFullyConnectedNode::setNthInput(unsigned idx, NodeValue val) { |
1925 | if (idx == 0) { Input_ = val; return; } |
1926 | if (idx == 1) { Weights_ = val; return; } |
1927 | if (idx == 2) { Bias_ = val; return; } |
1928 | if (idx == 3) { Scales_ = val; return; } |
1929 | if (idx == 4) { Offsets_ = val; return; } |
1930 | idx -= 5; |
1931 | llvm_unreachable("Invalid index" ); |
1932 | } |
1933 | |
1934 | llvm::StringRef DynamicRowwiseQuantizedFullyConnectedNode::getOutputName(unsigned idx) const { |
1935 | if (idx == 0) { return "Result" ; } |
1936 | llvm_unreachable("Invalid index" ); |
1937 | } |
1938 | |
1939 | std::string DynamicRowwiseQuantizedFullyConnectedNode::getDebugDesc() const { |
1940 | DescriptionBuilder db(getKindName()); |
1941 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
1942 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
1943 | db |
1944 | .addParam("Input" , *(getInput().getType())) |
1945 | .addParam("Weights" , *(getWeights().getType())) |
1946 | .addParam("Bias" , *(getBias().getType())) |
1947 | .addParam("Scales" , *(getScales().getType())) |
1948 | .addParam("Offsets" , *(getOffsets().getType())) |
1949 | .addParam("IsSymmetric" , getIsSymmetric()) |
1950 | .addParam("IsPerBatchElement" , getIsPerBatchElement()) |
1951 | .addParam("Users" , getNumUsers()); |
1952 | db.addParam("Result" , *(getResult().getType())); |
1953 | return db; |
1954 | } |
1955 | |
1956 | void DynamicRowwiseQuantizedFullyConnectedNode::visit(Node *parent, NodeWalker *visitor) { |
1957 | if (!visitor->shouldVisit(parent, this)) { return; } |
1958 | visitor->pre(parent, this); |
1959 | if (hasPredicate()) |
1960 | getPredicate().getNode()->visit(this, visitor); |
1961 | getInput().getNode()->visit(this, visitor); |
1962 | getWeights().getNode()->visit(this, visitor); |
1963 | getBias().getNode()->visit(this, visitor); |
1964 | getScales().getNode()->visit(this, visitor); |
1965 | getOffsets().getNode()->visit(this, visitor); |
1966 | visitor->post(parent, this); |
1967 | } |
1968 | |
1969 | bool DynamicRowwiseQuantizedFullyConnectedNode::isEqual(const DynamicRowwiseQuantizedFullyConnectedNode &other) const { |
1970 | return true && |
1971 | Input_ == other.Input_ && |
1972 | Weights_ == other.Weights_ && |
1973 | Bias_ == other.Bias_ && |
1974 | Scales_ == other.Scales_ && |
1975 | Offsets_ == other.Offsets_ && |
1976 | predicate_ == other.predicate_ && |
1977 | IsSymmetric_ == other.IsSymmetric_ && |
1978 | IsPerBatchElement_ == other.IsPerBatchElement_ && |
1979 | getType(0) == other.getType(0); |
1980 | } |
1981 | |
1982 | Node* DynamicRowwiseQuantizedFullyConnectedNode::clone() const { |
1983 | return new DynamicRowwiseQuantizedFullyConnectedNode(getName(), getResult().getType(), getInput(), getWeights(), getBias(), getScales(), getOffsets(), getIsSymmetric(), getIsPerBatchElement()); |
1984 | } |
1985 | |
1986 | llvm::hash_code DynamicRowwiseQuantizedFullyConnectedNode::getHash() const { |
1987 | return llvm::hash_combine( |
1988 | IsSymmetric_, |
1989 | IsPerBatchElement_, |
1990 | Input_, |
1991 | Weights_, |
1992 | Bias_, |
1993 | Scales_, |
1994 | Offsets_); |
1995 | } |
1996 | |
1997 | unsigned BatchNormalizationGradNode::getNumInputs() const { |
1998 | return 7; |
1999 | } |
2000 | |
2001 | std::string BatchNormalizationGradNode::getInputName(unsigned idx) const { |
2002 | if (idx == 0) { return "Input" ; } |
2003 | if (idx == 1) { return "Scale" ; } |
2004 | if (idx == 2) { return "Bias" ; } |
2005 | if (idx == 3) { return "Mean" ; } |
2006 | if (idx == 4) { return "Var" ; } |
2007 | if (idx == 5) { return "OriginalOutputForResult" ; } |
2008 | if (idx == 6) { return "GradOfOriginalOutputNamedResult" ; } |
2009 | idx -= 7; |
2010 | llvm_unreachable("Invalid index" ); |
2011 | } |
2012 | |
2013 | NodeValue BatchNormalizationGradNode::getNthInput(unsigned idx) { |
2014 | if (idx == 0) { return Input_; } |
2015 | if (idx == 1) { return Scale_; } |
2016 | if (idx == 2) { return Bias_; } |
2017 | if (idx == 3) { return Mean_; } |
2018 | if (idx == 4) { return Var_; } |
2019 | if (idx == 5) { return OriginalOutputForResult_; } |
2020 | if (idx == 6) { return GradOfOriginalOutputNamedResult_; } |
2021 | idx -= 7; |
2022 | llvm_unreachable("Invalid index" ); |
2023 | } |
2024 | |
2025 | void BatchNormalizationGradNode::setNthInput(unsigned idx, NodeValue val) { |
2026 | if (idx == 0) { Input_ = val; return; } |
2027 | if (idx == 1) { Scale_ = val; return; } |
2028 | if (idx == 2) { Bias_ = val; return; } |
2029 | if (idx == 3) { Mean_ = val; return; } |
2030 | if (idx == 4) { Var_ = val; return; } |
2031 | if (idx == 5) { OriginalOutputForResult_ = val; return; } |
2032 | if (idx == 6) { GradOfOriginalOutputNamedResult_ = val; return; } |
2033 | idx -= 7; |
2034 | llvm_unreachable("Invalid index" ); |
2035 | } |
2036 | |
2037 | llvm::StringRef BatchNormalizationGradNode::getOutputName(unsigned idx) const { |
2038 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
2039 | if (idx == 1) { return "GradOfInputNamedScale" ; } |
2040 | if (idx == 2) { return "GradOfInputNamedBias" ; } |
2041 | if (idx == 3) { return "GradOfInputNamedMean" ; } |
2042 | if (idx == 4) { return "GradOfInputNamedVar" ; } |
2043 | llvm_unreachable("Invalid index" ); |
2044 | } |
2045 | |
2046 | std::string BatchNormalizationGradNode::getDebugDesc() const { |
2047 | DescriptionBuilder db(getKindName()); |
2048 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2049 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2050 | db |
2051 | .addParam("Input" , *(getInput().getType())) |
2052 | .addParam("Scale" , *(getScale().getType())) |
2053 | .addParam("Bias" , *(getBias().getType())) |
2054 | .addParam("Mean" , *(getMean().getType())) |
2055 | .addParam("Var" , *(getVar().getType())) |
2056 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
2057 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
2058 | .addParam("ChannelIdx" , getChannelIdx()) |
2059 | .addParam("Epsilon" , getEpsilon()) |
2060 | .addParam("Momentum" , getMomentum()) |
2061 | .addParam("Users" , getNumUsers()); |
2062 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
2063 | db.addParam("GradOfInputNamedScale" , *(getGradOfInputNamedScale().getType())); |
2064 | db.addParam("GradOfInputNamedBias" , *(getGradOfInputNamedBias().getType())); |
2065 | db.addParam("GradOfInputNamedMean" , *(getGradOfInputNamedMean().getType())); |
2066 | db.addParam("GradOfInputNamedVar" , *(getGradOfInputNamedVar().getType())); |
2067 | return db; |
2068 | } |
2069 | |
2070 | void BatchNormalizationGradNode::visit(Node *parent, NodeWalker *visitor) { |
2071 | if (!visitor->shouldVisit(parent, this)) { return; } |
2072 | visitor->pre(parent, this); |
2073 | if (hasPredicate()) |
2074 | getPredicate().getNode()->visit(this, visitor); |
2075 | getInput().getNode()->visit(this, visitor); |
2076 | getScale().getNode()->visit(this, visitor); |
2077 | getBias().getNode()->visit(this, visitor); |
2078 | getMean().getNode()->visit(this, visitor); |
2079 | getVar().getNode()->visit(this, visitor); |
2080 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
2081 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
2082 | visitor->post(parent, this); |
2083 | } |
2084 | |
2085 | bool BatchNormalizationGradNode::isEqual(const BatchNormalizationGradNode &other) const { |
2086 | return true && |
2087 | Input_ == other.Input_ && |
2088 | Scale_ == other.Scale_ && |
2089 | Bias_ == other.Bias_ && |
2090 | Mean_ == other.Mean_ && |
2091 | Var_ == other.Var_ && |
2092 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
2093 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
2094 | predicate_ == other.predicate_ && |
2095 | ChannelIdx_ == other.ChannelIdx_ && |
2096 | Epsilon_ == other.Epsilon_ && |
2097 | Momentum_ == other.Momentum_ && |
2098 | getType(0) == other.getType(0) && |
2099 | getType(1) == other.getType(1) && |
2100 | getType(2) == other.getType(2) && |
2101 | getType(3) == other.getType(3) && |
2102 | getType(4) == other.getType(4); |
2103 | } |
2104 | |
2105 | Node* BatchNormalizationGradNode::clone() const { |
2106 | return new BatchNormalizationGradNode(getName(), getInput(), getScale(), getBias(), getMean(), getVar(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getChannelIdx(), getEpsilon(), getMomentum()); |
2107 | } |
2108 | |
2109 | llvm::hash_code BatchNormalizationGradNode::getHash() const { |
2110 | return llvm::hash_combine( |
2111 | ChannelIdx_, |
2112 | toBinary(Epsilon_), |
2113 | toBinary(Momentum_), |
2114 | Input_, |
2115 | Scale_, |
2116 | Bias_, |
2117 | Mean_, |
2118 | Var_, |
2119 | OriginalOutputForResult_, |
2120 | GradOfOriginalOutputNamedResult_); |
2121 | } |
2122 | |
2123 | unsigned BatchNormalizationNode::getNumInputs() const { |
2124 | return 5; |
2125 | } |
2126 | |
2127 | std::string BatchNormalizationNode::getInputName(unsigned idx) const { |
2128 | if (idx == 0) { return "Input" ; } |
2129 | if (idx == 1) { return "Scale" ; } |
2130 | if (idx == 2) { return "Bias" ; } |
2131 | if (idx == 3) { return "Mean" ; } |
2132 | if (idx == 4) { return "Var" ; } |
2133 | idx -= 5; |
2134 | llvm_unreachable("Invalid index" ); |
2135 | } |
2136 | |
2137 | NodeValue BatchNormalizationNode::getNthInput(unsigned idx) { |
2138 | if (idx == 0) { return Input_; } |
2139 | if (idx == 1) { return Scale_; } |
2140 | if (idx == 2) { return Bias_; } |
2141 | if (idx == 3) { return Mean_; } |
2142 | if (idx == 4) { return Var_; } |
2143 | idx -= 5; |
2144 | llvm_unreachable("Invalid index" ); |
2145 | } |
2146 | |
2147 | void BatchNormalizationNode::setNthInput(unsigned idx, NodeValue val) { |
2148 | if (idx == 0) { Input_ = val; return; } |
2149 | if (idx == 1) { Scale_ = val; return; } |
2150 | if (idx == 2) { Bias_ = val; return; } |
2151 | if (idx == 3) { Mean_ = val; return; } |
2152 | if (idx == 4) { Var_ = val; return; } |
2153 | idx -= 5; |
2154 | llvm_unreachable("Invalid index" ); |
2155 | } |
2156 | |
2157 | llvm::StringRef BatchNormalizationNode::getOutputName(unsigned idx) const { |
2158 | if (idx == 0) { return "Result" ; } |
2159 | llvm_unreachable("Invalid index" ); |
2160 | } |
2161 | |
2162 | std::string BatchNormalizationNode::getDebugDesc() const { |
2163 | DescriptionBuilder db(getKindName()); |
2164 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2165 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2166 | db |
2167 | .addParam("Input" , *(getInput().getType())) |
2168 | .addParam("Scale" , *(getScale().getType())) |
2169 | .addParam("Bias" , *(getBias().getType())) |
2170 | .addParam("Mean" , *(getMean().getType())) |
2171 | .addParam("Var" , *(getVar().getType())) |
2172 | .addParam("ChannelIdx" , getChannelIdx()) |
2173 | .addParam("Epsilon" , getEpsilon()) |
2174 | .addParam("Momentum" , getMomentum()) |
2175 | .addParam("Users" , getNumUsers()); |
2176 | db.addParam("Result" , *(getResult().getType())); |
2177 | return db; |
2178 | } |
2179 | |
2180 | void BatchNormalizationNode::visit(Node *parent, NodeWalker *visitor) { |
2181 | if (!visitor->shouldVisit(parent, this)) { return; } |
2182 | visitor->pre(parent, this); |
2183 | if (hasPredicate()) |
2184 | getPredicate().getNode()->visit(this, visitor); |
2185 | getInput().getNode()->visit(this, visitor); |
2186 | getScale().getNode()->visit(this, visitor); |
2187 | getBias().getNode()->visit(this, visitor); |
2188 | getMean().getNode()->visit(this, visitor); |
2189 | getVar().getNode()->visit(this, visitor); |
2190 | visitor->post(parent, this); |
2191 | } |
2192 | |
2193 | bool BatchNormalizationNode::isEqual(const BatchNormalizationNode &other) const { |
2194 | return true && |
2195 | Input_ == other.Input_ && |
2196 | Scale_ == other.Scale_ && |
2197 | Bias_ == other.Bias_ && |
2198 | Mean_ == other.Mean_ && |
2199 | Var_ == other.Var_ && |
2200 | predicate_ == other.predicate_ && |
2201 | ChannelIdx_ == other.ChannelIdx_ && |
2202 | Epsilon_ == other.Epsilon_ && |
2203 | Momentum_ == other.Momentum_ && |
2204 | getType(0) == other.getType(0); |
2205 | } |
2206 | |
2207 | Node* BatchNormalizationNode::clone() const { |
2208 | return new BatchNormalizationNode(getName(), getResult().getType(), getInput(), getScale(), getBias(), getMean(), getVar(), getChannelIdx(), getEpsilon(), getMomentum()); |
2209 | } |
2210 | |
2211 | llvm::hash_code BatchNormalizationNode::getHash() const { |
2212 | return llvm::hash_combine( |
2213 | ChannelIdx_, |
2214 | toBinary(Epsilon_), |
2215 | toBinary(Momentum_), |
2216 | Input_, |
2217 | Scale_, |
2218 | Bias_, |
2219 | Mean_, |
2220 | Var_); |
2221 | } |
2222 | |
2223 | BatchNormalizationGradNode *BatchNormalizationNode::getGrad(GraphGradMapper &builder) { |
2224 | auto *x = new BatchNormalizationGradNode(getName().str() + "_grad" , getInput(), getScale(), getBias(), getMean(), getVar(), getResult(), builder.getGradient(getResult()), getChannelIdx(), getEpsilon(), getMomentum()); |
2225 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
2226 | builder.addGradient(getScale(), x->getGradOfInputNamedScale()); |
2227 | builder.addGradient(getBias(), x->getGradOfInputNamedBias()); |
2228 | builder.addGradient(getMean(), x->getGradOfInputNamedMean()); |
2229 | builder.addGradient(getVar(), x->getGradOfInputNamedVar()); |
2230 | return x; |
2231 | } |
2232 | |
2233 | unsigned InstanceNormalizationNode::getNumInputs() const { |
2234 | return 3; |
2235 | } |
2236 | |
2237 | std::string InstanceNormalizationNode::getInputName(unsigned idx) const { |
2238 | if (idx == 0) { return "Input" ; } |
2239 | if (idx == 1) { return "Scale" ; } |
2240 | if (idx == 2) { return "Bias" ; } |
2241 | idx -= 3; |
2242 | llvm_unreachable("Invalid index" ); |
2243 | } |
2244 | |
2245 | NodeValue InstanceNormalizationNode::getNthInput(unsigned idx) { |
2246 | if (idx == 0) { return Input_; } |
2247 | if (idx == 1) { return Scale_; } |
2248 | if (idx == 2) { return Bias_; } |
2249 | idx -= 3; |
2250 | llvm_unreachable("Invalid index" ); |
2251 | } |
2252 | |
2253 | void InstanceNormalizationNode::setNthInput(unsigned idx, NodeValue val) { |
2254 | if (idx == 0) { Input_ = val; return; } |
2255 | if (idx == 1) { Scale_ = val; return; } |
2256 | if (idx == 2) { Bias_ = val; return; } |
2257 | idx -= 3; |
2258 | llvm_unreachable("Invalid index" ); |
2259 | } |
2260 | |
2261 | llvm::StringRef InstanceNormalizationNode::getOutputName(unsigned idx) const { |
2262 | if (idx == 0) { return "Result" ; } |
2263 | llvm_unreachable("Invalid index" ); |
2264 | } |
2265 | |
2266 | std::string InstanceNormalizationNode::getDebugDesc() const { |
2267 | DescriptionBuilder db(getKindName()); |
2268 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2269 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2270 | db |
2271 | .addParam("Input" , *(getInput().getType())) |
2272 | .addParam("Scale" , *(getScale().getType())) |
2273 | .addParam("Bias" , *(getBias().getType())) |
2274 | .addParam("ChannelIdx" , getChannelIdx()) |
2275 | .addParam("Epsilon" , getEpsilon()) |
2276 | .addParam("Users" , getNumUsers()); |
2277 | db.addParam("Result" , *(getResult().getType())); |
2278 | return db; |
2279 | } |
2280 | |
2281 | void InstanceNormalizationNode::visit(Node *parent, NodeWalker *visitor) { |
2282 | if (!visitor->shouldVisit(parent, this)) { return; } |
2283 | visitor->pre(parent, this); |
2284 | if (hasPredicate()) |
2285 | getPredicate().getNode()->visit(this, visitor); |
2286 | getInput().getNode()->visit(this, visitor); |
2287 | getScale().getNode()->visit(this, visitor); |
2288 | getBias().getNode()->visit(this, visitor); |
2289 | visitor->post(parent, this); |
2290 | } |
2291 | |
2292 | bool InstanceNormalizationNode::isEqual(const InstanceNormalizationNode &other) const { |
2293 | return true && |
2294 | Input_ == other.Input_ && |
2295 | Scale_ == other.Scale_ && |
2296 | Bias_ == other.Bias_ && |
2297 | predicate_ == other.predicate_ && |
2298 | ChannelIdx_ == other.ChannelIdx_ && |
2299 | Epsilon_ == other.Epsilon_ && |
2300 | getType(0) == other.getType(0); |
2301 | } |
2302 | |
2303 | Node* InstanceNormalizationNode::clone() const { |
2304 | return new InstanceNormalizationNode(getName(), getInput(), getScale(), getBias(), getChannelIdx(), getEpsilon()); |
2305 | } |
2306 | |
2307 | llvm::hash_code InstanceNormalizationNode::getHash() const { |
2308 | return llvm::hash_combine( |
2309 | ChannelIdx_, |
2310 | toBinary(Epsilon_), |
2311 | Input_, |
2312 | Scale_, |
2313 | Bias_); |
2314 | } |
2315 | |
2316 | unsigned MeanVarNormalizationNode::getNumInputs() const { |
2317 | return 3; |
2318 | } |
2319 | |
2320 | std::string MeanVarNormalizationNode::getInputName(unsigned idx) const { |
2321 | if (idx == 0) { return "Input" ; } |
2322 | if (idx == 1) { return "Mean" ; } |
2323 | if (idx == 2) { return "Var" ; } |
2324 | idx -= 3; |
2325 | llvm_unreachable("Invalid index" ); |
2326 | } |
2327 | |
2328 | NodeValue MeanVarNormalizationNode::getNthInput(unsigned idx) { |
2329 | if (idx == 0) { return Input_; } |
2330 | if (idx == 1) { return Mean_; } |
2331 | if (idx == 2) { return Var_; } |
2332 | idx -= 3; |
2333 | llvm_unreachable("Invalid index" ); |
2334 | } |
2335 | |
2336 | void MeanVarNormalizationNode::setNthInput(unsigned idx, NodeValue val) { |
2337 | if (idx == 0) { Input_ = val; return; } |
2338 | if (idx == 1) { Mean_ = val; return; } |
2339 | if (idx == 2) { Var_ = val; return; } |
2340 | idx -= 3; |
2341 | llvm_unreachable("Invalid index" ); |
2342 | } |
2343 | |
2344 | llvm::StringRef MeanVarNormalizationNode::getOutputName(unsigned idx) const { |
2345 | if (idx == 0) { return "NewMean" ; } |
2346 | if (idx == 1) { return "NewVar" ; } |
2347 | llvm_unreachable("Invalid index" ); |
2348 | } |
2349 | |
2350 | std::string MeanVarNormalizationNode::getDebugDesc() const { |
2351 | DescriptionBuilder db(getKindName()); |
2352 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2353 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2354 | db |
2355 | .addParam("Input" , *(getInput().getType())) |
2356 | .addParam("Mean" , *(getMean().getType())) |
2357 | .addParam("Var" , *(getVar().getType())) |
2358 | .addParam("ChannelIdx" , getChannelIdx()) |
2359 | .addParam("Momentum" , getMomentum()) |
2360 | .addParam("Users" , getNumUsers()); |
2361 | db.addParam("NewMean" , *(getNewMean().getType())); |
2362 | db.addParam("NewVar" , *(getNewVar().getType())); |
2363 | return db; |
2364 | } |
2365 | |
2366 | void MeanVarNormalizationNode::visit(Node *parent, NodeWalker *visitor) { |
2367 | if (!visitor->shouldVisit(parent, this)) { return; } |
2368 | visitor->pre(parent, this); |
2369 | if (hasPredicate()) |
2370 | getPredicate().getNode()->visit(this, visitor); |
2371 | getInput().getNode()->visit(this, visitor); |
2372 | getMean().getNode()->visit(this, visitor); |
2373 | getVar().getNode()->visit(this, visitor); |
2374 | visitor->post(parent, this); |
2375 | } |
2376 | |
2377 | bool MeanVarNormalizationNode::isEqual(const MeanVarNormalizationNode &other) const { |
2378 | return true && |
2379 | Input_ == other.Input_ && |
2380 | Mean_ == other.Mean_ && |
2381 | Var_ == other.Var_ && |
2382 | predicate_ == other.predicate_ && |
2383 | ChannelIdx_ == other.ChannelIdx_ && |
2384 | Momentum_ == other.Momentum_ && |
2385 | getType(0) == other.getType(0) && |
2386 | getType(1) == other.getType(1); |
2387 | } |
2388 | |
2389 | Node* MeanVarNormalizationNode::clone() const { |
2390 | return new MeanVarNormalizationNode(getName(), getInput(), getMean(), getVar(), getChannelIdx(), getMomentum()); |
2391 | } |
2392 | |
2393 | llvm::hash_code MeanVarNormalizationNode::getHash() const { |
2394 | return llvm::hash_combine( |
2395 | ChannelIdx_, |
2396 | toBinary(Momentum_), |
2397 | Input_, |
2398 | Mean_, |
2399 | Var_); |
2400 | } |
2401 | |
2402 | unsigned LocalResponseNormalizationGradNode::getNumInputs() const { |
2403 | return 3; |
2404 | } |
2405 | |
2406 | std::string LocalResponseNormalizationGradNode::getInputName(unsigned idx) const { |
2407 | if (idx == 0) { return "Input" ; } |
2408 | if (idx == 1) { return "OriginalOutputForResult" ; } |
2409 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
2410 | idx -= 3; |
2411 | llvm_unreachable("Invalid index" ); |
2412 | } |
2413 | |
2414 | NodeValue LocalResponseNormalizationGradNode::getNthInput(unsigned idx) { |
2415 | if (idx == 0) { return Input_; } |
2416 | if (idx == 1) { return OriginalOutputForResult_; } |
2417 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
2418 | idx -= 3; |
2419 | llvm_unreachable("Invalid index" ); |
2420 | } |
2421 | |
2422 | void LocalResponseNormalizationGradNode::setNthInput(unsigned idx, NodeValue val) { |
2423 | if (idx == 0) { Input_ = val; return; } |
2424 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
2425 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
2426 | idx -= 3; |
2427 | llvm_unreachable("Invalid index" ); |
2428 | } |
2429 | |
2430 | llvm::StringRef LocalResponseNormalizationGradNode::getOutputName(unsigned idx) const { |
2431 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
2432 | llvm_unreachable("Invalid index" ); |
2433 | } |
2434 | |
2435 | std::string LocalResponseNormalizationGradNode::getDebugDesc() const { |
2436 | DescriptionBuilder db(getKindName()); |
2437 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2438 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2439 | db |
2440 | .addParam("Input" , *(getInput().getType())) |
2441 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
2442 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
2443 | .addParam("HalfWindowSize" , getHalfWindowSize()) |
2444 | .addParam("Alpha" , getAlpha()) |
2445 | .addParam("Beta" , getBeta()) |
2446 | .addParam("K" , getK()) |
2447 | .addParam("Users" , getNumUsers()); |
2448 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
2449 | return db; |
2450 | } |
2451 | |
2452 | void LocalResponseNormalizationGradNode::visit(Node *parent, NodeWalker *visitor) { |
2453 | if (!visitor->shouldVisit(parent, this)) { return; } |
2454 | visitor->pre(parent, this); |
2455 | if (hasPredicate()) |
2456 | getPredicate().getNode()->visit(this, visitor); |
2457 | getInput().getNode()->visit(this, visitor); |
2458 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
2459 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
2460 | visitor->post(parent, this); |
2461 | } |
2462 | |
2463 | bool LocalResponseNormalizationGradNode::isEqual(const LocalResponseNormalizationGradNode &other) const { |
2464 | return true && |
2465 | Input_ == other.Input_ && |
2466 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
2467 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
2468 | predicate_ == other.predicate_ && |
2469 | HalfWindowSize_ == other.HalfWindowSize_ && |
2470 | Alpha_ == other.Alpha_ && |
2471 | Beta_ == other.Beta_ && |
2472 | K_ == other.K_ && |
2473 | getType(0) == other.getType(0); |
2474 | } |
2475 | |
2476 | Node* LocalResponseNormalizationGradNode::clone() const { |
2477 | return new LocalResponseNormalizationGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getHalfWindowSize(), getAlpha(), getBeta(), getK()); |
2478 | } |
2479 | |
2480 | llvm::hash_code LocalResponseNormalizationGradNode::getHash() const { |
2481 | return llvm::hash_combine( |
2482 | HalfWindowSize_, |
2483 | toBinary(Alpha_), |
2484 | toBinary(Beta_), |
2485 | toBinary(K_), |
2486 | Input_, |
2487 | OriginalOutputForResult_, |
2488 | GradOfOriginalOutputNamedResult_); |
2489 | } |
2490 | |
2491 | unsigned LocalResponseNormalizationNode::getNumInputs() const { |
2492 | return 1; |
2493 | } |
2494 | |
2495 | std::string LocalResponseNormalizationNode::getInputName(unsigned idx) const { |
2496 | if (idx == 0) { return "Input" ; } |
2497 | idx -= 1; |
2498 | llvm_unreachable("Invalid index" ); |
2499 | } |
2500 | |
2501 | NodeValue LocalResponseNormalizationNode::getNthInput(unsigned idx) { |
2502 | if (idx == 0) { return Input_; } |
2503 | idx -= 1; |
2504 | llvm_unreachable("Invalid index" ); |
2505 | } |
2506 | |
2507 | void LocalResponseNormalizationNode::setNthInput(unsigned idx, NodeValue val) { |
2508 | if (idx == 0) { Input_ = val; return; } |
2509 | idx -= 1; |
2510 | llvm_unreachable("Invalid index" ); |
2511 | } |
2512 | |
2513 | llvm::StringRef LocalResponseNormalizationNode::getOutputName(unsigned idx) const { |
2514 | if (idx == 0) { return "Result" ; } |
2515 | llvm_unreachable("Invalid index" ); |
2516 | } |
2517 | |
2518 | std::string LocalResponseNormalizationNode::getDebugDesc() const { |
2519 | DescriptionBuilder db(getKindName()); |
2520 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2521 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2522 | db |
2523 | .addParam("Input" , *(getInput().getType())) |
2524 | .addParam("HalfWindowSize" , getHalfWindowSize()) |
2525 | .addParam("Alpha" , getAlpha()) |
2526 | .addParam("Beta" , getBeta()) |
2527 | .addParam("K" , getK()) |
2528 | .addParam("Users" , getNumUsers()); |
2529 | db.addParam("Result" , *(getResult().getType())); |
2530 | return db; |
2531 | } |
2532 | |
2533 | void LocalResponseNormalizationNode::visit(Node *parent, NodeWalker *visitor) { |
2534 | if (!visitor->shouldVisit(parent, this)) { return; } |
2535 | visitor->pre(parent, this); |
2536 | if (hasPredicate()) |
2537 | getPredicate().getNode()->visit(this, visitor); |
2538 | getInput().getNode()->visit(this, visitor); |
2539 | visitor->post(parent, this); |
2540 | } |
2541 | |
2542 | bool LocalResponseNormalizationNode::isEqual(const LocalResponseNormalizationNode &other) const { |
2543 | return true && |
2544 | Input_ == other.Input_ && |
2545 | predicate_ == other.predicate_ && |
2546 | HalfWindowSize_ == other.HalfWindowSize_ && |
2547 | Alpha_ == other.Alpha_ && |
2548 | Beta_ == other.Beta_ && |
2549 | K_ == other.K_ && |
2550 | getType(0) == other.getType(0); |
2551 | } |
2552 | |
2553 | Node* LocalResponseNormalizationNode::clone() const { |
2554 | return new LocalResponseNormalizationNode(getName(), getInput(), getHalfWindowSize(), getAlpha(), getBeta(), getK()); |
2555 | } |
2556 | |
2557 | llvm::hash_code LocalResponseNormalizationNode::getHash() const { |
2558 | return llvm::hash_combine( |
2559 | HalfWindowSize_, |
2560 | toBinary(Alpha_), |
2561 | toBinary(Beta_), |
2562 | toBinary(K_), |
2563 | Input_); |
2564 | } |
2565 | |
2566 | LocalResponseNormalizationGradNode *LocalResponseNormalizationNode::getGrad(GraphGradMapper &builder) { |
2567 | auto *x = new LocalResponseNormalizationGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult()), getHalfWindowSize(), getAlpha(), getBeta(), getK()); |
2568 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
2569 | return x; |
2570 | } |
2571 | |
2572 | unsigned LayerNormalizationNode::getNumInputs() const { |
2573 | return 3; |
2574 | } |
2575 | |
2576 | std::string LayerNormalizationNode::getInputName(unsigned idx) const { |
2577 | if (idx == 0) { return "Input" ; } |
2578 | if (idx == 1) { return "Scale" ; } |
2579 | if (idx == 2) { return "Bias" ; } |
2580 | idx -= 3; |
2581 | llvm_unreachable("Invalid index" ); |
2582 | } |
2583 | |
2584 | NodeValue LayerNormalizationNode::getNthInput(unsigned idx) { |
2585 | if (idx == 0) { return Input_; } |
2586 | if (idx == 1) { return Scale_; } |
2587 | if (idx == 2) { return Bias_; } |
2588 | idx -= 3; |
2589 | llvm_unreachable("Invalid index" ); |
2590 | } |
2591 | |
2592 | void LayerNormalizationNode::setNthInput(unsigned idx, NodeValue val) { |
2593 | if (idx == 0) { Input_ = val; return; } |
2594 | if (idx == 1) { Scale_ = val; return; } |
2595 | if (idx == 2) { Bias_ = val; return; } |
2596 | idx -= 3; |
2597 | llvm_unreachable("Invalid index" ); |
2598 | } |
2599 | |
2600 | llvm::StringRef LayerNormalizationNode::getOutputName(unsigned idx) const { |
2601 | if (idx == 0) { return "Result" ; } |
2602 | llvm_unreachable("Invalid index" ); |
2603 | } |
2604 | |
2605 | std::string LayerNormalizationNode::getDebugDesc() const { |
2606 | DescriptionBuilder db(getKindName()); |
2607 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2608 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2609 | db |
2610 | .addParam("Input" , *(getInput().getType())) |
2611 | .addParam("Scale" , *(getScale().getType())) |
2612 | .addParam("Bias" , *(getBias().getType())) |
2613 | .addParam("Epsilon" , getEpsilon()) |
2614 | .addParam("Users" , getNumUsers()); |
2615 | db.addParam("Result" , *(getResult().getType())); |
2616 | return db; |
2617 | } |
2618 | |
2619 | void LayerNormalizationNode::visit(Node *parent, NodeWalker *visitor) { |
2620 | if (!visitor->shouldVisit(parent, this)) { return; } |
2621 | visitor->pre(parent, this); |
2622 | if (hasPredicate()) |
2623 | getPredicate().getNode()->visit(this, visitor); |
2624 | getInput().getNode()->visit(this, visitor); |
2625 | getScale().getNode()->visit(this, visitor); |
2626 | getBias().getNode()->visit(this, visitor); |
2627 | visitor->post(parent, this); |
2628 | } |
2629 | |
2630 | bool LayerNormalizationNode::isEqual(const LayerNormalizationNode &other) const { |
2631 | return true && |
2632 | Input_ == other.Input_ && |
2633 | Scale_ == other.Scale_ && |
2634 | Bias_ == other.Bias_ && |
2635 | predicate_ == other.predicate_ && |
2636 | Epsilon_ == other.Epsilon_ && |
2637 | getType(0) == other.getType(0); |
2638 | } |
2639 | |
2640 | Node* LayerNormalizationNode::clone() const { |
2641 | return new LayerNormalizationNode(getName(), getResult().getType(), getInput(), getScale(), getBias(), getEpsilon()); |
2642 | } |
2643 | |
2644 | llvm::hash_code LayerNormalizationNode::getHash() const { |
2645 | return llvm::hash_combine( |
2646 | toBinary(Epsilon_), |
2647 | Input_, |
2648 | Scale_, |
2649 | Bias_); |
2650 | } |
2651 | |
2652 | unsigned BatchBoxCoxNode::getNumInputs() const { |
2653 | return 3; |
2654 | } |
2655 | |
2656 | std::string BatchBoxCoxNode::getInputName(unsigned idx) const { |
2657 | if (idx == 0) { return "Input" ; } |
2658 | if (idx == 1) { return "Lambda1" ; } |
2659 | if (idx == 2) { return "Lambda2" ; } |
2660 | idx -= 3; |
2661 | llvm_unreachable("Invalid index" ); |
2662 | } |
2663 | |
2664 | NodeValue BatchBoxCoxNode::getNthInput(unsigned idx) { |
2665 | if (idx == 0) { return Input_; } |
2666 | if (idx == 1) { return Lambda1_; } |
2667 | if (idx == 2) { return Lambda2_; } |
2668 | idx -= 3; |
2669 | llvm_unreachable("Invalid index" ); |
2670 | } |
2671 | |
2672 | void BatchBoxCoxNode::setNthInput(unsigned idx, NodeValue val) { |
2673 | if (idx == 0) { Input_ = val; return; } |
2674 | if (idx == 1) { Lambda1_ = val; return; } |
2675 | if (idx == 2) { Lambda2_ = val; return; } |
2676 | idx -= 3; |
2677 | llvm_unreachable("Invalid index" ); |
2678 | } |
2679 | |
2680 | llvm::StringRef BatchBoxCoxNode::getOutputName(unsigned idx) const { |
2681 | if (idx == 0) { return "Result" ; } |
2682 | llvm_unreachable("Invalid index" ); |
2683 | } |
2684 | |
2685 | std::string BatchBoxCoxNode::getDebugDesc() const { |
2686 | DescriptionBuilder db(getKindName()); |
2687 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2688 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2689 | db |
2690 | .addParam("Input" , *(getInput().getType())) |
2691 | .addParam("Lambda1" , *(getLambda1().getType())) |
2692 | .addParam("Lambda2" , *(getLambda2().getType())) |
2693 | .addParam("Epsilon" , getEpsilon()) |
2694 | .addParam("Users" , getNumUsers()); |
2695 | db.addParam("Result" , *(getResult().getType())); |
2696 | return db; |
2697 | } |
2698 | |
2699 | void BatchBoxCoxNode::visit(Node *parent, NodeWalker *visitor) { |
2700 | if (!visitor->shouldVisit(parent, this)) { return; } |
2701 | visitor->pre(parent, this); |
2702 | if (hasPredicate()) |
2703 | getPredicate().getNode()->visit(this, visitor); |
2704 | getInput().getNode()->visit(this, visitor); |
2705 | getLambda1().getNode()->visit(this, visitor); |
2706 | getLambda2().getNode()->visit(this, visitor); |
2707 | visitor->post(parent, this); |
2708 | } |
2709 | |
2710 | bool BatchBoxCoxNode::isEqual(const BatchBoxCoxNode &other) const { |
2711 | return true && |
2712 | Input_ == other.Input_ && |
2713 | Lambda1_ == other.Lambda1_ && |
2714 | Lambda2_ == other.Lambda2_ && |
2715 | predicate_ == other.predicate_ && |
2716 | Epsilon_ == other.Epsilon_ && |
2717 | getType(0) == other.getType(0); |
2718 | } |
2719 | |
2720 | Node* BatchBoxCoxNode::clone() const { |
2721 | return new BatchBoxCoxNode(getName(), getInput(), getLambda1(), getLambda2(), getEpsilon()); |
2722 | } |
2723 | |
2724 | llvm::hash_code BatchBoxCoxNode::getHash() const { |
2725 | return llvm::hash_combine( |
2726 | toBinary(Epsilon_), |
2727 | Input_, |
2728 | Lambda1_, |
2729 | Lambda2_); |
2730 | } |
2731 | |
2732 | unsigned VectorNormNode::getNumInputs() const { |
2733 | return 1; |
2734 | } |
2735 | |
2736 | std::string VectorNormNode::getInputName(unsigned idx) const { |
2737 | if (idx == 0) { return "Input" ; } |
2738 | idx -= 1; |
2739 | llvm_unreachable("Invalid index" ); |
2740 | } |
2741 | |
2742 | NodeValue VectorNormNode::getNthInput(unsigned idx) { |
2743 | if (idx == 0) { return Input_; } |
2744 | idx -= 1; |
2745 | llvm_unreachable("Invalid index" ); |
2746 | } |
2747 | |
2748 | void VectorNormNode::setNthInput(unsigned idx, NodeValue val) { |
2749 | if (idx == 0) { Input_ = val; return; } |
2750 | idx -= 1; |
2751 | llvm_unreachable("Invalid index" ); |
2752 | } |
2753 | |
2754 | llvm::StringRef VectorNormNode::getOutputName(unsigned idx) const { |
2755 | if (idx == 0) { return "Result" ; } |
2756 | llvm_unreachable("Invalid index" ); |
2757 | } |
2758 | |
2759 | std::string VectorNormNode::getDebugDesc() const { |
2760 | DescriptionBuilder db(getKindName()); |
2761 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2762 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2763 | db |
2764 | .addParam("Input" , *(getInput().getType())) |
2765 | .addParam("Axis" , getAxis()) |
2766 | .addParam("P" , getP()) |
2767 | .addParam("Users" , getNumUsers()); |
2768 | db.addParam("Result" , *(getResult().getType())); |
2769 | return db; |
2770 | } |
2771 | |
2772 | void VectorNormNode::visit(Node *parent, NodeWalker *visitor) { |
2773 | if (!visitor->shouldVisit(parent, this)) { return; } |
2774 | visitor->pre(parent, this); |
2775 | if (hasPredicate()) |
2776 | getPredicate().getNode()->visit(this, visitor); |
2777 | getInput().getNode()->visit(this, visitor); |
2778 | visitor->post(parent, this); |
2779 | } |
2780 | |
2781 | bool VectorNormNode::isEqual(const VectorNormNode &other) const { |
2782 | return true && |
2783 | Input_ == other.Input_ && |
2784 | predicate_ == other.predicate_ && |
2785 | Axis_ == other.Axis_ && |
2786 | P_ == other.P_ && |
2787 | getType(0) == other.getType(0); |
2788 | } |
2789 | |
2790 | Node* VectorNormNode::clone() const { |
2791 | return new VectorNormNode(getName(), getResult().getType(), getInput(), getAxis(), getP()); |
2792 | } |
2793 | |
2794 | llvm::hash_code VectorNormNode::getHash() const { |
2795 | return llvm::hash_combine( |
2796 | Axis_, |
2797 | P_, |
2798 | Input_); |
2799 | } |
2800 | |
2801 | unsigned BucketizeNode::getNumInputs() const { |
2802 | return 1; |
2803 | } |
2804 | |
2805 | std::string BucketizeNode::getInputName(unsigned idx) const { |
2806 | if (idx == 0) { return "Input" ; } |
2807 | idx -= 1; |
2808 | llvm_unreachable("Invalid index" ); |
2809 | } |
2810 | |
2811 | NodeValue BucketizeNode::getNthInput(unsigned idx) { |
2812 | if (idx == 0) { return Input_; } |
2813 | idx -= 1; |
2814 | llvm_unreachable("Invalid index" ); |
2815 | } |
2816 | |
2817 | void BucketizeNode::setNthInput(unsigned idx, NodeValue val) { |
2818 | if (idx == 0) { Input_ = val; return; } |
2819 | idx -= 1; |
2820 | llvm_unreachable("Invalid index" ); |
2821 | } |
2822 | |
2823 | llvm::StringRef BucketizeNode::getOutputName(unsigned idx) const { |
2824 | if (idx == 0) { return "Result" ; } |
2825 | llvm_unreachable("Invalid index" ); |
2826 | } |
2827 | |
2828 | std::string BucketizeNode::getDebugDesc() const { |
2829 | DescriptionBuilder db(getKindName()); |
2830 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2831 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2832 | db |
2833 | .addParam("Input" , *(getInput().getType())) |
2834 | .addParam("Boundaries" , getBoundaries()) |
2835 | .addParam("Users" , getNumUsers()); |
2836 | db.addParam("Result" , *(getResult().getType())); |
2837 | return db; |
2838 | } |
2839 | |
2840 | void BucketizeNode::visit(Node *parent, NodeWalker *visitor) { |
2841 | if (!visitor->shouldVisit(parent, this)) { return; } |
2842 | visitor->pre(parent, this); |
2843 | if (hasPredicate()) |
2844 | getPredicate().getNode()->visit(this, visitor); |
2845 | getInput().getNode()->visit(this, visitor); |
2846 | visitor->post(parent, this); |
2847 | } |
2848 | |
2849 | bool BucketizeNode::isEqual(const BucketizeNode &other) const { |
2850 | return true && |
2851 | Input_ == other.Input_ && |
2852 | predicate_ == other.predicate_ && |
2853 | Boundaries_ == other.Boundaries_ && |
2854 | getType(0) == other.getType(0); |
2855 | } |
2856 | |
2857 | Node* BucketizeNode::clone() const { |
2858 | return new BucketizeNode(getName(), getResult().getType(), getInput(), getBoundaries()); |
2859 | } |
2860 | |
2861 | llvm::hash_code BucketizeNode::getHash() const { |
2862 | return llvm::hash_combine( |
2863 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
2864 | std::vector<size_t> sizeVec = toBinary(floatVec); |
2865 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
2866 | }(Boundaries_), |
2867 | Input_); |
2868 | } |
2869 | |
2870 | unsigned SoftMaxGradNode::getNumInputs() const { |
2871 | return 4; |
2872 | } |
2873 | |
2874 | std::string SoftMaxGradNode::getInputName(unsigned idx) const { |
2875 | if (idx == 0) { return "Input" ; } |
2876 | if (idx == 1) { return "Selected" ; } |
2877 | if (idx == 2) { return "OriginalOutputForResult" ; } |
2878 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
2879 | idx -= 4; |
2880 | llvm_unreachable("Invalid index" ); |
2881 | } |
2882 | |
2883 | NodeValue SoftMaxGradNode::getNthInput(unsigned idx) { |
2884 | if (idx == 0) { return Input_; } |
2885 | if (idx == 1) { return Selected_; } |
2886 | if (idx == 2) { return OriginalOutputForResult_; } |
2887 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
2888 | idx -= 4; |
2889 | llvm_unreachable("Invalid index" ); |
2890 | } |
2891 | |
2892 | void SoftMaxGradNode::setNthInput(unsigned idx, NodeValue val) { |
2893 | if (idx == 0) { Input_ = val; return; } |
2894 | if (idx == 1) { Selected_ = val; return; } |
2895 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
2896 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
2897 | idx -= 4; |
2898 | llvm_unreachable("Invalid index" ); |
2899 | } |
2900 | |
2901 | llvm::StringRef SoftMaxGradNode::getOutputName(unsigned idx) const { |
2902 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
2903 | if (idx == 1) { return "GradOfInputNamedSelected" ; } |
2904 | llvm_unreachable("Invalid index" ); |
2905 | } |
2906 | |
2907 | std::string SoftMaxGradNode::getDebugDesc() const { |
2908 | DescriptionBuilder db(getKindName()); |
2909 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2910 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2911 | db |
2912 | .addParam("Input" , *(getInput().getType())) |
2913 | .addParam("Selected" , *(getSelected().getType())) |
2914 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
2915 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
2916 | .addParam("Users" , getNumUsers()); |
2917 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
2918 | db.addParam("GradOfInputNamedSelected" , *(getGradOfInputNamedSelected().getType())); |
2919 | return db; |
2920 | } |
2921 | |
2922 | void SoftMaxGradNode::visit(Node *parent, NodeWalker *visitor) { |
2923 | if (!visitor->shouldVisit(parent, this)) { return; } |
2924 | visitor->pre(parent, this); |
2925 | if (hasPredicate()) |
2926 | getPredicate().getNode()->visit(this, visitor); |
2927 | getInput().getNode()->visit(this, visitor); |
2928 | getSelected().getNode()->visit(this, visitor); |
2929 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
2930 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
2931 | visitor->post(parent, this); |
2932 | } |
2933 | |
2934 | bool SoftMaxGradNode::isEqual(const SoftMaxGradNode &other) const { |
2935 | return true && |
2936 | Input_ == other.Input_ && |
2937 | Selected_ == other.Selected_ && |
2938 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
2939 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
2940 | predicate_ == other.predicate_ && |
2941 | getType(0) == other.getType(0) && |
2942 | getType(1) == other.getType(1); |
2943 | } |
2944 | |
2945 | Node* SoftMaxGradNode::clone() const { |
2946 | return new SoftMaxGradNode(getName(), getInput(), getSelected(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
2947 | } |
2948 | |
2949 | llvm::hash_code SoftMaxGradNode::getHash() const { |
2950 | return llvm::hash_combine( |
2951 | Input_, |
2952 | Selected_, |
2953 | OriginalOutputForResult_, |
2954 | GradOfOriginalOutputNamedResult_); |
2955 | } |
2956 | |
2957 | unsigned SoftMaxNode::getNumInputs() const { |
2958 | return 2; |
2959 | } |
2960 | |
2961 | std::string SoftMaxNode::getInputName(unsigned idx) const { |
2962 | if (idx == 0) { return "Input" ; } |
2963 | if (idx == 1) { return "Selected" ; } |
2964 | idx -= 2; |
2965 | llvm_unreachable("Invalid index" ); |
2966 | } |
2967 | |
2968 | NodeValue SoftMaxNode::getNthInput(unsigned idx) { |
2969 | if (idx == 0) { return Input_; } |
2970 | if (idx == 1) { return Selected_; } |
2971 | idx -= 2; |
2972 | llvm_unreachable("Invalid index" ); |
2973 | } |
2974 | |
2975 | void SoftMaxNode::setNthInput(unsigned idx, NodeValue val) { |
2976 | if (idx == 0) { Input_ = val; return; } |
2977 | if (idx == 1) { Selected_ = val; return; } |
2978 | idx -= 2; |
2979 | llvm_unreachable("Invalid index" ); |
2980 | } |
2981 | |
2982 | llvm::StringRef SoftMaxNode::getOutputName(unsigned idx) const { |
2983 | if (idx == 0) { return "Result" ; } |
2984 | llvm_unreachable("Invalid index" ); |
2985 | } |
2986 | |
2987 | std::string SoftMaxNode::getDebugDesc() const { |
2988 | DescriptionBuilder db(getKindName()); |
2989 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
2990 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
2991 | db |
2992 | .addParam("Input" , *(getInput().getType())) |
2993 | .addParam("Selected" , *(getSelected().getType())) |
2994 | .addParam("Users" , getNumUsers()); |
2995 | db.addParam("Result" , *(getResult().getType())); |
2996 | return db; |
2997 | } |
2998 | |
2999 | void SoftMaxNode::visit(Node *parent, NodeWalker *visitor) { |
3000 | if (!visitor->shouldVisit(parent, this)) { return; } |
3001 | visitor->pre(parent, this); |
3002 | if (hasPredicate()) |
3003 | getPredicate().getNode()->visit(this, visitor); |
3004 | getInput().getNode()->visit(this, visitor); |
3005 | getSelected().getNode()->visit(this, visitor); |
3006 | visitor->post(parent, this); |
3007 | } |
3008 | |
3009 | bool SoftMaxNode::isEqual(const SoftMaxNode &other) const { |
3010 | return true && |
3011 | Input_ == other.Input_ && |
3012 | Selected_ == other.Selected_ && |
3013 | predicate_ == other.predicate_ && |
3014 | getType(0) == other.getType(0); |
3015 | } |
3016 | |
3017 | Node* SoftMaxNode::clone() const { |
3018 | return new SoftMaxNode(getName(), getResult().getType(), getInput(), getSelected()); |
3019 | } |
3020 | |
3021 | llvm::hash_code SoftMaxNode::getHash() const { |
3022 | return llvm::hash_combine( |
3023 | Input_, |
3024 | Selected_); |
3025 | } |
3026 | |
3027 | SoftMaxGradNode *SoftMaxNode::getGrad(GraphGradMapper &builder) { |
3028 | auto *x = new SoftMaxGradNode(getName().str() + "_grad" , getInput(), getSelected(), getResult(), builder.getGradient(getResult())); |
3029 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
3030 | builder.addGradient(getSelected(), x->getGradOfInputNamedSelected()); |
3031 | return x; |
3032 | } |
3033 | |
3034 | unsigned LogSoftMaxGradNode::getNumInputs() const { |
3035 | return 4; |
3036 | } |
3037 | |
3038 | std::string LogSoftMaxGradNode::getInputName(unsigned idx) const { |
3039 | if (idx == 0) { return "Input" ; } |
3040 | if (idx == 1) { return "Selected" ; } |
3041 | if (idx == 2) { return "OriginalOutputForResult" ; } |
3042 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
3043 | idx -= 4; |
3044 | llvm_unreachable("Invalid index" ); |
3045 | } |
3046 | |
3047 | NodeValue LogSoftMaxGradNode::getNthInput(unsigned idx) { |
3048 | if (idx == 0) { return Input_; } |
3049 | if (idx == 1) { return Selected_; } |
3050 | if (idx == 2) { return OriginalOutputForResult_; } |
3051 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
3052 | idx -= 4; |
3053 | llvm_unreachable("Invalid index" ); |
3054 | } |
3055 | |
3056 | void LogSoftMaxGradNode::setNthInput(unsigned idx, NodeValue val) { |
3057 | if (idx == 0) { Input_ = val; return; } |
3058 | if (idx == 1) { Selected_ = val; return; } |
3059 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
3060 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
3061 | idx -= 4; |
3062 | llvm_unreachable("Invalid index" ); |
3063 | } |
3064 | |
3065 | llvm::StringRef LogSoftMaxGradNode::getOutputName(unsigned idx) const { |
3066 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
3067 | if (idx == 1) { return "GradOfInputNamedSelected" ; } |
3068 | llvm_unreachable("Invalid index" ); |
3069 | } |
3070 | |
3071 | std::string LogSoftMaxGradNode::getDebugDesc() const { |
3072 | DescriptionBuilder db(getKindName()); |
3073 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3074 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3075 | db |
3076 | .addParam("Input" , *(getInput().getType())) |
3077 | .addParam("Selected" , *(getSelected().getType())) |
3078 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
3079 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
3080 | .addParam("Users" , getNumUsers()); |
3081 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
3082 | db.addParam("GradOfInputNamedSelected" , *(getGradOfInputNamedSelected().getType())); |
3083 | return db; |
3084 | } |
3085 | |
3086 | void LogSoftMaxGradNode::visit(Node *parent, NodeWalker *visitor) { |
3087 | if (!visitor->shouldVisit(parent, this)) { return; } |
3088 | visitor->pre(parent, this); |
3089 | if (hasPredicate()) |
3090 | getPredicate().getNode()->visit(this, visitor); |
3091 | getInput().getNode()->visit(this, visitor); |
3092 | getSelected().getNode()->visit(this, visitor); |
3093 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
3094 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
3095 | visitor->post(parent, this); |
3096 | } |
3097 | |
3098 | bool LogSoftMaxGradNode::isEqual(const LogSoftMaxGradNode &other) const { |
3099 | return true && |
3100 | Input_ == other.Input_ && |
3101 | Selected_ == other.Selected_ && |
3102 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
3103 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
3104 | predicate_ == other.predicate_ && |
3105 | getType(0) == other.getType(0) && |
3106 | getType(1) == other.getType(1); |
3107 | } |
3108 | |
3109 | Node* LogSoftMaxGradNode::clone() const { |
3110 | return new LogSoftMaxGradNode(getName(), getInput(), getSelected(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
3111 | } |
3112 | |
3113 | llvm::hash_code LogSoftMaxGradNode::getHash() const { |
3114 | return llvm::hash_combine( |
3115 | Input_, |
3116 | Selected_, |
3117 | OriginalOutputForResult_, |
3118 | GradOfOriginalOutputNamedResult_); |
3119 | } |
3120 | |
3121 | unsigned LogSoftMaxNode::getNumInputs() const { |
3122 | return 2; |
3123 | } |
3124 | |
3125 | std::string LogSoftMaxNode::getInputName(unsigned idx) const { |
3126 | if (idx == 0) { return "Input" ; } |
3127 | if (idx == 1) { return "Selected" ; } |
3128 | idx -= 2; |
3129 | llvm_unreachable("Invalid index" ); |
3130 | } |
3131 | |
3132 | NodeValue LogSoftMaxNode::getNthInput(unsigned idx) { |
3133 | if (idx == 0) { return Input_; } |
3134 | if (idx == 1) { return Selected_; } |
3135 | idx -= 2; |
3136 | llvm_unreachable("Invalid index" ); |
3137 | } |
3138 | |
3139 | void LogSoftMaxNode::setNthInput(unsigned idx, NodeValue val) { |
3140 | if (idx == 0) { Input_ = val; return; } |
3141 | if (idx == 1) { Selected_ = val; return; } |
3142 | idx -= 2; |
3143 | llvm_unreachable("Invalid index" ); |
3144 | } |
3145 | |
3146 | llvm::StringRef LogSoftMaxNode::getOutputName(unsigned idx) const { |
3147 | if (idx == 0) { return "Result" ; } |
3148 | llvm_unreachable("Invalid index" ); |
3149 | } |
3150 | |
3151 | std::string LogSoftMaxNode::getDebugDesc() const { |
3152 | DescriptionBuilder db(getKindName()); |
3153 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3154 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3155 | db |
3156 | .addParam("Input" , *(getInput().getType())) |
3157 | .addParam("Selected" , *(getSelected().getType())) |
3158 | .addParam("Users" , getNumUsers()); |
3159 | db.addParam("Result" , *(getResult().getType())); |
3160 | return db; |
3161 | } |
3162 | |
3163 | void LogSoftMaxNode::visit(Node *parent, NodeWalker *visitor) { |
3164 | if (!visitor->shouldVisit(parent, this)) { return; } |
3165 | visitor->pre(parent, this); |
3166 | if (hasPredicate()) |
3167 | getPredicate().getNode()->visit(this, visitor); |
3168 | getInput().getNode()->visit(this, visitor); |
3169 | getSelected().getNode()->visit(this, visitor); |
3170 | visitor->post(parent, this); |
3171 | } |
3172 | |
3173 | bool LogSoftMaxNode::isEqual(const LogSoftMaxNode &other) const { |
3174 | return true && |
3175 | Input_ == other.Input_ && |
3176 | Selected_ == other.Selected_ && |
3177 | predicate_ == other.predicate_ && |
3178 | getType(0) == other.getType(0); |
3179 | } |
3180 | |
3181 | Node* LogSoftMaxNode::clone() const { |
3182 | return new LogSoftMaxNode(getName(), getResult().getType(), getInput(), getSelected()); |
3183 | } |
3184 | |
3185 | llvm::hash_code LogSoftMaxNode::getHash() const { |
3186 | return llvm::hash_combine( |
3187 | Input_, |
3188 | Selected_); |
3189 | } |
3190 | |
3191 | LogSoftMaxGradNode *LogSoftMaxNode::getGrad(GraphGradMapper &builder) { |
3192 | auto *x = new LogSoftMaxGradNode(getName().str() + "_grad" , getInput(), getSelected(), getResult(), builder.getGradient(getResult())); |
3193 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
3194 | builder.addGradient(getSelected(), x->getGradOfInputNamedSelected()); |
3195 | return x; |
3196 | } |
3197 | |
3198 | unsigned CrossEntropyLossGradNode::getNumInputs() const { |
3199 | return 4; |
3200 | } |
3201 | |
3202 | std::string CrossEntropyLossGradNode::getInputName(unsigned idx) const { |
3203 | if (idx == 0) { return "P" ; } |
3204 | if (idx == 1) { return "Labels" ; } |
3205 | if (idx == 2) { return "OriginalOutputForCE" ; } |
3206 | if (idx == 3) { return "GradOfOriginalOutputNamedCE" ; } |
3207 | idx -= 4; |
3208 | llvm_unreachable("Invalid index" ); |
3209 | } |
3210 | |
3211 | NodeValue CrossEntropyLossGradNode::getNthInput(unsigned idx) { |
3212 | if (idx == 0) { return P_; } |
3213 | if (idx == 1) { return Labels_; } |
3214 | if (idx == 2) { return OriginalOutputForCE_; } |
3215 | if (idx == 3) { return GradOfOriginalOutputNamedCE_; } |
3216 | idx -= 4; |
3217 | llvm_unreachable("Invalid index" ); |
3218 | } |
3219 | |
3220 | void CrossEntropyLossGradNode::setNthInput(unsigned idx, NodeValue val) { |
3221 | if (idx == 0) { P_ = val; return; } |
3222 | if (idx == 1) { Labels_ = val; return; } |
3223 | if (idx == 2) { OriginalOutputForCE_ = val; return; } |
3224 | if (idx == 3) { GradOfOriginalOutputNamedCE_ = val; return; } |
3225 | idx -= 4; |
3226 | llvm_unreachable("Invalid index" ); |
3227 | } |
3228 | |
3229 | llvm::StringRef CrossEntropyLossGradNode::getOutputName(unsigned idx) const { |
3230 | if (idx == 0) { return "GradOfInputNamedP" ; } |
3231 | if (idx == 1) { return "GradOfInputNamedLabels" ; } |
3232 | llvm_unreachable("Invalid index" ); |
3233 | } |
3234 | |
3235 | std::string CrossEntropyLossGradNode::getDebugDesc() const { |
3236 | DescriptionBuilder db(getKindName()); |
3237 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3238 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3239 | db |
3240 | .addParam("P" , *(getP().getType())) |
3241 | .addParam("Labels" , *(getLabels().getType())) |
3242 | .addParam("OriginalOutputForCE" , *(getOriginalOutputForCE().getType())) |
3243 | .addParam("GradOfOriginalOutputNamedCE" , *(getGradOfOriginalOutputNamedCE().getType())) |
3244 | .addParam("Users" , getNumUsers()); |
3245 | db.addParam("GradOfInputNamedP" , *(getGradOfInputNamedP().getType())); |
3246 | db.addParam("GradOfInputNamedLabels" , *(getGradOfInputNamedLabels().getType())); |
3247 | return db; |
3248 | } |
3249 | |
3250 | void CrossEntropyLossGradNode::visit(Node *parent, NodeWalker *visitor) { |
3251 | if (!visitor->shouldVisit(parent, this)) { return; } |
3252 | visitor->pre(parent, this); |
3253 | if (hasPredicate()) |
3254 | getPredicate().getNode()->visit(this, visitor); |
3255 | getP().getNode()->visit(this, visitor); |
3256 | getLabels().getNode()->visit(this, visitor); |
3257 | getOriginalOutputForCE().getNode()->visit(this, visitor); |
3258 | getGradOfOriginalOutputNamedCE().getNode()->visit(this, visitor); |
3259 | visitor->post(parent, this); |
3260 | } |
3261 | |
3262 | bool CrossEntropyLossGradNode::isEqual(const CrossEntropyLossGradNode &other) const { |
3263 | return true && |
3264 | P_ == other.P_ && |
3265 | Labels_ == other.Labels_ && |
3266 | OriginalOutputForCE_ == other.OriginalOutputForCE_ && |
3267 | GradOfOriginalOutputNamedCE_ == other.GradOfOriginalOutputNamedCE_ && |
3268 | predicate_ == other.predicate_ && |
3269 | getType(0) == other.getType(0) && |
3270 | getType(1) == other.getType(1); |
3271 | } |
3272 | |
3273 | Node* CrossEntropyLossGradNode::clone() const { |
3274 | return new CrossEntropyLossGradNode(getName(), getP(), getLabels(), getOriginalOutputForCE(), getGradOfOriginalOutputNamedCE()); |
3275 | } |
3276 | |
3277 | llvm::hash_code CrossEntropyLossGradNode::getHash() const { |
3278 | return llvm::hash_combine( |
3279 | P_, |
3280 | Labels_, |
3281 | OriginalOutputForCE_, |
3282 | GradOfOriginalOutputNamedCE_); |
3283 | } |
3284 | |
3285 | unsigned CrossEntropyLossNode::getNumInputs() const { |
3286 | return 2; |
3287 | } |
3288 | |
3289 | std::string CrossEntropyLossNode::getInputName(unsigned idx) const { |
3290 | if (idx == 0) { return "P" ; } |
3291 | if (idx == 1) { return "Labels" ; } |
3292 | idx -= 2; |
3293 | llvm_unreachable("Invalid index" ); |
3294 | } |
3295 | |
3296 | NodeValue CrossEntropyLossNode::getNthInput(unsigned idx) { |
3297 | if (idx == 0) { return P_; } |
3298 | if (idx == 1) { return Labels_; } |
3299 | idx -= 2; |
3300 | llvm_unreachable("Invalid index" ); |
3301 | } |
3302 | |
3303 | void CrossEntropyLossNode::setNthInput(unsigned idx, NodeValue val) { |
3304 | if (idx == 0) { P_ = val; return; } |
3305 | if (idx == 1) { Labels_ = val; return; } |
3306 | idx -= 2; |
3307 | llvm_unreachable("Invalid index" ); |
3308 | } |
3309 | |
3310 | llvm::StringRef CrossEntropyLossNode::getOutputName(unsigned idx) const { |
3311 | if (idx == 0) { return "CE" ; } |
3312 | llvm_unreachable("Invalid index" ); |
3313 | } |
3314 | |
3315 | std::string CrossEntropyLossNode::getDebugDesc() const { |
3316 | DescriptionBuilder db(getKindName()); |
3317 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3318 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3319 | db |
3320 | .addParam("P" , *(getP().getType())) |
3321 | .addParam("Labels" , *(getLabels().getType())) |
3322 | .addParam("Users" , getNumUsers()); |
3323 | db.addParam("CE" , *(getCE().getType())); |
3324 | return db; |
3325 | } |
3326 | |
3327 | void CrossEntropyLossNode::visit(Node *parent, NodeWalker *visitor) { |
3328 | if (!visitor->shouldVisit(parent, this)) { return; } |
3329 | visitor->pre(parent, this); |
3330 | if (hasPredicate()) |
3331 | getPredicate().getNode()->visit(this, visitor); |
3332 | getP().getNode()->visit(this, visitor); |
3333 | getLabels().getNode()->visit(this, visitor); |
3334 | visitor->post(parent, this); |
3335 | } |
3336 | |
3337 | bool CrossEntropyLossNode::isEqual(const CrossEntropyLossNode &other) const { |
3338 | return true && |
3339 | P_ == other.P_ && |
3340 | Labels_ == other.Labels_ && |
3341 | predicate_ == other.predicate_ && |
3342 | getType(0) == other.getType(0); |
3343 | } |
3344 | |
3345 | Node* CrossEntropyLossNode::clone() const { |
3346 | return new CrossEntropyLossNode(getName(), getCE().getType(), getP(), getLabels()); |
3347 | } |
3348 | |
3349 | llvm::hash_code CrossEntropyLossNode::getHash() const { |
3350 | return llvm::hash_combine( |
3351 | P_, |
3352 | Labels_); |
3353 | } |
3354 | |
3355 | CrossEntropyLossGradNode *CrossEntropyLossNode::getGrad(GraphGradMapper &builder) { |
3356 | auto *x = new CrossEntropyLossGradNode(getName().str() + "_grad" , getP(), getLabels(), getCE(), builder.getGradient(getCE())); |
3357 | builder.addGradient(getP(), x->getGradOfInputNamedP()); |
3358 | builder.addGradient(getLabels(), x->getGradOfInputNamedLabels()); |
3359 | return x; |
3360 | } |
3361 | |
3362 | unsigned RegressionGradNode::getNumInputs() const { |
3363 | return 4; |
3364 | } |
3365 | |
3366 | std::string RegressionGradNode::getInputName(unsigned idx) const { |
3367 | if (idx == 0) { return "Input" ; } |
3368 | if (idx == 1) { return "Expected" ; } |
3369 | if (idx == 2) { return "OriginalOutputForResult" ; } |
3370 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
3371 | idx -= 4; |
3372 | llvm_unreachable("Invalid index" ); |
3373 | } |
3374 | |
3375 | NodeValue RegressionGradNode::getNthInput(unsigned idx) { |
3376 | if (idx == 0) { return Input_; } |
3377 | if (idx == 1) { return Expected_; } |
3378 | if (idx == 2) { return OriginalOutputForResult_; } |
3379 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
3380 | idx -= 4; |
3381 | llvm_unreachable("Invalid index" ); |
3382 | } |
3383 | |
3384 | void RegressionGradNode::setNthInput(unsigned idx, NodeValue val) { |
3385 | if (idx == 0) { Input_ = val; return; } |
3386 | if (idx == 1) { Expected_ = val; return; } |
3387 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
3388 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
3389 | idx -= 4; |
3390 | llvm_unreachable("Invalid index" ); |
3391 | } |
3392 | |
3393 | llvm::StringRef RegressionGradNode::getOutputName(unsigned idx) const { |
3394 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
3395 | if (idx == 1) { return "GradOfInputNamedExpected" ; } |
3396 | llvm_unreachable("Invalid index" ); |
3397 | } |
3398 | |
3399 | std::string RegressionGradNode::getDebugDesc() const { |
3400 | DescriptionBuilder db(getKindName()); |
3401 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3402 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3403 | db |
3404 | .addParam("Input" , *(getInput().getType())) |
3405 | .addParam("Expected" , *(getExpected().getType())) |
3406 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
3407 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
3408 | .addParam("Users" , getNumUsers()); |
3409 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
3410 | db.addParam("GradOfInputNamedExpected" , *(getGradOfInputNamedExpected().getType())); |
3411 | return db; |
3412 | } |
3413 | |
3414 | void RegressionGradNode::visit(Node *parent, NodeWalker *visitor) { |
3415 | if (!visitor->shouldVisit(parent, this)) { return; } |
3416 | visitor->pre(parent, this); |
3417 | if (hasPredicate()) |
3418 | getPredicate().getNode()->visit(this, visitor); |
3419 | getInput().getNode()->visit(this, visitor); |
3420 | getExpected().getNode()->visit(this, visitor); |
3421 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
3422 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
3423 | visitor->post(parent, this); |
3424 | } |
3425 | |
3426 | bool RegressionGradNode::isEqual(const RegressionGradNode &other) const { |
3427 | return true && |
3428 | Input_ == other.Input_ && |
3429 | Expected_ == other.Expected_ && |
3430 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
3431 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
3432 | predicate_ == other.predicate_ && |
3433 | getType(0) == other.getType(0) && |
3434 | getType(1) == other.getType(1); |
3435 | } |
3436 | |
3437 | Node* RegressionGradNode::clone() const { |
3438 | return new RegressionGradNode(getName(), getInput(), getExpected(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
3439 | } |
3440 | |
3441 | llvm::hash_code RegressionGradNode::getHash() const { |
3442 | return llvm::hash_combine( |
3443 | Input_, |
3444 | Expected_, |
3445 | OriginalOutputForResult_, |
3446 | GradOfOriginalOutputNamedResult_); |
3447 | } |
3448 | |
3449 | unsigned RegressionNode::getNumInputs() const { |
3450 | return 2; |
3451 | } |
3452 | |
3453 | std::string RegressionNode::getInputName(unsigned idx) const { |
3454 | if (idx == 0) { return "Input" ; } |
3455 | if (idx == 1) { return "Expected" ; } |
3456 | idx -= 2; |
3457 | llvm_unreachable("Invalid index" ); |
3458 | } |
3459 | |
3460 | NodeValue RegressionNode::getNthInput(unsigned idx) { |
3461 | if (idx == 0) { return Input_; } |
3462 | if (idx == 1) { return Expected_; } |
3463 | idx -= 2; |
3464 | llvm_unreachable("Invalid index" ); |
3465 | } |
3466 | |
3467 | void RegressionNode::setNthInput(unsigned idx, NodeValue val) { |
3468 | if (idx == 0) { Input_ = val; return; } |
3469 | if (idx == 1) { Expected_ = val; return; } |
3470 | idx -= 2; |
3471 | llvm_unreachable("Invalid index" ); |
3472 | } |
3473 | |
3474 | llvm::StringRef RegressionNode::getOutputName(unsigned idx) const { |
3475 | if (idx == 0) { return "Result" ; } |
3476 | llvm_unreachable("Invalid index" ); |
3477 | } |
3478 | |
3479 | std::string RegressionNode::getDebugDesc() const { |
3480 | DescriptionBuilder db(getKindName()); |
3481 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3482 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3483 | db |
3484 | .addParam("Input" , *(getInput().getType())) |
3485 | .addParam("Expected" , *(getExpected().getType())) |
3486 | .addParam("Users" , getNumUsers()); |
3487 | db.addParam("Result" , *(getResult().getType())); |
3488 | return db; |
3489 | } |
3490 | |
3491 | void RegressionNode::visit(Node *parent, NodeWalker *visitor) { |
3492 | if (!visitor->shouldVisit(parent, this)) { return; } |
3493 | visitor->pre(parent, this); |
3494 | if (hasPredicate()) |
3495 | getPredicate().getNode()->visit(this, visitor); |
3496 | getInput().getNode()->visit(this, visitor); |
3497 | getExpected().getNode()->visit(this, visitor); |
3498 | visitor->post(parent, this); |
3499 | } |
3500 | |
3501 | bool RegressionNode::isEqual(const RegressionNode &other) const { |
3502 | return true && |
3503 | Input_ == other.Input_ && |
3504 | Expected_ == other.Expected_ && |
3505 | predicate_ == other.predicate_ && |
3506 | getType(0) == other.getType(0); |
3507 | } |
3508 | |
3509 | Node* RegressionNode::clone() const { |
3510 | return new RegressionNode(getName(), getInput(), getExpected()); |
3511 | } |
3512 | |
3513 | llvm::hash_code RegressionNode::getHash() const { |
3514 | return llvm::hash_combine( |
3515 | Input_, |
3516 | Expected_); |
3517 | } |
3518 | |
3519 | RegressionGradNode *RegressionNode::getGrad(GraphGradMapper &builder) { |
3520 | auto *x = new RegressionGradNode(getName().str() + "_grad" , getInput(), getExpected(), getResult(), builder.getGradient(getResult())); |
3521 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
3522 | builder.addGradient(getExpected(), x->getGradOfInputNamedExpected()); |
3523 | return x; |
3524 | } |
3525 | |
3526 | unsigned SigmoidCrossEntropyWithLogitsNode::getNumInputs() const { |
3527 | return 2; |
3528 | } |
3529 | |
3530 | std::string SigmoidCrossEntropyWithLogitsNode::getInputName(unsigned idx) const { |
3531 | if (idx == 0) { return "Logits" ; } |
3532 | if (idx == 1) { return "Targets" ; } |
3533 | idx -= 2; |
3534 | llvm_unreachable("Invalid index" ); |
3535 | } |
3536 | |
3537 | NodeValue SigmoidCrossEntropyWithLogitsNode::getNthInput(unsigned idx) { |
3538 | if (idx == 0) { return Logits_; } |
3539 | if (idx == 1) { return Targets_; } |
3540 | idx -= 2; |
3541 | llvm_unreachable("Invalid index" ); |
3542 | } |
3543 | |
3544 | void SigmoidCrossEntropyWithLogitsNode::setNthInput(unsigned idx, NodeValue val) { |
3545 | if (idx == 0) { Logits_ = val; return; } |
3546 | if (idx == 1) { Targets_ = val; return; } |
3547 | idx -= 2; |
3548 | llvm_unreachable("Invalid index" ); |
3549 | } |
3550 | |
3551 | llvm::StringRef SigmoidCrossEntropyWithLogitsNode::getOutputName(unsigned idx) const { |
3552 | if (idx == 0) { return "Result" ; } |
3553 | llvm_unreachable("Invalid index" ); |
3554 | } |
3555 | |
3556 | std::string SigmoidCrossEntropyWithLogitsNode::getDebugDesc() const { |
3557 | DescriptionBuilder db(getKindName()); |
3558 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3559 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3560 | db |
3561 | .addParam("Logits" , *(getLogits().getType())) |
3562 | .addParam("Targets" , *(getTargets().getType())) |
3563 | .addParam("Users" , getNumUsers()); |
3564 | db.addParam("Result" , *(getResult().getType())); |
3565 | return db; |
3566 | } |
3567 | |
3568 | void SigmoidCrossEntropyWithLogitsNode::visit(Node *parent, NodeWalker *visitor) { |
3569 | if (!visitor->shouldVisit(parent, this)) { return; } |
3570 | visitor->pre(parent, this); |
3571 | if (hasPredicate()) |
3572 | getPredicate().getNode()->visit(this, visitor); |
3573 | getLogits().getNode()->visit(this, visitor); |
3574 | getTargets().getNode()->visit(this, visitor); |
3575 | visitor->post(parent, this); |
3576 | } |
3577 | |
3578 | bool SigmoidCrossEntropyWithLogitsNode::isEqual(const SigmoidCrossEntropyWithLogitsNode &other) const { |
3579 | return true && |
3580 | Logits_ == other.Logits_ && |
3581 | Targets_ == other.Targets_ && |
3582 | predicate_ == other.predicate_ && |
3583 | getType(0) == other.getType(0); |
3584 | } |
3585 | |
3586 | Node* SigmoidCrossEntropyWithLogitsNode::clone() const { |
3587 | return new SigmoidCrossEntropyWithLogitsNode(getName(), getResult().getType(), getLogits(), getTargets()); |
3588 | } |
3589 | |
3590 | llvm::hash_code SigmoidCrossEntropyWithLogitsNode::getHash() const { |
3591 | return llvm::hash_combine( |
3592 | Logits_, |
3593 | Targets_); |
3594 | } |
3595 | |
3596 | unsigned AddGradNode::getNumInputs() const { |
3597 | return 4; |
3598 | } |
3599 | |
3600 | std::string AddGradNode::getInputName(unsigned idx) const { |
3601 | if (idx == 0) { return "LHS" ; } |
3602 | if (idx == 1) { return "RHS" ; } |
3603 | if (idx == 2) { return "OriginalOutputForResult" ; } |
3604 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
3605 | idx -= 4; |
3606 | llvm_unreachable("Invalid index" ); |
3607 | } |
3608 | |
3609 | NodeValue AddGradNode::getNthInput(unsigned idx) { |
3610 | if (idx == 0) { return LHS_; } |
3611 | if (idx == 1) { return RHS_; } |
3612 | if (idx == 2) { return OriginalOutputForResult_; } |
3613 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
3614 | idx -= 4; |
3615 | llvm_unreachable("Invalid index" ); |
3616 | } |
3617 | |
3618 | void AddGradNode::setNthInput(unsigned idx, NodeValue val) { |
3619 | if (idx == 0) { LHS_ = val; return; } |
3620 | if (idx == 1) { RHS_ = val; return; } |
3621 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
3622 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
3623 | idx -= 4; |
3624 | llvm_unreachable("Invalid index" ); |
3625 | } |
3626 | |
3627 | llvm::StringRef AddGradNode::getOutputName(unsigned idx) const { |
3628 | if (idx == 0) { return "GradOfInputNamedLHS" ; } |
3629 | if (idx == 1) { return "GradOfInputNamedRHS" ; } |
3630 | llvm_unreachable("Invalid index" ); |
3631 | } |
3632 | |
3633 | std::string AddGradNode::getDebugDesc() const { |
3634 | DescriptionBuilder db(getKindName()); |
3635 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3636 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3637 | db |
3638 | .addParam("LHS" , *(getLHS().getType())) |
3639 | .addParam("RHS" , *(getRHS().getType())) |
3640 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
3641 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
3642 | .addParam("Users" , getNumUsers()); |
3643 | db.addParam("GradOfInputNamedLHS" , *(getGradOfInputNamedLHS().getType())); |
3644 | db.addParam("GradOfInputNamedRHS" , *(getGradOfInputNamedRHS().getType())); |
3645 | return db; |
3646 | } |
3647 | |
3648 | void AddGradNode::visit(Node *parent, NodeWalker *visitor) { |
3649 | if (!visitor->shouldVisit(parent, this)) { return; } |
3650 | visitor->pre(parent, this); |
3651 | if (hasPredicate()) |
3652 | getPredicate().getNode()->visit(this, visitor); |
3653 | getLHS().getNode()->visit(this, visitor); |
3654 | getRHS().getNode()->visit(this, visitor); |
3655 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
3656 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
3657 | visitor->post(parent, this); |
3658 | } |
3659 | |
3660 | bool AddGradNode::isEqual(const AddGradNode &other) const { |
3661 | return true && |
3662 | LHS_ == other.LHS_ && |
3663 | RHS_ == other.RHS_ && |
3664 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
3665 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
3666 | predicate_ == other.predicate_ && |
3667 | getType(0) == other.getType(0) && |
3668 | getType(1) == other.getType(1); |
3669 | } |
3670 | |
3671 | Node* AddGradNode::clone() const { |
3672 | return new AddGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
3673 | } |
3674 | |
3675 | llvm::hash_code AddGradNode::getHash() const { |
3676 | return llvm::hash_combine( |
3677 | LHS_, |
3678 | RHS_, |
3679 | OriginalOutputForResult_, |
3680 | GradOfOriginalOutputNamedResult_); |
3681 | } |
3682 | |
3683 | unsigned AddNode::getNumInputs() const { |
3684 | return 2; |
3685 | } |
3686 | |
3687 | std::string AddNode::getInputName(unsigned idx) const { |
3688 | if (idx == 0) { return "LHS" ; } |
3689 | if (idx == 1) { return "RHS" ; } |
3690 | idx -= 2; |
3691 | llvm_unreachable("Invalid index" ); |
3692 | } |
3693 | |
3694 | NodeValue AddNode::getNthInput(unsigned idx) { |
3695 | if (idx == 0) { return LHS_; } |
3696 | if (idx == 1) { return RHS_; } |
3697 | idx -= 2; |
3698 | llvm_unreachable("Invalid index" ); |
3699 | } |
3700 | |
3701 | void AddNode::setNthInput(unsigned idx, NodeValue val) { |
3702 | if (idx == 0) { LHS_ = val; return; } |
3703 | if (idx == 1) { RHS_ = val; return; } |
3704 | idx -= 2; |
3705 | llvm_unreachable("Invalid index" ); |
3706 | } |
3707 | |
3708 | llvm::StringRef AddNode::getOutputName(unsigned idx) const { |
3709 | if (idx == 0) { return "Result" ; } |
3710 | llvm_unreachable("Invalid index" ); |
3711 | } |
3712 | |
3713 | std::string AddNode::getDebugDesc() const { |
3714 | DescriptionBuilder db(getKindName()); |
3715 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3716 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3717 | db |
3718 | .addParam("LHS" , *(getLHS().getType())) |
3719 | .addParam("RHS" , *(getRHS().getType())) |
3720 | .addParam("Users" , getNumUsers()); |
3721 | db.addParam("Result" , *(getResult().getType())); |
3722 | return db; |
3723 | } |
3724 | |
3725 | void AddNode::visit(Node *parent, NodeWalker *visitor) { |
3726 | if (!visitor->shouldVisit(parent, this)) { return; } |
3727 | visitor->pre(parent, this); |
3728 | if (hasPredicate()) |
3729 | getPredicate().getNode()->visit(this, visitor); |
3730 | getLHS().getNode()->visit(this, visitor); |
3731 | getRHS().getNode()->visit(this, visitor); |
3732 | visitor->post(parent, this); |
3733 | } |
3734 | |
3735 | bool AddNode::isEqual(const AddNode &other) const { |
3736 | return true && |
3737 | LHS_ == other.LHS_ && |
3738 | RHS_ == other.RHS_ && |
3739 | predicate_ == other.predicate_ && |
3740 | getType(0) == other.getType(0); |
3741 | } |
3742 | |
3743 | Node* AddNode::clone() const { |
3744 | return new AddNode(getName(), getResult().getType(), getLHS(), getRHS()); |
3745 | } |
3746 | |
3747 | llvm::hash_code AddNode::getHash() const { |
3748 | return llvm::hash_combine( |
3749 | LHS_, |
3750 | RHS_); |
3751 | } |
3752 | |
3753 | AddGradNode *AddNode::getGrad(GraphGradMapper &builder) { |
3754 | auto *x = new AddGradNode(getName().str() + "_grad" , getLHS(), getRHS(), getResult(), builder.getGradient(getResult())); |
3755 | builder.addGradient(getLHS(), x->getGradOfInputNamedLHS()); |
3756 | builder.addGradient(getRHS(), x->getGradOfInputNamedRHS()); |
3757 | return x; |
3758 | } |
3759 | |
3760 | unsigned MulGradNode::getNumInputs() const { |
3761 | return 4; |
3762 | } |
3763 | |
3764 | std::string MulGradNode::getInputName(unsigned idx) const { |
3765 | if (idx == 0) { return "LHS" ; } |
3766 | if (idx == 1) { return "RHS" ; } |
3767 | if (idx == 2) { return "OriginalOutputForResult" ; } |
3768 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
3769 | idx -= 4; |
3770 | llvm_unreachable("Invalid index" ); |
3771 | } |
3772 | |
3773 | NodeValue MulGradNode::getNthInput(unsigned idx) { |
3774 | if (idx == 0) { return LHS_; } |
3775 | if (idx == 1) { return RHS_; } |
3776 | if (idx == 2) { return OriginalOutputForResult_; } |
3777 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
3778 | idx -= 4; |
3779 | llvm_unreachable("Invalid index" ); |
3780 | } |
3781 | |
3782 | void MulGradNode::setNthInput(unsigned idx, NodeValue val) { |
3783 | if (idx == 0) { LHS_ = val; return; } |
3784 | if (idx == 1) { RHS_ = val; return; } |
3785 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
3786 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
3787 | idx -= 4; |
3788 | llvm_unreachable("Invalid index" ); |
3789 | } |
3790 | |
3791 | llvm::StringRef MulGradNode::getOutputName(unsigned idx) const { |
3792 | if (idx == 0) { return "GradOfInputNamedLHS" ; } |
3793 | if (idx == 1) { return "GradOfInputNamedRHS" ; } |
3794 | llvm_unreachable("Invalid index" ); |
3795 | } |
3796 | |
3797 | std::string MulGradNode::getDebugDesc() const { |
3798 | DescriptionBuilder db(getKindName()); |
3799 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3800 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3801 | db |
3802 | .addParam("LHS" , *(getLHS().getType())) |
3803 | .addParam("RHS" , *(getRHS().getType())) |
3804 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
3805 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
3806 | .addParam("Users" , getNumUsers()); |
3807 | db.addParam("GradOfInputNamedLHS" , *(getGradOfInputNamedLHS().getType())); |
3808 | db.addParam("GradOfInputNamedRHS" , *(getGradOfInputNamedRHS().getType())); |
3809 | return db; |
3810 | } |
3811 | |
3812 | void MulGradNode::visit(Node *parent, NodeWalker *visitor) { |
3813 | if (!visitor->shouldVisit(parent, this)) { return; } |
3814 | visitor->pre(parent, this); |
3815 | if (hasPredicate()) |
3816 | getPredicate().getNode()->visit(this, visitor); |
3817 | getLHS().getNode()->visit(this, visitor); |
3818 | getRHS().getNode()->visit(this, visitor); |
3819 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
3820 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
3821 | visitor->post(parent, this); |
3822 | } |
3823 | |
3824 | bool MulGradNode::isEqual(const MulGradNode &other) const { |
3825 | return true && |
3826 | LHS_ == other.LHS_ && |
3827 | RHS_ == other.RHS_ && |
3828 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
3829 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
3830 | predicate_ == other.predicate_ && |
3831 | getType(0) == other.getType(0) && |
3832 | getType(1) == other.getType(1); |
3833 | } |
3834 | |
3835 | Node* MulGradNode::clone() const { |
3836 | return new MulGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
3837 | } |
3838 | |
3839 | llvm::hash_code MulGradNode::getHash() const { |
3840 | return llvm::hash_combine( |
3841 | LHS_, |
3842 | RHS_, |
3843 | OriginalOutputForResult_, |
3844 | GradOfOriginalOutputNamedResult_); |
3845 | } |
3846 | |
3847 | unsigned MulNode::getNumInputs() const { |
3848 | return 2; |
3849 | } |
3850 | |
3851 | std::string MulNode::getInputName(unsigned idx) const { |
3852 | if (idx == 0) { return "LHS" ; } |
3853 | if (idx == 1) { return "RHS" ; } |
3854 | idx -= 2; |
3855 | llvm_unreachable("Invalid index" ); |
3856 | } |
3857 | |
3858 | NodeValue MulNode::getNthInput(unsigned idx) { |
3859 | if (idx == 0) { return LHS_; } |
3860 | if (idx == 1) { return RHS_; } |
3861 | idx -= 2; |
3862 | llvm_unreachable("Invalid index" ); |
3863 | } |
3864 | |
3865 | void MulNode::setNthInput(unsigned idx, NodeValue val) { |
3866 | if (idx == 0) { LHS_ = val; return; } |
3867 | if (idx == 1) { RHS_ = val; return; } |
3868 | idx -= 2; |
3869 | llvm_unreachable("Invalid index" ); |
3870 | } |
3871 | |
3872 | llvm::StringRef MulNode::getOutputName(unsigned idx) const { |
3873 | if (idx == 0) { return "Result" ; } |
3874 | llvm_unreachable("Invalid index" ); |
3875 | } |
3876 | |
3877 | std::string MulNode::getDebugDesc() const { |
3878 | DescriptionBuilder db(getKindName()); |
3879 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3880 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3881 | db |
3882 | .addParam("LHS" , *(getLHS().getType())) |
3883 | .addParam("RHS" , *(getRHS().getType())) |
3884 | .addParam("Users" , getNumUsers()); |
3885 | db.addParam("Result" , *(getResult().getType())); |
3886 | return db; |
3887 | } |
3888 | |
3889 | void MulNode::visit(Node *parent, NodeWalker *visitor) { |
3890 | if (!visitor->shouldVisit(parent, this)) { return; } |
3891 | visitor->pre(parent, this); |
3892 | if (hasPredicate()) |
3893 | getPredicate().getNode()->visit(this, visitor); |
3894 | getLHS().getNode()->visit(this, visitor); |
3895 | getRHS().getNode()->visit(this, visitor); |
3896 | visitor->post(parent, this); |
3897 | } |
3898 | |
3899 | bool MulNode::isEqual(const MulNode &other) const { |
3900 | return true && |
3901 | LHS_ == other.LHS_ && |
3902 | RHS_ == other.RHS_ && |
3903 | predicate_ == other.predicate_ && |
3904 | getType(0) == other.getType(0); |
3905 | } |
3906 | |
3907 | Node* MulNode::clone() const { |
3908 | return new MulNode(getName(), getResult().getType(), getLHS(), getRHS()); |
3909 | } |
3910 | |
3911 | llvm::hash_code MulNode::getHash() const { |
3912 | return llvm::hash_combine( |
3913 | LHS_, |
3914 | RHS_); |
3915 | } |
3916 | |
3917 | MulGradNode *MulNode::getGrad(GraphGradMapper &builder) { |
3918 | auto *x = new MulGradNode(getName().str() + "_grad" , getLHS(), getRHS(), getResult(), builder.getGradient(getResult())); |
3919 | builder.addGradient(getLHS(), x->getGradOfInputNamedLHS()); |
3920 | builder.addGradient(getRHS(), x->getGradOfInputNamedRHS()); |
3921 | return x; |
3922 | } |
3923 | |
3924 | unsigned SubGradNode::getNumInputs() const { |
3925 | return 4; |
3926 | } |
3927 | |
3928 | std::string SubGradNode::getInputName(unsigned idx) const { |
3929 | if (idx == 0) { return "LHS" ; } |
3930 | if (idx == 1) { return "RHS" ; } |
3931 | if (idx == 2) { return "OriginalOutputForResult" ; } |
3932 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
3933 | idx -= 4; |
3934 | llvm_unreachable("Invalid index" ); |
3935 | } |
3936 | |
3937 | NodeValue SubGradNode::getNthInput(unsigned idx) { |
3938 | if (idx == 0) { return LHS_; } |
3939 | if (idx == 1) { return RHS_; } |
3940 | if (idx == 2) { return OriginalOutputForResult_; } |
3941 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
3942 | idx -= 4; |
3943 | llvm_unreachable("Invalid index" ); |
3944 | } |
3945 | |
3946 | void SubGradNode::setNthInput(unsigned idx, NodeValue val) { |
3947 | if (idx == 0) { LHS_ = val; return; } |
3948 | if (idx == 1) { RHS_ = val; return; } |
3949 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
3950 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
3951 | idx -= 4; |
3952 | llvm_unreachable("Invalid index" ); |
3953 | } |
3954 | |
3955 | llvm::StringRef SubGradNode::getOutputName(unsigned idx) const { |
3956 | if (idx == 0) { return "GradOfInputNamedLHS" ; } |
3957 | if (idx == 1) { return "GradOfInputNamedRHS" ; } |
3958 | llvm_unreachable("Invalid index" ); |
3959 | } |
3960 | |
3961 | std::string SubGradNode::getDebugDesc() const { |
3962 | DescriptionBuilder db(getKindName()); |
3963 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
3964 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
3965 | db |
3966 | .addParam("LHS" , *(getLHS().getType())) |
3967 | .addParam("RHS" , *(getRHS().getType())) |
3968 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
3969 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
3970 | .addParam("Users" , getNumUsers()); |
3971 | db.addParam("GradOfInputNamedLHS" , *(getGradOfInputNamedLHS().getType())); |
3972 | db.addParam("GradOfInputNamedRHS" , *(getGradOfInputNamedRHS().getType())); |
3973 | return db; |
3974 | } |
3975 | |
3976 | void SubGradNode::visit(Node *parent, NodeWalker *visitor) { |
3977 | if (!visitor->shouldVisit(parent, this)) { return; } |
3978 | visitor->pre(parent, this); |
3979 | if (hasPredicate()) |
3980 | getPredicate().getNode()->visit(this, visitor); |
3981 | getLHS().getNode()->visit(this, visitor); |
3982 | getRHS().getNode()->visit(this, visitor); |
3983 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
3984 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
3985 | visitor->post(parent, this); |
3986 | } |
3987 | |
3988 | bool SubGradNode::isEqual(const SubGradNode &other) const { |
3989 | return true && |
3990 | LHS_ == other.LHS_ && |
3991 | RHS_ == other.RHS_ && |
3992 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
3993 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
3994 | predicate_ == other.predicate_ && |
3995 | getType(0) == other.getType(0) && |
3996 | getType(1) == other.getType(1); |
3997 | } |
3998 | |
3999 | Node* SubGradNode::clone() const { |
4000 | return new SubGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
4001 | } |
4002 | |
4003 | llvm::hash_code SubGradNode::getHash() const { |
4004 | return llvm::hash_combine( |
4005 | LHS_, |
4006 | RHS_, |
4007 | OriginalOutputForResult_, |
4008 | GradOfOriginalOutputNamedResult_); |
4009 | } |
4010 | |
4011 | unsigned SubNode::getNumInputs() const { |
4012 | return 2; |
4013 | } |
4014 | |
4015 | std::string SubNode::getInputName(unsigned idx) const { |
4016 | if (idx == 0) { return "LHS" ; } |
4017 | if (idx == 1) { return "RHS" ; } |
4018 | idx -= 2; |
4019 | llvm_unreachable("Invalid index" ); |
4020 | } |
4021 | |
4022 | NodeValue SubNode::getNthInput(unsigned idx) { |
4023 | if (idx == 0) { return LHS_; } |
4024 | if (idx == 1) { return RHS_; } |
4025 | idx -= 2; |
4026 | llvm_unreachable("Invalid index" ); |
4027 | } |
4028 | |
4029 | void SubNode::setNthInput(unsigned idx, NodeValue val) { |
4030 | if (idx == 0) { LHS_ = val; return; } |
4031 | if (idx == 1) { RHS_ = val; return; } |
4032 | idx -= 2; |
4033 | llvm_unreachable("Invalid index" ); |
4034 | } |
4035 | |
4036 | llvm::StringRef SubNode::getOutputName(unsigned idx) const { |
4037 | if (idx == 0) { return "Result" ; } |
4038 | llvm_unreachable("Invalid index" ); |
4039 | } |
4040 | |
4041 | std::string SubNode::getDebugDesc() const { |
4042 | DescriptionBuilder db(getKindName()); |
4043 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4044 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4045 | db |
4046 | .addParam("LHS" , *(getLHS().getType())) |
4047 | .addParam("RHS" , *(getRHS().getType())) |
4048 | .addParam("Users" , getNumUsers()); |
4049 | db.addParam("Result" , *(getResult().getType())); |
4050 | return db; |
4051 | } |
4052 | |
4053 | void SubNode::visit(Node *parent, NodeWalker *visitor) { |
4054 | if (!visitor->shouldVisit(parent, this)) { return; } |
4055 | visitor->pre(parent, this); |
4056 | if (hasPredicate()) |
4057 | getPredicate().getNode()->visit(this, visitor); |
4058 | getLHS().getNode()->visit(this, visitor); |
4059 | getRHS().getNode()->visit(this, visitor); |
4060 | visitor->post(parent, this); |
4061 | } |
4062 | |
4063 | bool SubNode::isEqual(const SubNode &other) const { |
4064 | return true && |
4065 | LHS_ == other.LHS_ && |
4066 | RHS_ == other.RHS_ && |
4067 | predicate_ == other.predicate_ && |
4068 | getType(0) == other.getType(0); |
4069 | } |
4070 | |
4071 | Node* SubNode::clone() const { |
4072 | return new SubNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4073 | } |
4074 | |
4075 | llvm::hash_code SubNode::getHash() const { |
4076 | return llvm::hash_combine( |
4077 | LHS_, |
4078 | RHS_); |
4079 | } |
4080 | |
4081 | SubGradNode *SubNode::getGrad(GraphGradMapper &builder) { |
4082 | auto *x = new SubGradNode(getName().str() + "_grad" , getLHS(), getRHS(), getResult(), builder.getGradient(getResult())); |
4083 | builder.addGradient(getLHS(), x->getGradOfInputNamedLHS()); |
4084 | builder.addGradient(getRHS(), x->getGradOfInputNamedRHS()); |
4085 | return x; |
4086 | } |
4087 | |
4088 | unsigned DivGradNode::getNumInputs() const { |
4089 | return 4; |
4090 | } |
4091 | |
4092 | std::string DivGradNode::getInputName(unsigned idx) const { |
4093 | if (idx == 0) { return "LHS" ; } |
4094 | if (idx == 1) { return "RHS" ; } |
4095 | if (idx == 2) { return "OriginalOutputForResult" ; } |
4096 | if (idx == 3) { return "GradOfOriginalOutputNamedResult" ; } |
4097 | idx -= 4; |
4098 | llvm_unreachable("Invalid index" ); |
4099 | } |
4100 | |
4101 | NodeValue DivGradNode::getNthInput(unsigned idx) { |
4102 | if (idx == 0) { return LHS_; } |
4103 | if (idx == 1) { return RHS_; } |
4104 | if (idx == 2) { return OriginalOutputForResult_; } |
4105 | if (idx == 3) { return GradOfOriginalOutputNamedResult_; } |
4106 | idx -= 4; |
4107 | llvm_unreachable("Invalid index" ); |
4108 | } |
4109 | |
4110 | void DivGradNode::setNthInput(unsigned idx, NodeValue val) { |
4111 | if (idx == 0) { LHS_ = val; return; } |
4112 | if (idx == 1) { RHS_ = val; return; } |
4113 | if (idx == 2) { OriginalOutputForResult_ = val; return; } |
4114 | if (idx == 3) { GradOfOriginalOutputNamedResult_ = val; return; } |
4115 | idx -= 4; |
4116 | llvm_unreachable("Invalid index" ); |
4117 | } |
4118 | |
4119 | llvm::StringRef DivGradNode::getOutputName(unsigned idx) const { |
4120 | if (idx == 0) { return "GradOfInputNamedLHS" ; } |
4121 | if (idx == 1) { return "GradOfInputNamedRHS" ; } |
4122 | llvm_unreachable("Invalid index" ); |
4123 | } |
4124 | |
4125 | std::string DivGradNode::getDebugDesc() const { |
4126 | DescriptionBuilder db(getKindName()); |
4127 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4128 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4129 | db |
4130 | .addParam("LHS" , *(getLHS().getType())) |
4131 | .addParam("RHS" , *(getRHS().getType())) |
4132 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
4133 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
4134 | .addParam("Users" , getNumUsers()); |
4135 | db.addParam("GradOfInputNamedLHS" , *(getGradOfInputNamedLHS().getType())); |
4136 | db.addParam("GradOfInputNamedRHS" , *(getGradOfInputNamedRHS().getType())); |
4137 | return db; |
4138 | } |
4139 | |
4140 | void DivGradNode::visit(Node *parent, NodeWalker *visitor) { |
4141 | if (!visitor->shouldVisit(parent, this)) { return; } |
4142 | visitor->pre(parent, this); |
4143 | if (hasPredicate()) |
4144 | getPredicate().getNode()->visit(this, visitor); |
4145 | getLHS().getNode()->visit(this, visitor); |
4146 | getRHS().getNode()->visit(this, visitor); |
4147 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
4148 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
4149 | visitor->post(parent, this); |
4150 | } |
4151 | |
4152 | bool DivGradNode::isEqual(const DivGradNode &other) const { |
4153 | return true && |
4154 | LHS_ == other.LHS_ && |
4155 | RHS_ == other.RHS_ && |
4156 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
4157 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
4158 | predicate_ == other.predicate_ && |
4159 | getType(0) == other.getType(0) && |
4160 | getType(1) == other.getType(1); |
4161 | } |
4162 | |
4163 | Node* DivGradNode::clone() const { |
4164 | return new DivGradNode(getName(), getLHS(), getRHS(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
4165 | } |
4166 | |
4167 | llvm::hash_code DivGradNode::getHash() const { |
4168 | return llvm::hash_combine( |
4169 | LHS_, |
4170 | RHS_, |
4171 | OriginalOutputForResult_, |
4172 | GradOfOriginalOutputNamedResult_); |
4173 | } |
4174 | |
4175 | unsigned DivNode::getNumInputs() const { |
4176 | return 2; |
4177 | } |
4178 | |
4179 | std::string DivNode::getInputName(unsigned idx) const { |
4180 | if (idx == 0) { return "LHS" ; } |
4181 | if (idx == 1) { return "RHS" ; } |
4182 | idx -= 2; |
4183 | llvm_unreachable("Invalid index" ); |
4184 | } |
4185 | |
4186 | NodeValue DivNode::getNthInput(unsigned idx) { |
4187 | if (idx == 0) { return LHS_; } |
4188 | if (idx == 1) { return RHS_; } |
4189 | idx -= 2; |
4190 | llvm_unreachable("Invalid index" ); |
4191 | } |
4192 | |
4193 | void DivNode::setNthInput(unsigned idx, NodeValue val) { |
4194 | if (idx == 0) { LHS_ = val; return; } |
4195 | if (idx == 1) { RHS_ = val; return; } |
4196 | idx -= 2; |
4197 | llvm_unreachable("Invalid index" ); |
4198 | } |
4199 | |
4200 | llvm::StringRef DivNode::getOutputName(unsigned idx) const { |
4201 | if (idx == 0) { return "Result" ; } |
4202 | llvm_unreachable("Invalid index" ); |
4203 | } |
4204 | |
4205 | std::string DivNode::getDebugDesc() const { |
4206 | DescriptionBuilder db(getKindName()); |
4207 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4208 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4209 | db |
4210 | .addParam("LHS" , *(getLHS().getType())) |
4211 | .addParam("RHS" , *(getRHS().getType())) |
4212 | .addParam("Users" , getNumUsers()); |
4213 | db.addParam("Result" , *(getResult().getType())); |
4214 | return db; |
4215 | } |
4216 | |
4217 | void DivNode::visit(Node *parent, NodeWalker *visitor) { |
4218 | if (!visitor->shouldVisit(parent, this)) { return; } |
4219 | visitor->pre(parent, this); |
4220 | if (hasPredicate()) |
4221 | getPredicate().getNode()->visit(this, visitor); |
4222 | getLHS().getNode()->visit(this, visitor); |
4223 | getRHS().getNode()->visit(this, visitor); |
4224 | visitor->post(parent, this); |
4225 | } |
4226 | |
4227 | bool DivNode::isEqual(const DivNode &other) const { |
4228 | return true && |
4229 | LHS_ == other.LHS_ && |
4230 | RHS_ == other.RHS_ && |
4231 | predicate_ == other.predicate_ && |
4232 | getType(0) == other.getType(0); |
4233 | } |
4234 | |
4235 | Node* DivNode::clone() const { |
4236 | return new DivNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4237 | } |
4238 | |
4239 | llvm::hash_code DivNode::getHash() const { |
4240 | return llvm::hash_combine( |
4241 | LHS_, |
4242 | RHS_); |
4243 | } |
4244 | |
4245 | DivGradNode *DivNode::getGrad(GraphGradMapper &builder) { |
4246 | auto *x = new DivGradNode(getName().str() + "_grad" , getLHS(), getRHS(), getResult(), builder.getGradient(getResult())); |
4247 | builder.addGradient(getLHS(), x->getGradOfInputNamedLHS()); |
4248 | builder.addGradient(getRHS(), x->getGradOfInputNamedRHS()); |
4249 | return x; |
4250 | } |
4251 | |
4252 | unsigned FloorDivNode::getNumInputs() const { |
4253 | return 2; |
4254 | } |
4255 | |
4256 | std::string FloorDivNode::getInputName(unsigned idx) const { |
4257 | if (idx == 0) { return "LHS" ; } |
4258 | if (idx == 1) { return "RHS" ; } |
4259 | idx -= 2; |
4260 | llvm_unreachable("Invalid index" ); |
4261 | } |
4262 | |
4263 | NodeValue FloorDivNode::getNthInput(unsigned idx) { |
4264 | if (idx == 0) { return LHS_; } |
4265 | if (idx == 1) { return RHS_; } |
4266 | idx -= 2; |
4267 | llvm_unreachable("Invalid index" ); |
4268 | } |
4269 | |
4270 | void FloorDivNode::setNthInput(unsigned idx, NodeValue val) { |
4271 | if (idx == 0) { LHS_ = val; return; } |
4272 | if (idx == 1) { RHS_ = val; return; } |
4273 | idx -= 2; |
4274 | llvm_unreachable("Invalid index" ); |
4275 | } |
4276 | |
4277 | llvm::StringRef FloorDivNode::getOutputName(unsigned idx) const { |
4278 | if (idx == 0) { return "Result" ; } |
4279 | llvm_unreachable("Invalid index" ); |
4280 | } |
4281 | |
4282 | std::string FloorDivNode::getDebugDesc() const { |
4283 | DescriptionBuilder db(getKindName()); |
4284 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4285 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4286 | db |
4287 | .addParam("LHS" , *(getLHS().getType())) |
4288 | .addParam("RHS" , *(getRHS().getType())) |
4289 | .addParam("Truncate" , getTruncate()) |
4290 | .addParam("Users" , getNumUsers()); |
4291 | db.addParam("Result" , *(getResult().getType())); |
4292 | return db; |
4293 | } |
4294 | |
4295 | void FloorDivNode::visit(Node *parent, NodeWalker *visitor) { |
4296 | if (!visitor->shouldVisit(parent, this)) { return; } |
4297 | visitor->pre(parent, this); |
4298 | if (hasPredicate()) |
4299 | getPredicate().getNode()->visit(this, visitor); |
4300 | getLHS().getNode()->visit(this, visitor); |
4301 | getRHS().getNode()->visit(this, visitor); |
4302 | visitor->post(parent, this); |
4303 | } |
4304 | |
4305 | bool FloorDivNode::isEqual(const FloorDivNode &other) const { |
4306 | return true && |
4307 | LHS_ == other.LHS_ && |
4308 | RHS_ == other.RHS_ && |
4309 | predicate_ == other.predicate_ && |
4310 | Truncate_ == other.Truncate_ && |
4311 | getType(0) == other.getType(0); |
4312 | } |
4313 | |
4314 | Node* FloorDivNode::clone() const { |
4315 | return new FloorDivNode(getName(), getResult().getType(), getLHS(), getRHS(), getTruncate()); |
4316 | } |
4317 | |
4318 | llvm::hash_code FloorDivNode::getHash() const { |
4319 | return llvm::hash_combine( |
4320 | Truncate_, |
4321 | LHS_, |
4322 | RHS_); |
4323 | } |
4324 | |
4325 | unsigned FmodNode::getNumInputs() const { |
4326 | return 2; |
4327 | } |
4328 | |
4329 | std::string FmodNode::getInputName(unsigned idx) const { |
4330 | if (idx == 0) { return "LHS" ; } |
4331 | if (idx == 1) { return "RHS" ; } |
4332 | idx -= 2; |
4333 | llvm_unreachable("Invalid index" ); |
4334 | } |
4335 | |
4336 | NodeValue FmodNode::getNthInput(unsigned idx) { |
4337 | if (idx == 0) { return LHS_; } |
4338 | if (idx == 1) { return RHS_; } |
4339 | idx -= 2; |
4340 | llvm_unreachable("Invalid index" ); |
4341 | } |
4342 | |
4343 | void FmodNode::setNthInput(unsigned idx, NodeValue val) { |
4344 | if (idx == 0) { LHS_ = val; return; } |
4345 | if (idx == 1) { RHS_ = val; return; } |
4346 | idx -= 2; |
4347 | llvm_unreachable("Invalid index" ); |
4348 | } |
4349 | |
4350 | llvm::StringRef FmodNode::getOutputName(unsigned idx) const { |
4351 | if (idx == 0) { return "Result" ; } |
4352 | llvm_unreachable("Invalid index" ); |
4353 | } |
4354 | |
4355 | std::string FmodNode::getDebugDesc() const { |
4356 | DescriptionBuilder db(getKindName()); |
4357 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4358 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4359 | db |
4360 | .addParam("LHS" , *(getLHS().getType())) |
4361 | .addParam("RHS" , *(getRHS().getType())) |
4362 | .addParam("Users" , getNumUsers()); |
4363 | db.addParam("Result" , *(getResult().getType())); |
4364 | return db; |
4365 | } |
4366 | |
4367 | void FmodNode::visit(Node *parent, NodeWalker *visitor) { |
4368 | if (!visitor->shouldVisit(parent, this)) { return; } |
4369 | visitor->pre(parent, this); |
4370 | if (hasPredicate()) |
4371 | getPredicate().getNode()->visit(this, visitor); |
4372 | getLHS().getNode()->visit(this, visitor); |
4373 | getRHS().getNode()->visit(this, visitor); |
4374 | visitor->post(parent, this); |
4375 | } |
4376 | |
4377 | bool FmodNode::isEqual(const FmodNode &other) const { |
4378 | return true && |
4379 | LHS_ == other.LHS_ && |
4380 | RHS_ == other.RHS_ && |
4381 | predicate_ == other.predicate_ && |
4382 | getType(0) == other.getType(0); |
4383 | } |
4384 | |
4385 | Node* FmodNode::clone() const { |
4386 | return new FmodNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4387 | } |
4388 | |
4389 | llvm::hash_code FmodNode::getHash() const { |
4390 | return llvm::hash_combine( |
4391 | LHS_, |
4392 | RHS_); |
4393 | } |
4394 | |
4395 | unsigned MaxNode::getNumInputs() const { |
4396 | return 2; |
4397 | } |
4398 | |
4399 | std::string MaxNode::getInputName(unsigned idx) const { |
4400 | if (idx == 0) { return "LHS" ; } |
4401 | if (idx == 1) { return "RHS" ; } |
4402 | idx -= 2; |
4403 | llvm_unreachable("Invalid index" ); |
4404 | } |
4405 | |
4406 | NodeValue MaxNode::getNthInput(unsigned idx) { |
4407 | if (idx == 0) { return LHS_; } |
4408 | if (idx == 1) { return RHS_; } |
4409 | idx -= 2; |
4410 | llvm_unreachable("Invalid index" ); |
4411 | } |
4412 | |
4413 | void MaxNode::setNthInput(unsigned idx, NodeValue val) { |
4414 | if (idx == 0) { LHS_ = val; return; } |
4415 | if (idx == 1) { RHS_ = val; return; } |
4416 | idx -= 2; |
4417 | llvm_unreachable("Invalid index" ); |
4418 | } |
4419 | |
4420 | llvm::StringRef MaxNode::getOutputName(unsigned idx) const { |
4421 | if (idx == 0) { return "Result" ; } |
4422 | llvm_unreachable("Invalid index" ); |
4423 | } |
4424 | |
4425 | std::string MaxNode::getDebugDesc() const { |
4426 | DescriptionBuilder db(getKindName()); |
4427 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4428 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4429 | db |
4430 | .addParam("LHS" , *(getLHS().getType())) |
4431 | .addParam("RHS" , *(getRHS().getType())) |
4432 | .addParam("Users" , getNumUsers()); |
4433 | db.addParam("Result" , *(getResult().getType())); |
4434 | return db; |
4435 | } |
4436 | |
4437 | void MaxNode::visit(Node *parent, NodeWalker *visitor) { |
4438 | if (!visitor->shouldVisit(parent, this)) { return; } |
4439 | visitor->pre(parent, this); |
4440 | if (hasPredicate()) |
4441 | getPredicate().getNode()->visit(this, visitor); |
4442 | getLHS().getNode()->visit(this, visitor); |
4443 | getRHS().getNode()->visit(this, visitor); |
4444 | visitor->post(parent, this); |
4445 | } |
4446 | |
4447 | bool MaxNode::isEqual(const MaxNode &other) const { |
4448 | return true && |
4449 | LHS_ == other.LHS_ && |
4450 | RHS_ == other.RHS_ && |
4451 | predicate_ == other.predicate_ && |
4452 | getType(0) == other.getType(0); |
4453 | } |
4454 | |
4455 | Node* MaxNode::clone() const { |
4456 | return new MaxNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4457 | } |
4458 | |
4459 | llvm::hash_code MaxNode::getHash() const { |
4460 | return llvm::hash_combine( |
4461 | LHS_, |
4462 | RHS_); |
4463 | } |
4464 | |
4465 | unsigned MinNode::getNumInputs() const { |
4466 | return 2; |
4467 | } |
4468 | |
4469 | std::string MinNode::getInputName(unsigned idx) const { |
4470 | if (idx == 0) { return "LHS" ; } |
4471 | if (idx == 1) { return "RHS" ; } |
4472 | idx -= 2; |
4473 | llvm_unreachable("Invalid index" ); |
4474 | } |
4475 | |
4476 | NodeValue MinNode::getNthInput(unsigned idx) { |
4477 | if (idx == 0) { return LHS_; } |
4478 | if (idx == 1) { return RHS_; } |
4479 | idx -= 2; |
4480 | llvm_unreachable("Invalid index" ); |
4481 | } |
4482 | |
4483 | void MinNode::setNthInput(unsigned idx, NodeValue val) { |
4484 | if (idx == 0) { LHS_ = val; return; } |
4485 | if (idx == 1) { RHS_ = val; return; } |
4486 | idx -= 2; |
4487 | llvm_unreachable("Invalid index" ); |
4488 | } |
4489 | |
4490 | llvm::StringRef MinNode::getOutputName(unsigned idx) const { |
4491 | if (idx == 0) { return "Result" ; } |
4492 | llvm_unreachable("Invalid index" ); |
4493 | } |
4494 | |
4495 | std::string MinNode::getDebugDesc() const { |
4496 | DescriptionBuilder db(getKindName()); |
4497 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4498 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4499 | db |
4500 | .addParam("LHS" , *(getLHS().getType())) |
4501 | .addParam("RHS" , *(getRHS().getType())) |
4502 | .addParam("Users" , getNumUsers()); |
4503 | db.addParam("Result" , *(getResult().getType())); |
4504 | return db; |
4505 | } |
4506 | |
4507 | void MinNode::visit(Node *parent, NodeWalker *visitor) { |
4508 | if (!visitor->shouldVisit(parent, this)) { return; } |
4509 | visitor->pre(parent, this); |
4510 | if (hasPredicate()) |
4511 | getPredicate().getNode()->visit(this, visitor); |
4512 | getLHS().getNode()->visit(this, visitor); |
4513 | getRHS().getNode()->visit(this, visitor); |
4514 | visitor->post(parent, this); |
4515 | } |
4516 | |
4517 | bool MinNode::isEqual(const MinNode &other) const { |
4518 | return true && |
4519 | LHS_ == other.LHS_ && |
4520 | RHS_ == other.RHS_ && |
4521 | predicate_ == other.predicate_ && |
4522 | getType(0) == other.getType(0); |
4523 | } |
4524 | |
4525 | Node* MinNode::clone() const { |
4526 | return new MinNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4527 | } |
4528 | |
4529 | llvm::hash_code MinNode::getHash() const { |
4530 | return llvm::hash_combine( |
4531 | LHS_, |
4532 | RHS_); |
4533 | } |
4534 | |
4535 | unsigned CmpEQNode::getNumInputs() const { |
4536 | return 2; |
4537 | } |
4538 | |
4539 | std::string CmpEQNode::getInputName(unsigned idx) const { |
4540 | if (idx == 0) { return "LHS" ; } |
4541 | if (idx == 1) { return "RHS" ; } |
4542 | idx -= 2; |
4543 | llvm_unreachable("Invalid index" ); |
4544 | } |
4545 | |
4546 | NodeValue CmpEQNode::getNthInput(unsigned idx) { |
4547 | if (idx == 0) { return LHS_; } |
4548 | if (idx == 1) { return RHS_; } |
4549 | idx -= 2; |
4550 | llvm_unreachable("Invalid index" ); |
4551 | } |
4552 | |
4553 | void CmpEQNode::setNthInput(unsigned idx, NodeValue val) { |
4554 | if (idx == 0) { LHS_ = val; return; } |
4555 | if (idx == 1) { RHS_ = val; return; } |
4556 | idx -= 2; |
4557 | llvm_unreachable("Invalid index" ); |
4558 | } |
4559 | |
4560 | llvm::StringRef CmpEQNode::getOutputName(unsigned idx) const { |
4561 | if (idx == 0) { return "Result" ; } |
4562 | llvm_unreachable("Invalid index" ); |
4563 | } |
4564 | |
4565 | std::string CmpEQNode::getDebugDesc() const { |
4566 | DescriptionBuilder db(getKindName()); |
4567 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4568 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4569 | db |
4570 | .addParam("LHS" , *(getLHS().getType())) |
4571 | .addParam("RHS" , *(getRHS().getType())) |
4572 | .addParam("Users" , getNumUsers()); |
4573 | db.addParam("Result" , *(getResult().getType())); |
4574 | return db; |
4575 | } |
4576 | |
4577 | void CmpEQNode::visit(Node *parent, NodeWalker *visitor) { |
4578 | if (!visitor->shouldVisit(parent, this)) { return; } |
4579 | visitor->pre(parent, this); |
4580 | if (hasPredicate()) |
4581 | getPredicate().getNode()->visit(this, visitor); |
4582 | getLHS().getNode()->visit(this, visitor); |
4583 | getRHS().getNode()->visit(this, visitor); |
4584 | visitor->post(parent, this); |
4585 | } |
4586 | |
4587 | bool CmpEQNode::isEqual(const CmpEQNode &other) const { |
4588 | return true && |
4589 | LHS_ == other.LHS_ && |
4590 | RHS_ == other.RHS_ && |
4591 | predicate_ == other.predicate_ && |
4592 | getType(0) == other.getType(0); |
4593 | } |
4594 | |
4595 | Node* CmpEQNode::clone() const { |
4596 | return new CmpEQNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4597 | } |
4598 | |
4599 | llvm::hash_code CmpEQNode::getHash() const { |
4600 | return llvm::hash_combine( |
4601 | LHS_, |
4602 | RHS_); |
4603 | } |
4604 | |
4605 | unsigned CmpNEQNode::getNumInputs() const { |
4606 | return 2; |
4607 | } |
4608 | |
4609 | std::string CmpNEQNode::getInputName(unsigned idx) const { |
4610 | if (idx == 0) { return "LHS" ; } |
4611 | if (idx == 1) { return "RHS" ; } |
4612 | idx -= 2; |
4613 | llvm_unreachable("Invalid index" ); |
4614 | } |
4615 | |
4616 | NodeValue CmpNEQNode::getNthInput(unsigned idx) { |
4617 | if (idx == 0) { return LHS_; } |
4618 | if (idx == 1) { return RHS_; } |
4619 | idx -= 2; |
4620 | llvm_unreachable("Invalid index" ); |
4621 | } |
4622 | |
4623 | void CmpNEQNode::setNthInput(unsigned idx, NodeValue val) { |
4624 | if (idx == 0) { LHS_ = val; return; } |
4625 | if (idx == 1) { RHS_ = val; return; } |
4626 | idx -= 2; |
4627 | llvm_unreachable("Invalid index" ); |
4628 | } |
4629 | |
4630 | llvm::StringRef CmpNEQNode::getOutputName(unsigned idx) const { |
4631 | if (idx == 0) { return "Result" ; } |
4632 | llvm_unreachable("Invalid index" ); |
4633 | } |
4634 | |
4635 | std::string CmpNEQNode::getDebugDesc() const { |
4636 | DescriptionBuilder db(getKindName()); |
4637 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4638 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4639 | db |
4640 | .addParam("LHS" , *(getLHS().getType())) |
4641 | .addParam("RHS" , *(getRHS().getType())) |
4642 | .addParam("Users" , getNumUsers()); |
4643 | db.addParam("Result" , *(getResult().getType())); |
4644 | return db; |
4645 | } |
4646 | |
4647 | void CmpNEQNode::visit(Node *parent, NodeWalker *visitor) { |
4648 | if (!visitor->shouldVisit(parent, this)) { return; } |
4649 | visitor->pre(parent, this); |
4650 | if (hasPredicate()) |
4651 | getPredicate().getNode()->visit(this, visitor); |
4652 | getLHS().getNode()->visit(this, visitor); |
4653 | getRHS().getNode()->visit(this, visitor); |
4654 | visitor->post(parent, this); |
4655 | } |
4656 | |
4657 | bool CmpNEQNode::isEqual(const CmpNEQNode &other) const { |
4658 | return true && |
4659 | LHS_ == other.LHS_ && |
4660 | RHS_ == other.RHS_ && |
4661 | predicate_ == other.predicate_ && |
4662 | getType(0) == other.getType(0); |
4663 | } |
4664 | |
4665 | Node* CmpNEQNode::clone() const { |
4666 | return new CmpNEQNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4667 | } |
4668 | |
4669 | llvm::hash_code CmpNEQNode::getHash() const { |
4670 | return llvm::hash_combine( |
4671 | LHS_, |
4672 | RHS_); |
4673 | } |
4674 | |
4675 | unsigned CmpLTNode::getNumInputs() const { |
4676 | return 2; |
4677 | } |
4678 | |
4679 | std::string CmpLTNode::getInputName(unsigned idx) const { |
4680 | if (idx == 0) { return "LHS" ; } |
4681 | if (idx == 1) { return "RHS" ; } |
4682 | idx -= 2; |
4683 | llvm_unreachable("Invalid index" ); |
4684 | } |
4685 | |
4686 | NodeValue CmpLTNode::getNthInput(unsigned idx) { |
4687 | if (idx == 0) { return LHS_; } |
4688 | if (idx == 1) { return RHS_; } |
4689 | idx -= 2; |
4690 | llvm_unreachable("Invalid index" ); |
4691 | } |
4692 | |
4693 | void CmpLTNode::setNthInput(unsigned idx, NodeValue val) { |
4694 | if (idx == 0) { LHS_ = val; return; } |
4695 | if (idx == 1) { RHS_ = val; return; } |
4696 | idx -= 2; |
4697 | llvm_unreachable("Invalid index" ); |
4698 | } |
4699 | |
4700 | llvm::StringRef CmpLTNode::getOutputName(unsigned idx) const { |
4701 | if (idx == 0) { return "Result" ; } |
4702 | llvm_unreachable("Invalid index" ); |
4703 | } |
4704 | |
4705 | std::string CmpLTNode::getDebugDesc() const { |
4706 | DescriptionBuilder db(getKindName()); |
4707 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4708 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4709 | db |
4710 | .addParam("LHS" , *(getLHS().getType())) |
4711 | .addParam("RHS" , *(getRHS().getType())) |
4712 | .addParam("Users" , getNumUsers()); |
4713 | db.addParam("Result" , *(getResult().getType())); |
4714 | return db; |
4715 | } |
4716 | |
4717 | void CmpLTNode::visit(Node *parent, NodeWalker *visitor) { |
4718 | if (!visitor->shouldVisit(parent, this)) { return; } |
4719 | visitor->pre(parent, this); |
4720 | if (hasPredicate()) |
4721 | getPredicate().getNode()->visit(this, visitor); |
4722 | getLHS().getNode()->visit(this, visitor); |
4723 | getRHS().getNode()->visit(this, visitor); |
4724 | visitor->post(parent, this); |
4725 | } |
4726 | |
4727 | bool CmpLTNode::isEqual(const CmpLTNode &other) const { |
4728 | return true && |
4729 | LHS_ == other.LHS_ && |
4730 | RHS_ == other.RHS_ && |
4731 | predicate_ == other.predicate_ && |
4732 | getType(0) == other.getType(0); |
4733 | } |
4734 | |
4735 | Node* CmpLTNode::clone() const { |
4736 | return new CmpLTNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4737 | } |
4738 | |
4739 | llvm::hash_code CmpLTNode::getHash() const { |
4740 | return llvm::hash_combine( |
4741 | LHS_, |
4742 | RHS_); |
4743 | } |
4744 | |
4745 | unsigned CmpLTENode::getNumInputs() const { |
4746 | return 2; |
4747 | } |
4748 | |
4749 | std::string CmpLTENode::getInputName(unsigned idx) const { |
4750 | if (idx == 0) { return "LHS" ; } |
4751 | if (idx == 1) { return "RHS" ; } |
4752 | idx -= 2; |
4753 | llvm_unreachable("Invalid index" ); |
4754 | } |
4755 | |
4756 | NodeValue CmpLTENode::getNthInput(unsigned idx) { |
4757 | if (idx == 0) { return LHS_; } |
4758 | if (idx == 1) { return RHS_; } |
4759 | idx -= 2; |
4760 | llvm_unreachable("Invalid index" ); |
4761 | } |
4762 | |
4763 | void CmpLTENode::setNthInput(unsigned idx, NodeValue val) { |
4764 | if (idx == 0) { LHS_ = val; return; } |
4765 | if (idx == 1) { RHS_ = val; return; } |
4766 | idx -= 2; |
4767 | llvm_unreachable("Invalid index" ); |
4768 | } |
4769 | |
4770 | llvm::StringRef CmpLTENode::getOutputName(unsigned idx) const { |
4771 | if (idx == 0) { return "Result" ; } |
4772 | llvm_unreachable("Invalid index" ); |
4773 | } |
4774 | |
4775 | std::string CmpLTENode::getDebugDesc() const { |
4776 | DescriptionBuilder db(getKindName()); |
4777 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4778 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4779 | db |
4780 | .addParam("LHS" , *(getLHS().getType())) |
4781 | .addParam("RHS" , *(getRHS().getType())) |
4782 | .addParam("Users" , getNumUsers()); |
4783 | db.addParam("Result" , *(getResult().getType())); |
4784 | return db; |
4785 | } |
4786 | |
4787 | void CmpLTENode::visit(Node *parent, NodeWalker *visitor) { |
4788 | if (!visitor->shouldVisit(parent, this)) { return; } |
4789 | visitor->pre(parent, this); |
4790 | if (hasPredicate()) |
4791 | getPredicate().getNode()->visit(this, visitor); |
4792 | getLHS().getNode()->visit(this, visitor); |
4793 | getRHS().getNode()->visit(this, visitor); |
4794 | visitor->post(parent, this); |
4795 | } |
4796 | |
4797 | bool CmpLTENode::isEqual(const CmpLTENode &other) const { |
4798 | return true && |
4799 | LHS_ == other.LHS_ && |
4800 | RHS_ == other.RHS_ && |
4801 | predicate_ == other.predicate_ && |
4802 | getType(0) == other.getType(0); |
4803 | } |
4804 | |
4805 | Node* CmpLTENode::clone() const { |
4806 | return new CmpLTENode(getName(), getResult().getType(), getLHS(), getRHS()); |
4807 | } |
4808 | |
4809 | llvm::hash_code CmpLTENode::getHash() const { |
4810 | return llvm::hash_combine( |
4811 | LHS_, |
4812 | RHS_); |
4813 | } |
4814 | |
4815 | unsigned PowNode::getNumInputs() const { |
4816 | return 2; |
4817 | } |
4818 | |
4819 | std::string PowNode::getInputName(unsigned idx) const { |
4820 | if (idx == 0) { return "LHS" ; } |
4821 | if (idx == 1) { return "RHS" ; } |
4822 | idx -= 2; |
4823 | llvm_unreachable("Invalid index" ); |
4824 | } |
4825 | |
4826 | NodeValue PowNode::getNthInput(unsigned idx) { |
4827 | if (idx == 0) { return LHS_; } |
4828 | if (idx == 1) { return RHS_; } |
4829 | idx -= 2; |
4830 | llvm_unreachable("Invalid index" ); |
4831 | } |
4832 | |
4833 | void PowNode::setNthInput(unsigned idx, NodeValue val) { |
4834 | if (idx == 0) { LHS_ = val; return; } |
4835 | if (idx == 1) { RHS_ = val; return; } |
4836 | idx -= 2; |
4837 | llvm_unreachable("Invalid index" ); |
4838 | } |
4839 | |
4840 | llvm::StringRef PowNode::getOutputName(unsigned idx) const { |
4841 | if (idx == 0) { return "Result" ; } |
4842 | llvm_unreachable("Invalid index" ); |
4843 | } |
4844 | |
4845 | std::string PowNode::getDebugDesc() const { |
4846 | DescriptionBuilder db(getKindName()); |
4847 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4848 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4849 | db |
4850 | .addParam("LHS" , *(getLHS().getType())) |
4851 | .addParam("RHS" , *(getRHS().getType())) |
4852 | .addParam("Users" , getNumUsers()); |
4853 | db.addParam("Result" , *(getResult().getType())); |
4854 | return db; |
4855 | } |
4856 | |
4857 | void PowNode::visit(Node *parent, NodeWalker *visitor) { |
4858 | if (!visitor->shouldVisit(parent, this)) { return; } |
4859 | visitor->pre(parent, this); |
4860 | if (hasPredicate()) |
4861 | getPredicate().getNode()->visit(this, visitor); |
4862 | getLHS().getNode()->visit(this, visitor); |
4863 | getRHS().getNode()->visit(this, visitor); |
4864 | visitor->post(parent, this); |
4865 | } |
4866 | |
4867 | bool PowNode::isEqual(const PowNode &other) const { |
4868 | return true && |
4869 | LHS_ == other.LHS_ && |
4870 | RHS_ == other.RHS_ && |
4871 | predicate_ == other.predicate_ && |
4872 | getType(0) == other.getType(0); |
4873 | } |
4874 | |
4875 | Node* PowNode::clone() const { |
4876 | return new PowNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4877 | } |
4878 | |
4879 | llvm::hash_code PowNode::getHash() const { |
4880 | return llvm::hash_combine( |
4881 | LHS_, |
4882 | RHS_); |
4883 | } |
4884 | |
4885 | unsigned AndNode::getNumInputs() const { |
4886 | return 2; |
4887 | } |
4888 | |
4889 | std::string AndNode::getInputName(unsigned idx) const { |
4890 | if (idx == 0) { return "LHS" ; } |
4891 | if (idx == 1) { return "RHS" ; } |
4892 | idx -= 2; |
4893 | llvm_unreachable("Invalid index" ); |
4894 | } |
4895 | |
4896 | NodeValue AndNode::getNthInput(unsigned idx) { |
4897 | if (idx == 0) { return LHS_; } |
4898 | if (idx == 1) { return RHS_; } |
4899 | idx -= 2; |
4900 | llvm_unreachable("Invalid index" ); |
4901 | } |
4902 | |
4903 | void AndNode::setNthInput(unsigned idx, NodeValue val) { |
4904 | if (idx == 0) { LHS_ = val; return; } |
4905 | if (idx == 1) { RHS_ = val; return; } |
4906 | idx -= 2; |
4907 | llvm_unreachable("Invalid index" ); |
4908 | } |
4909 | |
4910 | llvm::StringRef AndNode::getOutputName(unsigned idx) const { |
4911 | if (idx == 0) { return "Result" ; } |
4912 | llvm_unreachable("Invalid index" ); |
4913 | } |
4914 | |
4915 | std::string AndNode::getDebugDesc() const { |
4916 | DescriptionBuilder db(getKindName()); |
4917 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4918 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4919 | db |
4920 | .addParam("LHS" , *(getLHS().getType())) |
4921 | .addParam("RHS" , *(getRHS().getType())) |
4922 | .addParam("Users" , getNumUsers()); |
4923 | db.addParam("Result" , *(getResult().getType())); |
4924 | return db; |
4925 | } |
4926 | |
4927 | void AndNode::visit(Node *parent, NodeWalker *visitor) { |
4928 | if (!visitor->shouldVisit(parent, this)) { return; } |
4929 | visitor->pre(parent, this); |
4930 | if (hasPredicate()) |
4931 | getPredicate().getNode()->visit(this, visitor); |
4932 | getLHS().getNode()->visit(this, visitor); |
4933 | getRHS().getNode()->visit(this, visitor); |
4934 | visitor->post(parent, this); |
4935 | } |
4936 | |
4937 | bool AndNode::isEqual(const AndNode &other) const { |
4938 | return true && |
4939 | LHS_ == other.LHS_ && |
4940 | RHS_ == other.RHS_ && |
4941 | predicate_ == other.predicate_ && |
4942 | getType(0) == other.getType(0); |
4943 | } |
4944 | |
4945 | Node* AndNode::clone() const { |
4946 | return new AndNode(getName(), getResult().getType(), getLHS(), getRHS()); |
4947 | } |
4948 | |
4949 | llvm::hash_code AndNode::getHash() const { |
4950 | return llvm::hash_combine( |
4951 | LHS_, |
4952 | RHS_); |
4953 | } |
4954 | |
4955 | unsigned BitwiseAndNode::getNumInputs() const { |
4956 | return 2; |
4957 | } |
4958 | |
4959 | std::string BitwiseAndNode::getInputName(unsigned idx) const { |
4960 | if (idx == 0) { return "LHS" ; } |
4961 | if (idx == 1) { return "RHS" ; } |
4962 | idx -= 2; |
4963 | llvm_unreachable("Invalid index" ); |
4964 | } |
4965 | |
4966 | NodeValue BitwiseAndNode::getNthInput(unsigned idx) { |
4967 | if (idx == 0) { return LHS_; } |
4968 | if (idx == 1) { return RHS_; } |
4969 | idx -= 2; |
4970 | llvm_unreachable("Invalid index" ); |
4971 | } |
4972 | |
4973 | void BitwiseAndNode::setNthInput(unsigned idx, NodeValue val) { |
4974 | if (idx == 0) { LHS_ = val; return; } |
4975 | if (idx == 1) { RHS_ = val; return; } |
4976 | idx -= 2; |
4977 | llvm_unreachable("Invalid index" ); |
4978 | } |
4979 | |
4980 | llvm::StringRef BitwiseAndNode::getOutputName(unsigned idx) const { |
4981 | if (idx == 0) { return "Result" ; } |
4982 | llvm_unreachable("Invalid index" ); |
4983 | } |
4984 | |
4985 | std::string BitwiseAndNode::getDebugDesc() const { |
4986 | DescriptionBuilder db(getKindName()); |
4987 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
4988 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
4989 | db |
4990 | .addParam("LHS" , *(getLHS().getType())) |
4991 | .addParam("RHS" , *(getRHS().getType())) |
4992 | .addParam("Users" , getNumUsers()); |
4993 | db.addParam("Result" , *(getResult().getType())); |
4994 | return db; |
4995 | } |
4996 | |
4997 | void BitwiseAndNode::visit(Node *parent, NodeWalker *visitor) { |
4998 | if (!visitor->shouldVisit(parent, this)) { return; } |
4999 | visitor->pre(parent, this); |
5000 | if (hasPredicate()) |
5001 | getPredicate().getNode()->visit(this, visitor); |
5002 | getLHS().getNode()->visit(this, visitor); |
5003 | getRHS().getNode()->visit(this, visitor); |
5004 | visitor->post(parent, this); |
5005 | } |
5006 | |
5007 | bool BitwiseAndNode::isEqual(const BitwiseAndNode &other) const { |
5008 | return true && |
5009 | LHS_ == other.LHS_ && |
5010 | RHS_ == other.RHS_ && |
5011 | predicate_ == other.predicate_ && |
5012 | getType(0) == other.getType(0); |
5013 | } |
5014 | |
5015 | Node* BitwiseAndNode::clone() const { |
5016 | return new BitwiseAndNode(getName(), getResult().getType(), getLHS(), getRHS()); |
5017 | } |
5018 | |
5019 | llvm::hash_code BitwiseAndNode::getHash() const { |
5020 | return llvm::hash_combine( |
5021 | LHS_, |
5022 | RHS_); |
5023 | } |
5024 | |
5025 | unsigned OrNode::getNumInputs() const { |
5026 | return 2; |
5027 | } |
5028 | |
5029 | std::string OrNode::getInputName(unsigned idx) const { |
5030 | if (idx == 0) { return "LHS" ; } |
5031 | if (idx == 1) { return "RHS" ; } |
5032 | idx -= 2; |
5033 | llvm_unreachable("Invalid index" ); |
5034 | } |
5035 | |
5036 | NodeValue OrNode::getNthInput(unsigned idx) { |
5037 | if (idx == 0) { return LHS_; } |
5038 | if (idx == 1) { return RHS_; } |
5039 | idx -= 2; |
5040 | llvm_unreachable("Invalid index" ); |
5041 | } |
5042 | |
5043 | void OrNode::setNthInput(unsigned idx, NodeValue val) { |
5044 | if (idx == 0) { LHS_ = val; return; } |
5045 | if (idx == 1) { RHS_ = val; return; } |
5046 | idx -= 2; |
5047 | llvm_unreachable("Invalid index" ); |
5048 | } |
5049 | |
5050 | llvm::StringRef OrNode::getOutputName(unsigned idx) const { |
5051 | if (idx == 0) { return "Result" ; } |
5052 | llvm_unreachable("Invalid index" ); |
5053 | } |
5054 | |
5055 | std::string OrNode::getDebugDesc() const { |
5056 | DescriptionBuilder db(getKindName()); |
5057 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5058 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5059 | db |
5060 | .addParam("LHS" , *(getLHS().getType())) |
5061 | .addParam("RHS" , *(getRHS().getType())) |
5062 | .addParam("Users" , getNumUsers()); |
5063 | db.addParam("Result" , *(getResult().getType())); |
5064 | return db; |
5065 | } |
5066 | |
5067 | void OrNode::visit(Node *parent, NodeWalker *visitor) { |
5068 | if (!visitor->shouldVisit(parent, this)) { return; } |
5069 | visitor->pre(parent, this); |
5070 | if (hasPredicate()) |
5071 | getPredicate().getNode()->visit(this, visitor); |
5072 | getLHS().getNode()->visit(this, visitor); |
5073 | getRHS().getNode()->visit(this, visitor); |
5074 | visitor->post(parent, this); |
5075 | } |
5076 | |
5077 | bool OrNode::isEqual(const OrNode &other) const { |
5078 | return true && |
5079 | LHS_ == other.LHS_ && |
5080 | RHS_ == other.RHS_ && |
5081 | predicate_ == other.predicate_ && |
5082 | getType(0) == other.getType(0); |
5083 | } |
5084 | |
5085 | Node* OrNode::clone() const { |
5086 | return new OrNode(getName(), getResult().getType(), getLHS(), getRHS()); |
5087 | } |
5088 | |
5089 | llvm::hash_code OrNode::getHash() const { |
5090 | return llvm::hash_combine( |
5091 | LHS_, |
5092 | RHS_); |
5093 | } |
5094 | |
5095 | unsigned BitwiseOrNode::getNumInputs() const { |
5096 | return 2; |
5097 | } |
5098 | |
5099 | std::string BitwiseOrNode::getInputName(unsigned idx) const { |
5100 | if (idx == 0) { return "LHS" ; } |
5101 | if (idx == 1) { return "RHS" ; } |
5102 | idx -= 2; |
5103 | llvm_unreachable("Invalid index" ); |
5104 | } |
5105 | |
5106 | NodeValue BitwiseOrNode::getNthInput(unsigned idx) { |
5107 | if (idx == 0) { return LHS_; } |
5108 | if (idx == 1) { return RHS_; } |
5109 | idx -= 2; |
5110 | llvm_unreachable("Invalid index" ); |
5111 | } |
5112 | |
5113 | void BitwiseOrNode::setNthInput(unsigned idx, NodeValue val) { |
5114 | if (idx == 0) { LHS_ = val; return; } |
5115 | if (idx == 1) { RHS_ = val; return; } |
5116 | idx -= 2; |
5117 | llvm_unreachable("Invalid index" ); |
5118 | } |
5119 | |
5120 | llvm::StringRef BitwiseOrNode::getOutputName(unsigned idx) const { |
5121 | if (idx == 0) { return "Result" ; } |
5122 | llvm_unreachable("Invalid index" ); |
5123 | } |
5124 | |
5125 | std::string BitwiseOrNode::getDebugDesc() const { |
5126 | DescriptionBuilder db(getKindName()); |
5127 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5128 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5129 | db |
5130 | .addParam("LHS" , *(getLHS().getType())) |
5131 | .addParam("RHS" , *(getRHS().getType())) |
5132 | .addParam("Users" , getNumUsers()); |
5133 | db.addParam("Result" , *(getResult().getType())); |
5134 | return db; |
5135 | } |
5136 | |
5137 | void BitwiseOrNode::visit(Node *parent, NodeWalker *visitor) { |
5138 | if (!visitor->shouldVisit(parent, this)) { return; } |
5139 | visitor->pre(parent, this); |
5140 | if (hasPredicate()) |
5141 | getPredicate().getNode()->visit(this, visitor); |
5142 | getLHS().getNode()->visit(this, visitor); |
5143 | getRHS().getNode()->visit(this, visitor); |
5144 | visitor->post(parent, this); |
5145 | } |
5146 | |
5147 | bool BitwiseOrNode::isEqual(const BitwiseOrNode &other) const { |
5148 | return true && |
5149 | LHS_ == other.LHS_ && |
5150 | RHS_ == other.RHS_ && |
5151 | predicate_ == other.predicate_ && |
5152 | getType(0) == other.getType(0); |
5153 | } |
5154 | |
5155 | Node* BitwiseOrNode::clone() const { |
5156 | return new BitwiseOrNode(getName(), getResult().getType(), getLHS(), getRHS()); |
5157 | } |
5158 | |
5159 | llvm::hash_code BitwiseOrNode::getHash() const { |
5160 | return llvm::hash_combine( |
5161 | LHS_, |
5162 | RHS_); |
5163 | } |
5164 | |
5165 | unsigned XorNode::getNumInputs() const { |
5166 | return 2; |
5167 | } |
5168 | |
5169 | std::string XorNode::getInputName(unsigned idx) const { |
5170 | if (idx == 0) { return "LHS" ; } |
5171 | if (idx == 1) { return "RHS" ; } |
5172 | idx -= 2; |
5173 | llvm_unreachable("Invalid index" ); |
5174 | } |
5175 | |
5176 | NodeValue XorNode::getNthInput(unsigned idx) { |
5177 | if (idx == 0) { return LHS_; } |
5178 | if (idx == 1) { return RHS_; } |
5179 | idx -= 2; |
5180 | llvm_unreachable("Invalid index" ); |
5181 | } |
5182 | |
5183 | void XorNode::setNthInput(unsigned idx, NodeValue val) { |
5184 | if (idx == 0) { LHS_ = val; return; } |
5185 | if (idx == 1) { RHS_ = val; return; } |
5186 | idx -= 2; |
5187 | llvm_unreachable("Invalid index" ); |
5188 | } |
5189 | |
5190 | llvm::StringRef XorNode::getOutputName(unsigned idx) const { |
5191 | if (idx == 0) { return "Result" ; } |
5192 | llvm_unreachable("Invalid index" ); |
5193 | } |
5194 | |
5195 | std::string XorNode::getDebugDesc() const { |
5196 | DescriptionBuilder db(getKindName()); |
5197 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5198 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5199 | db |
5200 | .addParam("LHS" , *(getLHS().getType())) |
5201 | .addParam("RHS" , *(getRHS().getType())) |
5202 | .addParam("Users" , getNumUsers()); |
5203 | db.addParam("Result" , *(getResult().getType())); |
5204 | return db; |
5205 | } |
5206 | |
5207 | void XorNode::visit(Node *parent, NodeWalker *visitor) { |
5208 | if (!visitor->shouldVisit(parent, this)) { return; } |
5209 | visitor->pre(parent, this); |
5210 | if (hasPredicate()) |
5211 | getPredicate().getNode()->visit(this, visitor); |
5212 | getLHS().getNode()->visit(this, visitor); |
5213 | getRHS().getNode()->visit(this, visitor); |
5214 | visitor->post(parent, this); |
5215 | } |
5216 | |
5217 | bool XorNode::isEqual(const XorNode &other) const { |
5218 | return true && |
5219 | LHS_ == other.LHS_ && |
5220 | RHS_ == other.RHS_ && |
5221 | predicate_ == other.predicate_ && |
5222 | getType(0) == other.getType(0); |
5223 | } |
5224 | |
5225 | Node* XorNode::clone() const { |
5226 | return new XorNode(getName(), getResult().getType(), getLHS(), getRHS()); |
5227 | } |
5228 | |
5229 | llvm::hash_code XorNode::getHash() const { |
5230 | return llvm::hash_combine( |
5231 | LHS_, |
5232 | RHS_); |
5233 | } |
5234 | |
5235 | unsigned BitwiseXorNode::getNumInputs() const { |
5236 | return 2; |
5237 | } |
5238 | |
5239 | std::string BitwiseXorNode::getInputName(unsigned idx) const { |
5240 | if (idx == 0) { return "LHS" ; } |
5241 | if (idx == 1) { return "RHS" ; } |
5242 | idx -= 2; |
5243 | llvm_unreachable("Invalid index" ); |
5244 | } |
5245 | |
5246 | NodeValue BitwiseXorNode::getNthInput(unsigned idx) { |
5247 | if (idx == 0) { return LHS_; } |
5248 | if (idx == 1) { return RHS_; } |
5249 | idx -= 2; |
5250 | llvm_unreachable("Invalid index" ); |
5251 | } |
5252 | |
5253 | void BitwiseXorNode::setNthInput(unsigned idx, NodeValue val) { |
5254 | if (idx == 0) { LHS_ = val; return; } |
5255 | if (idx == 1) { RHS_ = val; return; } |
5256 | idx -= 2; |
5257 | llvm_unreachable("Invalid index" ); |
5258 | } |
5259 | |
5260 | llvm::StringRef BitwiseXorNode::getOutputName(unsigned idx) const { |
5261 | if (idx == 0) { return "Result" ; } |
5262 | llvm_unreachable("Invalid index" ); |
5263 | } |
5264 | |
5265 | std::string BitwiseXorNode::getDebugDesc() const { |
5266 | DescriptionBuilder db(getKindName()); |
5267 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5268 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5269 | db |
5270 | .addParam("LHS" , *(getLHS().getType())) |
5271 | .addParam("RHS" , *(getRHS().getType())) |
5272 | .addParam("Users" , getNumUsers()); |
5273 | db.addParam("Result" , *(getResult().getType())); |
5274 | return db; |
5275 | } |
5276 | |
5277 | void BitwiseXorNode::visit(Node *parent, NodeWalker *visitor) { |
5278 | if (!visitor->shouldVisit(parent, this)) { return; } |
5279 | visitor->pre(parent, this); |
5280 | if (hasPredicate()) |
5281 | getPredicate().getNode()->visit(this, visitor); |
5282 | getLHS().getNode()->visit(this, visitor); |
5283 | getRHS().getNode()->visit(this, visitor); |
5284 | visitor->post(parent, this); |
5285 | } |
5286 | |
5287 | bool BitwiseXorNode::isEqual(const BitwiseXorNode &other) const { |
5288 | return true && |
5289 | LHS_ == other.LHS_ && |
5290 | RHS_ == other.RHS_ && |
5291 | predicate_ == other.predicate_ && |
5292 | getType(0) == other.getType(0); |
5293 | } |
5294 | |
5295 | Node* BitwiseXorNode::clone() const { |
5296 | return new BitwiseXorNode(getName(), getResult().getType(), getLHS(), getRHS()); |
5297 | } |
5298 | |
5299 | llvm::hash_code BitwiseXorNode::getHash() const { |
5300 | return llvm::hash_combine( |
5301 | LHS_, |
5302 | RHS_); |
5303 | } |
5304 | |
5305 | unsigned NotNode::getNumInputs() const { |
5306 | return 1; |
5307 | } |
5308 | |
5309 | std::string NotNode::getInputName(unsigned idx) const { |
5310 | if (idx == 0) { return "Input" ; } |
5311 | idx -= 1; |
5312 | llvm_unreachable("Invalid index" ); |
5313 | } |
5314 | |
5315 | NodeValue NotNode::getNthInput(unsigned idx) { |
5316 | if (idx == 0) { return Input_; } |
5317 | idx -= 1; |
5318 | llvm_unreachable("Invalid index" ); |
5319 | } |
5320 | |
5321 | void NotNode::setNthInput(unsigned idx, NodeValue val) { |
5322 | if (idx == 0) { Input_ = val; return; } |
5323 | idx -= 1; |
5324 | llvm_unreachable("Invalid index" ); |
5325 | } |
5326 | |
5327 | llvm::StringRef NotNode::getOutputName(unsigned idx) const { |
5328 | if (idx == 0) { return "Result" ; } |
5329 | llvm_unreachable("Invalid index" ); |
5330 | } |
5331 | |
5332 | std::string NotNode::getDebugDesc() const { |
5333 | DescriptionBuilder db(getKindName()); |
5334 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5335 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5336 | db |
5337 | .addParam("Input" , *(getInput().getType())) |
5338 | .addParam("Users" , getNumUsers()); |
5339 | db.addParam("Result" , *(getResult().getType())); |
5340 | return db; |
5341 | } |
5342 | |
5343 | void NotNode::visit(Node *parent, NodeWalker *visitor) { |
5344 | if (!visitor->shouldVisit(parent, this)) { return; } |
5345 | visitor->pre(parent, this); |
5346 | if (hasPredicate()) |
5347 | getPredicate().getNode()->visit(this, visitor); |
5348 | getInput().getNode()->visit(this, visitor); |
5349 | visitor->post(parent, this); |
5350 | } |
5351 | |
5352 | bool NotNode::isEqual(const NotNode &other) const { |
5353 | return true && |
5354 | Input_ == other.Input_ && |
5355 | predicate_ == other.predicate_ && |
5356 | getType(0) == other.getType(0); |
5357 | } |
5358 | |
5359 | Node* NotNode::clone() const { |
5360 | return new NotNode(getName(), getResult().getType(), getInput()); |
5361 | } |
5362 | |
5363 | llvm::hash_code NotNode::getHash() const { |
5364 | return llvm::hash_combine( |
5365 | Input_); |
5366 | } |
5367 | |
5368 | unsigned BitwiseNotNode::getNumInputs() const { |
5369 | return 1; |
5370 | } |
5371 | |
5372 | std::string BitwiseNotNode::getInputName(unsigned idx) const { |
5373 | if (idx == 0) { return "Input" ; } |
5374 | idx -= 1; |
5375 | llvm_unreachable("Invalid index" ); |
5376 | } |
5377 | |
5378 | NodeValue BitwiseNotNode::getNthInput(unsigned idx) { |
5379 | if (idx == 0) { return Input_; } |
5380 | idx -= 1; |
5381 | llvm_unreachable("Invalid index" ); |
5382 | } |
5383 | |
5384 | void BitwiseNotNode::setNthInput(unsigned idx, NodeValue val) { |
5385 | if (idx == 0) { Input_ = val; return; } |
5386 | idx -= 1; |
5387 | llvm_unreachable("Invalid index" ); |
5388 | } |
5389 | |
5390 | llvm::StringRef BitwiseNotNode::getOutputName(unsigned idx) const { |
5391 | if (idx == 0) { return "Result" ; } |
5392 | llvm_unreachable("Invalid index" ); |
5393 | } |
5394 | |
5395 | std::string BitwiseNotNode::getDebugDesc() const { |
5396 | DescriptionBuilder db(getKindName()); |
5397 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5398 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5399 | db |
5400 | .addParam("Input" , *(getInput().getType())) |
5401 | .addParam("Users" , getNumUsers()); |
5402 | db.addParam("Result" , *(getResult().getType())); |
5403 | return db; |
5404 | } |
5405 | |
5406 | void BitwiseNotNode::visit(Node *parent, NodeWalker *visitor) { |
5407 | if (!visitor->shouldVisit(parent, this)) { return; } |
5408 | visitor->pre(parent, this); |
5409 | if (hasPredicate()) |
5410 | getPredicate().getNode()->visit(this, visitor); |
5411 | getInput().getNode()->visit(this, visitor); |
5412 | visitor->post(parent, this); |
5413 | } |
5414 | |
5415 | bool BitwiseNotNode::isEqual(const BitwiseNotNode &other) const { |
5416 | return true && |
5417 | Input_ == other.Input_ && |
5418 | predicate_ == other.predicate_ && |
5419 | getType(0) == other.getType(0); |
5420 | } |
5421 | |
5422 | Node* BitwiseNotNode::clone() const { |
5423 | return new BitwiseNotNode(getName(), getResult().getType(), getInput()); |
5424 | } |
5425 | |
5426 | llvm::hash_code BitwiseNotNode::getHash() const { |
5427 | return llvm::hash_combine( |
5428 | Input_); |
5429 | } |
5430 | |
5431 | unsigned NegNode::getNumInputs() const { |
5432 | return 1; |
5433 | } |
5434 | |
5435 | std::string NegNode::getInputName(unsigned idx) const { |
5436 | if (idx == 0) { return "Input" ; } |
5437 | idx -= 1; |
5438 | llvm_unreachable("Invalid index" ); |
5439 | } |
5440 | |
5441 | NodeValue NegNode::getNthInput(unsigned idx) { |
5442 | if (idx == 0) { return Input_; } |
5443 | idx -= 1; |
5444 | llvm_unreachable("Invalid index" ); |
5445 | } |
5446 | |
5447 | void NegNode::setNthInput(unsigned idx, NodeValue val) { |
5448 | if (idx == 0) { Input_ = val; return; } |
5449 | idx -= 1; |
5450 | llvm_unreachable("Invalid index" ); |
5451 | } |
5452 | |
5453 | llvm::StringRef NegNode::getOutputName(unsigned idx) const { |
5454 | if (idx == 0) { return "Result" ; } |
5455 | llvm_unreachable("Invalid index" ); |
5456 | } |
5457 | |
5458 | std::string NegNode::getDebugDesc() const { |
5459 | DescriptionBuilder db(getKindName()); |
5460 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5461 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5462 | db |
5463 | .addParam("Input" , *(getInput().getType())) |
5464 | .addParam("Users" , getNumUsers()); |
5465 | db.addParam("Result" , *(getResult().getType())); |
5466 | return db; |
5467 | } |
5468 | |
5469 | void NegNode::visit(Node *parent, NodeWalker *visitor) { |
5470 | if (!visitor->shouldVisit(parent, this)) { return; } |
5471 | visitor->pre(parent, this); |
5472 | if (hasPredicate()) |
5473 | getPredicate().getNode()->visit(this, visitor); |
5474 | getInput().getNode()->visit(this, visitor); |
5475 | visitor->post(parent, this); |
5476 | } |
5477 | |
5478 | bool NegNode::isEqual(const NegNode &other) const { |
5479 | return true && |
5480 | Input_ == other.Input_ && |
5481 | predicate_ == other.predicate_ && |
5482 | getType(0) == other.getType(0); |
5483 | } |
5484 | |
5485 | Node* NegNode::clone() const { |
5486 | return new NegNode(getName(), getResult().getType(), getInput()); |
5487 | } |
5488 | |
5489 | llvm::hash_code NegNode::getHash() const { |
5490 | return llvm::hash_combine( |
5491 | Input_); |
5492 | } |
5493 | |
5494 | unsigned AbsNode::getNumInputs() const { |
5495 | return 1; |
5496 | } |
5497 | |
5498 | std::string AbsNode::getInputName(unsigned idx) const { |
5499 | if (idx == 0) { return "Input" ; } |
5500 | idx -= 1; |
5501 | llvm_unreachable("Invalid index" ); |
5502 | } |
5503 | |
5504 | NodeValue AbsNode::getNthInput(unsigned idx) { |
5505 | if (idx == 0) { return Input_; } |
5506 | idx -= 1; |
5507 | llvm_unreachable("Invalid index" ); |
5508 | } |
5509 | |
5510 | void AbsNode::setNthInput(unsigned idx, NodeValue val) { |
5511 | if (idx == 0) { Input_ = val; return; } |
5512 | idx -= 1; |
5513 | llvm_unreachable("Invalid index" ); |
5514 | } |
5515 | |
5516 | llvm::StringRef AbsNode::getOutputName(unsigned idx) const { |
5517 | if (idx == 0) { return "Result" ; } |
5518 | llvm_unreachable("Invalid index" ); |
5519 | } |
5520 | |
5521 | std::string AbsNode::getDebugDesc() const { |
5522 | DescriptionBuilder db(getKindName()); |
5523 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5524 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5525 | db |
5526 | .addParam("Input" , *(getInput().getType())) |
5527 | .addParam("Users" , getNumUsers()); |
5528 | db.addParam("Result" , *(getResult().getType())); |
5529 | return db; |
5530 | } |
5531 | |
5532 | void AbsNode::visit(Node *parent, NodeWalker *visitor) { |
5533 | if (!visitor->shouldVisit(parent, this)) { return; } |
5534 | visitor->pre(parent, this); |
5535 | if (hasPredicate()) |
5536 | getPredicate().getNode()->visit(this, visitor); |
5537 | getInput().getNode()->visit(this, visitor); |
5538 | visitor->post(parent, this); |
5539 | } |
5540 | |
5541 | bool AbsNode::isEqual(const AbsNode &other) const { |
5542 | return true && |
5543 | Input_ == other.Input_ && |
5544 | predicate_ == other.predicate_ && |
5545 | getType(0) == other.getType(0); |
5546 | } |
5547 | |
5548 | Node* AbsNode::clone() const { |
5549 | return new AbsNode(getName(), getResult().getType(), getInput()); |
5550 | } |
5551 | |
5552 | llvm::hash_code AbsNode::getHash() const { |
5553 | return llvm::hash_combine( |
5554 | Input_); |
5555 | } |
5556 | |
5557 | unsigned FloorNode::getNumInputs() const { |
5558 | return 1; |
5559 | } |
5560 | |
5561 | std::string FloorNode::getInputName(unsigned idx) const { |
5562 | if (idx == 0) { return "Input" ; } |
5563 | idx -= 1; |
5564 | llvm_unreachable("Invalid index" ); |
5565 | } |
5566 | |
5567 | NodeValue FloorNode::getNthInput(unsigned idx) { |
5568 | if (idx == 0) { return Input_; } |
5569 | idx -= 1; |
5570 | llvm_unreachable("Invalid index" ); |
5571 | } |
5572 | |
5573 | void FloorNode::setNthInput(unsigned idx, NodeValue val) { |
5574 | if (idx == 0) { Input_ = val; return; } |
5575 | idx -= 1; |
5576 | llvm_unreachable("Invalid index" ); |
5577 | } |
5578 | |
5579 | llvm::StringRef FloorNode::getOutputName(unsigned idx) const { |
5580 | if (idx == 0) { return "Result" ; } |
5581 | llvm_unreachable("Invalid index" ); |
5582 | } |
5583 | |
5584 | std::string FloorNode::getDebugDesc() const { |
5585 | DescriptionBuilder db(getKindName()); |
5586 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5587 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5588 | db |
5589 | .addParam("Input" , *(getInput().getType())) |
5590 | .addParam("Users" , getNumUsers()); |
5591 | db.addParam("Result" , *(getResult().getType())); |
5592 | return db; |
5593 | } |
5594 | |
5595 | void FloorNode::visit(Node *parent, NodeWalker *visitor) { |
5596 | if (!visitor->shouldVisit(parent, this)) { return; } |
5597 | visitor->pre(parent, this); |
5598 | if (hasPredicate()) |
5599 | getPredicate().getNode()->visit(this, visitor); |
5600 | getInput().getNode()->visit(this, visitor); |
5601 | visitor->post(parent, this); |
5602 | } |
5603 | |
5604 | bool FloorNode::isEqual(const FloorNode &other) const { |
5605 | return true && |
5606 | Input_ == other.Input_ && |
5607 | predicate_ == other.predicate_ && |
5608 | getType(0) == other.getType(0); |
5609 | } |
5610 | |
5611 | Node* FloorNode::clone() const { |
5612 | return new FloorNode(getName(), getResult().getType(), getInput()); |
5613 | } |
5614 | |
5615 | llvm::hash_code FloorNode::getHash() const { |
5616 | return llvm::hash_combine( |
5617 | Input_); |
5618 | } |
5619 | |
5620 | unsigned SignNode::getNumInputs() const { |
5621 | return 1; |
5622 | } |
5623 | |
5624 | std::string SignNode::getInputName(unsigned idx) const { |
5625 | if (idx == 0) { return "Input" ; } |
5626 | idx -= 1; |
5627 | llvm_unreachable("Invalid index" ); |
5628 | } |
5629 | |
5630 | NodeValue SignNode::getNthInput(unsigned idx) { |
5631 | if (idx == 0) { return Input_; } |
5632 | idx -= 1; |
5633 | llvm_unreachable("Invalid index" ); |
5634 | } |
5635 | |
5636 | void SignNode::setNthInput(unsigned idx, NodeValue val) { |
5637 | if (idx == 0) { Input_ = val; return; } |
5638 | idx -= 1; |
5639 | llvm_unreachable("Invalid index" ); |
5640 | } |
5641 | |
5642 | llvm::StringRef SignNode::getOutputName(unsigned idx) const { |
5643 | if (idx == 0) { return "Result" ; } |
5644 | llvm_unreachable("Invalid index" ); |
5645 | } |
5646 | |
5647 | std::string SignNode::getDebugDesc() const { |
5648 | DescriptionBuilder db(getKindName()); |
5649 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5650 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5651 | db |
5652 | .addParam("Input" , *(getInput().getType())) |
5653 | .addParam("Users" , getNumUsers()); |
5654 | db.addParam("Result" , *(getResult().getType())); |
5655 | return db; |
5656 | } |
5657 | |
5658 | void SignNode::visit(Node *parent, NodeWalker *visitor) { |
5659 | if (!visitor->shouldVisit(parent, this)) { return; } |
5660 | visitor->pre(parent, this); |
5661 | if (hasPredicate()) |
5662 | getPredicate().getNode()->visit(this, visitor); |
5663 | getInput().getNode()->visit(this, visitor); |
5664 | visitor->post(parent, this); |
5665 | } |
5666 | |
5667 | bool SignNode::isEqual(const SignNode &other) const { |
5668 | return true && |
5669 | Input_ == other.Input_ && |
5670 | predicate_ == other.predicate_ && |
5671 | getType(0) == other.getType(0); |
5672 | } |
5673 | |
5674 | Node* SignNode::clone() const { |
5675 | return new SignNode(getName(), getResult().getType(), getInput()); |
5676 | } |
5677 | |
5678 | llvm::hash_code SignNode::getHash() const { |
5679 | return llvm::hash_combine( |
5680 | Input_); |
5681 | } |
5682 | |
5683 | unsigned CeilNode::getNumInputs() const { |
5684 | return 1; |
5685 | } |
5686 | |
5687 | std::string CeilNode::getInputName(unsigned idx) const { |
5688 | if (idx == 0) { return "Input" ; } |
5689 | idx -= 1; |
5690 | llvm_unreachable("Invalid index" ); |
5691 | } |
5692 | |
5693 | NodeValue CeilNode::getNthInput(unsigned idx) { |
5694 | if (idx == 0) { return Input_; } |
5695 | idx -= 1; |
5696 | llvm_unreachable("Invalid index" ); |
5697 | } |
5698 | |
5699 | void CeilNode::setNthInput(unsigned idx, NodeValue val) { |
5700 | if (idx == 0) { Input_ = val; return; } |
5701 | idx -= 1; |
5702 | llvm_unreachable("Invalid index" ); |
5703 | } |
5704 | |
5705 | llvm::StringRef CeilNode::getOutputName(unsigned idx) const { |
5706 | if (idx == 0) { return "Result" ; } |
5707 | llvm_unreachable("Invalid index" ); |
5708 | } |
5709 | |
5710 | std::string CeilNode::getDebugDesc() const { |
5711 | DescriptionBuilder db(getKindName()); |
5712 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5713 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5714 | db |
5715 | .addParam("Input" , *(getInput().getType())) |
5716 | .addParam("Users" , getNumUsers()); |
5717 | db.addParam("Result" , *(getResult().getType())); |
5718 | return db; |
5719 | } |
5720 | |
5721 | void CeilNode::visit(Node *parent, NodeWalker *visitor) { |
5722 | if (!visitor->shouldVisit(parent, this)) { return; } |
5723 | visitor->pre(parent, this); |
5724 | if (hasPredicate()) |
5725 | getPredicate().getNode()->visit(this, visitor); |
5726 | getInput().getNode()->visit(this, visitor); |
5727 | visitor->post(parent, this); |
5728 | } |
5729 | |
5730 | bool CeilNode::isEqual(const CeilNode &other) const { |
5731 | return true && |
5732 | Input_ == other.Input_ && |
5733 | predicate_ == other.predicate_ && |
5734 | getType(0) == other.getType(0); |
5735 | } |
5736 | |
5737 | Node* CeilNode::clone() const { |
5738 | return new CeilNode(getName(), getResult().getType(), getInput()); |
5739 | } |
5740 | |
5741 | llvm::hash_code CeilNode::getHash() const { |
5742 | return llvm::hash_combine( |
5743 | Input_); |
5744 | } |
5745 | |
5746 | unsigned RoundNode::getNumInputs() const { |
5747 | return 1; |
5748 | } |
5749 | |
5750 | std::string RoundNode::getInputName(unsigned idx) const { |
5751 | if (idx == 0) { return "Input" ; } |
5752 | idx -= 1; |
5753 | llvm_unreachable("Invalid index" ); |
5754 | } |
5755 | |
5756 | NodeValue RoundNode::getNthInput(unsigned idx) { |
5757 | if (idx == 0) { return Input_; } |
5758 | idx -= 1; |
5759 | llvm_unreachable("Invalid index" ); |
5760 | } |
5761 | |
5762 | void RoundNode::setNthInput(unsigned idx, NodeValue val) { |
5763 | if (idx == 0) { Input_ = val; return; } |
5764 | idx -= 1; |
5765 | llvm_unreachable("Invalid index" ); |
5766 | } |
5767 | |
5768 | llvm::StringRef RoundNode::getOutputName(unsigned idx) const { |
5769 | if (idx == 0) { return "Result" ; } |
5770 | llvm_unreachable("Invalid index" ); |
5771 | } |
5772 | |
5773 | std::string RoundNode::getDebugDesc() const { |
5774 | DescriptionBuilder db(getKindName()); |
5775 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5776 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5777 | db |
5778 | .addParam("Input" , *(getInput().getType())) |
5779 | .addParam("Users" , getNumUsers()); |
5780 | db.addParam("Result" , *(getResult().getType())); |
5781 | return db; |
5782 | } |
5783 | |
5784 | void RoundNode::visit(Node *parent, NodeWalker *visitor) { |
5785 | if (!visitor->shouldVisit(parent, this)) { return; } |
5786 | visitor->pre(parent, this); |
5787 | if (hasPredicate()) |
5788 | getPredicate().getNode()->visit(this, visitor); |
5789 | getInput().getNode()->visit(this, visitor); |
5790 | visitor->post(parent, this); |
5791 | } |
5792 | |
5793 | bool RoundNode::isEqual(const RoundNode &other) const { |
5794 | return true && |
5795 | Input_ == other.Input_ && |
5796 | predicate_ == other.predicate_ && |
5797 | getType(0) == other.getType(0); |
5798 | } |
5799 | |
5800 | Node* RoundNode::clone() const { |
5801 | return new RoundNode(getName(), getResult().getType(), getInput()); |
5802 | } |
5803 | |
5804 | llvm::hash_code RoundNode::getHash() const { |
5805 | return llvm::hash_combine( |
5806 | Input_); |
5807 | } |
5808 | |
5809 | unsigned TruncateNode::getNumInputs() const { |
5810 | return 1; |
5811 | } |
5812 | |
5813 | std::string TruncateNode::getInputName(unsigned idx) const { |
5814 | if (idx == 0) { return "Input" ; } |
5815 | idx -= 1; |
5816 | llvm_unreachable("Invalid index" ); |
5817 | } |
5818 | |
5819 | NodeValue TruncateNode::getNthInput(unsigned idx) { |
5820 | if (idx == 0) { return Input_; } |
5821 | idx -= 1; |
5822 | llvm_unreachable("Invalid index" ); |
5823 | } |
5824 | |
5825 | void TruncateNode::setNthInput(unsigned idx, NodeValue val) { |
5826 | if (idx == 0) { Input_ = val; return; } |
5827 | idx -= 1; |
5828 | llvm_unreachable("Invalid index" ); |
5829 | } |
5830 | |
5831 | llvm::StringRef TruncateNode::getOutputName(unsigned idx) const { |
5832 | if (idx == 0) { return "Result" ; } |
5833 | llvm_unreachable("Invalid index" ); |
5834 | } |
5835 | |
5836 | std::string TruncateNode::getDebugDesc() const { |
5837 | DescriptionBuilder db(getKindName()); |
5838 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5839 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5840 | db |
5841 | .addParam("Input" , *(getInput().getType())) |
5842 | .addParam("Users" , getNumUsers()); |
5843 | db.addParam("Result" , *(getResult().getType())); |
5844 | return db; |
5845 | } |
5846 | |
5847 | void TruncateNode::visit(Node *parent, NodeWalker *visitor) { |
5848 | if (!visitor->shouldVisit(parent, this)) { return; } |
5849 | visitor->pre(parent, this); |
5850 | if (hasPredicate()) |
5851 | getPredicate().getNode()->visit(this, visitor); |
5852 | getInput().getNode()->visit(this, visitor); |
5853 | visitor->post(parent, this); |
5854 | } |
5855 | |
5856 | bool TruncateNode::isEqual(const TruncateNode &other) const { |
5857 | return true && |
5858 | Input_ == other.Input_ && |
5859 | predicate_ == other.predicate_ && |
5860 | getType(0) == other.getType(0); |
5861 | } |
5862 | |
5863 | Node* TruncateNode::clone() const { |
5864 | return new TruncateNode(getName(), getResult().getType(), getInput()); |
5865 | } |
5866 | |
5867 | llvm::hash_code TruncateNode::getHash() const { |
5868 | return llvm::hash_combine( |
5869 | Input_); |
5870 | } |
5871 | |
5872 | unsigned SqrtNode::getNumInputs() const { |
5873 | return 1; |
5874 | } |
5875 | |
5876 | std::string SqrtNode::getInputName(unsigned idx) const { |
5877 | if (idx == 0) { return "Input" ; } |
5878 | idx -= 1; |
5879 | llvm_unreachable("Invalid index" ); |
5880 | } |
5881 | |
5882 | NodeValue SqrtNode::getNthInput(unsigned idx) { |
5883 | if (idx == 0) { return Input_; } |
5884 | idx -= 1; |
5885 | llvm_unreachable("Invalid index" ); |
5886 | } |
5887 | |
5888 | void SqrtNode::setNthInput(unsigned idx, NodeValue val) { |
5889 | if (idx == 0) { Input_ = val; return; } |
5890 | idx -= 1; |
5891 | llvm_unreachable("Invalid index" ); |
5892 | } |
5893 | |
5894 | llvm::StringRef SqrtNode::getOutputName(unsigned idx) const { |
5895 | if (idx == 0) { return "Result" ; } |
5896 | llvm_unreachable("Invalid index" ); |
5897 | } |
5898 | |
5899 | std::string SqrtNode::getDebugDesc() const { |
5900 | DescriptionBuilder db(getKindName()); |
5901 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5902 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5903 | db |
5904 | .addParam("Input" , *(getInput().getType())) |
5905 | .addParam("Users" , getNumUsers()); |
5906 | db.addParam("Result" , *(getResult().getType())); |
5907 | return db; |
5908 | } |
5909 | |
5910 | void SqrtNode::visit(Node *parent, NodeWalker *visitor) { |
5911 | if (!visitor->shouldVisit(parent, this)) { return; } |
5912 | visitor->pre(parent, this); |
5913 | if (hasPredicate()) |
5914 | getPredicate().getNode()->visit(this, visitor); |
5915 | getInput().getNode()->visit(this, visitor); |
5916 | visitor->post(parent, this); |
5917 | } |
5918 | |
5919 | bool SqrtNode::isEqual(const SqrtNode &other) const { |
5920 | return true && |
5921 | Input_ == other.Input_ && |
5922 | predicate_ == other.predicate_ && |
5923 | getType(0) == other.getType(0); |
5924 | } |
5925 | |
5926 | Node* SqrtNode::clone() const { |
5927 | return new SqrtNode(getName(), getResult().getType(), getInput()); |
5928 | } |
5929 | |
5930 | llvm::hash_code SqrtNode::getHash() const { |
5931 | return llvm::hash_combine( |
5932 | Input_); |
5933 | } |
5934 | |
5935 | unsigned RsqrtNode::getNumInputs() const { |
5936 | return 1; |
5937 | } |
5938 | |
5939 | std::string RsqrtNode::getInputName(unsigned idx) const { |
5940 | if (idx == 0) { return "Input" ; } |
5941 | idx -= 1; |
5942 | llvm_unreachable("Invalid index" ); |
5943 | } |
5944 | |
5945 | NodeValue RsqrtNode::getNthInput(unsigned idx) { |
5946 | if (idx == 0) { return Input_; } |
5947 | idx -= 1; |
5948 | llvm_unreachable("Invalid index" ); |
5949 | } |
5950 | |
5951 | void RsqrtNode::setNthInput(unsigned idx, NodeValue val) { |
5952 | if (idx == 0) { Input_ = val; return; } |
5953 | idx -= 1; |
5954 | llvm_unreachable("Invalid index" ); |
5955 | } |
5956 | |
5957 | llvm::StringRef RsqrtNode::getOutputName(unsigned idx) const { |
5958 | if (idx == 0) { return "Result" ; } |
5959 | llvm_unreachable("Invalid index" ); |
5960 | } |
5961 | |
5962 | std::string RsqrtNode::getDebugDesc() const { |
5963 | DescriptionBuilder db(getKindName()); |
5964 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
5965 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
5966 | db |
5967 | .addParam("Input" , *(getInput().getType())) |
5968 | .addParam("Users" , getNumUsers()); |
5969 | db.addParam("Result" , *(getResult().getType())); |
5970 | return db; |
5971 | } |
5972 | |
5973 | void RsqrtNode::visit(Node *parent, NodeWalker *visitor) { |
5974 | if (!visitor->shouldVisit(parent, this)) { return; } |
5975 | visitor->pre(parent, this); |
5976 | if (hasPredicate()) |
5977 | getPredicate().getNode()->visit(this, visitor); |
5978 | getInput().getNode()->visit(this, visitor); |
5979 | visitor->post(parent, this); |
5980 | } |
5981 | |
5982 | bool RsqrtNode::isEqual(const RsqrtNode &other) const { |
5983 | return true && |
5984 | Input_ == other.Input_ && |
5985 | predicate_ == other.predicate_ && |
5986 | getType(0) == other.getType(0); |
5987 | } |
5988 | |
5989 | Node* RsqrtNode::clone() const { |
5990 | return new RsqrtNode(getName(), getResult().getType(), getInput()); |
5991 | } |
5992 | |
5993 | llvm::hash_code RsqrtNode::getHash() const { |
5994 | return llvm::hash_combine( |
5995 | Input_); |
5996 | } |
5997 | |
5998 | unsigned ReciprocalNode::getNumInputs() const { |
5999 | return 1; |
6000 | } |
6001 | |
6002 | std::string ReciprocalNode::getInputName(unsigned idx) const { |
6003 | if (idx == 0) { return "Input" ; } |
6004 | idx -= 1; |
6005 | llvm_unreachable("Invalid index" ); |
6006 | } |
6007 | |
6008 | NodeValue ReciprocalNode::getNthInput(unsigned idx) { |
6009 | if (idx == 0) { return Input_; } |
6010 | idx -= 1; |
6011 | llvm_unreachable("Invalid index" ); |
6012 | } |
6013 | |
6014 | void ReciprocalNode::setNthInput(unsigned idx, NodeValue val) { |
6015 | if (idx == 0) { Input_ = val; return; } |
6016 | idx -= 1; |
6017 | llvm_unreachable("Invalid index" ); |
6018 | } |
6019 | |
6020 | llvm::StringRef ReciprocalNode::getOutputName(unsigned idx) const { |
6021 | if (idx == 0) { return "Result" ; } |
6022 | llvm_unreachable("Invalid index" ); |
6023 | } |
6024 | |
6025 | std::string ReciprocalNode::getDebugDesc() const { |
6026 | DescriptionBuilder db(getKindName()); |
6027 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6028 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6029 | db |
6030 | .addParam("Input" , *(getInput().getType())) |
6031 | .addParam("Users" , getNumUsers()); |
6032 | db.addParam("Result" , *(getResult().getType())); |
6033 | return db; |
6034 | } |
6035 | |
6036 | void ReciprocalNode::visit(Node *parent, NodeWalker *visitor) { |
6037 | if (!visitor->shouldVisit(parent, this)) { return; } |
6038 | visitor->pre(parent, this); |
6039 | if (hasPredicate()) |
6040 | getPredicate().getNode()->visit(this, visitor); |
6041 | getInput().getNode()->visit(this, visitor); |
6042 | visitor->post(parent, this); |
6043 | } |
6044 | |
6045 | bool ReciprocalNode::isEqual(const ReciprocalNode &other) const { |
6046 | return true && |
6047 | Input_ == other.Input_ && |
6048 | predicate_ == other.predicate_ && |
6049 | getType(0) == other.getType(0); |
6050 | } |
6051 | |
6052 | Node* ReciprocalNode::clone() const { |
6053 | return new ReciprocalNode(getName(), getResult().getType(), getInput()); |
6054 | } |
6055 | |
6056 | llvm::hash_code ReciprocalNode::getHash() const { |
6057 | return llvm::hash_combine( |
6058 | Input_); |
6059 | } |
6060 | |
6061 | unsigned SinNode::getNumInputs() const { |
6062 | return 1; |
6063 | } |
6064 | |
6065 | std::string SinNode::getInputName(unsigned idx) const { |
6066 | if (idx == 0) { return "Input" ; } |
6067 | idx -= 1; |
6068 | llvm_unreachable("Invalid index" ); |
6069 | } |
6070 | |
6071 | NodeValue SinNode::getNthInput(unsigned idx) { |
6072 | if (idx == 0) { return Input_; } |
6073 | idx -= 1; |
6074 | llvm_unreachable("Invalid index" ); |
6075 | } |
6076 | |
6077 | void SinNode::setNthInput(unsigned idx, NodeValue val) { |
6078 | if (idx == 0) { Input_ = val; return; } |
6079 | idx -= 1; |
6080 | llvm_unreachable("Invalid index" ); |
6081 | } |
6082 | |
6083 | llvm::StringRef SinNode::getOutputName(unsigned idx) const { |
6084 | if (idx == 0) { return "Result" ; } |
6085 | llvm_unreachable("Invalid index" ); |
6086 | } |
6087 | |
6088 | std::string SinNode::getDebugDesc() const { |
6089 | DescriptionBuilder db(getKindName()); |
6090 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6091 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6092 | db |
6093 | .addParam("Input" , *(getInput().getType())) |
6094 | .addParam("Users" , getNumUsers()); |
6095 | db.addParam("Result" , *(getResult().getType())); |
6096 | return db; |
6097 | } |
6098 | |
6099 | void SinNode::visit(Node *parent, NodeWalker *visitor) { |
6100 | if (!visitor->shouldVisit(parent, this)) { return; } |
6101 | visitor->pre(parent, this); |
6102 | if (hasPredicate()) |
6103 | getPredicate().getNode()->visit(this, visitor); |
6104 | getInput().getNode()->visit(this, visitor); |
6105 | visitor->post(parent, this); |
6106 | } |
6107 | |
6108 | bool SinNode::isEqual(const SinNode &other) const { |
6109 | return true && |
6110 | Input_ == other.Input_ && |
6111 | predicate_ == other.predicate_ && |
6112 | getType(0) == other.getType(0); |
6113 | } |
6114 | |
6115 | Node* SinNode::clone() const { |
6116 | return new SinNode(getName(), getResult().getType(), getInput()); |
6117 | } |
6118 | |
6119 | llvm::hash_code SinNode::getHash() const { |
6120 | return llvm::hash_combine( |
6121 | Input_); |
6122 | } |
6123 | |
6124 | unsigned CosNode::getNumInputs() const { |
6125 | return 1; |
6126 | } |
6127 | |
6128 | std::string CosNode::getInputName(unsigned idx) const { |
6129 | if (idx == 0) { return "Input" ; } |
6130 | idx -= 1; |
6131 | llvm_unreachable("Invalid index" ); |
6132 | } |
6133 | |
6134 | NodeValue CosNode::getNthInput(unsigned idx) { |
6135 | if (idx == 0) { return Input_; } |
6136 | idx -= 1; |
6137 | llvm_unreachable("Invalid index" ); |
6138 | } |
6139 | |
6140 | void CosNode::setNthInput(unsigned idx, NodeValue val) { |
6141 | if (idx == 0) { Input_ = val; return; } |
6142 | idx -= 1; |
6143 | llvm_unreachable("Invalid index" ); |
6144 | } |
6145 | |
6146 | llvm::StringRef CosNode::getOutputName(unsigned idx) const { |
6147 | if (idx == 0) { return "Result" ; } |
6148 | llvm_unreachable("Invalid index" ); |
6149 | } |
6150 | |
6151 | std::string CosNode::getDebugDesc() const { |
6152 | DescriptionBuilder db(getKindName()); |
6153 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6154 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6155 | db |
6156 | .addParam("Input" , *(getInput().getType())) |
6157 | .addParam("Users" , getNumUsers()); |
6158 | db.addParam("Result" , *(getResult().getType())); |
6159 | return db; |
6160 | } |
6161 | |
6162 | void CosNode::visit(Node *parent, NodeWalker *visitor) { |
6163 | if (!visitor->shouldVisit(parent, this)) { return; } |
6164 | visitor->pre(parent, this); |
6165 | if (hasPredicate()) |
6166 | getPredicate().getNode()->visit(this, visitor); |
6167 | getInput().getNode()->visit(this, visitor); |
6168 | visitor->post(parent, this); |
6169 | } |
6170 | |
6171 | bool CosNode::isEqual(const CosNode &other) const { |
6172 | return true && |
6173 | Input_ == other.Input_ && |
6174 | predicate_ == other.predicate_ && |
6175 | getType(0) == other.getType(0); |
6176 | } |
6177 | |
6178 | Node* CosNode::clone() const { |
6179 | return new CosNode(getName(), getResult().getType(), getInput()); |
6180 | } |
6181 | |
6182 | llvm::hash_code CosNode::getHash() const { |
6183 | return llvm::hash_combine( |
6184 | Input_); |
6185 | } |
6186 | |
6187 | unsigned LogNode::getNumInputs() const { |
6188 | return 1; |
6189 | } |
6190 | |
6191 | std::string LogNode::getInputName(unsigned idx) const { |
6192 | if (idx == 0) { return "Input" ; } |
6193 | idx -= 1; |
6194 | llvm_unreachable("Invalid index" ); |
6195 | } |
6196 | |
6197 | NodeValue LogNode::getNthInput(unsigned idx) { |
6198 | if (idx == 0) { return Input_; } |
6199 | idx -= 1; |
6200 | llvm_unreachable("Invalid index" ); |
6201 | } |
6202 | |
6203 | void LogNode::setNthInput(unsigned idx, NodeValue val) { |
6204 | if (idx == 0) { Input_ = val; return; } |
6205 | idx -= 1; |
6206 | llvm_unreachable("Invalid index" ); |
6207 | } |
6208 | |
6209 | llvm::StringRef LogNode::getOutputName(unsigned idx) const { |
6210 | if (idx == 0) { return "Result" ; } |
6211 | llvm_unreachable("Invalid index" ); |
6212 | } |
6213 | |
6214 | std::string LogNode::getDebugDesc() const { |
6215 | DescriptionBuilder db(getKindName()); |
6216 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6217 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6218 | db |
6219 | .addParam("Input" , *(getInput().getType())) |
6220 | .addParam("Users" , getNumUsers()); |
6221 | db.addParam("Result" , *(getResult().getType())); |
6222 | return db; |
6223 | } |
6224 | |
6225 | void LogNode::visit(Node *parent, NodeWalker *visitor) { |
6226 | if (!visitor->shouldVisit(parent, this)) { return; } |
6227 | visitor->pre(parent, this); |
6228 | if (hasPredicate()) |
6229 | getPredicate().getNode()->visit(this, visitor); |
6230 | getInput().getNode()->visit(this, visitor); |
6231 | visitor->post(parent, this); |
6232 | } |
6233 | |
6234 | bool LogNode::isEqual(const LogNode &other) const { |
6235 | return true && |
6236 | Input_ == other.Input_ && |
6237 | predicate_ == other.predicate_ && |
6238 | getType(0) == other.getType(0); |
6239 | } |
6240 | |
6241 | Node* LogNode::clone() const { |
6242 | return new LogNode(getName(), getResult().getType(), getInput()); |
6243 | } |
6244 | |
6245 | llvm::hash_code LogNode::getHash() const { |
6246 | return llvm::hash_combine( |
6247 | Input_); |
6248 | } |
6249 | |
6250 | unsigned AcosNode::getNumInputs() const { |
6251 | return 1; |
6252 | } |
6253 | |
6254 | std::string AcosNode::getInputName(unsigned idx) const { |
6255 | if (idx == 0) { return "Input" ; } |
6256 | idx -= 1; |
6257 | llvm_unreachable("Invalid index" ); |
6258 | } |
6259 | |
6260 | NodeValue AcosNode::getNthInput(unsigned idx) { |
6261 | if (idx == 0) { return Input_; } |
6262 | idx -= 1; |
6263 | llvm_unreachable("Invalid index" ); |
6264 | } |
6265 | |
6266 | void AcosNode::setNthInput(unsigned idx, NodeValue val) { |
6267 | if (idx == 0) { Input_ = val; return; } |
6268 | idx -= 1; |
6269 | llvm_unreachable("Invalid index" ); |
6270 | } |
6271 | |
6272 | llvm::StringRef AcosNode::getOutputName(unsigned idx) const { |
6273 | if (idx == 0) { return "Result" ; } |
6274 | llvm_unreachable("Invalid index" ); |
6275 | } |
6276 | |
6277 | std::string AcosNode::getDebugDesc() const { |
6278 | DescriptionBuilder db(getKindName()); |
6279 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6280 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6281 | db |
6282 | .addParam("Input" , *(getInput().getType())) |
6283 | .addParam("Users" , getNumUsers()); |
6284 | db.addParam("Result" , *(getResult().getType())); |
6285 | return db; |
6286 | } |
6287 | |
6288 | void AcosNode::visit(Node *parent, NodeWalker *visitor) { |
6289 | if (!visitor->shouldVisit(parent, this)) { return; } |
6290 | visitor->pre(parent, this); |
6291 | if (hasPredicate()) |
6292 | getPredicate().getNode()->visit(this, visitor); |
6293 | getInput().getNode()->visit(this, visitor); |
6294 | visitor->post(parent, this); |
6295 | } |
6296 | |
6297 | bool AcosNode::isEqual(const AcosNode &other) const { |
6298 | return true && |
6299 | Input_ == other.Input_ && |
6300 | predicate_ == other.predicate_ && |
6301 | getType(0) == other.getType(0); |
6302 | } |
6303 | |
6304 | Node* AcosNode::clone() const { |
6305 | return new AcosNode(getName(), getResult().getType(), getInput()); |
6306 | } |
6307 | |
6308 | llvm::hash_code AcosNode::getHash() const { |
6309 | return llvm::hash_combine( |
6310 | Input_); |
6311 | } |
6312 | |
6313 | unsigned AsinNode::getNumInputs() const { |
6314 | return 1; |
6315 | } |
6316 | |
6317 | std::string AsinNode::getInputName(unsigned idx) const { |
6318 | if (idx == 0) { return "Input" ; } |
6319 | idx -= 1; |
6320 | llvm_unreachable("Invalid index" ); |
6321 | } |
6322 | |
6323 | NodeValue AsinNode::getNthInput(unsigned idx) { |
6324 | if (idx == 0) { return Input_; } |
6325 | idx -= 1; |
6326 | llvm_unreachable("Invalid index" ); |
6327 | } |
6328 | |
6329 | void AsinNode::setNthInput(unsigned idx, NodeValue val) { |
6330 | if (idx == 0) { Input_ = val; return; } |
6331 | idx -= 1; |
6332 | llvm_unreachable("Invalid index" ); |
6333 | } |
6334 | |
6335 | llvm::StringRef AsinNode::getOutputName(unsigned idx) const { |
6336 | if (idx == 0) { return "Result" ; } |
6337 | llvm_unreachable("Invalid index" ); |
6338 | } |
6339 | |
6340 | std::string AsinNode::getDebugDesc() const { |
6341 | DescriptionBuilder db(getKindName()); |
6342 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6343 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6344 | db |
6345 | .addParam("Input" , *(getInput().getType())) |
6346 | .addParam("Users" , getNumUsers()); |
6347 | db.addParam("Result" , *(getResult().getType())); |
6348 | return db; |
6349 | } |
6350 | |
6351 | void AsinNode::visit(Node *parent, NodeWalker *visitor) { |
6352 | if (!visitor->shouldVisit(parent, this)) { return; } |
6353 | visitor->pre(parent, this); |
6354 | if (hasPredicate()) |
6355 | getPredicate().getNode()->visit(this, visitor); |
6356 | getInput().getNode()->visit(this, visitor); |
6357 | visitor->post(parent, this); |
6358 | } |
6359 | |
6360 | bool AsinNode::isEqual(const AsinNode &other) const { |
6361 | return true && |
6362 | Input_ == other.Input_ && |
6363 | predicate_ == other.predicate_ && |
6364 | getType(0) == other.getType(0); |
6365 | } |
6366 | |
6367 | Node* AsinNode::clone() const { |
6368 | return new AsinNode(getName(), getResult().getType(), getInput()); |
6369 | } |
6370 | |
6371 | llvm::hash_code AsinNode::getHash() const { |
6372 | return llvm::hash_combine( |
6373 | Input_); |
6374 | } |
6375 | |
6376 | unsigned AtanNode::getNumInputs() const { |
6377 | return 1; |
6378 | } |
6379 | |
6380 | std::string AtanNode::getInputName(unsigned idx) const { |
6381 | if (idx == 0) { return "Input" ; } |
6382 | idx -= 1; |
6383 | llvm_unreachable("Invalid index" ); |
6384 | } |
6385 | |
6386 | NodeValue AtanNode::getNthInput(unsigned idx) { |
6387 | if (idx == 0) { return Input_; } |
6388 | idx -= 1; |
6389 | llvm_unreachable("Invalid index" ); |
6390 | } |
6391 | |
6392 | void AtanNode::setNthInput(unsigned idx, NodeValue val) { |
6393 | if (idx == 0) { Input_ = val; return; } |
6394 | idx -= 1; |
6395 | llvm_unreachable("Invalid index" ); |
6396 | } |
6397 | |
6398 | llvm::StringRef AtanNode::getOutputName(unsigned idx) const { |
6399 | if (idx == 0) { return "Result" ; } |
6400 | llvm_unreachable("Invalid index" ); |
6401 | } |
6402 | |
6403 | std::string AtanNode::getDebugDesc() const { |
6404 | DescriptionBuilder db(getKindName()); |
6405 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6406 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6407 | db |
6408 | .addParam("Input" , *(getInput().getType())) |
6409 | .addParam("Users" , getNumUsers()); |
6410 | db.addParam("Result" , *(getResult().getType())); |
6411 | return db; |
6412 | } |
6413 | |
6414 | void AtanNode::visit(Node *parent, NodeWalker *visitor) { |
6415 | if (!visitor->shouldVisit(parent, this)) { return; } |
6416 | visitor->pre(parent, this); |
6417 | if (hasPredicate()) |
6418 | getPredicate().getNode()->visit(this, visitor); |
6419 | getInput().getNode()->visit(this, visitor); |
6420 | visitor->post(parent, this); |
6421 | } |
6422 | |
6423 | bool AtanNode::isEqual(const AtanNode &other) const { |
6424 | return true && |
6425 | Input_ == other.Input_ && |
6426 | predicate_ == other.predicate_ && |
6427 | getType(0) == other.getType(0); |
6428 | } |
6429 | |
6430 | Node* AtanNode::clone() const { |
6431 | return new AtanNode(getName(), getResult().getType(), getInput()); |
6432 | } |
6433 | |
6434 | llvm::hash_code AtanNode::getHash() const { |
6435 | return llvm::hash_combine( |
6436 | Input_); |
6437 | } |
6438 | |
6439 | unsigned ErfNode::getNumInputs() const { |
6440 | return 1; |
6441 | } |
6442 | |
6443 | std::string ErfNode::getInputName(unsigned idx) const { |
6444 | if (idx == 0) { return "Input" ; } |
6445 | idx -= 1; |
6446 | llvm_unreachable("Invalid index" ); |
6447 | } |
6448 | |
6449 | NodeValue ErfNode::getNthInput(unsigned idx) { |
6450 | if (idx == 0) { return Input_; } |
6451 | idx -= 1; |
6452 | llvm_unreachable("Invalid index" ); |
6453 | } |
6454 | |
6455 | void ErfNode::setNthInput(unsigned idx, NodeValue val) { |
6456 | if (idx == 0) { Input_ = val; return; } |
6457 | idx -= 1; |
6458 | llvm_unreachable("Invalid index" ); |
6459 | } |
6460 | |
6461 | llvm::StringRef ErfNode::getOutputName(unsigned idx) const { |
6462 | if (idx == 0) { return "Result" ; } |
6463 | llvm_unreachable("Invalid index" ); |
6464 | } |
6465 | |
6466 | std::string ErfNode::getDebugDesc() const { |
6467 | DescriptionBuilder db(getKindName()); |
6468 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6469 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6470 | db |
6471 | .addParam("Input" , *(getInput().getType())) |
6472 | .addParam("Users" , getNumUsers()); |
6473 | db.addParam("Result" , *(getResult().getType())); |
6474 | return db; |
6475 | } |
6476 | |
6477 | void ErfNode::visit(Node *parent, NodeWalker *visitor) { |
6478 | if (!visitor->shouldVisit(parent, this)) { return; } |
6479 | visitor->pre(parent, this); |
6480 | if (hasPredicate()) |
6481 | getPredicate().getNode()->visit(this, visitor); |
6482 | getInput().getNode()->visit(this, visitor); |
6483 | visitor->post(parent, this); |
6484 | } |
6485 | |
6486 | bool ErfNode::isEqual(const ErfNode &other) const { |
6487 | return true && |
6488 | Input_ == other.Input_ && |
6489 | predicate_ == other.predicate_ && |
6490 | getType(0) == other.getType(0); |
6491 | } |
6492 | |
6493 | Node* ErfNode::clone() const { |
6494 | return new ErfNode(getName(), getResult().getType(), getInput()); |
6495 | } |
6496 | |
6497 | llvm::hash_code ErfNode::getHash() const { |
6498 | return llvm::hash_combine( |
6499 | Input_); |
6500 | } |
6501 | |
6502 | unsigned ExpNode::getNumInputs() const { |
6503 | return 1; |
6504 | } |
6505 | |
6506 | std::string ExpNode::getInputName(unsigned idx) const { |
6507 | if (idx == 0) { return "Input" ; } |
6508 | idx -= 1; |
6509 | llvm_unreachable("Invalid index" ); |
6510 | } |
6511 | |
6512 | NodeValue ExpNode::getNthInput(unsigned idx) { |
6513 | if (idx == 0) { return Input_; } |
6514 | idx -= 1; |
6515 | llvm_unreachable("Invalid index" ); |
6516 | } |
6517 | |
6518 | void ExpNode::setNthInput(unsigned idx, NodeValue val) { |
6519 | if (idx == 0) { Input_ = val; return; } |
6520 | idx -= 1; |
6521 | llvm_unreachable("Invalid index" ); |
6522 | } |
6523 | |
6524 | llvm::StringRef ExpNode::getOutputName(unsigned idx) const { |
6525 | if (idx == 0) { return "Result" ; } |
6526 | llvm_unreachable("Invalid index" ); |
6527 | } |
6528 | |
6529 | std::string ExpNode::getDebugDesc() const { |
6530 | DescriptionBuilder db(getKindName()); |
6531 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6532 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6533 | db |
6534 | .addParam("Input" , *(getInput().getType())) |
6535 | .addParam("Users" , getNumUsers()); |
6536 | db.addParam("Result" , *(getResult().getType())); |
6537 | return db; |
6538 | } |
6539 | |
6540 | void ExpNode::visit(Node *parent, NodeWalker *visitor) { |
6541 | if (!visitor->shouldVisit(parent, this)) { return; } |
6542 | visitor->pre(parent, this); |
6543 | if (hasPredicate()) |
6544 | getPredicate().getNode()->visit(this, visitor); |
6545 | getInput().getNode()->visit(this, visitor); |
6546 | visitor->post(parent, this); |
6547 | } |
6548 | |
6549 | bool ExpNode::isEqual(const ExpNode &other) const { |
6550 | return true && |
6551 | Input_ == other.Input_ && |
6552 | predicate_ == other.predicate_ && |
6553 | getType(0) == other.getType(0); |
6554 | } |
6555 | |
6556 | Node* ExpNode::clone() const { |
6557 | return new ExpNode(getName(), getResult().getType(), getInput()); |
6558 | } |
6559 | |
6560 | llvm::hash_code ExpNode::getHash() const { |
6561 | return llvm::hash_combine( |
6562 | Input_); |
6563 | } |
6564 | |
6565 | unsigned LogitNode::getNumInputs() const { |
6566 | return 1; |
6567 | } |
6568 | |
6569 | std::string LogitNode::getInputName(unsigned idx) const { |
6570 | if (idx == 0) { return "Input" ; } |
6571 | idx -= 1; |
6572 | llvm_unreachable("Invalid index" ); |
6573 | } |
6574 | |
6575 | NodeValue LogitNode::getNthInput(unsigned idx) { |
6576 | if (idx == 0) { return Input_; } |
6577 | idx -= 1; |
6578 | llvm_unreachable("Invalid index" ); |
6579 | } |
6580 | |
6581 | void LogitNode::setNthInput(unsigned idx, NodeValue val) { |
6582 | if (idx == 0) { Input_ = val; return; } |
6583 | idx -= 1; |
6584 | llvm_unreachable("Invalid index" ); |
6585 | } |
6586 | |
6587 | llvm::StringRef LogitNode::getOutputName(unsigned idx) const { |
6588 | if (idx == 0) { return "Result" ; } |
6589 | llvm_unreachable("Invalid index" ); |
6590 | } |
6591 | |
6592 | std::string LogitNode::getDebugDesc() const { |
6593 | DescriptionBuilder db(getKindName()); |
6594 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6595 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6596 | db |
6597 | .addParam("Input" , *(getInput().getType())) |
6598 | .addParam("Epsilon" , getEpsilon()) |
6599 | .addParam("Users" , getNumUsers()); |
6600 | db.addParam("Result" , *(getResult().getType())); |
6601 | return db; |
6602 | } |
6603 | |
6604 | void LogitNode::visit(Node *parent, NodeWalker *visitor) { |
6605 | if (!visitor->shouldVisit(parent, this)) { return; } |
6606 | visitor->pre(parent, this); |
6607 | if (hasPredicate()) |
6608 | getPredicate().getNode()->visit(this, visitor); |
6609 | getInput().getNode()->visit(this, visitor); |
6610 | visitor->post(parent, this); |
6611 | } |
6612 | |
6613 | bool LogitNode::isEqual(const LogitNode &other) const { |
6614 | return true && |
6615 | Input_ == other.Input_ && |
6616 | predicate_ == other.predicate_ && |
6617 | Epsilon_ == other.Epsilon_ && |
6618 | getType(0) == other.getType(0); |
6619 | } |
6620 | |
6621 | Node* LogitNode::clone() const { |
6622 | return new LogitNode(getName(), getResult().getType(), getInput(), getEpsilon()); |
6623 | } |
6624 | |
6625 | llvm::hash_code LogitNode::getHash() const { |
6626 | return llvm::hash_combine( |
6627 | toBinary(Epsilon_), |
6628 | Input_); |
6629 | } |
6630 | |
6631 | unsigned NonZeroNode::getNumInputs() const { |
6632 | return 1; |
6633 | } |
6634 | |
6635 | std::string NonZeroNode::getInputName(unsigned idx) const { |
6636 | if (idx == 0) { return "Cond" ; } |
6637 | idx -= 1; |
6638 | llvm_unreachable("Invalid index" ); |
6639 | } |
6640 | |
6641 | NodeValue NonZeroNode::getNthInput(unsigned idx) { |
6642 | if (idx == 0) { return Cond_; } |
6643 | idx -= 1; |
6644 | llvm_unreachable("Invalid index" ); |
6645 | } |
6646 | |
6647 | void NonZeroNode::setNthInput(unsigned idx, NodeValue val) { |
6648 | if (idx == 0) { Cond_ = val; return; } |
6649 | idx -= 1; |
6650 | llvm_unreachable("Invalid index" ); |
6651 | } |
6652 | |
6653 | llvm::StringRef NonZeroNode::getOutputName(unsigned idx) const { |
6654 | if (idx == 0) { return "Result" ; } |
6655 | llvm_unreachable("Invalid index" ); |
6656 | } |
6657 | |
6658 | std::string NonZeroNode::getDebugDesc() const { |
6659 | DescriptionBuilder db(getKindName()); |
6660 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6661 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6662 | db |
6663 | .addParam("Cond" , *(getCond().getType())) |
6664 | .addParam("Users" , getNumUsers()); |
6665 | db.addParam("Result" , *(getResult().getType())); |
6666 | return db; |
6667 | } |
6668 | |
6669 | void NonZeroNode::visit(Node *parent, NodeWalker *visitor) { |
6670 | if (!visitor->shouldVisit(parent, this)) { return; } |
6671 | visitor->pre(parent, this); |
6672 | if (hasPredicate()) |
6673 | getPredicate().getNode()->visit(this, visitor); |
6674 | getCond().getNode()->visit(this, visitor); |
6675 | visitor->post(parent, this); |
6676 | } |
6677 | |
6678 | bool NonZeroNode::isEqual(const NonZeroNode &other) const { |
6679 | return true && |
6680 | Cond_ == other.Cond_ && |
6681 | predicate_ == other.predicate_ && |
6682 | getType(0) == other.getType(0); |
6683 | } |
6684 | |
6685 | Node* NonZeroNode::clone() const { |
6686 | return new NonZeroNode(getName(), getResult().getType(), getCond()); |
6687 | } |
6688 | |
6689 | llvm::hash_code NonZeroNode::getHash() const { |
6690 | return llvm::hash_combine( |
6691 | Cond_); |
6692 | } |
6693 | |
6694 | unsigned SelectNode::getNumInputs() const { |
6695 | return 3; |
6696 | } |
6697 | |
6698 | std::string SelectNode::getInputName(unsigned idx) const { |
6699 | if (idx == 0) { return "Cond" ; } |
6700 | if (idx == 1) { return "LHS" ; } |
6701 | if (idx == 2) { return "RHS" ; } |
6702 | idx -= 3; |
6703 | llvm_unreachable("Invalid index" ); |
6704 | } |
6705 | |
6706 | NodeValue SelectNode::getNthInput(unsigned idx) { |
6707 | if (idx == 0) { return Cond_; } |
6708 | if (idx == 1) { return LHS_; } |
6709 | if (idx == 2) { return RHS_; } |
6710 | idx -= 3; |
6711 | llvm_unreachable("Invalid index" ); |
6712 | } |
6713 | |
6714 | void SelectNode::setNthInput(unsigned idx, NodeValue val) { |
6715 | if (idx == 0) { Cond_ = val; return; } |
6716 | if (idx == 1) { LHS_ = val; return; } |
6717 | if (idx == 2) { RHS_ = val; return; } |
6718 | idx -= 3; |
6719 | llvm_unreachable("Invalid index" ); |
6720 | } |
6721 | |
6722 | llvm::StringRef SelectNode::getOutputName(unsigned idx) const { |
6723 | if (idx == 0) { return "Result" ; } |
6724 | llvm_unreachable("Invalid index" ); |
6725 | } |
6726 | |
6727 | std::string SelectNode::getDebugDesc() const { |
6728 | DescriptionBuilder db(getKindName()); |
6729 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6730 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6731 | db |
6732 | .addParam("Cond" , *(getCond().getType())) |
6733 | .addParam("LHS" , *(getLHS().getType())) |
6734 | .addParam("RHS" , *(getRHS().getType())) |
6735 | .addParam("Users" , getNumUsers()); |
6736 | db.addParam("Result" , *(getResult().getType())); |
6737 | return db; |
6738 | } |
6739 | |
6740 | void SelectNode::visit(Node *parent, NodeWalker *visitor) { |
6741 | if (!visitor->shouldVisit(parent, this)) { return; } |
6742 | visitor->pre(parent, this); |
6743 | if (hasPredicate()) |
6744 | getPredicate().getNode()->visit(this, visitor); |
6745 | getCond().getNode()->visit(this, visitor); |
6746 | getLHS().getNode()->visit(this, visitor); |
6747 | getRHS().getNode()->visit(this, visitor); |
6748 | visitor->post(parent, this); |
6749 | } |
6750 | |
6751 | bool SelectNode::isEqual(const SelectNode &other) const { |
6752 | return true && |
6753 | Cond_ == other.Cond_ && |
6754 | LHS_ == other.LHS_ && |
6755 | RHS_ == other.RHS_ && |
6756 | predicate_ == other.predicate_ && |
6757 | getType(0) == other.getType(0); |
6758 | } |
6759 | |
6760 | Node* SelectNode::clone() const { |
6761 | return new SelectNode(getName(), getResult().getType(), getCond(), getLHS(), getRHS()); |
6762 | } |
6763 | |
6764 | llvm::hash_code SelectNode::getHash() const { |
6765 | return llvm::hash_combine( |
6766 | Cond_, |
6767 | LHS_, |
6768 | RHS_); |
6769 | } |
6770 | |
6771 | unsigned BatchedAddNode::getNumInputs() const { |
6772 | return 2; |
6773 | } |
6774 | |
6775 | std::string BatchedAddNode::getInputName(unsigned idx) const { |
6776 | if (idx == 0) { return "Batch" ; } |
6777 | if (idx == 1) { return "Slice" ; } |
6778 | idx -= 2; |
6779 | llvm_unreachable("Invalid index" ); |
6780 | } |
6781 | |
6782 | NodeValue BatchedAddNode::getNthInput(unsigned idx) { |
6783 | if (idx == 0) { return Batch_; } |
6784 | if (idx == 1) { return Slice_; } |
6785 | idx -= 2; |
6786 | llvm_unreachable("Invalid index" ); |
6787 | } |
6788 | |
6789 | void BatchedAddNode::setNthInput(unsigned idx, NodeValue val) { |
6790 | if (idx == 0) { Batch_ = val; return; } |
6791 | if (idx == 1) { Slice_ = val; return; } |
6792 | idx -= 2; |
6793 | llvm_unreachable("Invalid index" ); |
6794 | } |
6795 | |
6796 | llvm::StringRef BatchedAddNode::getOutputName(unsigned idx) const { |
6797 | if (idx == 0) { return "Result" ; } |
6798 | llvm_unreachable("Invalid index" ); |
6799 | } |
6800 | |
6801 | std::string BatchedAddNode::getDebugDesc() const { |
6802 | DescriptionBuilder db(getKindName()); |
6803 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6804 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6805 | db |
6806 | .addParam("Batch" , *(getBatch().getType())) |
6807 | .addParam("Slice" , *(getSlice().getType())) |
6808 | .addParam("Users" , getNumUsers()); |
6809 | db.addParam("Result" , *(getResult().getType())); |
6810 | return db; |
6811 | } |
6812 | |
6813 | void BatchedAddNode::visit(Node *parent, NodeWalker *visitor) { |
6814 | if (!visitor->shouldVisit(parent, this)) { return; } |
6815 | visitor->pre(parent, this); |
6816 | if (hasPredicate()) |
6817 | getPredicate().getNode()->visit(this, visitor); |
6818 | getBatch().getNode()->visit(this, visitor); |
6819 | getSlice().getNode()->visit(this, visitor); |
6820 | visitor->post(parent, this); |
6821 | } |
6822 | |
6823 | bool BatchedAddNode::isEqual(const BatchedAddNode &other) const { |
6824 | return true && |
6825 | Batch_ == other.Batch_ && |
6826 | Slice_ == other.Slice_ && |
6827 | predicate_ == other.predicate_ && |
6828 | getType(0) == other.getType(0); |
6829 | } |
6830 | |
6831 | Node* BatchedAddNode::clone() const { |
6832 | return new BatchedAddNode(getName(), getResult().getType(), getBatch(), getSlice()); |
6833 | } |
6834 | |
6835 | llvm::hash_code BatchedAddNode::getHash() const { |
6836 | return llvm::hash_combine( |
6837 | Batch_, |
6838 | Slice_); |
6839 | } |
6840 | |
6841 | unsigned BatchedMulNode::getNumInputs() const { |
6842 | return 2; |
6843 | } |
6844 | |
6845 | std::string BatchedMulNode::getInputName(unsigned idx) const { |
6846 | if (idx == 0) { return "Batch" ; } |
6847 | if (idx == 1) { return "Slice" ; } |
6848 | idx -= 2; |
6849 | llvm_unreachable("Invalid index" ); |
6850 | } |
6851 | |
6852 | NodeValue BatchedMulNode::getNthInput(unsigned idx) { |
6853 | if (idx == 0) { return Batch_; } |
6854 | if (idx == 1) { return Slice_; } |
6855 | idx -= 2; |
6856 | llvm_unreachable("Invalid index" ); |
6857 | } |
6858 | |
6859 | void BatchedMulNode::setNthInput(unsigned idx, NodeValue val) { |
6860 | if (idx == 0) { Batch_ = val; return; } |
6861 | if (idx == 1) { Slice_ = val; return; } |
6862 | idx -= 2; |
6863 | llvm_unreachable("Invalid index" ); |
6864 | } |
6865 | |
6866 | llvm::StringRef BatchedMulNode::getOutputName(unsigned idx) const { |
6867 | if (idx == 0) { return "Result" ; } |
6868 | llvm_unreachable("Invalid index" ); |
6869 | } |
6870 | |
6871 | std::string BatchedMulNode::getDebugDesc() const { |
6872 | DescriptionBuilder db(getKindName()); |
6873 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6874 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6875 | db |
6876 | .addParam("Batch" , *(getBatch().getType())) |
6877 | .addParam("Slice" , *(getSlice().getType())) |
6878 | .addParam("Users" , getNumUsers()); |
6879 | db.addParam("Result" , *(getResult().getType())); |
6880 | return db; |
6881 | } |
6882 | |
6883 | void BatchedMulNode::visit(Node *parent, NodeWalker *visitor) { |
6884 | if (!visitor->shouldVisit(parent, this)) { return; } |
6885 | visitor->pre(parent, this); |
6886 | if (hasPredicate()) |
6887 | getPredicate().getNode()->visit(this, visitor); |
6888 | getBatch().getNode()->visit(this, visitor); |
6889 | getSlice().getNode()->visit(this, visitor); |
6890 | visitor->post(parent, this); |
6891 | } |
6892 | |
6893 | bool BatchedMulNode::isEqual(const BatchedMulNode &other) const { |
6894 | return true && |
6895 | Batch_ == other.Batch_ && |
6896 | Slice_ == other.Slice_ && |
6897 | predicate_ == other.predicate_ && |
6898 | getType(0) == other.getType(0); |
6899 | } |
6900 | |
6901 | Node* BatchedMulNode::clone() const { |
6902 | return new BatchedMulNode(getName(), getResult().getType(), getBatch(), getSlice()); |
6903 | } |
6904 | |
6905 | llvm::hash_code BatchedMulNode::getHash() const { |
6906 | return llvm::hash_combine( |
6907 | Batch_, |
6908 | Slice_); |
6909 | } |
6910 | |
6911 | unsigned MatMulNode::getNumInputs() const { |
6912 | return 2; |
6913 | } |
6914 | |
6915 | std::string MatMulNode::getInputName(unsigned idx) const { |
6916 | if (idx == 0) { return "LHS" ; } |
6917 | if (idx == 1) { return "RHS" ; } |
6918 | idx -= 2; |
6919 | llvm_unreachable("Invalid index" ); |
6920 | } |
6921 | |
6922 | NodeValue MatMulNode::getNthInput(unsigned idx) { |
6923 | if (idx == 0) { return LHS_; } |
6924 | if (idx == 1) { return RHS_; } |
6925 | idx -= 2; |
6926 | llvm_unreachable("Invalid index" ); |
6927 | } |
6928 | |
6929 | void MatMulNode::setNthInput(unsigned idx, NodeValue val) { |
6930 | if (idx == 0) { LHS_ = val; return; } |
6931 | if (idx == 1) { RHS_ = val; return; } |
6932 | idx -= 2; |
6933 | llvm_unreachable("Invalid index" ); |
6934 | } |
6935 | |
6936 | llvm::StringRef MatMulNode::getOutputName(unsigned idx) const { |
6937 | if (idx == 0) { return "Result" ; } |
6938 | llvm_unreachable("Invalid index" ); |
6939 | } |
6940 | |
6941 | std::string MatMulNode::getDebugDesc() const { |
6942 | DescriptionBuilder db(getKindName()); |
6943 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
6944 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
6945 | db |
6946 | .addParam("LHS" , *(getLHS().getType())) |
6947 | .addParam("RHS" , *(getRHS().getType())) |
6948 | .addParam("Users" , getNumUsers()); |
6949 | db.addParam("Result" , *(getResult().getType())); |
6950 | return db; |
6951 | } |
6952 | |
6953 | void MatMulNode::visit(Node *parent, NodeWalker *visitor) { |
6954 | if (!visitor->shouldVisit(parent, this)) { return; } |
6955 | visitor->pre(parent, this); |
6956 | if (hasPredicate()) |
6957 | getPredicate().getNode()->visit(this, visitor); |
6958 | getLHS().getNode()->visit(this, visitor); |
6959 | getRHS().getNode()->visit(this, visitor); |
6960 | visitor->post(parent, this); |
6961 | } |
6962 | |
6963 | bool MatMulNode::isEqual(const MatMulNode &other) const { |
6964 | return true && |
6965 | LHS_ == other.LHS_ && |
6966 | RHS_ == other.RHS_ && |
6967 | predicate_ == other.predicate_ && |
6968 | getType(0) == other.getType(0); |
6969 | } |
6970 | |
6971 | Node* MatMulNode::clone() const { |
6972 | return new MatMulNode(getName(), getResult().getType(), getLHS(), getRHS()); |
6973 | } |
6974 | |
6975 | llvm::hash_code MatMulNode::getHash() const { |
6976 | return llvm::hash_combine( |
6977 | LHS_, |
6978 | RHS_); |
6979 | } |
6980 | |
6981 | unsigned BatchMatMulNode::getNumInputs() const { |
6982 | return 2; |
6983 | } |
6984 | |
6985 | std::string BatchMatMulNode::getInputName(unsigned idx) const { |
6986 | if (idx == 0) { return "LHS" ; } |
6987 | if (idx == 1) { return "RHS" ; } |
6988 | idx -= 2; |
6989 | llvm_unreachable("Invalid index" ); |
6990 | } |
6991 | |
6992 | NodeValue BatchMatMulNode::getNthInput(unsigned idx) { |
6993 | if (idx == 0) { return LHS_; } |
6994 | if (idx == 1) { return RHS_; } |
6995 | idx -= 2; |
6996 | llvm_unreachable("Invalid index" ); |
6997 | } |
6998 | |
6999 | void BatchMatMulNode::setNthInput(unsigned idx, NodeValue val) { |
7000 | if (idx == 0) { LHS_ = val; return; } |
7001 | if (idx == 1) { RHS_ = val; return; } |
7002 | idx -= 2; |
7003 | llvm_unreachable("Invalid index" ); |
7004 | } |
7005 | |
7006 | llvm::StringRef BatchMatMulNode::getOutputName(unsigned idx) const { |
7007 | if (idx == 0) { return "Result" ; } |
7008 | llvm_unreachable("Invalid index" ); |
7009 | } |
7010 | |
7011 | std::string BatchMatMulNode::getDebugDesc() const { |
7012 | DescriptionBuilder db(getKindName()); |
7013 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7014 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7015 | db |
7016 | .addParam("LHS" , *(getLHS().getType())) |
7017 | .addParam("RHS" , *(getRHS().getType())) |
7018 | .addParam("Users" , getNumUsers()); |
7019 | db.addParam("Result" , *(getResult().getType())); |
7020 | return db; |
7021 | } |
7022 | |
7023 | void BatchMatMulNode::visit(Node *parent, NodeWalker *visitor) { |
7024 | if (!visitor->shouldVisit(parent, this)) { return; } |
7025 | visitor->pre(parent, this); |
7026 | if (hasPredicate()) |
7027 | getPredicate().getNode()->visit(this, visitor); |
7028 | getLHS().getNode()->visit(this, visitor); |
7029 | getRHS().getNode()->visit(this, visitor); |
7030 | visitor->post(parent, this); |
7031 | } |
7032 | |
7033 | bool BatchMatMulNode::isEqual(const BatchMatMulNode &other) const { |
7034 | return true && |
7035 | LHS_ == other.LHS_ && |
7036 | RHS_ == other.RHS_ && |
7037 | predicate_ == other.predicate_ && |
7038 | getType(0) == other.getType(0); |
7039 | } |
7040 | |
7041 | Node* BatchMatMulNode::clone() const { |
7042 | return new BatchMatMulNode(getName(), getResult().getType(), getLHS(), getRHS()); |
7043 | } |
7044 | |
7045 | llvm::hash_code BatchMatMulNode::getHash() const { |
7046 | return llvm::hash_combine( |
7047 | LHS_, |
7048 | RHS_); |
7049 | } |
7050 | |
7051 | unsigned BatchedReduceAddNode::getNumInputs() const { |
7052 | return 1; |
7053 | } |
7054 | |
7055 | std::string BatchedReduceAddNode::getInputName(unsigned idx) const { |
7056 | if (idx == 0) { return "Batch" ; } |
7057 | idx -= 1; |
7058 | llvm_unreachable("Invalid index" ); |
7059 | } |
7060 | |
7061 | NodeValue BatchedReduceAddNode::getNthInput(unsigned idx) { |
7062 | if (idx == 0) { return Batch_; } |
7063 | idx -= 1; |
7064 | llvm_unreachable("Invalid index" ); |
7065 | } |
7066 | |
7067 | void BatchedReduceAddNode::setNthInput(unsigned idx, NodeValue val) { |
7068 | if (idx == 0) { Batch_ = val; return; } |
7069 | idx -= 1; |
7070 | llvm_unreachable("Invalid index" ); |
7071 | } |
7072 | |
7073 | llvm::StringRef BatchedReduceAddNode::getOutputName(unsigned idx) const { |
7074 | if (idx == 0) { return "Result" ; } |
7075 | llvm_unreachable("Invalid index" ); |
7076 | } |
7077 | |
7078 | std::string BatchedReduceAddNode::getDebugDesc() const { |
7079 | DescriptionBuilder db(getKindName()); |
7080 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7081 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7082 | db |
7083 | .addParam("Batch" , *(getBatch().getType())) |
7084 | .addParam("Axis" , getAxis()) |
7085 | .addParam("Users" , getNumUsers()); |
7086 | db.addParam("Result" , *(getResult().getType())); |
7087 | return db; |
7088 | } |
7089 | |
7090 | void BatchedReduceAddNode::visit(Node *parent, NodeWalker *visitor) { |
7091 | if (!visitor->shouldVisit(parent, this)) { return; } |
7092 | visitor->pre(parent, this); |
7093 | if (hasPredicate()) |
7094 | getPredicate().getNode()->visit(this, visitor); |
7095 | getBatch().getNode()->visit(this, visitor); |
7096 | visitor->post(parent, this); |
7097 | } |
7098 | |
7099 | bool BatchedReduceAddNode::isEqual(const BatchedReduceAddNode &other) const { |
7100 | return true && |
7101 | Batch_ == other.Batch_ && |
7102 | predicate_ == other.predicate_ && |
7103 | Axis_ == other.Axis_ && |
7104 | getType(0) == other.getType(0); |
7105 | } |
7106 | |
7107 | Node* BatchedReduceAddNode::clone() const { |
7108 | return new BatchedReduceAddNode(getName(), getResult().getType(), getBatch(), getAxis()); |
7109 | } |
7110 | |
7111 | llvm::hash_code BatchedReduceAddNode::getHash() const { |
7112 | return llvm::hash_combine( |
7113 | Axis_, |
7114 | Batch_); |
7115 | } |
7116 | |
7117 | unsigned BatchedReduceSumSquareNode::getNumInputs() const { |
7118 | return 1; |
7119 | } |
7120 | |
7121 | std::string BatchedReduceSumSquareNode::getInputName(unsigned idx) const { |
7122 | if (idx == 0) { return "Batch" ; } |
7123 | idx -= 1; |
7124 | llvm_unreachable("Invalid index" ); |
7125 | } |
7126 | |
7127 | NodeValue BatchedReduceSumSquareNode::getNthInput(unsigned idx) { |
7128 | if (idx == 0) { return Batch_; } |
7129 | idx -= 1; |
7130 | llvm_unreachable("Invalid index" ); |
7131 | } |
7132 | |
7133 | void BatchedReduceSumSquareNode::setNthInput(unsigned idx, NodeValue val) { |
7134 | if (idx == 0) { Batch_ = val; return; } |
7135 | idx -= 1; |
7136 | llvm_unreachable("Invalid index" ); |
7137 | } |
7138 | |
7139 | llvm::StringRef BatchedReduceSumSquareNode::getOutputName(unsigned idx) const { |
7140 | if (idx == 0) { return "Result" ; } |
7141 | llvm_unreachable("Invalid index" ); |
7142 | } |
7143 | |
7144 | std::string BatchedReduceSumSquareNode::getDebugDesc() const { |
7145 | DescriptionBuilder db(getKindName()); |
7146 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7147 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7148 | db |
7149 | .addParam("Batch" , *(getBatch().getType())) |
7150 | .addParam("Axis" , getAxis()) |
7151 | .addParam("Users" , getNumUsers()); |
7152 | db.addParam("Result" , *(getResult().getType())); |
7153 | return db; |
7154 | } |
7155 | |
7156 | void BatchedReduceSumSquareNode::visit(Node *parent, NodeWalker *visitor) { |
7157 | if (!visitor->shouldVisit(parent, this)) { return; } |
7158 | visitor->pre(parent, this); |
7159 | if (hasPredicate()) |
7160 | getPredicate().getNode()->visit(this, visitor); |
7161 | getBatch().getNode()->visit(this, visitor); |
7162 | visitor->post(parent, this); |
7163 | } |
7164 | |
7165 | bool BatchedReduceSumSquareNode::isEqual(const BatchedReduceSumSquareNode &other) const { |
7166 | return true && |
7167 | Batch_ == other.Batch_ && |
7168 | predicate_ == other.predicate_ && |
7169 | Axis_ == other.Axis_ && |
7170 | getType(0) == other.getType(0); |
7171 | } |
7172 | |
7173 | Node* BatchedReduceSumSquareNode::clone() const { |
7174 | return new BatchedReduceSumSquareNode(getName(), getResult().getType(), getBatch(), getAxis()); |
7175 | } |
7176 | |
7177 | llvm::hash_code BatchedReduceSumSquareNode::getHash() const { |
7178 | return llvm::hash_combine( |
7179 | Axis_, |
7180 | Batch_); |
7181 | } |
7182 | |
7183 | unsigned BatchedReduceMeanNode::getNumInputs() const { |
7184 | return 1; |
7185 | } |
7186 | |
7187 | std::string BatchedReduceMeanNode::getInputName(unsigned idx) const { |
7188 | if (idx == 0) { return "Batch" ; } |
7189 | idx -= 1; |
7190 | llvm_unreachable("Invalid index" ); |
7191 | } |
7192 | |
7193 | NodeValue BatchedReduceMeanNode::getNthInput(unsigned idx) { |
7194 | if (idx == 0) { return Batch_; } |
7195 | idx -= 1; |
7196 | llvm_unreachable("Invalid index" ); |
7197 | } |
7198 | |
7199 | void BatchedReduceMeanNode::setNthInput(unsigned idx, NodeValue val) { |
7200 | if (idx == 0) { Batch_ = val; return; } |
7201 | idx -= 1; |
7202 | llvm_unreachable("Invalid index" ); |
7203 | } |
7204 | |
7205 | llvm::StringRef BatchedReduceMeanNode::getOutputName(unsigned idx) const { |
7206 | if (idx == 0) { return "Result" ; } |
7207 | llvm_unreachable("Invalid index" ); |
7208 | } |
7209 | |
7210 | std::string BatchedReduceMeanNode::getDebugDesc() const { |
7211 | DescriptionBuilder db(getKindName()); |
7212 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7213 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7214 | db |
7215 | .addParam("Batch" , *(getBatch().getType())) |
7216 | .addParam("Axes" , getAxes()) |
7217 | .addParam("Users" , getNumUsers()); |
7218 | db.addParam("Result" , *(getResult().getType())); |
7219 | return db; |
7220 | } |
7221 | |
7222 | void BatchedReduceMeanNode::visit(Node *parent, NodeWalker *visitor) { |
7223 | if (!visitor->shouldVisit(parent, this)) { return; } |
7224 | visitor->pre(parent, this); |
7225 | if (hasPredicate()) |
7226 | getPredicate().getNode()->visit(this, visitor); |
7227 | getBatch().getNode()->visit(this, visitor); |
7228 | visitor->post(parent, this); |
7229 | } |
7230 | |
7231 | bool BatchedReduceMeanNode::isEqual(const BatchedReduceMeanNode &other) const { |
7232 | return true && |
7233 | Batch_ == other.Batch_ && |
7234 | predicate_ == other.predicate_ && |
7235 | Axes_ == other.Axes_ && |
7236 | getType(0) == other.getType(0); |
7237 | } |
7238 | |
7239 | Node* BatchedReduceMeanNode::clone() const { |
7240 | return new BatchedReduceMeanNode(getName(), getResult().getType(), getBatch(), getAxes()); |
7241 | } |
7242 | |
7243 | llvm::hash_code BatchedReduceMeanNode::getHash() const { |
7244 | return llvm::hash_combine( |
7245 | llvm::hash_combine_range(Axes_.begin(), Axes_.end()), |
7246 | Batch_); |
7247 | } |
7248 | |
7249 | unsigned BatchedReduceMinNode::getNumInputs() const { |
7250 | return 1; |
7251 | } |
7252 | |
7253 | std::string BatchedReduceMinNode::getInputName(unsigned idx) const { |
7254 | if (idx == 0) { return "Batch" ; } |
7255 | idx -= 1; |
7256 | llvm_unreachable("Invalid index" ); |
7257 | } |
7258 | |
7259 | NodeValue BatchedReduceMinNode::getNthInput(unsigned idx) { |
7260 | if (idx == 0) { return Batch_; } |
7261 | idx -= 1; |
7262 | llvm_unreachable("Invalid index" ); |
7263 | } |
7264 | |
7265 | void BatchedReduceMinNode::setNthInput(unsigned idx, NodeValue val) { |
7266 | if (idx == 0) { Batch_ = val; return; } |
7267 | idx -= 1; |
7268 | llvm_unreachable("Invalid index" ); |
7269 | } |
7270 | |
7271 | llvm::StringRef BatchedReduceMinNode::getOutputName(unsigned idx) const { |
7272 | if (idx == 0) { return "Result" ; } |
7273 | llvm_unreachable("Invalid index" ); |
7274 | } |
7275 | |
7276 | std::string BatchedReduceMinNode::getDebugDesc() const { |
7277 | DescriptionBuilder db(getKindName()); |
7278 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7279 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7280 | db |
7281 | .addParam("Batch" , *(getBatch().getType())) |
7282 | .addParam("Axes" , getAxes()) |
7283 | .addParam("Users" , getNumUsers()); |
7284 | db.addParam("Result" , *(getResult().getType())); |
7285 | return db; |
7286 | } |
7287 | |
7288 | void BatchedReduceMinNode::visit(Node *parent, NodeWalker *visitor) { |
7289 | if (!visitor->shouldVisit(parent, this)) { return; } |
7290 | visitor->pre(parent, this); |
7291 | if (hasPredicate()) |
7292 | getPredicate().getNode()->visit(this, visitor); |
7293 | getBatch().getNode()->visit(this, visitor); |
7294 | visitor->post(parent, this); |
7295 | } |
7296 | |
7297 | bool BatchedReduceMinNode::isEqual(const BatchedReduceMinNode &other) const { |
7298 | return true && |
7299 | Batch_ == other.Batch_ && |
7300 | predicate_ == other.predicate_ && |
7301 | Axes_ == other.Axes_ && |
7302 | getType(0) == other.getType(0); |
7303 | } |
7304 | |
7305 | Node* BatchedReduceMinNode::clone() const { |
7306 | return new BatchedReduceMinNode(getName(), getResult().getType(), getBatch(), getAxes()); |
7307 | } |
7308 | |
7309 | llvm::hash_code BatchedReduceMinNode::getHash() const { |
7310 | return llvm::hash_combine( |
7311 | llvm::hash_combine_range(Axes_.begin(), Axes_.end()), |
7312 | Batch_); |
7313 | } |
7314 | |
7315 | unsigned BatchedReduceMaxNode::getNumInputs() const { |
7316 | return 1; |
7317 | } |
7318 | |
7319 | std::string BatchedReduceMaxNode::getInputName(unsigned idx) const { |
7320 | if (idx == 0) { return "Batch" ; } |
7321 | idx -= 1; |
7322 | llvm_unreachable("Invalid index" ); |
7323 | } |
7324 | |
7325 | NodeValue BatchedReduceMaxNode::getNthInput(unsigned idx) { |
7326 | if (idx == 0) { return Batch_; } |
7327 | idx -= 1; |
7328 | llvm_unreachable("Invalid index" ); |
7329 | } |
7330 | |
7331 | void BatchedReduceMaxNode::setNthInput(unsigned idx, NodeValue val) { |
7332 | if (idx == 0) { Batch_ = val; return; } |
7333 | idx -= 1; |
7334 | llvm_unreachable("Invalid index" ); |
7335 | } |
7336 | |
7337 | llvm::StringRef BatchedReduceMaxNode::getOutputName(unsigned idx) const { |
7338 | if (idx == 0) { return "Result" ; } |
7339 | llvm_unreachable("Invalid index" ); |
7340 | } |
7341 | |
7342 | std::string BatchedReduceMaxNode::getDebugDesc() const { |
7343 | DescriptionBuilder db(getKindName()); |
7344 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7345 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7346 | db |
7347 | .addParam("Batch" , *(getBatch().getType())) |
7348 | .addParam("Axes" , getAxes()) |
7349 | .addParam("Users" , getNumUsers()); |
7350 | db.addParam("Result" , *(getResult().getType())); |
7351 | return db; |
7352 | } |
7353 | |
7354 | void BatchedReduceMaxNode::visit(Node *parent, NodeWalker *visitor) { |
7355 | if (!visitor->shouldVisit(parent, this)) { return; } |
7356 | visitor->pre(parent, this); |
7357 | if (hasPredicate()) |
7358 | getPredicate().getNode()->visit(this, visitor); |
7359 | getBatch().getNode()->visit(this, visitor); |
7360 | visitor->post(parent, this); |
7361 | } |
7362 | |
7363 | bool BatchedReduceMaxNode::isEqual(const BatchedReduceMaxNode &other) const { |
7364 | return true && |
7365 | Batch_ == other.Batch_ && |
7366 | predicate_ == other.predicate_ && |
7367 | Axes_ == other.Axes_ && |
7368 | getType(0) == other.getType(0); |
7369 | } |
7370 | |
7371 | Node* BatchedReduceMaxNode::clone() const { |
7372 | return new BatchedReduceMaxNode(getName(), getResult().getType(), getBatch(), getAxes()); |
7373 | } |
7374 | |
7375 | llvm::hash_code BatchedReduceMaxNode::getHash() const { |
7376 | return llvm::hash_combine( |
7377 | llvm::hash_combine_range(Axes_.begin(), Axes_.end()), |
7378 | Batch_); |
7379 | } |
7380 | |
7381 | unsigned BatchedReduceProdNode::getNumInputs() const { |
7382 | return 1; |
7383 | } |
7384 | |
7385 | std::string BatchedReduceProdNode::getInputName(unsigned idx) const { |
7386 | if (idx == 0) { return "Batch" ; } |
7387 | idx -= 1; |
7388 | llvm_unreachable("Invalid index" ); |
7389 | } |
7390 | |
7391 | NodeValue BatchedReduceProdNode::getNthInput(unsigned idx) { |
7392 | if (idx == 0) { return Batch_; } |
7393 | idx -= 1; |
7394 | llvm_unreachable("Invalid index" ); |
7395 | } |
7396 | |
7397 | void BatchedReduceProdNode::setNthInput(unsigned idx, NodeValue val) { |
7398 | if (idx == 0) { Batch_ = val; return; } |
7399 | idx -= 1; |
7400 | llvm_unreachable("Invalid index" ); |
7401 | } |
7402 | |
7403 | llvm::StringRef BatchedReduceProdNode::getOutputName(unsigned idx) const { |
7404 | if (idx == 0) { return "Result" ; } |
7405 | llvm_unreachable("Invalid index" ); |
7406 | } |
7407 | |
7408 | std::string BatchedReduceProdNode::getDebugDesc() const { |
7409 | DescriptionBuilder db(getKindName()); |
7410 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7411 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7412 | db |
7413 | .addParam("Batch" , *(getBatch().getType())) |
7414 | .addParam("Axis" , getAxis()) |
7415 | .addParam("Users" , getNumUsers()); |
7416 | db.addParam("Result" , *(getResult().getType())); |
7417 | return db; |
7418 | } |
7419 | |
7420 | void BatchedReduceProdNode::visit(Node *parent, NodeWalker *visitor) { |
7421 | if (!visitor->shouldVisit(parent, this)) { return; } |
7422 | visitor->pre(parent, this); |
7423 | if (hasPredicate()) |
7424 | getPredicate().getNode()->visit(this, visitor); |
7425 | getBatch().getNode()->visit(this, visitor); |
7426 | visitor->post(parent, this); |
7427 | } |
7428 | |
7429 | bool BatchedReduceProdNode::isEqual(const BatchedReduceProdNode &other) const { |
7430 | return true && |
7431 | Batch_ == other.Batch_ && |
7432 | predicate_ == other.predicate_ && |
7433 | Axis_ == other.Axis_ && |
7434 | getType(0) == other.getType(0); |
7435 | } |
7436 | |
7437 | Node* BatchedReduceProdNode::clone() const { |
7438 | return new BatchedReduceProdNode(getName(), getResult().getType(), getBatch(), getAxis()); |
7439 | } |
7440 | |
7441 | llvm::hash_code BatchedReduceProdNode::getHash() const { |
7442 | return llvm::hash_combine( |
7443 | Axis_, |
7444 | Batch_); |
7445 | } |
7446 | |
7447 | unsigned ChannelShuffleNode::getNumInputs() const { |
7448 | return 1; |
7449 | } |
7450 | |
7451 | std::string ChannelShuffleNode::getInputName(unsigned idx) const { |
7452 | if (idx == 0) { return "Input" ; } |
7453 | idx -= 1; |
7454 | llvm_unreachable("Invalid index" ); |
7455 | } |
7456 | |
7457 | NodeValue ChannelShuffleNode::getNthInput(unsigned idx) { |
7458 | if (idx == 0) { return Input_; } |
7459 | idx -= 1; |
7460 | llvm_unreachable("Invalid index" ); |
7461 | } |
7462 | |
7463 | void ChannelShuffleNode::setNthInput(unsigned idx, NodeValue val) { |
7464 | if (idx == 0) { Input_ = val; return; } |
7465 | idx -= 1; |
7466 | llvm_unreachable("Invalid index" ); |
7467 | } |
7468 | |
7469 | llvm::StringRef ChannelShuffleNode::getOutputName(unsigned idx) const { |
7470 | if (idx == 0) { return "Result" ; } |
7471 | llvm_unreachable("Invalid index" ); |
7472 | } |
7473 | |
7474 | std::string ChannelShuffleNode::getDebugDesc() const { |
7475 | DescriptionBuilder db(getKindName()); |
7476 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7477 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7478 | db |
7479 | .addParam("Input" , *(getInput().getType())) |
7480 | .addParam("Group" , getGroup()) |
7481 | .addParam("Kernel" , getKernel()) |
7482 | .addParam("Users" , getNumUsers()); |
7483 | db.addParam("Result" , *(getResult().getType())); |
7484 | return db; |
7485 | } |
7486 | |
7487 | void ChannelShuffleNode::visit(Node *parent, NodeWalker *visitor) { |
7488 | if (!visitor->shouldVisit(parent, this)) { return; } |
7489 | visitor->pre(parent, this); |
7490 | if (hasPredicate()) |
7491 | getPredicate().getNode()->visit(this, visitor); |
7492 | getInput().getNode()->visit(this, visitor); |
7493 | visitor->post(parent, this); |
7494 | } |
7495 | |
7496 | bool ChannelShuffleNode::isEqual(const ChannelShuffleNode &other) const { |
7497 | return true && |
7498 | Input_ == other.Input_ && |
7499 | predicate_ == other.predicate_ && |
7500 | Group_ == other.Group_ && |
7501 | Kernel_ == other.Kernel_ && |
7502 | getType(0) == other.getType(0); |
7503 | } |
7504 | |
7505 | Node* ChannelShuffleNode::clone() const { |
7506 | return new ChannelShuffleNode(getName(), getResult().getType(), getInput(), getGroup(), getKernel()); |
7507 | } |
7508 | |
7509 | llvm::hash_code ChannelShuffleNode::getHash() const { |
7510 | return llvm::hash_combine( |
7511 | Group_, |
7512 | Kernel_, |
7513 | Input_); |
7514 | } |
7515 | |
7516 | unsigned CumSumNode::getNumInputs() const { |
7517 | return 1; |
7518 | } |
7519 | |
7520 | std::string CumSumNode::getInputName(unsigned idx) const { |
7521 | if (idx == 0) { return "Input" ; } |
7522 | idx -= 1; |
7523 | llvm_unreachable("Invalid index" ); |
7524 | } |
7525 | |
7526 | NodeValue CumSumNode::getNthInput(unsigned idx) { |
7527 | if (idx == 0) { return Input_; } |
7528 | idx -= 1; |
7529 | llvm_unreachable("Invalid index" ); |
7530 | } |
7531 | |
7532 | void CumSumNode::setNthInput(unsigned idx, NodeValue val) { |
7533 | if (idx == 0) { Input_ = val; return; } |
7534 | idx -= 1; |
7535 | llvm_unreachable("Invalid index" ); |
7536 | } |
7537 | |
7538 | llvm::StringRef CumSumNode::getOutputName(unsigned idx) const { |
7539 | if (idx == 0) { return "Result" ; } |
7540 | llvm_unreachable("Invalid index" ); |
7541 | } |
7542 | |
7543 | std::string CumSumNode::getDebugDesc() const { |
7544 | DescriptionBuilder db(getKindName()); |
7545 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7546 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7547 | db |
7548 | .addParam("Input" , *(getInput().getType())) |
7549 | .addParam("Dim" , getDim()) |
7550 | .addParam("Exclusive" , getExclusive()) |
7551 | .addParam("Reverse" , getReverse()) |
7552 | .addParam("Users" , getNumUsers()); |
7553 | db.addParam("Result" , *(getResult().getType())); |
7554 | return db; |
7555 | } |
7556 | |
7557 | void CumSumNode::visit(Node *parent, NodeWalker *visitor) { |
7558 | if (!visitor->shouldVisit(parent, this)) { return; } |
7559 | visitor->pre(parent, this); |
7560 | if (hasPredicate()) |
7561 | getPredicate().getNode()->visit(this, visitor); |
7562 | getInput().getNode()->visit(this, visitor); |
7563 | visitor->post(parent, this); |
7564 | } |
7565 | |
7566 | bool CumSumNode::isEqual(const CumSumNode &other) const { |
7567 | return true && |
7568 | Input_ == other.Input_ && |
7569 | predicate_ == other.predicate_ && |
7570 | Dim_ == other.Dim_ && |
7571 | Exclusive_ == other.Exclusive_ && |
7572 | Reverse_ == other.Reverse_ && |
7573 | getType(0) == other.getType(0); |
7574 | } |
7575 | |
7576 | Node* CumSumNode::clone() const { |
7577 | return new CumSumNode(getName(), getResult().getType(), getInput(), getDim(), getExclusive(), getReverse()); |
7578 | } |
7579 | |
7580 | llvm::hash_code CumSumNode::getHash() const { |
7581 | return llvm::hash_combine( |
7582 | Dim_, |
7583 | Exclusive_, |
7584 | Reverse_, |
7585 | Input_); |
7586 | } |
7587 | |
7588 | unsigned LengthsSumNode::getNumInputs() const { |
7589 | return 2; |
7590 | } |
7591 | |
7592 | std::string LengthsSumNode::getInputName(unsigned idx) const { |
7593 | if (idx == 0) { return "Data" ; } |
7594 | if (idx == 1) { return "Lengths" ; } |
7595 | idx -= 2; |
7596 | llvm_unreachable("Invalid index" ); |
7597 | } |
7598 | |
7599 | NodeValue LengthsSumNode::getNthInput(unsigned idx) { |
7600 | if (idx == 0) { return Data_; } |
7601 | if (idx == 1) { return Lengths_; } |
7602 | idx -= 2; |
7603 | llvm_unreachable("Invalid index" ); |
7604 | } |
7605 | |
7606 | void LengthsSumNode::setNthInput(unsigned idx, NodeValue val) { |
7607 | if (idx == 0) { Data_ = val; return; } |
7608 | if (idx == 1) { Lengths_ = val; return; } |
7609 | idx -= 2; |
7610 | llvm_unreachable("Invalid index" ); |
7611 | } |
7612 | |
7613 | llvm::StringRef LengthsSumNode::getOutputName(unsigned idx) const { |
7614 | if (idx == 0) { return "Result" ; } |
7615 | llvm_unreachable("Invalid index" ); |
7616 | } |
7617 | |
7618 | std::string LengthsSumNode::getDebugDesc() const { |
7619 | DescriptionBuilder db(getKindName()); |
7620 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7621 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7622 | db |
7623 | .addParam("Data" , *(getData().getType())) |
7624 | .addParam("Lengths" , *(getLengths().getType())) |
7625 | .addParam("Users" , getNumUsers()); |
7626 | db.addParam("Result" , *(getResult().getType())); |
7627 | return db; |
7628 | } |
7629 | |
7630 | void LengthsSumNode::visit(Node *parent, NodeWalker *visitor) { |
7631 | if (!visitor->shouldVisit(parent, this)) { return; } |
7632 | visitor->pre(parent, this); |
7633 | if (hasPredicate()) |
7634 | getPredicate().getNode()->visit(this, visitor); |
7635 | getData().getNode()->visit(this, visitor); |
7636 | getLengths().getNode()->visit(this, visitor); |
7637 | visitor->post(parent, this); |
7638 | } |
7639 | |
7640 | bool LengthsSumNode::isEqual(const LengthsSumNode &other) const { |
7641 | return true && |
7642 | Data_ == other.Data_ && |
7643 | Lengths_ == other.Lengths_ && |
7644 | predicate_ == other.predicate_ && |
7645 | getType(0) == other.getType(0); |
7646 | } |
7647 | |
7648 | Node* LengthsSumNode::clone() const { |
7649 | return new LengthsSumNode(getName(), getResult().getType(), getData(), getLengths()); |
7650 | } |
7651 | |
7652 | llvm::hash_code LengthsSumNode::getHash() const { |
7653 | return llvm::hash_combine( |
7654 | Data_, |
7655 | Lengths_); |
7656 | } |
7657 | |
7658 | unsigned SparseLengthsSumGradNode::getNumInputs() const { |
7659 | return 5; |
7660 | } |
7661 | |
7662 | std::string SparseLengthsSumGradNode::getInputName(unsigned idx) const { |
7663 | if (idx == 0) { return "Data" ; } |
7664 | if (idx == 1) { return "Indices" ; } |
7665 | if (idx == 2) { return "Lengths" ; } |
7666 | if (idx == 3) { return "OriginalOutputForResult" ; } |
7667 | if (idx == 4) { return "GradOfOriginalOutputNamedResult" ; } |
7668 | idx -= 5; |
7669 | llvm_unreachable("Invalid index" ); |
7670 | } |
7671 | |
7672 | NodeValue SparseLengthsSumGradNode::getNthInput(unsigned idx) { |
7673 | if (idx == 0) { return Data_; } |
7674 | if (idx == 1) { return Indices_; } |
7675 | if (idx == 2) { return Lengths_; } |
7676 | if (idx == 3) { return OriginalOutputForResult_; } |
7677 | if (idx == 4) { return GradOfOriginalOutputNamedResult_; } |
7678 | idx -= 5; |
7679 | llvm_unreachable("Invalid index" ); |
7680 | } |
7681 | |
7682 | void SparseLengthsSumGradNode::setNthInput(unsigned idx, NodeValue val) { |
7683 | if (idx == 0) { Data_ = val; return; } |
7684 | if (idx == 1) { Indices_ = val; return; } |
7685 | if (idx == 2) { Lengths_ = val; return; } |
7686 | if (idx == 3) { OriginalOutputForResult_ = val; return; } |
7687 | if (idx == 4) { GradOfOriginalOutputNamedResult_ = val; return; } |
7688 | idx -= 5; |
7689 | llvm_unreachable("Invalid index" ); |
7690 | } |
7691 | |
7692 | llvm::StringRef SparseLengthsSumGradNode::getOutputName(unsigned idx) const { |
7693 | if (idx == 0) { return "GradOfInputNamedData" ; } |
7694 | if (idx == 1) { return "GradOfInputNamedIndices" ; } |
7695 | if (idx == 2) { return "GradOfInputNamedLengths" ; } |
7696 | llvm_unreachable("Invalid index" ); |
7697 | } |
7698 | |
7699 | std::string SparseLengthsSumGradNode::getDebugDesc() const { |
7700 | DescriptionBuilder db(getKindName()); |
7701 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7702 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7703 | db |
7704 | .addParam("Data" , *(getData().getType())) |
7705 | .addParam("Indices" , *(getIndices().getType())) |
7706 | .addParam("Lengths" , *(getLengths().getType())) |
7707 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
7708 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
7709 | .addParam("LengthsMode" , getLengthsMode()) |
7710 | .addParam("AvgLength" , getAvgLength()) |
7711 | .addParam("Users" , getNumUsers()); |
7712 | db.addParam("GradOfInputNamedData" , *(getGradOfInputNamedData().getType())); |
7713 | db.addParam("GradOfInputNamedIndices" , *(getGradOfInputNamedIndices().getType())); |
7714 | db.addParam("GradOfInputNamedLengths" , *(getGradOfInputNamedLengths().getType())); |
7715 | return db; |
7716 | } |
7717 | |
7718 | void SparseLengthsSumGradNode::visit(Node *parent, NodeWalker *visitor) { |
7719 | if (!visitor->shouldVisit(parent, this)) { return; } |
7720 | visitor->pre(parent, this); |
7721 | if (hasPredicate()) |
7722 | getPredicate().getNode()->visit(this, visitor); |
7723 | getData().getNode()->visit(this, visitor); |
7724 | getIndices().getNode()->visit(this, visitor); |
7725 | getLengths().getNode()->visit(this, visitor); |
7726 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
7727 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
7728 | visitor->post(parent, this); |
7729 | } |
7730 | |
7731 | bool SparseLengthsSumGradNode::isEqual(const SparseLengthsSumGradNode &other) const { |
7732 | return true && |
7733 | Data_ == other.Data_ && |
7734 | Indices_ == other.Indices_ && |
7735 | Lengths_ == other.Lengths_ && |
7736 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
7737 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
7738 | predicate_ == other.predicate_ && |
7739 | LengthsMode_ == other.LengthsMode_ && |
7740 | AvgLength_ == other.AvgLength_ && |
7741 | getType(0) == other.getType(0) && |
7742 | getType(1) == other.getType(1) && |
7743 | getType(2) == other.getType(2); |
7744 | } |
7745 | |
7746 | Node* SparseLengthsSumGradNode::clone() const { |
7747 | return new SparseLengthsSumGradNode(getName(), getData(), getIndices(), getLengths(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getLengthsMode(), getAvgLength()); |
7748 | } |
7749 | |
7750 | llvm::hash_code SparseLengthsSumGradNode::getHash() const { |
7751 | return llvm::hash_combine( |
7752 | LengthsMode_, |
7753 | toBinary(AvgLength_), |
7754 | Data_, |
7755 | Indices_, |
7756 | Lengths_, |
7757 | OriginalOutputForResult_, |
7758 | GradOfOriginalOutputNamedResult_); |
7759 | } |
7760 | |
7761 | unsigned SparseLengthsSumNode::getNumInputs() const { |
7762 | return 3; |
7763 | } |
7764 | |
7765 | std::string SparseLengthsSumNode::getInputName(unsigned idx) const { |
7766 | if (idx == 0) { return "Data" ; } |
7767 | if (idx == 1) { return "Indices" ; } |
7768 | if (idx == 2) { return "Lengths" ; } |
7769 | idx -= 3; |
7770 | llvm_unreachable("Invalid index" ); |
7771 | } |
7772 | |
7773 | NodeValue SparseLengthsSumNode::getNthInput(unsigned idx) { |
7774 | if (idx == 0) { return Data_; } |
7775 | if (idx == 1) { return Indices_; } |
7776 | if (idx == 2) { return Lengths_; } |
7777 | idx -= 3; |
7778 | llvm_unreachable("Invalid index" ); |
7779 | } |
7780 | |
7781 | void SparseLengthsSumNode::setNthInput(unsigned idx, NodeValue val) { |
7782 | if (idx == 0) { Data_ = val; return; } |
7783 | if (idx == 1) { Indices_ = val; return; } |
7784 | if (idx == 2) { Lengths_ = val; return; } |
7785 | idx -= 3; |
7786 | llvm_unreachable("Invalid index" ); |
7787 | } |
7788 | |
7789 | llvm::StringRef SparseLengthsSumNode::getOutputName(unsigned idx) const { |
7790 | if (idx == 0) { return "Result" ; } |
7791 | llvm_unreachable("Invalid index" ); |
7792 | } |
7793 | |
7794 | std::string SparseLengthsSumNode::getDebugDesc() const { |
7795 | DescriptionBuilder db(getKindName()); |
7796 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7797 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7798 | db |
7799 | .addParam("Data" , *(getData().getType())) |
7800 | .addParam("Indices" , *(getIndices().getType())) |
7801 | .addParam("Lengths" , *(getLengths().getType())) |
7802 | .addParam("LengthsMode" , getLengthsMode()) |
7803 | .addParam("AvgLength" , getAvgLength()) |
7804 | .addParam("Users" , getNumUsers()); |
7805 | db.addParam("Result" , *(getResult().getType())); |
7806 | return db; |
7807 | } |
7808 | |
7809 | void SparseLengthsSumNode::visit(Node *parent, NodeWalker *visitor) { |
7810 | if (!visitor->shouldVisit(parent, this)) { return; } |
7811 | visitor->pre(parent, this); |
7812 | if (hasPredicate()) |
7813 | getPredicate().getNode()->visit(this, visitor); |
7814 | getData().getNode()->visit(this, visitor); |
7815 | getIndices().getNode()->visit(this, visitor); |
7816 | getLengths().getNode()->visit(this, visitor); |
7817 | visitor->post(parent, this); |
7818 | } |
7819 | |
7820 | bool SparseLengthsSumNode::isEqual(const SparseLengthsSumNode &other) const { |
7821 | return true && |
7822 | Data_ == other.Data_ && |
7823 | Indices_ == other.Indices_ && |
7824 | Lengths_ == other.Lengths_ && |
7825 | predicate_ == other.predicate_ && |
7826 | LengthsMode_ == other.LengthsMode_ && |
7827 | AvgLength_ == other.AvgLength_ && |
7828 | getType(0) == other.getType(0); |
7829 | } |
7830 | |
7831 | Node* SparseLengthsSumNode::clone() const { |
7832 | return new SparseLengthsSumNode(getName(), getResult().getType(), getData(), getIndices(), getLengths(), getLengthsMode(), getAvgLength()); |
7833 | } |
7834 | |
7835 | llvm::hash_code SparseLengthsSumNode::getHash() const { |
7836 | return llvm::hash_combine( |
7837 | LengthsMode_, |
7838 | toBinary(AvgLength_), |
7839 | Data_, |
7840 | Indices_, |
7841 | Lengths_); |
7842 | } |
7843 | |
7844 | SparseLengthsSumGradNode *SparseLengthsSumNode::getGrad(GraphGradMapper &builder) { |
7845 | auto *x = new SparseLengthsSumGradNode(getName().str() + "_grad" , getData(), getIndices(), getLengths(), getResult(), builder.getGradient(getResult()), getLengthsMode(), getAvgLength()); |
7846 | builder.addGradient(getData(), x->getGradOfInputNamedData()); |
7847 | builder.addGradient(getIndices(), x->getGradOfInputNamedIndices()); |
7848 | builder.addGradient(getLengths(), x->getGradOfInputNamedLengths()); |
7849 | return x; |
7850 | } |
7851 | |
7852 | unsigned SparseLengthsWeightedSumGradNode::getNumInputs() const { |
7853 | return 6; |
7854 | } |
7855 | |
7856 | std::string SparseLengthsWeightedSumGradNode::getInputName(unsigned idx) const { |
7857 | if (idx == 0) { return "Data" ; } |
7858 | if (idx == 1) { return "Weights" ; } |
7859 | if (idx == 2) { return "Indices" ; } |
7860 | if (idx == 3) { return "Lengths" ; } |
7861 | if (idx == 4) { return "OriginalOutputForResult" ; } |
7862 | if (idx == 5) { return "GradOfOriginalOutputNamedResult" ; } |
7863 | idx -= 6; |
7864 | llvm_unreachable("Invalid index" ); |
7865 | } |
7866 | |
7867 | NodeValue SparseLengthsWeightedSumGradNode::getNthInput(unsigned idx) { |
7868 | if (idx == 0) { return Data_; } |
7869 | if (idx == 1) { return Weights_; } |
7870 | if (idx == 2) { return Indices_; } |
7871 | if (idx == 3) { return Lengths_; } |
7872 | if (idx == 4) { return OriginalOutputForResult_; } |
7873 | if (idx == 5) { return GradOfOriginalOutputNamedResult_; } |
7874 | idx -= 6; |
7875 | llvm_unreachable("Invalid index" ); |
7876 | } |
7877 | |
7878 | void SparseLengthsWeightedSumGradNode::setNthInput(unsigned idx, NodeValue val) { |
7879 | if (idx == 0) { Data_ = val; return; } |
7880 | if (idx == 1) { Weights_ = val; return; } |
7881 | if (idx == 2) { Indices_ = val; return; } |
7882 | if (idx == 3) { Lengths_ = val; return; } |
7883 | if (idx == 4) { OriginalOutputForResult_ = val; return; } |
7884 | if (idx == 5) { GradOfOriginalOutputNamedResult_ = val; return; } |
7885 | idx -= 6; |
7886 | llvm_unreachable("Invalid index" ); |
7887 | } |
7888 | |
7889 | llvm::StringRef SparseLengthsWeightedSumGradNode::getOutputName(unsigned idx) const { |
7890 | if (idx == 0) { return "GradOfInputNamedData" ; } |
7891 | if (idx == 1) { return "GradOfInputNamedWeights" ; } |
7892 | if (idx == 2) { return "GradOfInputNamedIndices" ; } |
7893 | if (idx == 3) { return "GradOfInputNamedLengths" ; } |
7894 | llvm_unreachable("Invalid index" ); |
7895 | } |
7896 | |
7897 | std::string SparseLengthsWeightedSumGradNode::getDebugDesc() const { |
7898 | DescriptionBuilder db(getKindName()); |
7899 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
7900 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
7901 | db |
7902 | .addParam("Data" , *(getData().getType())) |
7903 | .addParam("Weights" , *(getWeights().getType())) |
7904 | .addParam("Indices" , *(getIndices().getType())) |
7905 | .addParam("Lengths" , *(getLengths().getType())) |
7906 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
7907 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
7908 | .addParam("LengthsMode" , getLengthsMode()) |
7909 | .addParam("AvgLength" , getAvgLength()) |
7910 | .addParam("Users" , getNumUsers()); |
7911 | db.addParam("GradOfInputNamedData" , *(getGradOfInputNamedData().getType())); |
7912 | db.addParam("GradOfInputNamedWeights" , *(getGradOfInputNamedWeights().getType())); |
7913 | db.addParam("GradOfInputNamedIndices" , *(getGradOfInputNamedIndices().getType())); |
7914 | db.addParam("GradOfInputNamedLengths" , *(getGradOfInputNamedLengths().getType())); |
7915 | return db; |
7916 | } |
7917 | |
7918 | void SparseLengthsWeightedSumGradNode::visit(Node *parent, NodeWalker *visitor) { |
7919 | if (!visitor->shouldVisit(parent, this)) { return; } |
7920 | visitor->pre(parent, this); |
7921 | if (hasPredicate()) |
7922 | getPredicate().getNode()->visit(this, visitor); |
7923 | getData().getNode()->visit(this, visitor); |
7924 | getWeights().getNode()->visit(this, visitor); |
7925 | getIndices().getNode()->visit(this, visitor); |
7926 | getLengths().getNode()->visit(this, visitor); |
7927 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
7928 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
7929 | visitor->post(parent, this); |
7930 | } |
7931 | |
7932 | bool SparseLengthsWeightedSumGradNode::isEqual(const SparseLengthsWeightedSumGradNode &other) const { |
7933 | return true && |
7934 | Data_ == other.Data_ && |
7935 | Weights_ == other.Weights_ && |
7936 | Indices_ == other.Indices_ && |
7937 | Lengths_ == other.Lengths_ && |
7938 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
7939 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
7940 | predicate_ == other.predicate_ && |
7941 | LengthsMode_ == other.LengthsMode_ && |
7942 | AvgLength_ == other.AvgLength_ && |
7943 | getType(0) == other.getType(0) && |
7944 | getType(1) == other.getType(1) && |
7945 | getType(2) == other.getType(2) && |
7946 | getType(3) == other.getType(3); |
7947 | } |
7948 | |
7949 | Node* SparseLengthsWeightedSumGradNode::clone() const { |
7950 | return new SparseLengthsWeightedSumGradNode(getName(), getData(), getWeights(), getIndices(), getLengths(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), getLengthsMode(), getAvgLength()); |
7951 | } |
7952 | |
7953 | llvm::hash_code SparseLengthsWeightedSumGradNode::getHash() const { |
7954 | return llvm::hash_combine( |
7955 | LengthsMode_, |
7956 | toBinary(AvgLength_), |
7957 | Data_, |
7958 | Weights_, |
7959 | Indices_, |
7960 | Lengths_, |
7961 | OriginalOutputForResult_, |
7962 | GradOfOriginalOutputNamedResult_); |
7963 | } |
7964 | |
7965 | unsigned SparseLengthsWeightedSumNode::getNumInputs() const { |
7966 | return 4; |
7967 | } |
7968 | |
7969 | std::string SparseLengthsWeightedSumNode::getInputName(unsigned idx) const { |
7970 | if (idx == 0) { return "Data" ; } |
7971 | if (idx == 1) { return "Weights" ; } |
7972 | if (idx == 2) { return "Indices" ; } |
7973 | if (idx == 3) { return "Lengths" ; } |
7974 | idx -= 4; |
7975 | llvm_unreachable("Invalid index" ); |
7976 | } |
7977 | |
7978 | NodeValue SparseLengthsWeightedSumNode::getNthInput(unsigned idx) { |
7979 | if (idx == 0) { return Data_; } |
7980 | if (idx == 1) { return Weights_; } |
7981 | if (idx == 2) { return Indices_; } |
7982 | if (idx == 3) { return Lengths_; } |
7983 | idx -= 4; |
7984 | llvm_unreachable("Invalid index" ); |
7985 | } |
7986 | |
7987 | void SparseLengthsWeightedSumNode::setNthInput(unsigned idx, NodeValue val) { |
7988 | if (idx == 0) { Data_ = val; return; } |
7989 | if (idx == 1) { Weights_ = val; return; } |
7990 | if (idx == 2) { Indices_ = val; return; } |
7991 | if (idx == 3) { Lengths_ = val; return; } |
7992 | idx -= 4; |
7993 | llvm_unreachable("Invalid index" ); |
7994 | } |
7995 | |
7996 | llvm::StringRef SparseLengthsWeightedSumNode::getOutputName(unsigned idx) const { |
7997 | if (idx == 0) { return "Result" ; } |
7998 | llvm_unreachable("Invalid index" ); |
7999 | } |
8000 | |
8001 | std::string SparseLengthsWeightedSumNode::getDebugDesc() const { |
8002 | DescriptionBuilder db(getKindName()); |
8003 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8004 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8005 | db |
8006 | .addParam("Data" , *(getData().getType())) |
8007 | .addParam("Weights" , *(getWeights().getType())) |
8008 | .addParam("Indices" , *(getIndices().getType())) |
8009 | .addParam("Lengths" , *(getLengths().getType())) |
8010 | .addParam("LengthsMode" , getLengthsMode()) |
8011 | .addParam("AvgLength" , getAvgLength()) |
8012 | .addParam("Users" , getNumUsers()); |
8013 | db.addParam("Result" , *(getResult().getType())); |
8014 | return db; |
8015 | } |
8016 | |
8017 | void SparseLengthsWeightedSumNode::visit(Node *parent, NodeWalker *visitor) { |
8018 | if (!visitor->shouldVisit(parent, this)) { return; } |
8019 | visitor->pre(parent, this); |
8020 | if (hasPredicate()) |
8021 | getPredicate().getNode()->visit(this, visitor); |
8022 | getData().getNode()->visit(this, visitor); |
8023 | getWeights().getNode()->visit(this, visitor); |
8024 | getIndices().getNode()->visit(this, visitor); |
8025 | getLengths().getNode()->visit(this, visitor); |
8026 | visitor->post(parent, this); |
8027 | } |
8028 | |
8029 | bool SparseLengthsWeightedSumNode::isEqual(const SparseLengthsWeightedSumNode &other) const { |
8030 | return true && |
8031 | Data_ == other.Data_ && |
8032 | Weights_ == other.Weights_ && |
8033 | Indices_ == other.Indices_ && |
8034 | Lengths_ == other.Lengths_ && |
8035 | predicate_ == other.predicate_ && |
8036 | LengthsMode_ == other.LengthsMode_ && |
8037 | AvgLength_ == other.AvgLength_ && |
8038 | getType(0) == other.getType(0); |
8039 | } |
8040 | |
8041 | Node* SparseLengthsWeightedSumNode::clone() const { |
8042 | return new SparseLengthsWeightedSumNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getLengths(), getLengthsMode(), getAvgLength()); |
8043 | } |
8044 | |
8045 | llvm::hash_code SparseLengthsWeightedSumNode::getHash() const { |
8046 | return llvm::hash_combine( |
8047 | LengthsMode_, |
8048 | toBinary(AvgLength_), |
8049 | Data_, |
8050 | Weights_, |
8051 | Indices_, |
8052 | Lengths_); |
8053 | } |
8054 | |
8055 | SparseLengthsWeightedSumGradNode *SparseLengthsWeightedSumNode::getGrad(GraphGradMapper &builder) { |
8056 | auto *x = new SparseLengthsWeightedSumGradNode(getName().str() + "_grad" , getData(), getWeights(), getIndices(), getLengths(), getResult(), builder.getGradient(getResult()), getLengthsMode(), getAvgLength()); |
8057 | builder.addGradient(getData(), x->getGradOfInputNamedData()); |
8058 | builder.addGradient(getWeights(), x->getGradOfInputNamedWeights()); |
8059 | builder.addGradient(getIndices(), x->getGradOfInputNamedIndices()); |
8060 | builder.addGradient(getLengths(), x->getGradOfInputNamedLengths()); |
8061 | return x; |
8062 | } |
8063 | |
8064 | unsigned EmbeddingNode::getNumInputs() const { |
8065 | return 2; |
8066 | } |
8067 | |
8068 | std::string EmbeddingNode::getInputName(unsigned idx) const { |
8069 | if (idx == 0) { return "Weights" ; } |
8070 | if (idx == 1) { return "Indices" ; } |
8071 | idx -= 2; |
8072 | llvm_unreachable("Invalid index" ); |
8073 | } |
8074 | |
8075 | NodeValue EmbeddingNode::getNthInput(unsigned idx) { |
8076 | if (idx == 0) { return Weights_; } |
8077 | if (idx == 1) { return Indices_; } |
8078 | idx -= 2; |
8079 | llvm_unreachable("Invalid index" ); |
8080 | } |
8081 | |
8082 | void EmbeddingNode::setNthInput(unsigned idx, NodeValue val) { |
8083 | if (idx == 0) { Weights_ = val; return; } |
8084 | if (idx == 1) { Indices_ = val; return; } |
8085 | idx -= 2; |
8086 | llvm_unreachable("Invalid index" ); |
8087 | } |
8088 | |
8089 | llvm::StringRef EmbeddingNode::getOutputName(unsigned idx) const { |
8090 | if (idx == 0) { return "Result" ; } |
8091 | llvm_unreachable("Invalid index" ); |
8092 | } |
8093 | |
8094 | std::string EmbeddingNode::getDebugDesc() const { |
8095 | DescriptionBuilder db(getKindName()); |
8096 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8097 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8098 | db |
8099 | .addParam("Weights" , *(getWeights().getType())) |
8100 | .addParam("Indices" , *(getIndices().getType())) |
8101 | .addParam("PadIdx" , getPadIdx()) |
8102 | .addParam("Scale" , getScale()) |
8103 | .addParam("Sparse" , getSparse()) |
8104 | .addParam("Users" , getNumUsers()); |
8105 | db.addParam("Result" , *(getResult().getType())); |
8106 | return db; |
8107 | } |
8108 | |
8109 | void EmbeddingNode::visit(Node *parent, NodeWalker *visitor) { |
8110 | if (!visitor->shouldVisit(parent, this)) { return; } |
8111 | visitor->pre(parent, this); |
8112 | if (hasPredicate()) |
8113 | getPredicate().getNode()->visit(this, visitor); |
8114 | getWeights().getNode()->visit(this, visitor); |
8115 | getIndices().getNode()->visit(this, visitor); |
8116 | visitor->post(parent, this); |
8117 | } |
8118 | |
8119 | bool EmbeddingNode::isEqual(const EmbeddingNode &other) const { |
8120 | return true && |
8121 | Weights_ == other.Weights_ && |
8122 | Indices_ == other.Indices_ && |
8123 | predicate_ == other.predicate_ && |
8124 | PadIdx_ == other.PadIdx_ && |
8125 | Scale_ == other.Scale_ && |
8126 | Sparse_ == other.Sparse_ && |
8127 | getType(0) == other.getType(0); |
8128 | } |
8129 | |
8130 | Node* EmbeddingNode::clone() const { |
8131 | return new EmbeddingNode(getName(), getResult().getType(), getWeights(), getIndices(), getPadIdx(), getScale(), getSparse()); |
8132 | } |
8133 | |
8134 | llvm::hash_code EmbeddingNode::getHash() const { |
8135 | return llvm::hash_combine( |
8136 | PadIdx_, |
8137 | Scale_, |
8138 | Sparse_, |
8139 | Weights_, |
8140 | Indices_); |
8141 | } |
8142 | |
8143 | unsigned EmbeddingBagNode::getNumInputs() const { |
8144 | return 4; |
8145 | } |
8146 | |
8147 | std::string EmbeddingBagNode::getInputName(unsigned idx) const { |
8148 | if (idx == 0) { return "Data" ; } |
8149 | if (idx == 1) { return "Weights" ; } |
8150 | if (idx == 2) { return "Indices" ; } |
8151 | if (idx == 3) { return "Offsets" ; } |
8152 | idx -= 4; |
8153 | llvm_unreachable("Invalid index" ); |
8154 | } |
8155 | |
8156 | NodeValue EmbeddingBagNode::getNthInput(unsigned idx) { |
8157 | if (idx == 0) { return Data_; } |
8158 | if (idx == 1) { return Weights_; } |
8159 | if (idx == 2) { return Indices_; } |
8160 | if (idx == 3) { return Offsets_; } |
8161 | idx -= 4; |
8162 | llvm_unreachable("Invalid index" ); |
8163 | } |
8164 | |
8165 | void EmbeddingBagNode::setNthInput(unsigned idx, NodeValue val) { |
8166 | if (idx == 0) { Data_ = val; return; } |
8167 | if (idx == 1) { Weights_ = val; return; } |
8168 | if (idx == 2) { Indices_ = val; return; } |
8169 | if (idx == 3) { Offsets_ = val; return; } |
8170 | idx -= 4; |
8171 | llvm_unreachable("Invalid index" ); |
8172 | } |
8173 | |
8174 | llvm::StringRef EmbeddingBagNode::getOutputName(unsigned idx) const { |
8175 | if (idx == 0) { return "Result" ; } |
8176 | llvm_unreachable("Invalid index" ); |
8177 | } |
8178 | |
8179 | std::string EmbeddingBagNode::getDebugDesc() const { |
8180 | DescriptionBuilder db(getKindName()); |
8181 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8182 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8183 | db |
8184 | .addParam("Data" , *(getData().getType())) |
8185 | .addParam("Weights" , *(getWeights().getType())) |
8186 | .addParam("Indices" , *(getIndices().getType())) |
8187 | .addParam("Offsets" , *(getOffsets().getType())) |
8188 | .addParam("HasEndOffset" , getHasEndOffset()) |
8189 | .addParam("LengthsMode" , getLengthsMode()) |
8190 | .addParam("AvgLength" , getAvgLength()) |
8191 | .addParam("Users" , getNumUsers()); |
8192 | db.addParam("Result" , *(getResult().getType())); |
8193 | return db; |
8194 | } |
8195 | |
8196 | void EmbeddingBagNode::visit(Node *parent, NodeWalker *visitor) { |
8197 | if (!visitor->shouldVisit(parent, this)) { return; } |
8198 | visitor->pre(parent, this); |
8199 | if (hasPredicate()) |
8200 | getPredicate().getNode()->visit(this, visitor); |
8201 | getData().getNode()->visit(this, visitor); |
8202 | getWeights().getNode()->visit(this, visitor); |
8203 | getIndices().getNode()->visit(this, visitor); |
8204 | getOffsets().getNode()->visit(this, visitor); |
8205 | visitor->post(parent, this); |
8206 | } |
8207 | |
8208 | bool EmbeddingBagNode::isEqual(const EmbeddingBagNode &other) const { |
8209 | return true && |
8210 | Data_ == other.Data_ && |
8211 | Weights_ == other.Weights_ && |
8212 | Indices_ == other.Indices_ && |
8213 | Offsets_ == other.Offsets_ && |
8214 | predicate_ == other.predicate_ && |
8215 | HasEndOffset_ == other.HasEndOffset_ && |
8216 | LengthsMode_ == other.LengthsMode_ && |
8217 | AvgLength_ == other.AvgLength_ && |
8218 | getType(0) == other.getType(0); |
8219 | } |
8220 | |
8221 | Node* EmbeddingBagNode::clone() const { |
8222 | return new EmbeddingBagNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getOffsets(), getHasEndOffset(), getLengthsMode(), getAvgLength()); |
8223 | } |
8224 | |
8225 | llvm::hash_code EmbeddingBagNode::getHash() const { |
8226 | return llvm::hash_combine( |
8227 | HasEndOffset_, |
8228 | LengthsMode_, |
8229 | toBinary(AvgLength_), |
8230 | Data_, |
8231 | Weights_, |
8232 | Indices_, |
8233 | Offsets_); |
8234 | } |
8235 | |
8236 | unsigned EmbeddingBagByteRowwiseOffsetsNode::getNumInputs() const { |
8237 | return 4; |
8238 | } |
8239 | |
8240 | std::string EmbeddingBagByteRowwiseOffsetsNode::getInputName(unsigned idx) const { |
8241 | if (idx == 0) { return "Data" ; } |
8242 | if (idx == 1) { return "Weights" ; } |
8243 | if (idx == 2) { return "Indices" ; } |
8244 | if (idx == 3) { return "Offsets" ; } |
8245 | idx -= 4; |
8246 | llvm_unreachable("Invalid index" ); |
8247 | } |
8248 | |
8249 | NodeValue EmbeddingBagByteRowwiseOffsetsNode::getNthInput(unsigned idx) { |
8250 | if (idx == 0) { return Data_; } |
8251 | if (idx == 1) { return Weights_; } |
8252 | if (idx == 2) { return Indices_; } |
8253 | if (idx == 3) { return Offsets_; } |
8254 | idx -= 4; |
8255 | llvm_unreachable("Invalid index" ); |
8256 | } |
8257 | |
8258 | void EmbeddingBagByteRowwiseOffsetsNode::setNthInput(unsigned idx, NodeValue val) { |
8259 | if (idx == 0) { Data_ = val; return; } |
8260 | if (idx == 1) { Weights_ = val; return; } |
8261 | if (idx == 2) { Indices_ = val; return; } |
8262 | if (idx == 3) { Offsets_ = val; return; } |
8263 | idx -= 4; |
8264 | llvm_unreachable("Invalid index" ); |
8265 | } |
8266 | |
8267 | llvm::StringRef EmbeddingBagByteRowwiseOffsetsNode::getOutputName(unsigned idx) const { |
8268 | if (idx == 0) { return "Result" ; } |
8269 | llvm_unreachable("Invalid index" ); |
8270 | } |
8271 | |
8272 | std::string EmbeddingBagByteRowwiseOffsetsNode::getDebugDesc() const { |
8273 | DescriptionBuilder db(getKindName()); |
8274 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8275 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8276 | db |
8277 | .addParam("Data" , *(getData().getType())) |
8278 | .addParam("Weights" , *(getWeights().getType())) |
8279 | .addParam("Indices" , *(getIndices().getType())) |
8280 | .addParam("Offsets" , *(getOffsets().getType())) |
8281 | .addParam("UseFP16Accumulation" , getUseFP16Accumulation()) |
8282 | .addParam("HasEndOffset" , getHasEndOffset()) |
8283 | .addParam("LengthsMode" , getLengthsMode()) |
8284 | .addParam("AvgLength" , getAvgLength()) |
8285 | .addParam("Users" , getNumUsers()); |
8286 | db.addParam("Result" , *(getResult().getType())); |
8287 | return db; |
8288 | } |
8289 | |
8290 | void EmbeddingBagByteRowwiseOffsetsNode::visit(Node *parent, NodeWalker *visitor) { |
8291 | if (!visitor->shouldVisit(parent, this)) { return; } |
8292 | visitor->pre(parent, this); |
8293 | if (hasPredicate()) |
8294 | getPredicate().getNode()->visit(this, visitor); |
8295 | getData().getNode()->visit(this, visitor); |
8296 | getWeights().getNode()->visit(this, visitor); |
8297 | getIndices().getNode()->visit(this, visitor); |
8298 | getOffsets().getNode()->visit(this, visitor); |
8299 | visitor->post(parent, this); |
8300 | } |
8301 | |
8302 | bool EmbeddingBagByteRowwiseOffsetsNode::isEqual(const EmbeddingBagByteRowwiseOffsetsNode &other) const { |
8303 | return true && |
8304 | Data_ == other.Data_ && |
8305 | Weights_ == other.Weights_ && |
8306 | Indices_ == other.Indices_ && |
8307 | Offsets_ == other.Offsets_ && |
8308 | predicate_ == other.predicate_ && |
8309 | UseFP16Accumulation_ == other.UseFP16Accumulation_ && |
8310 | HasEndOffset_ == other.HasEndOffset_ && |
8311 | LengthsMode_ == other.LengthsMode_ && |
8312 | AvgLength_ == other.AvgLength_ && |
8313 | getType(0) == other.getType(0); |
8314 | } |
8315 | |
8316 | Node* EmbeddingBagByteRowwiseOffsetsNode::clone() const { |
8317 | return new EmbeddingBagByteRowwiseOffsetsNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getOffsets(), getUseFP16Accumulation(), getHasEndOffset(), getLengthsMode(), getAvgLength()); |
8318 | } |
8319 | |
8320 | llvm::hash_code EmbeddingBagByteRowwiseOffsetsNode::getHash() const { |
8321 | return llvm::hash_combine( |
8322 | UseFP16Accumulation_, |
8323 | HasEndOffset_, |
8324 | LengthsMode_, |
8325 | toBinary(AvgLength_), |
8326 | Data_, |
8327 | Weights_, |
8328 | Indices_, |
8329 | Offsets_); |
8330 | } |
8331 | |
8332 | unsigned RowwiseQuantizedSparseLengthsWeightedSumNode::getNumInputs() const { |
8333 | return 6; |
8334 | } |
8335 | |
8336 | std::string RowwiseQuantizedSparseLengthsWeightedSumNode::getInputName(unsigned idx) const { |
8337 | if (idx == 0) { return "Data" ; } |
8338 | if (idx == 1) { return "Scales" ; } |
8339 | if (idx == 2) { return "Offsets" ; } |
8340 | if (idx == 3) { return "Weights" ; } |
8341 | if (idx == 4) { return "Indices" ; } |
8342 | if (idx == 5) { return "Lengths" ; } |
8343 | idx -= 6; |
8344 | llvm_unreachable("Invalid index" ); |
8345 | } |
8346 | |
8347 | NodeValue RowwiseQuantizedSparseLengthsWeightedSumNode::getNthInput(unsigned idx) { |
8348 | if (idx == 0) { return Data_; } |
8349 | if (idx == 1) { return Scales_; } |
8350 | if (idx == 2) { return Offsets_; } |
8351 | if (idx == 3) { return Weights_; } |
8352 | if (idx == 4) { return Indices_; } |
8353 | if (idx == 5) { return Lengths_; } |
8354 | idx -= 6; |
8355 | llvm_unreachable("Invalid index" ); |
8356 | } |
8357 | |
8358 | void RowwiseQuantizedSparseLengthsWeightedSumNode::setNthInput(unsigned idx, NodeValue val) { |
8359 | if (idx == 0) { Data_ = val; return; } |
8360 | if (idx == 1) { Scales_ = val; return; } |
8361 | if (idx == 2) { Offsets_ = val; return; } |
8362 | if (idx == 3) { Weights_ = val; return; } |
8363 | if (idx == 4) { Indices_ = val; return; } |
8364 | if (idx == 5) { Lengths_ = val; return; } |
8365 | idx -= 6; |
8366 | llvm_unreachable("Invalid index" ); |
8367 | } |
8368 | |
8369 | llvm::StringRef RowwiseQuantizedSparseLengthsWeightedSumNode::getOutputName(unsigned idx) const { |
8370 | if (idx == 0) { return "Result" ; } |
8371 | llvm_unreachable("Invalid index" ); |
8372 | } |
8373 | |
8374 | std::string RowwiseQuantizedSparseLengthsWeightedSumNode::getDebugDesc() const { |
8375 | DescriptionBuilder db(getKindName()); |
8376 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8377 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8378 | db |
8379 | .addParam("Data" , *(getData().getType())) |
8380 | .addParam("Scales" , *(getScales().getType())) |
8381 | .addParam("Offsets" , *(getOffsets().getType())) |
8382 | .addParam("Weights" , *(getWeights().getType())) |
8383 | .addParam("Indices" , *(getIndices().getType())) |
8384 | .addParam("Lengths" , *(getLengths().getType())) |
8385 | .addParam("UseFP16Accumulation" , getUseFP16Accumulation()) |
8386 | .addParam("LengthsMode" , getLengthsMode()) |
8387 | .addParam("AvgLength" , getAvgLength()) |
8388 | .addParam("Users" , getNumUsers()); |
8389 | db.addParam("Result" , *(getResult().getType())); |
8390 | return db; |
8391 | } |
8392 | |
8393 | void RowwiseQuantizedSparseLengthsWeightedSumNode::visit(Node *parent, NodeWalker *visitor) { |
8394 | if (!visitor->shouldVisit(parent, this)) { return; } |
8395 | visitor->pre(parent, this); |
8396 | if (hasPredicate()) |
8397 | getPredicate().getNode()->visit(this, visitor); |
8398 | getData().getNode()->visit(this, visitor); |
8399 | getScales().getNode()->visit(this, visitor); |
8400 | getOffsets().getNode()->visit(this, visitor); |
8401 | getWeights().getNode()->visit(this, visitor); |
8402 | getIndices().getNode()->visit(this, visitor); |
8403 | getLengths().getNode()->visit(this, visitor); |
8404 | visitor->post(parent, this); |
8405 | } |
8406 | |
8407 | bool RowwiseQuantizedSparseLengthsWeightedSumNode::isEqual(const RowwiseQuantizedSparseLengthsWeightedSumNode &other) const { |
8408 | return true && |
8409 | Data_ == other.Data_ && |
8410 | Scales_ == other.Scales_ && |
8411 | Offsets_ == other.Offsets_ && |
8412 | Weights_ == other.Weights_ && |
8413 | Indices_ == other.Indices_ && |
8414 | Lengths_ == other.Lengths_ && |
8415 | predicate_ == other.predicate_ && |
8416 | UseFP16Accumulation_ == other.UseFP16Accumulation_ && |
8417 | LengthsMode_ == other.LengthsMode_ && |
8418 | AvgLength_ == other.AvgLength_ && |
8419 | getType(0) == other.getType(0); |
8420 | } |
8421 | |
8422 | Node* RowwiseQuantizedSparseLengthsWeightedSumNode::clone() const { |
8423 | return new RowwiseQuantizedSparseLengthsWeightedSumNode(getName(), getResult().getType(), getData(), getScales(), getOffsets(), getWeights(), getIndices(), getLengths(), getUseFP16Accumulation(), getLengthsMode(), getAvgLength()); |
8424 | } |
8425 | |
8426 | llvm::hash_code RowwiseQuantizedSparseLengthsWeightedSumNode::getHash() const { |
8427 | return llvm::hash_combine( |
8428 | UseFP16Accumulation_, |
8429 | LengthsMode_, |
8430 | toBinary(AvgLength_), |
8431 | Data_, |
8432 | Scales_, |
8433 | Offsets_, |
8434 | Weights_, |
8435 | Indices_, |
8436 | Lengths_); |
8437 | } |
8438 | |
8439 | unsigned FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getNumInputs() const { |
8440 | return 4; |
8441 | } |
8442 | |
8443 | std::string FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getInputName(unsigned idx) const { |
8444 | if (idx == 0) { return "Data" ; } |
8445 | if (idx == 1) { return "Weights" ; } |
8446 | if (idx == 2) { return "Indices" ; } |
8447 | if (idx == 3) { return "Lengths" ; } |
8448 | idx -= 4; |
8449 | llvm_unreachable("Invalid index" ); |
8450 | } |
8451 | |
8452 | NodeValue FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getNthInput(unsigned idx) { |
8453 | if (idx == 0) { return Data_; } |
8454 | if (idx == 1) { return Weights_; } |
8455 | if (idx == 2) { return Indices_; } |
8456 | if (idx == 3) { return Lengths_; } |
8457 | idx -= 4; |
8458 | llvm_unreachable("Invalid index" ); |
8459 | } |
8460 | |
8461 | void FusedRowwiseQuantizedSparseLengthsWeightedSumNode::setNthInput(unsigned idx, NodeValue val) { |
8462 | if (idx == 0) { Data_ = val; return; } |
8463 | if (idx == 1) { Weights_ = val; return; } |
8464 | if (idx == 2) { Indices_ = val; return; } |
8465 | if (idx == 3) { Lengths_ = val; return; } |
8466 | idx -= 4; |
8467 | llvm_unreachable("Invalid index" ); |
8468 | } |
8469 | |
8470 | llvm::StringRef FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getOutputName(unsigned idx) const { |
8471 | if (idx == 0) { return "Result" ; } |
8472 | llvm_unreachable("Invalid index" ); |
8473 | } |
8474 | |
8475 | std::string FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getDebugDesc() const { |
8476 | DescriptionBuilder db(getKindName()); |
8477 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8478 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8479 | db |
8480 | .addParam("Data" , *(getData().getType())) |
8481 | .addParam("Weights" , *(getWeights().getType())) |
8482 | .addParam("Indices" , *(getIndices().getType())) |
8483 | .addParam("Lengths" , *(getLengths().getType())) |
8484 | .addParam("UseFP16Accumulation" , getUseFP16Accumulation()) |
8485 | .addParam("LengthsMode" , getLengthsMode()) |
8486 | .addParam("AvgLength" , getAvgLength()) |
8487 | .addParam("Users" , getNumUsers()); |
8488 | db.addParam("Result" , *(getResult().getType())); |
8489 | return db; |
8490 | } |
8491 | |
8492 | void FusedRowwiseQuantizedSparseLengthsWeightedSumNode::visit(Node *parent, NodeWalker *visitor) { |
8493 | if (!visitor->shouldVisit(parent, this)) { return; } |
8494 | visitor->pre(parent, this); |
8495 | if (hasPredicate()) |
8496 | getPredicate().getNode()->visit(this, visitor); |
8497 | getData().getNode()->visit(this, visitor); |
8498 | getWeights().getNode()->visit(this, visitor); |
8499 | getIndices().getNode()->visit(this, visitor); |
8500 | getLengths().getNode()->visit(this, visitor); |
8501 | visitor->post(parent, this); |
8502 | } |
8503 | |
8504 | bool FusedRowwiseQuantizedSparseLengthsWeightedSumNode::isEqual(const FusedRowwiseQuantizedSparseLengthsWeightedSumNode &other) const { |
8505 | return true && |
8506 | Data_ == other.Data_ && |
8507 | Weights_ == other.Weights_ && |
8508 | Indices_ == other.Indices_ && |
8509 | Lengths_ == other.Lengths_ && |
8510 | predicate_ == other.predicate_ && |
8511 | UseFP16Accumulation_ == other.UseFP16Accumulation_ && |
8512 | LengthsMode_ == other.LengthsMode_ && |
8513 | AvgLength_ == other.AvgLength_ && |
8514 | getType(0) == other.getType(0); |
8515 | } |
8516 | |
8517 | Node* FusedRowwiseQuantizedSparseLengthsWeightedSumNode::clone() const { |
8518 | return new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(getName(), getResult().getType(), getData(), getWeights(), getIndices(), getLengths(), getUseFP16Accumulation(), getLengthsMode(), getAvgLength()); |
8519 | } |
8520 | |
8521 | llvm::hash_code FusedRowwiseQuantizedSparseLengthsWeightedSumNode::getHash() const { |
8522 | return llvm::hash_combine( |
8523 | UseFP16Accumulation_, |
8524 | LengthsMode_, |
8525 | toBinary(AvgLength_), |
8526 | Data_, |
8527 | Weights_, |
8528 | Indices_, |
8529 | Lengths_); |
8530 | } |
8531 | |
8532 | unsigned FusedRowwiseQuantizedSparseLengthsSumNode::getNumInputs() const { |
8533 | return 3; |
8534 | } |
8535 | |
8536 | std::string FusedRowwiseQuantizedSparseLengthsSumNode::getInputName(unsigned idx) const { |
8537 | if (idx == 0) { return "Data" ; } |
8538 | if (idx == 1) { return "Indices" ; } |
8539 | if (idx == 2) { return "Lengths" ; } |
8540 | idx -= 3; |
8541 | llvm_unreachable("Invalid index" ); |
8542 | } |
8543 | |
8544 | NodeValue FusedRowwiseQuantizedSparseLengthsSumNode::getNthInput(unsigned idx) { |
8545 | if (idx == 0) { return Data_; } |
8546 | if (idx == 1) { return Indices_; } |
8547 | if (idx == 2) { return Lengths_; } |
8548 | idx -= 3; |
8549 | llvm_unreachable("Invalid index" ); |
8550 | } |
8551 | |
8552 | void FusedRowwiseQuantizedSparseLengthsSumNode::setNthInput(unsigned idx, NodeValue val) { |
8553 | if (idx == 0) { Data_ = val; return; } |
8554 | if (idx == 1) { Indices_ = val; return; } |
8555 | if (idx == 2) { Lengths_ = val; return; } |
8556 | idx -= 3; |
8557 | llvm_unreachable("Invalid index" ); |
8558 | } |
8559 | |
8560 | llvm::StringRef FusedRowwiseQuantizedSparseLengthsSumNode::getOutputName(unsigned idx) const { |
8561 | if (idx == 0) { return "Result" ; } |
8562 | llvm_unreachable("Invalid index" ); |
8563 | } |
8564 | |
8565 | std::string FusedRowwiseQuantizedSparseLengthsSumNode::getDebugDesc() const { |
8566 | DescriptionBuilder db(getKindName()); |
8567 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8568 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8569 | db |
8570 | .addParam("Data" , *(getData().getType())) |
8571 | .addParam("Indices" , *(getIndices().getType())) |
8572 | .addParam("Lengths" , *(getLengths().getType())) |
8573 | .addParam("UseFP16Accumulation" , getUseFP16Accumulation()) |
8574 | .addParam("LengthsMode" , getLengthsMode()) |
8575 | .addParam("AvgLength" , getAvgLength()) |
8576 | .addParam("Users" , getNumUsers()); |
8577 | db.addParam("Result" , *(getResult().getType())); |
8578 | return db; |
8579 | } |
8580 | |
8581 | void FusedRowwiseQuantizedSparseLengthsSumNode::visit(Node *parent, NodeWalker *visitor) { |
8582 | if (!visitor->shouldVisit(parent, this)) { return; } |
8583 | visitor->pre(parent, this); |
8584 | if (hasPredicate()) |
8585 | getPredicate().getNode()->visit(this, visitor); |
8586 | getData().getNode()->visit(this, visitor); |
8587 | getIndices().getNode()->visit(this, visitor); |
8588 | getLengths().getNode()->visit(this, visitor); |
8589 | visitor->post(parent, this); |
8590 | } |
8591 | |
8592 | bool FusedRowwiseQuantizedSparseLengthsSumNode::isEqual(const FusedRowwiseQuantizedSparseLengthsSumNode &other) const { |
8593 | return true && |
8594 | Data_ == other.Data_ && |
8595 | Indices_ == other.Indices_ && |
8596 | Lengths_ == other.Lengths_ && |
8597 | predicate_ == other.predicate_ && |
8598 | UseFP16Accumulation_ == other.UseFP16Accumulation_ && |
8599 | LengthsMode_ == other.LengthsMode_ && |
8600 | AvgLength_ == other.AvgLength_ && |
8601 | getType(0) == other.getType(0); |
8602 | } |
8603 | |
8604 | Node* FusedRowwiseQuantizedSparseLengthsSumNode::clone() const { |
8605 | return new FusedRowwiseQuantizedSparseLengthsSumNode(getName(), getResult().getType(), getData(), getIndices(), getLengths(), getUseFP16Accumulation(), getLengthsMode(), getAvgLength()); |
8606 | } |
8607 | |
8608 | llvm::hash_code FusedRowwiseQuantizedSparseLengthsSumNode::getHash() const { |
8609 | return llvm::hash_combine( |
8610 | UseFP16Accumulation_, |
8611 | LengthsMode_, |
8612 | toBinary(AvgLength_), |
8613 | Data_, |
8614 | Indices_, |
8615 | Lengths_); |
8616 | } |
8617 | |
8618 | unsigned LengthsToRangesNode::getNumInputs() const { |
8619 | return 1; |
8620 | } |
8621 | |
8622 | std::string LengthsToRangesNode::getInputName(unsigned idx) const { |
8623 | if (idx == 0) { return "Lengths" ; } |
8624 | idx -= 1; |
8625 | llvm_unreachable("Invalid index" ); |
8626 | } |
8627 | |
8628 | NodeValue LengthsToRangesNode::getNthInput(unsigned idx) { |
8629 | if (idx == 0) { return Lengths_; } |
8630 | idx -= 1; |
8631 | llvm_unreachable("Invalid index" ); |
8632 | } |
8633 | |
8634 | void LengthsToRangesNode::setNthInput(unsigned idx, NodeValue val) { |
8635 | if (idx == 0) { Lengths_ = val; return; } |
8636 | idx -= 1; |
8637 | llvm_unreachable("Invalid index" ); |
8638 | } |
8639 | |
8640 | llvm::StringRef LengthsToRangesNode::getOutputName(unsigned idx) const { |
8641 | if (idx == 0) { return "Result" ; } |
8642 | llvm_unreachable("Invalid index" ); |
8643 | } |
8644 | |
8645 | std::string LengthsToRangesNode::getDebugDesc() const { |
8646 | DescriptionBuilder db(getKindName()); |
8647 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8648 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8649 | db |
8650 | .addParam("Lengths" , *(getLengths().getType())) |
8651 | .addParam("Users" , getNumUsers()); |
8652 | db.addParam("Result" , *(getResult().getType())); |
8653 | return db; |
8654 | } |
8655 | |
8656 | void LengthsToRangesNode::visit(Node *parent, NodeWalker *visitor) { |
8657 | if (!visitor->shouldVisit(parent, this)) { return; } |
8658 | visitor->pre(parent, this); |
8659 | if (hasPredicate()) |
8660 | getPredicate().getNode()->visit(this, visitor); |
8661 | getLengths().getNode()->visit(this, visitor); |
8662 | visitor->post(parent, this); |
8663 | } |
8664 | |
8665 | bool LengthsToRangesNode::isEqual(const LengthsToRangesNode &other) const { |
8666 | return true && |
8667 | Lengths_ == other.Lengths_ && |
8668 | predicate_ == other.predicate_ && |
8669 | getType(0) == other.getType(0); |
8670 | } |
8671 | |
8672 | Node* LengthsToRangesNode::clone() const { |
8673 | return new LengthsToRangesNode(getName(), getResult().getType(), getLengths()); |
8674 | } |
8675 | |
8676 | llvm::hash_code LengthsToRangesNode::getHash() const { |
8677 | return llvm::hash_combine( |
8678 | Lengths_); |
8679 | } |
8680 | |
8681 | unsigned LengthsRangeFillNode::getNumInputs() const { |
8682 | return 1; |
8683 | } |
8684 | |
8685 | std::string LengthsRangeFillNode::getInputName(unsigned idx) const { |
8686 | if (idx == 0) { return "Lengths" ; } |
8687 | idx -= 1; |
8688 | llvm_unreachable("Invalid index" ); |
8689 | } |
8690 | |
8691 | NodeValue LengthsRangeFillNode::getNthInput(unsigned idx) { |
8692 | if (idx == 0) { return Lengths_; } |
8693 | idx -= 1; |
8694 | llvm_unreachable("Invalid index" ); |
8695 | } |
8696 | |
8697 | void LengthsRangeFillNode::setNthInput(unsigned idx, NodeValue val) { |
8698 | if (idx == 0) { Lengths_ = val; return; } |
8699 | idx -= 1; |
8700 | llvm_unreachable("Invalid index" ); |
8701 | } |
8702 | |
8703 | llvm::StringRef LengthsRangeFillNode::getOutputName(unsigned idx) const { |
8704 | if (idx == 0) { return "Result" ; } |
8705 | llvm_unreachable("Invalid index" ); |
8706 | } |
8707 | |
8708 | std::string LengthsRangeFillNode::getDebugDesc() const { |
8709 | DescriptionBuilder db(getKindName()); |
8710 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8711 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8712 | db |
8713 | .addParam("Lengths" , *(getLengths().getType())) |
8714 | .addParam("Users" , getNumUsers()); |
8715 | db.addParam("Result" , *(getResult().getType())); |
8716 | return db; |
8717 | } |
8718 | |
8719 | void LengthsRangeFillNode::visit(Node *parent, NodeWalker *visitor) { |
8720 | if (!visitor->shouldVisit(parent, this)) { return; } |
8721 | visitor->pre(parent, this); |
8722 | if (hasPredicate()) |
8723 | getPredicate().getNode()->visit(this, visitor); |
8724 | getLengths().getNode()->visit(this, visitor); |
8725 | visitor->post(parent, this); |
8726 | } |
8727 | |
8728 | bool LengthsRangeFillNode::isEqual(const LengthsRangeFillNode &other) const { |
8729 | return true && |
8730 | Lengths_ == other.Lengths_ && |
8731 | predicate_ == other.predicate_ && |
8732 | getType(0) == other.getType(0); |
8733 | } |
8734 | |
8735 | Node* LengthsRangeFillNode::clone() const { |
8736 | return new LengthsRangeFillNode(getName(), getResult().getType(), getLengths()); |
8737 | } |
8738 | |
8739 | llvm::hash_code LengthsRangeFillNode::getHash() const { |
8740 | return llvm::hash_combine( |
8741 | Lengths_); |
8742 | } |
8743 | |
8744 | unsigned BatchSparseToDenseNode::getNumInputs() const { |
8745 | return 3; |
8746 | } |
8747 | |
8748 | std::string BatchSparseToDenseNode::getInputName(unsigned idx) const { |
8749 | if (idx == 0) { return "Lengths" ; } |
8750 | if (idx == 1) { return "Indices" ; } |
8751 | if (idx == 2) { return "Values" ; } |
8752 | idx -= 3; |
8753 | llvm_unreachable("Invalid index" ); |
8754 | } |
8755 | |
8756 | NodeValue BatchSparseToDenseNode::getNthInput(unsigned idx) { |
8757 | if (idx == 0) { return Lengths_; } |
8758 | if (idx == 1) { return Indices_; } |
8759 | if (idx == 2) { return Values_; } |
8760 | idx -= 3; |
8761 | llvm_unreachable("Invalid index" ); |
8762 | } |
8763 | |
8764 | void BatchSparseToDenseNode::setNthInput(unsigned idx, NodeValue val) { |
8765 | if (idx == 0) { Lengths_ = val; return; } |
8766 | if (idx == 1) { Indices_ = val; return; } |
8767 | if (idx == 2) { Values_ = val; return; } |
8768 | idx -= 3; |
8769 | llvm_unreachable("Invalid index" ); |
8770 | } |
8771 | |
8772 | llvm::StringRef BatchSparseToDenseNode::getOutputName(unsigned idx) const { |
8773 | if (idx == 0) { return "Result" ; } |
8774 | llvm_unreachable("Invalid index" ); |
8775 | } |
8776 | |
8777 | std::string BatchSparseToDenseNode::getDebugDesc() const { |
8778 | DescriptionBuilder db(getKindName()); |
8779 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8780 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8781 | db |
8782 | .addParam("Lengths" , *(getLengths().getType())) |
8783 | .addParam("Indices" , *(getIndices().getType())) |
8784 | .addParam("Values" , *(getValues().getType())) |
8785 | .addParam("DefaultValue" , getDefaultValue()) |
8786 | .addParam("DenseLastDim" , getDenseLastDim()) |
8787 | .addParam("Users" , getNumUsers()); |
8788 | db.addParam("Result" , *(getResult().getType())); |
8789 | return db; |
8790 | } |
8791 | |
8792 | void BatchSparseToDenseNode::visit(Node *parent, NodeWalker *visitor) { |
8793 | if (!visitor->shouldVisit(parent, this)) { return; } |
8794 | visitor->pre(parent, this); |
8795 | if (hasPredicate()) |
8796 | getPredicate().getNode()->visit(this, visitor); |
8797 | getLengths().getNode()->visit(this, visitor); |
8798 | getIndices().getNode()->visit(this, visitor); |
8799 | getValues().getNode()->visit(this, visitor); |
8800 | visitor->post(parent, this); |
8801 | } |
8802 | |
8803 | bool BatchSparseToDenseNode::isEqual(const BatchSparseToDenseNode &other) const { |
8804 | return true && |
8805 | Lengths_ == other.Lengths_ && |
8806 | Indices_ == other.Indices_ && |
8807 | Values_ == other.Values_ && |
8808 | predicate_ == other.predicate_ && |
8809 | DefaultValue_ == other.DefaultValue_ && |
8810 | DenseLastDim_ == other.DenseLastDim_ && |
8811 | getType(0) == other.getType(0); |
8812 | } |
8813 | |
8814 | Node* BatchSparseToDenseNode::clone() const { |
8815 | return new BatchSparseToDenseNode(getName(), getResult().getType(), getLengths(), getIndices(), getValues(), getDefaultValue(), getDenseLastDim()); |
8816 | } |
8817 | |
8818 | llvm::hash_code BatchSparseToDenseNode::getHash() const { |
8819 | return llvm::hash_combine( |
8820 | toBinary(DefaultValue_), |
8821 | DenseLastDim_, |
8822 | Lengths_, |
8823 | Indices_, |
8824 | Values_); |
8825 | } |
8826 | |
8827 | unsigned FillExamplesWithIndicatorNode::getNumInputs() const { |
8828 | return 2; |
8829 | } |
8830 | |
8831 | std::string FillExamplesWithIndicatorNode::getInputName(unsigned idx) const { |
8832 | if (idx == 0) { return "Data" ; } |
8833 | if (idx == 1) { return "Indicator" ; } |
8834 | idx -= 2; |
8835 | llvm_unreachable("Invalid index" ); |
8836 | } |
8837 | |
8838 | NodeValue FillExamplesWithIndicatorNode::getNthInput(unsigned idx) { |
8839 | if (idx == 0) { return Data_; } |
8840 | if (idx == 1) { return Indicator_; } |
8841 | idx -= 2; |
8842 | llvm_unreachable("Invalid index" ); |
8843 | } |
8844 | |
8845 | void FillExamplesWithIndicatorNode::setNthInput(unsigned idx, NodeValue val) { |
8846 | if (idx == 0) { Data_ = val; return; } |
8847 | if (idx == 1) { Indicator_ = val; return; } |
8848 | idx -= 2; |
8849 | llvm_unreachable("Invalid index" ); |
8850 | } |
8851 | |
8852 | llvm::StringRef FillExamplesWithIndicatorNode::getOutputName(unsigned idx) const { |
8853 | if (idx == 0) { return "Result" ; } |
8854 | llvm_unreachable("Invalid index" ); |
8855 | } |
8856 | |
8857 | std::string FillExamplesWithIndicatorNode::getDebugDesc() const { |
8858 | DescriptionBuilder db(getKindName()); |
8859 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8860 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8861 | db |
8862 | .addParam("Data" , *(getData().getType())) |
8863 | .addParam("Indicator" , *(getIndicator().getType())) |
8864 | .addParam("Users" , getNumUsers()); |
8865 | db.addParam("Result" , *(getResult().getType())); |
8866 | return db; |
8867 | } |
8868 | |
8869 | void FillExamplesWithIndicatorNode::visit(Node *parent, NodeWalker *visitor) { |
8870 | if (!visitor->shouldVisit(parent, this)) { return; } |
8871 | visitor->pre(parent, this); |
8872 | if (hasPredicate()) |
8873 | getPredicate().getNode()->visit(this, visitor); |
8874 | getData().getNode()->visit(this, visitor); |
8875 | getIndicator().getNode()->visit(this, visitor); |
8876 | visitor->post(parent, this); |
8877 | } |
8878 | |
8879 | bool FillExamplesWithIndicatorNode::isEqual(const FillExamplesWithIndicatorNode &other) const { |
8880 | return true && |
8881 | Data_ == other.Data_ && |
8882 | Indicator_ == other.Indicator_ && |
8883 | predicate_ == other.predicate_ && |
8884 | getType(0) == other.getType(0); |
8885 | } |
8886 | |
8887 | Node* FillExamplesWithIndicatorNode::clone() const { |
8888 | return new FillExamplesWithIndicatorNode(getName(), getResult().getType(), getData(), getIndicator()); |
8889 | } |
8890 | |
8891 | llvm::hash_code FillExamplesWithIndicatorNode::getHash() const { |
8892 | return llvm::hash_combine( |
8893 | Data_, |
8894 | Indicator_); |
8895 | } |
8896 | |
8897 | unsigned SparseToDenseMaskNode::getNumInputs() const { |
8898 | return 4; |
8899 | } |
8900 | |
8901 | std::string SparseToDenseMaskNode::getInputName(unsigned idx) const { |
8902 | if (idx == 0) { return "Indices" ; } |
8903 | if (idx == 1) { return "Values" ; } |
8904 | if (idx == 2) { return "DefaultValue" ; } |
8905 | if (idx == 3) { return "Lengths" ; } |
8906 | idx -= 4; |
8907 | llvm_unreachable("Invalid index" ); |
8908 | } |
8909 | |
8910 | NodeValue SparseToDenseMaskNode::getNthInput(unsigned idx) { |
8911 | if (idx == 0) { return Indices_; } |
8912 | if (idx == 1) { return Values_; } |
8913 | if (idx == 2) { return DefaultValue_; } |
8914 | if (idx == 3) { return Lengths_; } |
8915 | idx -= 4; |
8916 | llvm_unreachable("Invalid index" ); |
8917 | } |
8918 | |
8919 | void SparseToDenseMaskNode::setNthInput(unsigned idx, NodeValue val) { |
8920 | if (idx == 0) { Indices_ = val; return; } |
8921 | if (idx == 1) { Values_ = val; return; } |
8922 | if (idx == 2) { DefaultValue_ = val; return; } |
8923 | if (idx == 3) { Lengths_ = val; return; } |
8924 | idx -= 4; |
8925 | llvm_unreachable("Invalid index" ); |
8926 | } |
8927 | |
8928 | llvm::StringRef SparseToDenseMaskNode::getOutputName(unsigned idx) const { |
8929 | if (idx == 0) { return "Result" ; } |
8930 | llvm_unreachable("Invalid index" ); |
8931 | } |
8932 | |
8933 | std::string SparseToDenseMaskNode::getDebugDesc() const { |
8934 | DescriptionBuilder db(getKindName()); |
8935 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
8936 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
8937 | db |
8938 | .addParam("Indices" , *(getIndices().getType())) |
8939 | .addParam("Values" , *(getValues().getType())) |
8940 | .addParam("DefaultValue" , *(getDefaultValue().getType())) |
8941 | .addParam("Lengths" , *(getLengths().getType())) |
8942 | .addParam("Mask" , getMask()) |
8943 | .addParam("Users" , getNumUsers()); |
8944 | db.addParam("Result" , *(getResult().getType())); |
8945 | return db; |
8946 | } |
8947 | |
8948 | void SparseToDenseMaskNode::visit(Node *parent, NodeWalker *visitor) { |
8949 | if (!visitor->shouldVisit(parent, this)) { return; } |
8950 | visitor->pre(parent, this); |
8951 | if (hasPredicate()) |
8952 | getPredicate().getNode()->visit(this, visitor); |
8953 | getIndices().getNode()->visit(this, visitor); |
8954 | getValues().getNode()->visit(this, visitor); |
8955 | getDefaultValue().getNode()->visit(this, visitor); |
8956 | getLengths().getNode()->visit(this, visitor); |
8957 | visitor->post(parent, this); |
8958 | } |
8959 | |
8960 | bool SparseToDenseMaskNode::isEqual(const SparseToDenseMaskNode &other) const { |
8961 | return true && |
8962 | Indices_ == other.Indices_ && |
8963 | Values_ == other.Values_ && |
8964 | DefaultValue_ == other.DefaultValue_ && |
8965 | Lengths_ == other.Lengths_ && |
8966 | predicate_ == other.predicate_ && |
8967 | Mask_ == other.Mask_ && |
8968 | getType(0) == other.getType(0); |
8969 | } |
8970 | |
8971 | Node* SparseToDenseMaskNode::clone() const { |
8972 | return new SparseToDenseMaskNode(getName(), getResult().getType(), getIndices(), getValues(), getDefaultValue(), getLengths(), getMask()); |
8973 | } |
8974 | |
8975 | llvm::hash_code SparseToDenseMaskNode::getHash() const { |
8976 | return llvm::hash_combine( |
8977 | llvm::hash_combine_range(Mask_.begin(), Mask_.end()), |
8978 | Indices_, |
8979 | Values_, |
8980 | DefaultValue_, |
8981 | Lengths_); |
8982 | } |
8983 | |
8984 | unsigned IsNaNNode::getNumInputs() const { |
8985 | return 1; |
8986 | } |
8987 | |
8988 | std::string IsNaNNode::getInputName(unsigned idx) const { |
8989 | if (idx == 0) { return "Input" ; } |
8990 | idx -= 1; |
8991 | llvm_unreachable("Invalid index" ); |
8992 | } |
8993 | |
8994 | NodeValue IsNaNNode::getNthInput(unsigned idx) { |
8995 | if (idx == 0) { return Input_; } |
8996 | idx -= 1; |
8997 | llvm_unreachable("Invalid index" ); |
8998 | } |
8999 | |
9000 | void IsNaNNode::setNthInput(unsigned idx, NodeValue val) { |
9001 | if (idx == 0) { Input_ = val; return; } |
9002 | idx -= 1; |
9003 | llvm_unreachable("Invalid index" ); |
9004 | } |
9005 | |
9006 | llvm::StringRef IsNaNNode::getOutputName(unsigned idx) const { |
9007 | if (idx == 0) { return "Result" ; } |
9008 | llvm_unreachable("Invalid index" ); |
9009 | } |
9010 | |
9011 | std::string IsNaNNode::getDebugDesc() const { |
9012 | DescriptionBuilder db(getKindName()); |
9013 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9014 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9015 | db |
9016 | .addParam("Input" , *(getInput().getType())) |
9017 | .addParam("Users" , getNumUsers()); |
9018 | db.addParam("Result" , *(getResult().getType())); |
9019 | return db; |
9020 | } |
9021 | |
9022 | void IsNaNNode::visit(Node *parent, NodeWalker *visitor) { |
9023 | if (!visitor->shouldVisit(parent, this)) { return; } |
9024 | visitor->pre(parent, this); |
9025 | if (hasPredicate()) |
9026 | getPredicate().getNode()->visit(this, visitor); |
9027 | getInput().getNode()->visit(this, visitor); |
9028 | visitor->post(parent, this); |
9029 | } |
9030 | |
9031 | bool IsNaNNode::isEqual(const IsNaNNode &other) const { |
9032 | return true && |
9033 | Input_ == other.Input_ && |
9034 | predicate_ == other.predicate_ && |
9035 | getType(0) == other.getType(0); |
9036 | } |
9037 | |
9038 | Node* IsNaNNode::clone() const { |
9039 | return new IsNaNNode(getName(), getResult().getType(), getInput()); |
9040 | } |
9041 | |
9042 | llvm::hash_code IsNaNNode::getHash() const { |
9043 | return llvm::hash_combine( |
9044 | Input_); |
9045 | } |
9046 | |
9047 | unsigned ReplaceNaNNode::getNumInputs() const { |
9048 | return 1; |
9049 | } |
9050 | |
9051 | std::string ReplaceNaNNode::getInputName(unsigned idx) const { |
9052 | if (idx == 0) { return "Input" ; } |
9053 | idx -= 1; |
9054 | llvm_unreachable("Invalid index" ); |
9055 | } |
9056 | |
9057 | NodeValue ReplaceNaNNode::getNthInput(unsigned idx) { |
9058 | if (idx == 0) { return Input_; } |
9059 | idx -= 1; |
9060 | llvm_unreachable("Invalid index" ); |
9061 | } |
9062 | |
9063 | void ReplaceNaNNode::setNthInput(unsigned idx, NodeValue val) { |
9064 | if (idx == 0) { Input_ = val; return; } |
9065 | idx -= 1; |
9066 | llvm_unreachable("Invalid index" ); |
9067 | } |
9068 | |
9069 | llvm::StringRef ReplaceNaNNode::getOutputName(unsigned idx) const { |
9070 | if (idx == 0) { return "Result" ; } |
9071 | llvm_unreachable("Invalid index" ); |
9072 | } |
9073 | |
9074 | std::string ReplaceNaNNode::getDebugDesc() const { |
9075 | DescriptionBuilder db(getKindName()); |
9076 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9077 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9078 | db |
9079 | .addParam("Input" , *(getInput().getType())) |
9080 | .addParam("Value" , getValue()) |
9081 | .addParam("Users" , getNumUsers()); |
9082 | db.addParam("Result" , *(getResult().getType())); |
9083 | return db; |
9084 | } |
9085 | |
9086 | void ReplaceNaNNode::visit(Node *parent, NodeWalker *visitor) { |
9087 | if (!visitor->shouldVisit(parent, this)) { return; } |
9088 | visitor->pre(parent, this); |
9089 | if (hasPredicate()) |
9090 | getPredicate().getNode()->visit(this, visitor); |
9091 | getInput().getNode()->visit(this, visitor); |
9092 | visitor->post(parent, this); |
9093 | } |
9094 | |
9095 | bool ReplaceNaNNode::isEqual(const ReplaceNaNNode &other) const { |
9096 | return true && |
9097 | Input_ == other.Input_ && |
9098 | predicate_ == other.predicate_ && |
9099 | Value_ == other.Value_ && |
9100 | getType(0) == other.getType(0); |
9101 | } |
9102 | |
9103 | Node* ReplaceNaNNode::clone() const { |
9104 | return new ReplaceNaNNode(getName(), getResult().getType(), getInput(), getValue()); |
9105 | } |
9106 | |
9107 | llvm::hash_code ReplaceNaNNode::getHash() const { |
9108 | return llvm::hash_combine( |
9109 | toBinary(Value_), |
9110 | Input_); |
9111 | } |
9112 | |
9113 | unsigned ModuloNode::getNumInputs() const { |
9114 | return 1; |
9115 | } |
9116 | |
9117 | std::string ModuloNode::getInputName(unsigned idx) const { |
9118 | if (idx == 0) { return "Input" ; } |
9119 | idx -= 1; |
9120 | llvm_unreachable("Invalid index" ); |
9121 | } |
9122 | |
9123 | NodeValue ModuloNode::getNthInput(unsigned idx) { |
9124 | if (idx == 0) { return Input_; } |
9125 | idx -= 1; |
9126 | llvm_unreachable("Invalid index" ); |
9127 | } |
9128 | |
9129 | void ModuloNode::setNthInput(unsigned idx, NodeValue val) { |
9130 | if (idx == 0) { Input_ = val; return; } |
9131 | idx -= 1; |
9132 | llvm_unreachable("Invalid index" ); |
9133 | } |
9134 | |
9135 | llvm::StringRef ModuloNode::getOutputName(unsigned idx) const { |
9136 | if (idx == 0) { return "Result" ; } |
9137 | llvm_unreachable("Invalid index" ); |
9138 | } |
9139 | |
9140 | std::string ModuloNode::getDebugDesc() const { |
9141 | DescriptionBuilder db(getKindName()); |
9142 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9143 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9144 | db |
9145 | .addParam("Input" , *(getInput().getType())) |
9146 | .addParam("Divisor" , getDivisor()) |
9147 | .addParam("SignFollowDivisor" , getSignFollowDivisor()) |
9148 | .addParam("Users" , getNumUsers()); |
9149 | db.addParam("Result" , *(getResult().getType())); |
9150 | return db; |
9151 | } |
9152 | |
9153 | void ModuloNode::visit(Node *parent, NodeWalker *visitor) { |
9154 | if (!visitor->shouldVisit(parent, this)) { return; } |
9155 | visitor->pre(parent, this); |
9156 | if (hasPredicate()) |
9157 | getPredicate().getNode()->visit(this, visitor); |
9158 | getInput().getNode()->visit(this, visitor); |
9159 | visitor->post(parent, this); |
9160 | } |
9161 | |
9162 | bool ModuloNode::isEqual(const ModuloNode &other) const { |
9163 | return true && |
9164 | Input_ == other.Input_ && |
9165 | predicate_ == other.predicate_ && |
9166 | Divisor_ == other.Divisor_ && |
9167 | SignFollowDivisor_ == other.SignFollowDivisor_ && |
9168 | getType(0) == other.getType(0); |
9169 | } |
9170 | |
9171 | Node* ModuloNode::clone() const { |
9172 | return new ModuloNode(getName(), getResult().getType(), getInput(), getDivisor(), getSignFollowDivisor()); |
9173 | } |
9174 | |
9175 | llvm::hash_code ModuloNode::getHash() const { |
9176 | return llvm::hash_combine( |
9177 | Divisor_, |
9178 | SignFollowDivisor_, |
9179 | Input_); |
9180 | } |
9181 | |
9182 | unsigned BatchedPairwiseDotProductNode::getNumInputs() const { |
9183 | return 0 + Inputs_.size(); |
9184 | } |
9185 | |
9186 | std::string BatchedPairwiseDotProductNode::getInputName(unsigned idx) const { |
9187 | idx -= 0; |
9188 | if (idx < Inputs_.size()) { return "Inputs" + std::to_string(idx); } |
9189 | idx -= Inputs_.size(); |
9190 | llvm_unreachable("Invalid index" ); |
9191 | } |
9192 | |
9193 | NodeValue BatchedPairwiseDotProductNode::getNthInput(unsigned idx) { |
9194 | idx -= 0; |
9195 | if (idx < Inputs_.size()) { return Inputs_[idx]; } |
9196 | idx -= Inputs_.size(); |
9197 | llvm_unreachable("Invalid index" ); |
9198 | } |
9199 | |
9200 | void BatchedPairwiseDotProductNode::setNthInput(unsigned idx, NodeValue val) { |
9201 | idx -= 0; |
9202 | if (idx < Inputs_.size()) { Inputs_[idx] = val; return; } |
9203 | idx -= Inputs_.size(); |
9204 | llvm_unreachable("Invalid index" ); |
9205 | } |
9206 | |
9207 | llvm::StringRef BatchedPairwiseDotProductNode::getOutputName(unsigned idx) const { |
9208 | if (idx == 0) { return "Result" ; } |
9209 | llvm_unreachable("Invalid index" ); |
9210 | } |
9211 | |
9212 | std::string BatchedPairwiseDotProductNode::getDebugDesc() const { |
9213 | DescriptionBuilder db(getKindName()); |
9214 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9215 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9216 | db |
9217 | .addParam("Users" , getNumUsers()); |
9218 | { |
9219 | unsigned mIndex = 0; |
9220 | for (const auto &II : getInputs()) { |
9221 | db.addParam("Inputs" +std::to_string(mIndex++), *II.getType()); |
9222 | } |
9223 | } |
9224 | db.addParam("Result" , *(getResult().getType())); |
9225 | return db; |
9226 | } |
9227 | |
9228 | void BatchedPairwiseDotProductNode::visit(Node *parent, NodeWalker *visitor) { |
9229 | if (!visitor->shouldVisit(parent, this)) { return; } |
9230 | visitor->pre(parent, this); |
9231 | if (hasPredicate()) |
9232 | getPredicate().getNode()->visit(this, visitor); |
9233 | for (auto &I : Inputs_) { I.getNode()->visit(this, visitor); } |
9234 | visitor->post(parent, this); |
9235 | } |
9236 | |
9237 | bool BatchedPairwiseDotProductNode::isEqual(const BatchedPairwiseDotProductNode &other) const { |
9238 | return true && |
9239 | predicate_ == other.predicate_ && |
9240 | Inputs_ == other.Inputs_ && |
9241 | getType(0) == other.getType(0); |
9242 | } |
9243 | |
9244 | Node* BatchedPairwiseDotProductNode::clone() const { |
9245 | return new BatchedPairwiseDotProductNode(getName(), getResult().getType(), getInputs()); |
9246 | } |
9247 | |
9248 | llvm::hash_code BatchedPairwiseDotProductNode::getHash() const { |
9249 | return llvm::hash_combine( |
9250 | llvm::hash_combine_range(Inputs_.begin(), Inputs_.end())); |
9251 | } |
9252 | |
9253 | unsigned BatchedPairwiseDotProductGradNode::getNumInputs() const { |
9254 | return 1 + OriginalInputs_.size(); |
9255 | } |
9256 | |
9257 | std::string BatchedPairwiseDotProductGradNode::getInputName(unsigned idx) const { |
9258 | if (idx == 0) { return "OutputGrad" ; } |
9259 | idx -= 1; |
9260 | if (idx < OriginalInputs_.size()) { return "OriginalInputs" + std::to_string(idx); } |
9261 | idx -= OriginalInputs_.size(); |
9262 | llvm_unreachable("Invalid index" ); |
9263 | } |
9264 | |
9265 | NodeValue BatchedPairwiseDotProductGradNode::getNthInput(unsigned idx) { |
9266 | if (idx == 0) { return OutputGrad_; } |
9267 | idx -= 1; |
9268 | if (idx < OriginalInputs_.size()) { return OriginalInputs_[idx]; } |
9269 | idx -= OriginalInputs_.size(); |
9270 | llvm_unreachable("Invalid index" ); |
9271 | } |
9272 | |
9273 | void BatchedPairwiseDotProductGradNode::setNthInput(unsigned idx, NodeValue val) { |
9274 | if (idx == 0) { OutputGrad_ = val; return; } |
9275 | idx -= 1; |
9276 | if (idx < OriginalInputs_.size()) { OriginalInputs_[idx] = val; return; } |
9277 | idx -= OriginalInputs_.size(); |
9278 | llvm_unreachable("Invalid index" ); |
9279 | } |
9280 | |
9281 | llvm::StringRef BatchedPairwiseDotProductGradNode::getOutputName(unsigned idx) const { |
9282 | llvm_unreachable("Invalid index" ); |
9283 | } |
9284 | |
9285 | std::string BatchedPairwiseDotProductGradNode::getDebugDesc() const { |
9286 | DescriptionBuilder db(getKindName()); |
9287 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9288 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9289 | db |
9290 | .addParam("OutputGrad" , *(getOutputGrad().getType())) |
9291 | .addParam("Users" , getNumUsers()); |
9292 | { |
9293 | unsigned mIndex = 0; |
9294 | for (const auto &II : getOriginalInputs()) { |
9295 | db.addParam("OriginalInputs" +std::to_string(mIndex++), *II.getType()); |
9296 | } |
9297 | } |
9298 | return db; |
9299 | } |
9300 | |
9301 | void BatchedPairwiseDotProductGradNode::visit(Node *parent, NodeWalker *visitor) { |
9302 | if (!visitor->shouldVisit(parent, this)) { return; } |
9303 | visitor->pre(parent, this); |
9304 | if (hasPredicate()) |
9305 | getPredicate().getNode()->visit(this, visitor); |
9306 | getOutputGrad().getNode()->visit(this, visitor); |
9307 | for (auto &I : OriginalInputs_) { I.getNode()->visit(this, visitor); } |
9308 | visitor->post(parent, this); |
9309 | } |
9310 | |
9311 | bool BatchedPairwiseDotProductGradNode::isEqual(const BatchedPairwiseDotProductGradNode &other) const { |
9312 | return true && |
9313 | OutputGrad_ == other.OutputGrad_ && |
9314 | predicate_ == other.predicate_ && |
9315 | OriginalInputs_ == other.OriginalInputs_; |
9316 | } |
9317 | |
9318 | Node* BatchedPairwiseDotProductGradNode::clone() const { |
9319 | return new BatchedPairwiseDotProductGradNode(getName(), getOutputGrad(), getOriginalInputs()); |
9320 | } |
9321 | |
9322 | llvm::hash_code BatchedPairwiseDotProductGradNode::getHash() const { |
9323 | return llvm::hash_combine( |
9324 | llvm::hash_combine_range(OriginalInputs_.begin(), OriginalInputs_.end()), |
9325 | OutputGrad_); |
9326 | } |
9327 | |
9328 | unsigned BatchedUnaryEmbeddingsBagsNode::getNumInputs() const { |
9329 | return 4; |
9330 | } |
9331 | |
9332 | std::string BatchedUnaryEmbeddingsBagsNode::getInputName(unsigned idx) const { |
9333 | if (idx == 0) { return "Weights" ; } |
9334 | if (idx == 1) { return "TableOffsets" ; } |
9335 | if (idx == 2) { return "Offsets" ; } |
9336 | if (idx == 3) { return "Indices" ; } |
9337 | idx -= 4; |
9338 | llvm_unreachable("Invalid index" ); |
9339 | } |
9340 | |
9341 | NodeValue BatchedUnaryEmbeddingsBagsNode::getNthInput(unsigned idx) { |
9342 | if (idx == 0) { return Weights_; } |
9343 | if (idx == 1) { return TableOffsets_; } |
9344 | if (idx == 2) { return Offsets_; } |
9345 | if (idx == 3) { return Indices_; } |
9346 | idx -= 4; |
9347 | llvm_unreachable("Invalid index" ); |
9348 | } |
9349 | |
9350 | void BatchedUnaryEmbeddingsBagsNode::setNthInput(unsigned idx, NodeValue val) { |
9351 | if (idx == 0) { Weights_ = val; return; } |
9352 | if (idx == 1) { TableOffsets_ = val; return; } |
9353 | if (idx == 2) { Offsets_ = val; return; } |
9354 | if (idx == 3) { Indices_ = val; return; } |
9355 | idx -= 4; |
9356 | llvm_unreachable("Invalid index" ); |
9357 | } |
9358 | |
9359 | llvm::StringRef BatchedUnaryEmbeddingsBagsNode::getOutputName(unsigned idx) const { |
9360 | if (idx == 0) { return "Result" ; } |
9361 | llvm_unreachable("Invalid index" ); |
9362 | } |
9363 | |
9364 | std::string BatchedUnaryEmbeddingsBagsNode::getDebugDesc() const { |
9365 | DescriptionBuilder db(getKindName()); |
9366 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9367 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9368 | db |
9369 | .addParam("Weights" , *(getWeights().getType())) |
9370 | .addParam("TableOffsets" , *(getTableOffsets().getType())) |
9371 | .addParam("Offsets" , *(getOffsets().getType())) |
9372 | .addParam("Indices" , *(getIndices().getType())) |
9373 | .addParam("Users" , getNumUsers()); |
9374 | db.addParam("Result" , *(getResult().getType())); |
9375 | return db; |
9376 | } |
9377 | |
9378 | void BatchedUnaryEmbeddingsBagsNode::visit(Node *parent, NodeWalker *visitor) { |
9379 | if (!visitor->shouldVisit(parent, this)) { return; } |
9380 | visitor->pre(parent, this); |
9381 | if (hasPredicate()) |
9382 | getPredicate().getNode()->visit(this, visitor); |
9383 | getWeights().getNode()->visit(this, visitor); |
9384 | getTableOffsets().getNode()->visit(this, visitor); |
9385 | getOffsets().getNode()->visit(this, visitor); |
9386 | getIndices().getNode()->visit(this, visitor); |
9387 | visitor->post(parent, this); |
9388 | } |
9389 | |
9390 | bool BatchedUnaryEmbeddingsBagsNode::isEqual(const BatchedUnaryEmbeddingsBagsNode &other) const { |
9391 | return true && |
9392 | Weights_ == other.Weights_ && |
9393 | TableOffsets_ == other.TableOffsets_ && |
9394 | Offsets_ == other.Offsets_ && |
9395 | Indices_ == other.Indices_ && |
9396 | predicate_ == other.predicate_ && |
9397 | getType(0) == other.getType(0); |
9398 | } |
9399 | |
9400 | Node* BatchedUnaryEmbeddingsBagsNode::clone() const { |
9401 | return new BatchedUnaryEmbeddingsBagsNode(getName(), getResult().getType(), getWeights(), getTableOffsets(), getOffsets(), getIndices()); |
9402 | } |
9403 | |
9404 | llvm::hash_code BatchedUnaryEmbeddingsBagsNode::getHash() const { |
9405 | return llvm::hash_combine( |
9406 | Weights_, |
9407 | TableOffsets_, |
9408 | Offsets_, |
9409 | Indices_); |
9410 | } |
9411 | |
9412 | unsigned IntNBitSplitEmbeddingBagsNode::getNumInputs() const { |
9413 | return 8; |
9414 | } |
9415 | |
9416 | std::string IntNBitSplitEmbeddingBagsNode::getInputName(unsigned idx) const { |
9417 | if (idx == 0) { return "DevWeights" ; } |
9418 | if (idx == 1) { return "UvmWeights" ; } |
9419 | if (idx == 2) { return "WeightsPlacements" ; } |
9420 | if (idx == 3) { return "WeightsOffsets" ; } |
9421 | if (idx == 4) { return "WeightsTys" ; } |
9422 | if (idx == 5) { return "DimOffsets" ; } |
9423 | if (idx == 6) { return "Indices" ; } |
9424 | if (idx == 7) { return "Offsets" ; } |
9425 | idx -= 8; |
9426 | llvm_unreachable("Invalid index" ); |
9427 | } |
9428 | |
9429 | NodeValue IntNBitSplitEmbeddingBagsNode::getNthInput(unsigned idx) { |
9430 | if (idx == 0) { return DevWeights_; } |
9431 | if (idx == 1) { return UvmWeights_; } |
9432 | if (idx == 2) { return WeightsPlacements_; } |
9433 | if (idx == 3) { return WeightsOffsets_; } |
9434 | if (idx == 4) { return WeightsTys_; } |
9435 | if (idx == 5) { return DimOffsets_; } |
9436 | if (idx == 6) { return Indices_; } |
9437 | if (idx == 7) { return Offsets_; } |
9438 | idx -= 8; |
9439 | llvm_unreachable("Invalid index" ); |
9440 | } |
9441 | |
9442 | void IntNBitSplitEmbeddingBagsNode::setNthInput(unsigned idx, NodeValue val) { |
9443 | if (idx == 0) { DevWeights_ = val; return; } |
9444 | if (idx == 1) { UvmWeights_ = val; return; } |
9445 | if (idx == 2) { WeightsPlacements_ = val; return; } |
9446 | if (idx == 3) { WeightsOffsets_ = val; return; } |
9447 | if (idx == 4) { WeightsTys_ = val; return; } |
9448 | if (idx == 5) { DimOffsets_ = val; return; } |
9449 | if (idx == 6) { Indices_ = val; return; } |
9450 | if (idx == 7) { Offsets_ = val; return; } |
9451 | idx -= 8; |
9452 | llvm_unreachable("Invalid index" ); |
9453 | } |
9454 | |
9455 | llvm::StringRef IntNBitSplitEmbeddingBagsNode::getOutputName(unsigned idx) const { |
9456 | if (idx == 0) { return "Result" ; } |
9457 | llvm_unreachable("Invalid index" ); |
9458 | } |
9459 | |
9460 | std::string IntNBitSplitEmbeddingBagsNode::getDebugDesc() const { |
9461 | DescriptionBuilder db(getKindName()); |
9462 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9463 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9464 | db |
9465 | .addParam("DevWeights" , *(getDevWeights().getType())) |
9466 | .addParam("UvmWeights" , *(getUvmWeights().getType())) |
9467 | .addParam("WeightsPlacements" , *(getWeightsPlacements().getType())) |
9468 | .addParam("WeightsOffsets" , *(getWeightsOffsets().getType())) |
9469 | .addParam("WeightsTys" , *(getWeightsTys().getType())) |
9470 | .addParam("DimOffsets" , *(getDimOffsets().getType())) |
9471 | .addParam("Indices" , *(getIndices().getType())) |
9472 | .addParam("Offsets" , *(getOffsets().getType())) |
9473 | .addParam("TotalDims" , getTotalDims()) |
9474 | .addParam("PoolingMode" , getPoolingMode()) |
9475 | .addParam("OutputDType" , getOutputDType()) |
9476 | .addParam("Users" , getNumUsers()); |
9477 | db.addParam("Result" , *(getResult().getType())); |
9478 | return db; |
9479 | } |
9480 | |
9481 | void IntNBitSplitEmbeddingBagsNode::visit(Node *parent, NodeWalker *visitor) { |
9482 | if (!visitor->shouldVisit(parent, this)) { return; } |
9483 | visitor->pre(parent, this); |
9484 | if (hasPredicate()) |
9485 | getPredicate().getNode()->visit(this, visitor); |
9486 | getDevWeights().getNode()->visit(this, visitor); |
9487 | getUvmWeights().getNode()->visit(this, visitor); |
9488 | getWeightsPlacements().getNode()->visit(this, visitor); |
9489 | getWeightsOffsets().getNode()->visit(this, visitor); |
9490 | getWeightsTys().getNode()->visit(this, visitor); |
9491 | getDimOffsets().getNode()->visit(this, visitor); |
9492 | getIndices().getNode()->visit(this, visitor); |
9493 | getOffsets().getNode()->visit(this, visitor); |
9494 | visitor->post(parent, this); |
9495 | } |
9496 | |
9497 | bool IntNBitSplitEmbeddingBagsNode::isEqual(const IntNBitSplitEmbeddingBagsNode &other) const { |
9498 | return true && |
9499 | DevWeights_ == other.DevWeights_ && |
9500 | UvmWeights_ == other.UvmWeights_ && |
9501 | WeightsPlacements_ == other.WeightsPlacements_ && |
9502 | WeightsOffsets_ == other.WeightsOffsets_ && |
9503 | WeightsTys_ == other.WeightsTys_ && |
9504 | DimOffsets_ == other.DimOffsets_ && |
9505 | Indices_ == other.Indices_ && |
9506 | Offsets_ == other.Offsets_ && |
9507 | predicate_ == other.predicate_ && |
9508 | TotalDims_ == other.TotalDims_ && |
9509 | PoolingMode_ == other.PoolingMode_ && |
9510 | OutputDType_ == other.OutputDType_ && |
9511 | getType(0) == other.getType(0); |
9512 | } |
9513 | |
9514 | Node* IntNBitSplitEmbeddingBagsNode::clone() const { |
9515 | return new IntNBitSplitEmbeddingBagsNode(getName(), getResult().getType(), getDevWeights(), getUvmWeights(), getWeightsPlacements(), getWeightsOffsets(), getWeightsTys(), getDimOffsets(), getIndices(), getOffsets(), getTotalDims(), getPoolingMode(), getOutputDType()); |
9516 | } |
9517 | |
9518 | llvm::hash_code IntNBitSplitEmbeddingBagsNode::getHash() const { |
9519 | return llvm::hash_combine( |
9520 | TotalDims_, |
9521 | PoolingMode_, |
9522 | OutputDType_, |
9523 | DevWeights_, |
9524 | UvmWeights_, |
9525 | WeightsPlacements_, |
9526 | WeightsOffsets_, |
9527 | WeightsTys_, |
9528 | DimOffsets_, |
9529 | Indices_, |
9530 | Offsets_); |
9531 | } |
9532 | |
9533 | unsigned IntNBitSplitEmbeddingWeightedBagsNode::getNumInputs() const { |
9534 | return 9; |
9535 | } |
9536 | |
9537 | std::string IntNBitSplitEmbeddingWeightedBagsNode::getInputName(unsigned idx) const { |
9538 | if (idx == 0) { return "DevWeights" ; } |
9539 | if (idx == 1) { return "UvmWeights" ; } |
9540 | if (idx == 2) { return "WeightsPlacements" ; } |
9541 | if (idx == 3) { return "WeightsOffsets" ; } |
9542 | if (idx == 4) { return "WeightsTys" ; } |
9543 | if (idx == 5) { return "DimOffsets" ; } |
9544 | if (idx == 6) { return "Indices" ; } |
9545 | if (idx == 7) { return "Offsets" ; } |
9546 | if (idx == 8) { return "IndiceWeight" ; } |
9547 | idx -= 9; |
9548 | llvm_unreachable("Invalid index" ); |
9549 | } |
9550 | |
9551 | NodeValue IntNBitSplitEmbeddingWeightedBagsNode::getNthInput(unsigned idx) { |
9552 | if (idx == 0) { return DevWeights_; } |
9553 | if (idx == 1) { return UvmWeights_; } |
9554 | if (idx == 2) { return WeightsPlacements_; } |
9555 | if (idx == 3) { return WeightsOffsets_; } |
9556 | if (idx == 4) { return WeightsTys_; } |
9557 | if (idx == 5) { return DimOffsets_; } |
9558 | if (idx == 6) { return Indices_; } |
9559 | if (idx == 7) { return Offsets_; } |
9560 | if (idx == 8) { return IndiceWeight_; } |
9561 | idx -= 9; |
9562 | llvm_unreachable("Invalid index" ); |
9563 | } |
9564 | |
9565 | void IntNBitSplitEmbeddingWeightedBagsNode::setNthInput(unsigned idx, NodeValue val) { |
9566 | if (idx == 0) { DevWeights_ = val; return; } |
9567 | if (idx == 1) { UvmWeights_ = val; return; } |
9568 | if (idx == 2) { WeightsPlacements_ = val; return; } |
9569 | if (idx == 3) { WeightsOffsets_ = val; return; } |
9570 | if (idx == 4) { WeightsTys_ = val; return; } |
9571 | if (idx == 5) { DimOffsets_ = val; return; } |
9572 | if (idx == 6) { Indices_ = val; return; } |
9573 | if (idx == 7) { Offsets_ = val; return; } |
9574 | if (idx == 8) { IndiceWeight_ = val; return; } |
9575 | idx -= 9; |
9576 | llvm_unreachable("Invalid index" ); |
9577 | } |
9578 | |
9579 | llvm::StringRef IntNBitSplitEmbeddingWeightedBagsNode::getOutputName(unsigned idx) const { |
9580 | if (idx == 0) { return "Result" ; } |
9581 | llvm_unreachable("Invalid index" ); |
9582 | } |
9583 | |
9584 | std::string IntNBitSplitEmbeddingWeightedBagsNode::getDebugDesc() const { |
9585 | DescriptionBuilder db(getKindName()); |
9586 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9587 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9588 | db |
9589 | .addParam("DevWeights" , *(getDevWeights().getType())) |
9590 | .addParam("UvmWeights" , *(getUvmWeights().getType())) |
9591 | .addParam("WeightsPlacements" , *(getWeightsPlacements().getType())) |
9592 | .addParam("WeightsOffsets" , *(getWeightsOffsets().getType())) |
9593 | .addParam("WeightsTys" , *(getWeightsTys().getType())) |
9594 | .addParam("DimOffsets" , *(getDimOffsets().getType())) |
9595 | .addParam("Indices" , *(getIndices().getType())) |
9596 | .addParam("Offsets" , *(getOffsets().getType())) |
9597 | .addParam("IndiceWeight" , *(getIndiceWeight().getType())) |
9598 | .addParam("TotalDims" , getTotalDims()) |
9599 | .addParam("PoolingMode" , getPoolingMode()) |
9600 | .addParam("OutputDType" , getOutputDType()) |
9601 | .addParam("Users" , getNumUsers()); |
9602 | db.addParam("Result" , *(getResult().getType())); |
9603 | return db; |
9604 | } |
9605 | |
9606 | void IntNBitSplitEmbeddingWeightedBagsNode::visit(Node *parent, NodeWalker *visitor) { |
9607 | if (!visitor->shouldVisit(parent, this)) { return; } |
9608 | visitor->pre(parent, this); |
9609 | if (hasPredicate()) |
9610 | getPredicate().getNode()->visit(this, visitor); |
9611 | getDevWeights().getNode()->visit(this, visitor); |
9612 | getUvmWeights().getNode()->visit(this, visitor); |
9613 | getWeightsPlacements().getNode()->visit(this, visitor); |
9614 | getWeightsOffsets().getNode()->visit(this, visitor); |
9615 | getWeightsTys().getNode()->visit(this, visitor); |
9616 | getDimOffsets().getNode()->visit(this, visitor); |
9617 | getIndices().getNode()->visit(this, visitor); |
9618 | getOffsets().getNode()->visit(this, visitor); |
9619 | getIndiceWeight().getNode()->visit(this, visitor); |
9620 | visitor->post(parent, this); |
9621 | } |
9622 | |
9623 | bool IntNBitSplitEmbeddingWeightedBagsNode::isEqual(const IntNBitSplitEmbeddingWeightedBagsNode &other) const { |
9624 | return true && |
9625 | DevWeights_ == other.DevWeights_ && |
9626 | UvmWeights_ == other.UvmWeights_ && |
9627 | WeightsPlacements_ == other.WeightsPlacements_ && |
9628 | WeightsOffsets_ == other.WeightsOffsets_ && |
9629 | WeightsTys_ == other.WeightsTys_ && |
9630 | DimOffsets_ == other.DimOffsets_ && |
9631 | Indices_ == other.Indices_ && |
9632 | Offsets_ == other.Offsets_ && |
9633 | IndiceWeight_ == other.IndiceWeight_ && |
9634 | predicate_ == other.predicate_ && |
9635 | TotalDims_ == other.TotalDims_ && |
9636 | PoolingMode_ == other.PoolingMode_ && |
9637 | OutputDType_ == other.OutputDType_ && |
9638 | getType(0) == other.getType(0); |
9639 | } |
9640 | |
9641 | Node* IntNBitSplitEmbeddingWeightedBagsNode::clone() const { |
9642 | return new IntNBitSplitEmbeddingWeightedBagsNode(getName(), getResult().getType(), getDevWeights(), getUvmWeights(), getWeightsPlacements(), getWeightsOffsets(), getWeightsTys(), getDimOffsets(), getIndices(), getOffsets(), getIndiceWeight(), getTotalDims(), getPoolingMode(), getOutputDType()); |
9643 | } |
9644 | |
9645 | llvm::hash_code IntNBitSplitEmbeddingWeightedBagsNode::getHash() const { |
9646 | return llvm::hash_combine( |
9647 | TotalDims_, |
9648 | PoolingMode_, |
9649 | OutputDType_, |
9650 | DevWeights_, |
9651 | UvmWeights_, |
9652 | WeightsPlacements_, |
9653 | WeightsOffsets_, |
9654 | WeightsTys_, |
9655 | DimOffsets_, |
9656 | Indices_, |
9657 | Offsets_, |
9658 | IndiceWeight_); |
9659 | } |
9660 | |
9661 | unsigned GaussianFillNode::getNumInputs() const { |
9662 | return 1; |
9663 | } |
9664 | |
9665 | std::string GaussianFillNode::getInputName(unsigned idx) const { |
9666 | if (idx == 0) { return "Input" ; } |
9667 | idx -= 1; |
9668 | llvm_unreachable("Invalid index" ); |
9669 | } |
9670 | |
9671 | NodeValue GaussianFillNode::getNthInput(unsigned idx) { |
9672 | if (idx == 0) { return Input_; } |
9673 | idx -= 1; |
9674 | llvm_unreachable("Invalid index" ); |
9675 | } |
9676 | |
9677 | void GaussianFillNode::setNthInput(unsigned idx, NodeValue val) { |
9678 | if (idx == 0) { Input_ = val; return; } |
9679 | idx -= 1; |
9680 | llvm_unreachable("Invalid index" ); |
9681 | } |
9682 | |
9683 | llvm::StringRef GaussianFillNode::getOutputName(unsigned idx) const { |
9684 | if (idx == 0) { return "Result" ; } |
9685 | llvm_unreachable("Invalid index" ); |
9686 | } |
9687 | |
9688 | std::string GaussianFillNode::getDebugDesc() const { |
9689 | DescriptionBuilder db(getKindName()); |
9690 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9691 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9692 | db |
9693 | .addParam("Input" , *(getInput().getType())) |
9694 | .addParam("Mean" , getMean()) |
9695 | .addParam("Scale" , getScale()) |
9696 | .addParam("Seed" , getSeed()) |
9697 | .addParam("Users" , getNumUsers()); |
9698 | db.addParam("Result" , *(getResult().getType())); |
9699 | return db; |
9700 | } |
9701 | |
9702 | void GaussianFillNode::visit(Node *parent, NodeWalker *visitor) { |
9703 | if (!visitor->shouldVisit(parent, this)) { return; } |
9704 | visitor->pre(parent, this); |
9705 | if (hasPredicate()) |
9706 | getPredicate().getNode()->visit(this, visitor); |
9707 | getInput().getNode()->visit(this, visitor); |
9708 | visitor->post(parent, this); |
9709 | } |
9710 | |
9711 | bool GaussianFillNode::isEqual(const GaussianFillNode &other) const { |
9712 | return true && |
9713 | Input_ == other.Input_ && |
9714 | predicate_ == other.predicate_ && |
9715 | Mean_ == other.Mean_ && |
9716 | Scale_ == other.Scale_ && |
9717 | Seed_ == other.Seed_ && |
9718 | getType(0) == other.getType(0); |
9719 | } |
9720 | |
9721 | Node* GaussianFillNode::clone() const { |
9722 | return new GaussianFillNode(getName(), getResult().getType(), getInput(), getMean(), getScale(), getSeed()); |
9723 | } |
9724 | |
9725 | llvm::hash_code GaussianFillNode::getHash() const { |
9726 | return llvm::hash_combine( |
9727 | toBinary(Mean_), |
9728 | toBinary(Scale_), |
9729 | toBinary(Seed_), |
9730 | Input_); |
9731 | } |
9732 | |
9733 | unsigned ReluGradNode::getNumInputs() const { |
9734 | return 3; |
9735 | } |
9736 | |
9737 | std::string ReluGradNode::getInputName(unsigned idx) const { |
9738 | if (idx == 0) { return "Input" ; } |
9739 | if (idx == 1) { return "OriginalOutputForResult" ; } |
9740 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
9741 | idx -= 3; |
9742 | llvm_unreachable("Invalid index" ); |
9743 | } |
9744 | |
9745 | NodeValue ReluGradNode::getNthInput(unsigned idx) { |
9746 | if (idx == 0) { return Input_; } |
9747 | if (idx == 1) { return OriginalOutputForResult_; } |
9748 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
9749 | idx -= 3; |
9750 | llvm_unreachable("Invalid index" ); |
9751 | } |
9752 | |
9753 | void ReluGradNode::setNthInput(unsigned idx, NodeValue val) { |
9754 | if (idx == 0) { Input_ = val; return; } |
9755 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
9756 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
9757 | idx -= 3; |
9758 | llvm_unreachable("Invalid index" ); |
9759 | } |
9760 | |
9761 | llvm::StringRef ReluGradNode::getOutputName(unsigned idx) const { |
9762 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
9763 | llvm_unreachable("Invalid index" ); |
9764 | } |
9765 | |
9766 | std::string ReluGradNode::getDebugDesc() const { |
9767 | DescriptionBuilder db(getKindName()); |
9768 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9769 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9770 | db |
9771 | .addParam("Input" , *(getInput().getType())) |
9772 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
9773 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
9774 | .addParam("Users" , getNumUsers()); |
9775 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
9776 | return db; |
9777 | } |
9778 | |
9779 | void ReluGradNode::visit(Node *parent, NodeWalker *visitor) { |
9780 | if (!visitor->shouldVisit(parent, this)) { return; } |
9781 | visitor->pre(parent, this); |
9782 | if (hasPredicate()) |
9783 | getPredicate().getNode()->visit(this, visitor); |
9784 | getInput().getNode()->visit(this, visitor); |
9785 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
9786 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
9787 | visitor->post(parent, this); |
9788 | } |
9789 | |
9790 | bool ReluGradNode::isEqual(const ReluGradNode &other) const { |
9791 | return true && |
9792 | Input_ == other.Input_ && |
9793 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
9794 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
9795 | predicate_ == other.predicate_ && |
9796 | getType(0) == other.getType(0); |
9797 | } |
9798 | |
9799 | Node* ReluGradNode::clone() const { |
9800 | return new ReluGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
9801 | } |
9802 | |
9803 | llvm::hash_code ReluGradNode::getHash() const { |
9804 | return llvm::hash_combine( |
9805 | Input_, |
9806 | OriginalOutputForResult_, |
9807 | GradOfOriginalOutputNamedResult_); |
9808 | } |
9809 | |
9810 | unsigned ReluNode::getNumInputs() const { |
9811 | return 1; |
9812 | } |
9813 | |
9814 | std::string ReluNode::getInputName(unsigned idx) const { |
9815 | if (idx == 0) { return "Input" ; } |
9816 | idx -= 1; |
9817 | llvm_unreachable("Invalid index" ); |
9818 | } |
9819 | |
9820 | NodeValue ReluNode::getNthInput(unsigned idx) { |
9821 | if (idx == 0) { return Input_; } |
9822 | idx -= 1; |
9823 | llvm_unreachable("Invalid index" ); |
9824 | } |
9825 | |
9826 | void ReluNode::setNthInput(unsigned idx, NodeValue val) { |
9827 | if (idx == 0) { Input_ = val; return; } |
9828 | idx -= 1; |
9829 | llvm_unreachable("Invalid index" ); |
9830 | } |
9831 | |
9832 | llvm::StringRef ReluNode::getOutputName(unsigned idx) const { |
9833 | if (idx == 0) { return "Result" ; } |
9834 | llvm_unreachable("Invalid index" ); |
9835 | } |
9836 | |
9837 | std::string ReluNode::getDebugDesc() const { |
9838 | DescriptionBuilder db(getKindName()); |
9839 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9840 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9841 | db |
9842 | .addParam("Input" , *(getInput().getType())) |
9843 | .addParam("Users" , getNumUsers()); |
9844 | db.addParam("Result" , *(getResult().getType())); |
9845 | return db; |
9846 | } |
9847 | |
9848 | void ReluNode::visit(Node *parent, NodeWalker *visitor) { |
9849 | if (!visitor->shouldVisit(parent, this)) { return; } |
9850 | visitor->pre(parent, this); |
9851 | if (hasPredicate()) |
9852 | getPredicate().getNode()->visit(this, visitor); |
9853 | getInput().getNode()->visit(this, visitor); |
9854 | visitor->post(parent, this); |
9855 | } |
9856 | |
9857 | bool ReluNode::isEqual(const ReluNode &other) const { |
9858 | return true && |
9859 | Input_ == other.Input_ && |
9860 | predicate_ == other.predicate_ && |
9861 | getType(0) == other.getType(0); |
9862 | } |
9863 | |
9864 | Node* ReluNode::clone() const { |
9865 | return new ReluNode(getName(), getResult().getType(), getInput()); |
9866 | } |
9867 | |
9868 | llvm::hash_code ReluNode::getHash() const { |
9869 | return llvm::hash_combine( |
9870 | Input_); |
9871 | } |
9872 | |
9873 | ReluGradNode *ReluNode::getGrad(GraphGradMapper &builder) { |
9874 | auto *x = new ReluGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult())); |
9875 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
9876 | return x; |
9877 | } |
9878 | |
9879 | unsigned HardSwishNode::getNumInputs() const { |
9880 | return 1; |
9881 | } |
9882 | |
9883 | std::string HardSwishNode::getInputName(unsigned idx) const { |
9884 | if (idx == 0) { return "Input" ; } |
9885 | idx -= 1; |
9886 | llvm_unreachable("Invalid index" ); |
9887 | } |
9888 | |
9889 | NodeValue HardSwishNode::getNthInput(unsigned idx) { |
9890 | if (idx == 0) { return Input_; } |
9891 | idx -= 1; |
9892 | llvm_unreachable("Invalid index" ); |
9893 | } |
9894 | |
9895 | void HardSwishNode::setNthInput(unsigned idx, NodeValue val) { |
9896 | if (idx == 0) { Input_ = val; return; } |
9897 | idx -= 1; |
9898 | llvm_unreachable("Invalid index" ); |
9899 | } |
9900 | |
9901 | llvm::StringRef HardSwishNode::getOutputName(unsigned idx) const { |
9902 | if (idx == 0) { return "Result" ; } |
9903 | llvm_unreachable("Invalid index" ); |
9904 | } |
9905 | |
9906 | std::string HardSwishNode::getDebugDesc() const { |
9907 | DescriptionBuilder db(getKindName()); |
9908 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9909 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9910 | db |
9911 | .addParam("Input" , *(getInput().getType())) |
9912 | .addParam("Users" , getNumUsers()); |
9913 | db.addParam("Result" , *(getResult().getType())); |
9914 | return db; |
9915 | } |
9916 | |
9917 | void HardSwishNode::visit(Node *parent, NodeWalker *visitor) { |
9918 | if (!visitor->shouldVisit(parent, this)) { return; } |
9919 | visitor->pre(parent, this); |
9920 | if (hasPredicate()) |
9921 | getPredicate().getNode()->visit(this, visitor); |
9922 | getInput().getNode()->visit(this, visitor); |
9923 | visitor->post(parent, this); |
9924 | } |
9925 | |
9926 | bool HardSwishNode::isEqual(const HardSwishNode &other) const { |
9927 | return true && |
9928 | Input_ == other.Input_ && |
9929 | predicate_ == other.predicate_ && |
9930 | getType(0) == other.getType(0); |
9931 | } |
9932 | |
9933 | Node* HardSwishNode::clone() const { |
9934 | return new HardSwishNode(getName(), getResult().getType(), getInput()); |
9935 | } |
9936 | |
9937 | llvm::hash_code HardSwishNode::getHash() const { |
9938 | return llvm::hash_combine( |
9939 | Input_); |
9940 | } |
9941 | |
9942 | unsigned GeluNode::getNumInputs() const { |
9943 | return 1; |
9944 | } |
9945 | |
9946 | std::string GeluNode::getInputName(unsigned idx) const { |
9947 | if (idx == 0) { return "Input" ; } |
9948 | idx -= 1; |
9949 | llvm_unreachable("Invalid index" ); |
9950 | } |
9951 | |
9952 | NodeValue GeluNode::getNthInput(unsigned idx) { |
9953 | if (idx == 0) { return Input_; } |
9954 | idx -= 1; |
9955 | llvm_unreachable("Invalid index" ); |
9956 | } |
9957 | |
9958 | void GeluNode::setNthInput(unsigned idx, NodeValue val) { |
9959 | if (idx == 0) { Input_ = val; return; } |
9960 | idx -= 1; |
9961 | llvm_unreachable("Invalid index" ); |
9962 | } |
9963 | |
9964 | llvm::StringRef GeluNode::getOutputName(unsigned idx) const { |
9965 | if (idx == 0) { return "Result" ; } |
9966 | llvm_unreachable("Invalid index" ); |
9967 | } |
9968 | |
9969 | std::string GeluNode::getDebugDesc() const { |
9970 | DescriptionBuilder db(getKindName()); |
9971 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
9972 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
9973 | db |
9974 | .addParam("Input" , *(getInput().getType())) |
9975 | .addParam("Users" , getNumUsers()); |
9976 | db.addParam("Result" , *(getResult().getType())); |
9977 | return db; |
9978 | } |
9979 | |
9980 | void GeluNode::visit(Node *parent, NodeWalker *visitor) { |
9981 | if (!visitor->shouldVisit(parent, this)) { return; } |
9982 | visitor->pre(parent, this); |
9983 | if (hasPredicate()) |
9984 | getPredicate().getNode()->visit(this, visitor); |
9985 | getInput().getNode()->visit(this, visitor); |
9986 | visitor->post(parent, this); |
9987 | } |
9988 | |
9989 | bool GeluNode::isEqual(const GeluNode &other) const { |
9990 | return true && |
9991 | Input_ == other.Input_ && |
9992 | predicate_ == other.predicate_ && |
9993 | getType(0) == other.getType(0); |
9994 | } |
9995 | |
9996 | Node* GeluNode::clone() const { |
9997 | return new GeluNode(getName(), getResult().getType(), getInput()); |
9998 | } |
9999 | |
10000 | llvm::hash_code GeluNode::getHash() const { |
10001 | return llvm::hash_combine( |
10002 | Input_); |
10003 | } |
10004 | |
10005 | unsigned ClipNode::getNumInputs() const { |
10006 | return 1; |
10007 | } |
10008 | |
10009 | std::string ClipNode::getInputName(unsigned idx) const { |
10010 | if (idx == 0) { return "Input" ; } |
10011 | idx -= 1; |
10012 | llvm_unreachable("Invalid index" ); |
10013 | } |
10014 | |
10015 | NodeValue ClipNode::getNthInput(unsigned idx) { |
10016 | if (idx == 0) { return Input_; } |
10017 | idx -= 1; |
10018 | llvm_unreachable("Invalid index" ); |
10019 | } |
10020 | |
10021 | void ClipNode::setNthInput(unsigned idx, NodeValue val) { |
10022 | if (idx == 0) { Input_ = val; return; } |
10023 | idx -= 1; |
10024 | llvm_unreachable("Invalid index" ); |
10025 | } |
10026 | |
10027 | llvm::StringRef ClipNode::getOutputName(unsigned idx) const { |
10028 | if (idx == 0) { return "Result" ; } |
10029 | llvm_unreachable("Invalid index" ); |
10030 | } |
10031 | |
10032 | std::string ClipNode::getDebugDesc() const { |
10033 | DescriptionBuilder db(getKindName()); |
10034 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10035 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10036 | db |
10037 | .addParam("Input" , *(getInput().getType())) |
10038 | .addParam("Min" , getMin()) |
10039 | .addParam("Max" , getMax()) |
10040 | .addParam("Users" , getNumUsers()); |
10041 | db.addParam("Result" , *(getResult().getType())); |
10042 | return db; |
10043 | } |
10044 | |
10045 | void ClipNode::visit(Node *parent, NodeWalker *visitor) { |
10046 | if (!visitor->shouldVisit(parent, this)) { return; } |
10047 | visitor->pre(parent, this); |
10048 | if (hasPredicate()) |
10049 | getPredicate().getNode()->visit(this, visitor); |
10050 | getInput().getNode()->visit(this, visitor); |
10051 | visitor->post(parent, this); |
10052 | } |
10053 | |
10054 | bool ClipNode::isEqual(const ClipNode &other) const { |
10055 | return true && |
10056 | Input_ == other.Input_ && |
10057 | predicate_ == other.predicate_ && |
10058 | Min_ == other.Min_ && |
10059 | Max_ == other.Max_ && |
10060 | getType(0) == other.getType(0); |
10061 | } |
10062 | |
10063 | Node* ClipNode::clone() const { |
10064 | return new ClipNode(getName(), getResult().getType(), getInput(), getMin(), getMax()); |
10065 | } |
10066 | |
10067 | llvm::hash_code ClipNode::getHash() const { |
10068 | return llvm::hash_combine( |
10069 | toBinary(Min_), |
10070 | toBinary(Max_), |
10071 | Input_); |
10072 | } |
10073 | |
10074 | unsigned PReluNode::getNumInputs() const { |
10075 | return 2; |
10076 | } |
10077 | |
10078 | std::string PReluNode::getInputName(unsigned idx) const { |
10079 | if (idx == 0) { return "Input" ; } |
10080 | if (idx == 1) { return "Slope" ; } |
10081 | idx -= 2; |
10082 | llvm_unreachable("Invalid index" ); |
10083 | } |
10084 | |
10085 | NodeValue PReluNode::getNthInput(unsigned idx) { |
10086 | if (idx == 0) { return Input_; } |
10087 | if (idx == 1) { return Slope_; } |
10088 | idx -= 2; |
10089 | llvm_unreachable("Invalid index" ); |
10090 | } |
10091 | |
10092 | void PReluNode::setNthInput(unsigned idx, NodeValue val) { |
10093 | if (idx == 0) { Input_ = val; return; } |
10094 | if (idx == 1) { Slope_ = val; return; } |
10095 | idx -= 2; |
10096 | llvm_unreachable("Invalid index" ); |
10097 | } |
10098 | |
10099 | llvm::StringRef PReluNode::getOutputName(unsigned idx) const { |
10100 | if (idx == 0) { return "Result" ; } |
10101 | llvm_unreachable("Invalid index" ); |
10102 | } |
10103 | |
10104 | std::string PReluNode::getDebugDesc() const { |
10105 | DescriptionBuilder db(getKindName()); |
10106 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10107 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10108 | db |
10109 | .addParam("Input" , *(getInput().getType())) |
10110 | .addParam("Slope" , *(getSlope().getType())) |
10111 | .addParam("Users" , getNumUsers()); |
10112 | db.addParam("Result" , *(getResult().getType())); |
10113 | return db; |
10114 | } |
10115 | |
10116 | void PReluNode::visit(Node *parent, NodeWalker *visitor) { |
10117 | if (!visitor->shouldVisit(parent, this)) { return; } |
10118 | visitor->pre(parent, this); |
10119 | if (hasPredicate()) |
10120 | getPredicate().getNode()->visit(this, visitor); |
10121 | getInput().getNode()->visit(this, visitor); |
10122 | getSlope().getNode()->visit(this, visitor); |
10123 | visitor->post(parent, this); |
10124 | } |
10125 | |
10126 | bool PReluNode::isEqual(const PReluNode &other) const { |
10127 | return true && |
10128 | Input_ == other.Input_ && |
10129 | Slope_ == other.Slope_ && |
10130 | predicate_ == other.predicate_ && |
10131 | getType(0) == other.getType(0); |
10132 | } |
10133 | |
10134 | Node* PReluNode::clone() const { |
10135 | return new PReluNode(getName(), getResult().getType(), getInput(), getSlope()); |
10136 | } |
10137 | |
10138 | llvm::hash_code PReluNode::getHash() const { |
10139 | return llvm::hash_combine( |
10140 | Input_, |
10141 | Slope_); |
10142 | } |
10143 | |
10144 | unsigned SigmoidGradNode::getNumInputs() const { |
10145 | return 3; |
10146 | } |
10147 | |
10148 | std::string SigmoidGradNode::getInputName(unsigned idx) const { |
10149 | if (idx == 0) { return "Input" ; } |
10150 | if (idx == 1) { return "OriginalOutputForResult" ; } |
10151 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
10152 | idx -= 3; |
10153 | llvm_unreachable("Invalid index" ); |
10154 | } |
10155 | |
10156 | NodeValue SigmoidGradNode::getNthInput(unsigned idx) { |
10157 | if (idx == 0) { return Input_; } |
10158 | if (idx == 1) { return OriginalOutputForResult_; } |
10159 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
10160 | idx -= 3; |
10161 | llvm_unreachable("Invalid index" ); |
10162 | } |
10163 | |
10164 | void SigmoidGradNode::setNthInput(unsigned idx, NodeValue val) { |
10165 | if (idx == 0) { Input_ = val; return; } |
10166 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
10167 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
10168 | idx -= 3; |
10169 | llvm_unreachable("Invalid index" ); |
10170 | } |
10171 | |
10172 | llvm::StringRef SigmoidGradNode::getOutputName(unsigned idx) const { |
10173 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
10174 | llvm_unreachable("Invalid index" ); |
10175 | } |
10176 | |
10177 | std::string SigmoidGradNode::getDebugDesc() const { |
10178 | DescriptionBuilder db(getKindName()); |
10179 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10180 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10181 | db |
10182 | .addParam("Input" , *(getInput().getType())) |
10183 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
10184 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
10185 | .addParam("Users" , getNumUsers()); |
10186 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
10187 | return db; |
10188 | } |
10189 | |
10190 | void SigmoidGradNode::visit(Node *parent, NodeWalker *visitor) { |
10191 | if (!visitor->shouldVisit(parent, this)) { return; } |
10192 | visitor->pre(parent, this); |
10193 | if (hasPredicate()) |
10194 | getPredicate().getNode()->visit(this, visitor); |
10195 | getInput().getNode()->visit(this, visitor); |
10196 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
10197 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
10198 | visitor->post(parent, this); |
10199 | } |
10200 | |
10201 | bool SigmoidGradNode::isEqual(const SigmoidGradNode &other) const { |
10202 | return true && |
10203 | Input_ == other.Input_ && |
10204 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
10205 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
10206 | predicate_ == other.predicate_ && |
10207 | getType(0) == other.getType(0); |
10208 | } |
10209 | |
10210 | Node* SigmoidGradNode::clone() const { |
10211 | return new SigmoidGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
10212 | } |
10213 | |
10214 | llvm::hash_code SigmoidGradNode::getHash() const { |
10215 | return llvm::hash_combine( |
10216 | Input_, |
10217 | OriginalOutputForResult_, |
10218 | GradOfOriginalOutputNamedResult_); |
10219 | } |
10220 | |
10221 | unsigned SigmoidNode::getNumInputs() const { |
10222 | return 1; |
10223 | } |
10224 | |
10225 | std::string SigmoidNode::getInputName(unsigned idx) const { |
10226 | if (idx == 0) { return "Input" ; } |
10227 | idx -= 1; |
10228 | llvm_unreachable("Invalid index" ); |
10229 | } |
10230 | |
10231 | NodeValue SigmoidNode::getNthInput(unsigned idx) { |
10232 | if (idx == 0) { return Input_; } |
10233 | idx -= 1; |
10234 | llvm_unreachable("Invalid index" ); |
10235 | } |
10236 | |
10237 | void SigmoidNode::setNthInput(unsigned idx, NodeValue val) { |
10238 | if (idx == 0) { Input_ = val; return; } |
10239 | idx -= 1; |
10240 | llvm_unreachable("Invalid index" ); |
10241 | } |
10242 | |
10243 | llvm::StringRef SigmoidNode::getOutputName(unsigned idx) const { |
10244 | if (idx == 0) { return "Result" ; } |
10245 | llvm_unreachable("Invalid index" ); |
10246 | } |
10247 | |
10248 | std::string SigmoidNode::getDebugDesc() const { |
10249 | DescriptionBuilder db(getKindName()); |
10250 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10251 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10252 | db |
10253 | .addParam("Input" , *(getInput().getType())) |
10254 | .addParam("Users" , getNumUsers()); |
10255 | db.addParam("Result" , *(getResult().getType())); |
10256 | return db; |
10257 | } |
10258 | |
10259 | void SigmoidNode::visit(Node *parent, NodeWalker *visitor) { |
10260 | if (!visitor->shouldVisit(parent, this)) { return; } |
10261 | visitor->pre(parent, this); |
10262 | if (hasPredicate()) |
10263 | getPredicate().getNode()->visit(this, visitor); |
10264 | getInput().getNode()->visit(this, visitor); |
10265 | visitor->post(parent, this); |
10266 | } |
10267 | |
10268 | bool SigmoidNode::isEqual(const SigmoidNode &other) const { |
10269 | return true && |
10270 | Input_ == other.Input_ && |
10271 | predicate_ == other.predicate_ && |
10272 | getType(0) == other.getType(0); |
10273 | } |
10274 | |
10275 | Node* SigmoidNode::clone() const { |
10276 | return new SigmoidNode(getName(), getResult().getType(), getInput()); |
10277 | } |
10278 | |
10279 | llvm::hash_code SigmoidNode::getHash() const { |
10280 | return llvm::hash_combine( |
10281 | Input_); |
10282 | } |
10283 | |
10284 | SigmoidGradNode *SigmoidNode::getGrad(GraphGradMapper &builder) { |
10285 | auto *x = new SigmoidGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult())); |
10286 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
10287 | return x; |
10288 | } |
10289 | |
10290 | unsigned SwishNode::getNumInputs() const { |
10291 | return 1; |
10292 | } |
10293 | |
10294 | std::string SwishNode::getInputName(unsigned idx) const { |
10295 | if (idx == 0) { return "Input" ; } |
10296 | idx -= 1; |
10297 | llvm_unreachable("Invalid index" ); |
10298 | } |
10299 | |
10300 | NodeValue SwishNode::getNthInput(unsigned idx) { |
10301 | if (idx == 0) { return Input_; } |
10302 | idx -= 1; |
10303 | llvm_unreachable("Invalid index" ); |
10304 | } |
10305 | |
10306 | void SwishNode::setNthInput(unsigned idx, NodeValue val) { |
10307 | if (idx == 0) { Input_ = val; return; } |
10308 | idx -= 1; |
10309 | llvm_unreachable("Invalid index" ); |
10310 | } |
10311 | |
10312 | llvm::StringRef SwishNode::getOutputName(unsigned idx) const { |
10313 | if (idx == 0) { return "Result" ; } |
10314 | llvm_unreachable("Invalid index" ); |
10315 | } |
10316 | |
10317 | std::string SwishNode::getDebugDesc() const { |
10318 | DescriptionBuilder db(getKindName()); |
10319 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10320 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10321 | db |
10322 | .addParam("Input" , *(getInput().getType())) |
10323 | .addParam("Users" , getNumUsers()); |
10324 | db.addParam("Result" , *(getResult().getType())); |
10325 | return db; |
10326 | } |
10327 | |
10328 | void SwishNode::visit(Node *parent, NodeWalker *visitor) { |
10329 | if (!visitor->shouldVisit(parent, this)) { return; } |
10330 | visitor->pre(parent, this); |
10331 | if (hasPredicate()) |
10332 | getPredicate().getNode()->visit(this, visitor); |
10333 | getInput().getNode()->visit(this, visitor); |
10334 | visitor->post(parent, this); |
10335 | } |
10336 | |
10337 | bool SwishNode::isEqual(const SwishNode &other) const { |
10338 | return true && |
10339 | Input_ == other.Input_ && |
10340 | predicate_ == other.predicate_ && |
10341 | getType(0) == other.getType(0); |
10342 | } |
10343 | |
10344 | Node* SwishNode::clone() const { |
10345 | return new SwishNode(getName(), getResult().getType(), getInput()); |
10346 | } |
10347 | |
10348 | llvm::hash_code SwishNode::getHash() const { |
10349 | return llvm::hash_combine( |
10350 | Input_); |
10351 | } |
10352 | |
10353 | unsigned TanhGradNode::getNumInputs() const { |
10354 | return 3; |
10355 | } |
10356 | |
10357 | std::string TanhGradNode::getInputName(unsigned idx) const { |
10358 | if (idx == 0) { return "Input" ; } |
10359 | if (idx == 1) { return "OriginalOutputForResult" ; } |
10360 | if (idx == 2) { return "GradOfOriginalOutputNamedResult" ; } |
10361 | idx -= 3; |
10362 | llvm_unreachable("Invalid index" ); |
10363 | } |
10364 | |
10365 | NodeValue TanhGradNode::getNthInput(unsigned idx) { |
10366 | if (idx == 0) { return Input_; } |
10367 | if (idx == 1) { return OriginalOutputForResult_; } |
10368 | if (idx == 2) { return GradOfOriginalOutputNamedResult_; } |
10369 | idx -= 3; |
10370 | llvm_unreachable("Invalid index" ); |
10371 | } |
10372 | |
10373 | void TanhGradNode::setNthInput(unsigned idx, NodeValue val) { |
10374 | if (idx == 0) { Input_ = val; return; } |
10375 | if (idx == 1) { OriginalOutputForResult_ = val; return; } |
10376 | if (idx == 2) { GradOfOriginalOutputNamedResult_ = val; return; } |
10377 | idx -= 3; |
10378 | llvm_unreachable("Invalid index" ); |
10379 | } |
10380 | |
10381 | llvm::StringRef TanhGradNode::getOutputName(unsigned idx) const { |
10382 | if (idx == 0) { return "GradOfInputNamedInput" ; } |
10383 | llvm_unreachable("Invalid index" ); |
10384 | } |
10385 | |
10386 | std::string TanhGradNode::getDebugDesc() const { |
10387 | DescriptionBuilder db(getKindName()); |
10388 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10389 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10390 | db |
10391 | .addParam("Input" , *(getInput().getType())) |
10392 | .addParam("OriginalOutputForResult" , *(getOriginalOutputForResult().getType())) |
10393 | .addParam("GradOfOriginalOutputNamedResult" , *(getGradOfOriginalOutputNamedResult().getType())) |
10394 | .addParam("Users" , getNumUsers()); |
10395 | db.addParam("GradOfInputNamedInput" , *(getGradOfInputNamedInput().getType())); |
10396 | return db; |
10397 | } |
10398 | |
10399 | void TanhGradNode::visit(Node *parent, NodeWalker *visitor) { |
10400 | if (!visitor->shouldVisit(parent, this)) { return; } |
10401 | visitor->pre(parent, this); |
10402 | if (hasPredicate()) |
10403 | getPredicate().getNode()->visit(this, visitor); |
10404 | getInput().getNode()->visit(this, visitor); |
10405 | getOriginalOutputForResult().getNode()->visit(this, visitor); |
10406 | getGradOfOriginalOutputNamedResult().getNode()->visit(this, visitor); |
10407 | visitor->post(parent, this); |
10408 | } |
10409 | |
10410 | bool TanhGradNode::isEqual(const TanhGradNode &other) const { |
10411 | return true && |
10412 | Input_ == other.Input_ && |
10413 | OriginalOutputForResult_ == other.OriginalOutputForResult_ && |
10414 | GradOfOriginalOutputNamedResult_ == other.GradOfOriginalOutputNamedResult_ && |
10415 | predicate_ == other.predicate_ && |
10416 | getType(0) == other.getType(0); |
10417 | } |
10418 | |
10419 | Node* TanhGradNode::clone() const { |
10420 | return new TanhGradNode(getName(), getInput(), getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult()); |
10421 | } |
10422 | |
10423 | llvm::hash_code TanhGradNode::getHash() const { |
10424 | return llvm::hash_combine( |
10425 | Input_, |
10426 | OriginalOutputForResult_, |
10427 | GradOfOriginalOutputNamedResult_); |
10428 | } |
10429 | |
10430 | unsigned TanhNode::getNumInputs() const { |
10431 | return 1; |
10432 | } |
10433 | |
10434 | std::string TanhNode::getInputName(unsigned idx) const { |
10435 | if (idx == 0) { return "Input" ; } |
10436 | idx -= 1; |
10437 | llvm_unreachable("Invalid index" ); |
10438 | } |
10439 | |
10440 | NodeValue TanhNode::getNthInput(unsigned idx) { |
10441 | if (idx == 0) { return Input_; } |
10442 | idx -= 1; |
10443 | llvm_unreachable("Invalid index" ); |
10444 | } |
10445 | |
10446 | void TanhNode::setNthInput(unsigned idx, NodeValue val) { |
10447 | if (idx == 0) { Input_ = val; return; } |
10448 | idx -= 1; |
10449 | llvm_unreachable("Invalid index" ); |
10450 | } |
10451 | |
10452 | llvm::StringRef TanhNode::getOutputName(unsigned idx) const { |
10453 | if (idx == 0) { return "Result" ; } |
10454 | llvm_unreachable("Invalid index" ); |
10455 | } |
10456 | |
10457 | std::string TanhNode::getDebugDesc() const { |
10458 | DescriptionBuilder db(getKindName()); |
10459 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10460 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10461 | db |
10462 | .addParam("Input" , *(getInput().getType())) |
10463 | .addParam("Users" , getNumUsers()); |
10464 | db.addParam("Result" , *(getResult().getType())); |
10465 | return db; |
10466 | } |
10467 | |
10468 | void TanhNode::visit(Node *parent, NodeWalker *visitor) { |
10469 | if (!visitor->shouldVisit(parent, this)) { return; } |
10470 | visitor->pre(parent, this); |
10471 | if (hasPredicate()) |
10472 | getPredicate().getNode()->visit(this, visitor); |
10473 | getInput().getNode()->visit(this, visitor); |
10474 | visitor->post(parent, this); |
10475 | } |
10476 | |
10477 | bool TanhNode::isEqual(const TanhNode &other) const { |
10478 | return true && |
10479 | Input_ == other.Input_ && |
10480 | predicate_ == other.predicate_ && |
10481 | getType(0) == other.getType(0); |
10482 | } |
10483 | |
10484 | Node* TanhNode::clone() const { |
10485 | return new TanhNode(getName(), getResult().getType(), getInput()); |
10486 | } |
10487 | |
10488 | llvm::hash_code TanhNode::getHash() const { |
10489 | return llvm::hash_combine( |
10490 | Input_); |
10491 | } |
10492 | |
10493 | TanhGradNode *TanhNode::getGrad(GraphGradMapper &builder) { |
10494 | auto *x = new TanhGradNode(getName().str() + "_grad" , getInput(), getResult(), builder.getGradient(getResult())); |
10495 | builder.addGradient(getInput(), x->getGradOfInputNamedInput()); |
10496 | return x; |
10497 | } |
10498 | |
10499 | unsigned LeakyReluNode::getNumInputs() const { |
10500 | return 1; |
10501 | } |
10502 | |
10503 | std::string LeakyReluNode::getInputName(unsigned idx) const { |
10504 | if (idx == 0) { return "Input" ; } |
10505 | idx -= 1; |
10506 | llvm_unreachable("Invalid index" ); |
10507 | } |
10508 | |
10509 | NodeValue LeakyReluNode::getNthInput(unsigned idx) { |
10510 | if (idx == 0) { return Input_; } |
10511 | idx -= 1; |
10512 | llvm_unreachable("Invalid index" ); |
10513 | } |
10514 | |
10515 | void LeakyReluNode::setNthInput(unsigned idx, NodeValue val) { |
10516 | if (idx == 0) { Input_ = val; return; } |
10517 | idx -= 1; |
10518 | llvm_unreachable("Invalid index" ); |
10519 | } |
10520 | |
10521 | llvm::StringRef LeakyReluNode::getOutputName(unsigned idx) const { |
10522 | if (idx == 0) { return "Result" ; } |
10523 | llvm_unreachable("Invalid index" ); |
10524 | } |
10525 | |
10526 | std::string LeakyReluNode::getDebugDesc() const { |
10527 | DescriptionBuilder db(getKindName()); |
10528 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10529 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10530 | db |
10531 | .addParam("Input" , *(getInput().getType())) |
10532 | .addParam("Alpha" , getAlpha()) |
10533 | .addParam("Users" , getNumUsers()); |
10534 | db.addParam("Result" , *(getResult().getType())); |
10535 | return db; |
10536 | } |
10537 | |
10538 | void LeakyReluNode::visit(Node *parent, NodeWalker *visitor) { |
10539 | if (!visitor->shouldVisit(parent, this)) { return; } |
10540 | visitor->pre(parent, this); |
10541 | if (hasPredicate()) |
10542 | getPredicate().getNode()->visit(this, visitor); |
10543 | getInput().getNode()->visit(this, visitor); |
10544 | visitor->post(parent, this); |
10545 | } |
10546 | |
10547 | bool LeakyReluNode::isEqual(const LeakyReluNode &other) const { |
10548 | return true && |
10549 | Input_ == other.Input_ && |
10550 | predicate_ == other.predicate_ && |
10551 | Alpha_ == other.Alpha_ && |
10552 | getType(0) == other.getType(0); |
10553 | } |
10554 | |
10555 | Node* LeakyReluNode::clone() const { |
10556 | return new LeakyReluNode(getName(), getResult().getType(), getInput(), getAlpha()); |
10557 | } |
10558 | |
10559 | llvm::hash_code LeakyReluNode::getHash() const { |
10560 | return llvm::hash_combine( |
10561 | toBinary(Alpha_), |
10562 | Input_); |
10563 | } |
10564 | |
10565 | unsigned SoftPlusNode::getNumInputs() const { |
10566 | return 1; |
10567 | } |
10568 | |
10569 | std::string SoftPlusNode::getInputName(unsigned idx) const { |
10570 | if (idx == 0) { return "Input" ; } |
10571 | idx -= 1; |
10572 | llvm_unreachable("Invalid index" ); |
10573 | } |
10574 | |
10575 | NodeValue SoftPlusNode::getNthInput(unsigned idx) { |
10576 | if (idx == 0) { return Input_; } |
10577 | idx -= 1; |
10578 | llvm_unreachable("Invalid index" ); |
10579 | } |
10580 | |
10581 | void SoftPlusNode::setNthInput(unsigned idx, NodeValue val) { |
10582 | if (idx == 0) { Input_ = val; return; } |
10583 | idx -= 1; |
10584 | llvm_unreachable("Invalid index" ); |
10585 | } |
10586 | |
10587 | llvm::StringRef SoftPlusNode::getOutputName(unsigned idx) const { |
10588 | if (idx == 0) { return "Result" ; } |
10589 | llvm_unreachable("Invalid index" ); |
10590 | } |
10591 | |
10592 | std::string SoftPlusNode::getDebugDesc() const { |
10593 | DescriptionBuilder db(getKindName()); |
10594 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10595 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10596 | db |
10597 | .addParam("Input" , *(getInput().getType())) |
10598 | .addParam("Users" , getNumUsers()); |
10599 | db.addParam("Result" , *(getResult().getType())); |
10600 | return db; |
10601 | } |
10602 | |
10603 | void SoftPlusNode::visit(Node *parent, NodeWalker *visitor) { |
10604 | if (!visitor->shouldVisit(parent, this)) { return; } |
10605 | visitor->pre(parent, this); |
10606 | if (hasPredicate()) |
10607 | getPredicate().getNode()->visit(this, visitor); |
10608 | getInput().getNode()->visit(this, visitor); |
10609 | visitor->post(parent, this); |
10610 | } |
10611 | |
10612 | bool SoftPlusNode::isEqual(const SoftPlusNode &other) const { |
10613 | return true && |
10614 | Input_ == other.Input_ && |
10615 | predicate_ == other.predicate_ && |
10616 | getType(0) == other.getType(0); |
10617 | } |
10618 | |
10619 | Node* SoftPlusNode::clone() const { |
10620 | return new SoftPlusNode(getName(), getResult().getType(), getInput()); |
10621 | } |
10622 | |
10623 | llvm::hash_code SoftPlusNode::getHash() const { |
10624 | return llvm::hash_combine( |
10625 | Input_); |
10626 | } |
10627 | |
10628 | unsigned ReshapeNode::getNumInputs() const { |
10629 | return 1; |
10630 | } |
10631 | |
10632 | std::string ReshapeNode::getInputName(unsigned idx) const { |
10633 | if (idx == 0) { return "Input" ; } |
10634 | idx -= 1; |
10635 | llvm_unreachable("Invalid index" ); |
10636 | } |
10637 | |
10638 | NodeValue ReshapeNode::getNthInput(unsigned idx) { |
10639 | if (idx == 0) { return Input_; } |
10640 | idx -= 1; |
10641 | llvm_unreachable("Invalid index" ); |
10642 | } |
10643 | |
10644 | void ReshapeNode::setNthInput(unsigned idx, NodeValue val) { |
10645 | if (idx == 0) { Input_ = val; return; } |
10646 | idx -= 1; |
10647 | llvm_unreachable("Invalid index" ); |
10648 | } |
10649 | |
10650 | llvm::StringRef ReshapeNode::getOutputName(unsigned idx) const { |
10651 | if (idx == 0) { return "Result" ; } |
10652 | llvm_unreachable("Invalid index" ); |
10653 | } |
10654 | |
10655 | std::string ReshapeNode::getDebugDesc() const { |
10656 | DescriptionBuilder db(getKindName()); |
10657 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10658 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10659 | db |
10660 | .addParam("Input" , *(getInput().getType())) |
10661 | .addParam("Dims" , getDims()) |
10662 | .addParam("Layout" , getLayout()) |
10663 | .addParam("Users" , getNumUsers()); |
10664 | db.addParam("Result" , *(getResult().getType())); |
10665 | return db; |
10666 | } |
10667 | |
10668 | void ReshapeNode::visit(Node *parent, NodeWalker *visitor) { |
10669 | if (!visitor->shouldVisit(parent, this)) { return; } |
10670 | visitor->pre(parent, this); |
10671 | if (hasPredicate()) |
10672 | getPredicate().getNode()->visit(this, visitor); |
10673 | getInput().getNode()->visit(this, visitor); |
10674 | visitor->post(parent, this); |
10675 | } |
10676 | |
10677 | bool ReshapeNode::isEqual(const ReshapeNode &other) const { |
10678 | return true && |
10679 | Input_ == other.Input_ && |
10680 | predicate_ == other.predicate_ && |
10681 | Dims_ == other.Dims_ && |
10682 | Layout_ == other.Layout_ && |
10683 | getType(0) == other.getType(0); |
10684 | } |
10685 | |
10686 | Node* ReshapeNode::clone() const { |
10687 | return new ReshapeNode(getName(), getResult().getType(), getInput(), getDims(), getLayout()); |
10688 | } |
10689 | |
10690 | llvm::hash_code ReshapeNode::getHash() const { |
10691 | return llvm::hash_combine( |
10692 | llvm::hash_combine_range(Dims_.begin(), Dims_.end()), |
10693 | Layout_, |
10694 | Input_); |
10695 | } |
10696 | |
10697 | unsigned TransposeNode::getNumInputs() const { |
10698 | return 1; |
10699 | } |
10700 | |
10701 | std::string TransposeNode::getInputName(unsigned idx) const { |
10702 | if (idx == 0) { return "Input" ; } |
10703 | idx -= 1; |
10704 | llvm_unreachable("Invalid index" ); |
10705 | } |
10706 | |
10707 | NodeValue TransposeNode::getNthInput(unsigned idx) { |
10708 | if (idx == 0) { return Input_; } |
10709 | idx -= 1; |
10710 | llvm_unreachable("Invalid index" ); |
10711 | } |
10712 | |
10713 | void TransposeNode::setNthInput(unsigned idx, NodeValue val) { |
10714 | if (idx == 0) { Input_ = val; return; } |
10715 | idx -= 1; |
10716 | llvm_unreachable("Invalid index" ); |
10717 | } |
10718 | |
10719 | llvm::StringRef TransposeNode::getOutputName(unsigned idx) const { |
10720 | if (idx == 0) { return "Result" ; } |
10721 | llvm_unreachable("Invalid index" ); |
10722 | } |
10723 | |
10724 | std::string TransposeNode::getDebugDesc() const { |
10725 | DescriptionBuilder db(getKindName()); |
10726 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10727 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10728 | db |
10729 | .addParam("Input" , *(getInput().getType())) |
10730 | .addParam("Shuffle" , getShuffle()) |
10731 | .addParam("Layout" , getLayout()) |
10732 | .addParam("Users" , getNumUsers()); |
10733 | db.addParam("Result" , *(getResult().getType())); |
10734 | return db; |
10735 | } |
10736 | |
10737 | void TransposeNode::visit(Node *parent, NodeWalker *visitor) { |
10738 | if (!visitor->shouldVisit(parent, this)) { return; } |
10739 | visitor->pre(parent, this); |
10740 | if (hasPredicate()) |
10741 | getPredicate().getNode()->visit(this, visitor); |
10742 | getInput().getNode()->visit(this, visitor); |
10743 | visitor->post(parent, this); |
10744 | } |
10745 | |
10746 | bool TransposeNode::isEqual(const TransposeNode &other) const { |
10747 | return true && |
10748 | Input_ == other.Input_ && |
10749 | predicate_ == other.predicate_ && |
10750 | Shuffle_ == other.Shuffle_ && |
10751 | Layout_ == other.Layout_ && |
10752 | getType(0) == other.getType(0); |
10753 | } |
10754 | |
10755 | Node* TransposeNode::clone() const { |
10756 | return new TransposeNode(getName(), getResult().getType(), getInput(), getShuffle(), getLayout()); |
10757 | } |
10758 | |
10759 | llvm::hash_code TransposeNode::getHash() const { |
10760 | return llvm::hash_combine( |
10761 | llvm::hash_combine_range(Shuffle_.begin(), Shuffle_.end()), |
10762 | Layout_, |
10763 | Input_); |
10764 | } |
10765 | |
10766 | unsigned ConcatNode::getNumInputs() const { |
10767 | return 0 + Inputs_.size(); |
10768 | } |
10769 | |
10770 | std::string ConcatNode::getInputName(unsigned idx) const { |
10771 | idx -= 0; |
10772 | if (idx < Inputs_.size()) { return "Inputs" + std::to_string(idx); } |
10773 | idx -= Inputs_.size(); |
10774 | llvm_unreachable("Invalid index" ); |
10775 | } |
10776 | |
10777 | NodeValue ConcatNode::getNthInput(unsigned idx) { |
10778 | idx -= 0; |
10779 | if (idx < Inputs_.size()) { return Inputs_[idx]; } |
10780 | idx -= Inputs_.size(); |
10781 | llvm_unreachable("Invalid index" ); |
10782 | } |
10783 | |
10784 | void ConcatNode::setNthInput(unsigned idx, NodeValue val) { |
10785 | idx -= 0; |
10786 | if (idx < Inputs_.size()) { Inputs_[idx] = val; return; } |
10787 | idx -= Inputs_.size(); |
10788 | llvm_unreachable("Invalid index" ); |
10789 | } |
10790 | |
10791 | llvm::StringRef ConcatNode::getOutputName(unsigned idx) const { |
10792 | if (idx == 0) { return "Result" ; } |
10793 | llvm_unreachable("Invalid index" ); |
10794 | } |
10795 | |
10796 | std::string ConcatNode::getDebugDesc() const { |
10797 | DescriptionBuilder db(getKindName()); |
10798 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10799 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10800 | db |
10801 | .addParam("Dim" , getDim()) |
10802 | .addParam("Users" , getNumUsers()); |
10803 | { |
10804 | unsigned mIndex = 0; |
10805 | for (const auto &II : getInputs()) { |
10806 | db.addParam("Inputs" +std::to_string(mIndex++), *II.getType()); |
10807 | } |
10808 | } |
10809 | db.addParam("Result" , *(getResult().getType())); |
10810 | return db; |
10811 | } |
10812 | |
10813 | void ConcatNode::visit(Node *parent, NodeWalker *visitor) { |
10814 | if (!visitor->shouldVisit(parent, this)) { return; } |
10815 | visitor->pre(parent, this); |
10816 | if (hasPredicate()) |
10817 | getPredicate().getNode()->visit(this, visitor); |
10818 | for (auto &I : Inputs_) { I.getNode()->visit(this, visitor); } |
10819 | visitor->post(parent, this); |
10820 | } |
10821 | |
10822 | bool ConcatNode::isEqual(const ConcatNode &other) const { |
10823 | return true && |
10824 | predicate_ == other.predicate_ && |
10825 | Inputs_ == other.Inputs_ && |
10826 | Dim_ == other.Dim_ && |
10827 | getType(0) == other.getType(0); |
10828 | } |
10829 | |
10830 | Node* ConcatNode::clone() const { |
10831 | return new ConcatNode(getName(), getResult().getType(), getInputs(), getDim()); |
10832 | } |
10833 | |
10834 | llvm::hash_code ConcatNode::getHash() const { |
10835 | return llvm::hash_combine( |
10836 | llvm::hash_combine_range(Inputs_.begin(), Inputs_.end()), |
10837 | Dim_); |
10838 | } |
10839 | |
10840 | unsigned SliceNode::getNumInputs() const { |
10841 | return 1; |
10842 | } |
10843 | |
10844 | std::string SliceNode::getInputName(unsigned idx) const { |
10845 | if (idx == 0) { return "Input" ; } |
10846 | idx -= 1; |
10847 | llvm_unreachable("Invalid index" ); |
10848 | } |
10849 | |
10850 | NodeValue SliceNode::getNthInput(unsigned idx) { |
10851 | if (idx == 0) { return Input_; } |
10852 | idx -= 1; |
10853 | llvm_unreachable("Invalid index" ); |
10854 | } |
10855 | |
10856 | void SliceNode::setNthInput(unsigned idx, NodeValue val) { |
10857 | if (idx == 0) { Input_ = val; return; } |
10858 | idx -= 1; |
10859 | llvm_unreachable("Invalid index" ); |
10860 | } |
10861 | |
10862 | llvm::StringRef SliceNode::getOutputName(unsigned idx) const { |
10863 | if (idx == 0) { return "Result" ; } |
10864 | llvm_unreachable("Invalid index" ); |
10865 | } |
10866 | |
10867 | std::string SliceNode::getDebugDesc() const { |
10868 | DescriptionBuilder db(getKindName()); |
10869 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10870 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10871 | db |
10872 | .addParam("Input" , *(getInput().getType())) |
10873 | .addParam("Start" , getStart()) |
10874 | .addParam("Users" , getNumUsers()); |
10875 | db.addParam("Result" , *(getResult().getType())); |
10876 | return db; |
10877 | } |
10878 | |
10879 | void SliceNode::visit(Node *parent, NodeWalker *visitor) { |
10880 | if (!visitor->shouldVisit(parent, this)) { return; } |
10881 | visitor->pre(parent, this); |
10882 | if (hasPredicate()) |
10883 | getPredicate().getNode()->visit(this, visitor); |
10884 | getInput().getNode()->visit(this, visitor); |
10885 | visitor->post(parent, this); |
10886 | } |
10887 | |
10888 | bool SliceNode::isEqual(const SliceNode &other) const { |
10889 | return true && |
10890 | Input_ == other.Input_ && |
10891 | predicate_ == other.predicate_ && |
10892 | Start_ == other.Start_ && |
10893 | getType(0) == other.getType(0); |
10894 | } |
10895 | |
10896 | Node* SliceNode::clone() const { |
10897 | return new SliceNode(getName(), getResult().getType(), getInput(), getStart()); |
10898 | } |
10899 | |
10900 | llvm::hash_code SliceNode::getHash() const { |
10901 | return llvm::hash_combine( |
10902 | llvm::hash_combine_range(Start_.begin(), Start_.end()), |
10903 | Input_); |
10904 | } |
10905 | |
10906 | unsigned InsertTensorNode::getNumInputs() const { |
10907 | return 2; |
10908 | } |
10909 | |
10910 | std::string InsertTensorNode::getInputName(unsigned idx) const { |
10911 | if (idx == 0) { return "Big" ; } |
10912 | if (idx == 1) { return "Small" ; } |
10913 | idx -= 2; |
10914 | llvm_unreachable("Invalid index" ); |
10915 | } |
10916 | |
10917 | NodeValue InsertTensorNode::getNthInput(unsigned idx) { |
10918 | if (idx == 0) { return Big_; } |
10919 | if (idx == 1) { return Small_; } |
10920 | idx -= 2; |
10921 | llvm_unreachable("Invalid index" ); |
10922 | } |
10923 | |
10924 | void InsertTensorNode::setNthInput(unsigned idx, NodeValue val) { |
10925 | if (idx == 0) { Big_ = val; return; } |
10926 | if (idx == 1) { Small_ = val; return; } |
10927 | idx -= 2; |
10928 | llvm_unreachable("Invalid index" ); |
10929 | } |
10930 | |
10931 | llvm::StringRef InsertTensorNode::getOutputName(unsigned idx) const { |
10932 | if (idx == 0) { return "Result" ; } |
10933 | llvm_unreachable("Invalid index" ); |
10934 | } |
10935 | |
10936 | std::string InsertTensorNode::getDebugDesc() const { |
10937 | DescriptionBuilder db(getKindName()); |
10938 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
10939 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
10940 | db |
10941 | .addParam("Big" , *(getBig().getType())) |
10942 | .addParam("Small" , *(getSmall().getType())) |
10943 | .addParam("Start" , getStart()) |
10944 | .addParam("Count" , getCount()) |
10945 | .addParam("Axis" , getAxis()) |
10946 | .addParam("Users" , getNumUsers()); |
10947 | db.addParam("Result" , *(getResult().getType())); |
10948 | return db; |
10949 | } |
10950 | |
10951 | void InsertTensorNode::visit(Node *parent, NodeWalker *visitor) { |
10952 | if (!visitor->shouldVisit(parent, this)) { return; } |
10953 | visitor->pre(parent, this); |
10954 | if (hasPredicate()) |
10955 | getPredicate().getNode()->visit(this, visitor); |
10956 | getBig().getNode()->visit(this, visitor); |
10957 | getSmall().getNode()->visit(this, visitor); |
10958 | visitor->post(parent, this); |
10959 | } |
10960 | |
10961 | bool InsertTensorNode::isEqual(const InsertTensorNode &other) const { |
10962 | return true && |
10963 | Big_ == other.Big_ && |
10964 | Small_ == other.Small_ && |
10965 | predicate_ == other.predicate_ && |
10966 | Start_ == other.Start_ && |
10967 | Count_ == other.Count_ && |
10968 | Axis_ == other.Axis_ && |
10969 | getType(0) == other.getType(0); |
10970 | } |
10971 | |
10972 | Node* InsertTensorNode::clone() const { |
10973 | return new InsertTensorNode(getName(), getBig(), getSmall(), getStart(), getCount(), getAxis()); |
10974 | } |
10975 | |
10976 | llvm::hash_code InsertTensorNode::getHash() const { |
10977 | return llvm::hash_combine( |
10978 | llvm::hash_combine_range(Start_.begin(), Start_.end()), |
10979 | Count_, |
10980 | Axis_, |
10981 | Big_, |
10982 | Small_); |
10983 | } |
10984 | |
10985 | unsigned GatherNode::getNumInputs() const { |
10986 | return 2; |
10987 | } |
10988 | |
10989 | std::string GatherNode::getInputName(unsigned idx) const { |
10990 | if (idx == 0) { return "Data" ; } |
10991 | if (idx == 1) { return "Indices" ; } |
10992 | idx -= 2; |
10993 | llvm_unreachable("Invalid index" ); |
10994 | } |
10995 | |
10996 | NodeValue GatherNode::getNthInput(unsigned idx) { |
10997 | if (idx == 0) { return Data_; } |
10998 | if (idx == 1) { return Indices_; } |
10999 | idx -= 2; |
11000 | llvm_unreachable("Invalid index" ); |
11001 | } |
11002 | |
11003 | void GatherNode::setNthInput(unsigned idx, NodeValue val) { |
11004 | if (idx == 0) { Data_ = val; return; } |
11005 | if (idx == 1) { Indices_ = val; return; } |
11006 | idx -= 2; |
11007 | llvm_unreachable("Invalid index" ); |
11008 | } |
11009 | |
11010 | llvm::StringRef GatherNode::getOutputName(unsigned idx) const { |
11011 | if (idx == 0) { return "Result" ; } |
11012 | llvm_unreachable("Invalid index" ); |
11013 | } |
11014 | |
11015 | std::string GatherNode::getDebugDesc() const { |
11016 | DescriptionBuilder db(getKindName()); |
11017 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11018 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11019 | db |
11020 | .addParam("Data" , *(getData().getType())) |
11021 | .addParam("Indices" , *(getIndices().getType())) |
11022 | .addParam("BatchDims" , getBatchDims()) |
11023 | .addParam("Users" , getNumUsers()); |
11024 | db.addParam("Result" , *(getResult().getType())); |
11025 | return db; |
11026 | } |
11027 | |
11028 | void GatherNode::visit(Node *parent, NodeWalker *visitor) { |
11029 | if (!visitor->shouldVisit(parent, this)) { return; } |
11030 | visitor->pre(parent, this); |
11031 | if (hasPredicate()) |
11032 | getPredicate().getNode()->visit(this, visitor); |
11033 | getData().getNode()->visit(this, visitor); |
11034 | getIndices().getNode()->visit(this, visitor); |
11035 | visitor->post(parent, this); |
11036 | } |
11037 | |
11038 | bool GatherNode::isEqual(const GatherNode &other) const { |
11039 | return true && |
11040 | Data_ == other.Data_ && |
11041 | Indices_ == other.Indices_ && |
11042 | predicate_ == other.predicate_ && |
11043 | BatchDims_ == other.BatchDims_ && |
11044 | getType(0) == other.getType(0); |
11045 | } |
11046 | |
11047 | Node* GatherNode::clone() const { |
11048 | return new GatherNode(getName(), getResult().getType(), getData(), getIndices(), getBatchDims()); |
11049 | } |
11050 | |
11051 | llvm::hash_code GatherNode::getHash() const { |
11052 | return llvm::hash_combine( |
11053 | BatchDims_, |
11054 | Data_, |
11055 | Indices_); |
11056 | } |
11057 | |
11058 | unsigned GatherNDNode::getNumInputs() const { |
11059 | return 2; |
11060 | } |
11061 | |
11062 | std::string GatherNDNode::getInputName(unsigned idx) const { |
11063 | if (idx == 0) { return "Data" ; } |
11064 | if (idx == 1) { return "Indices" ; } |
11065 | idx -= 2; |
11066 | llvm_unreachable("Invalid index" ); |
11067 | } |
11068 | |
11069 | NodeValue GatherNDNode::getNthInput(unsigned idx) { |
11070 | if (idx == 0) { return Data_; } |
11071 | if (idx == 1) { return Indices_; } |
11072 | idx -= 2; |
11073 | llvm_unreachable("Invalid index" ); |
11074 | } |
11075 | |
11076 | void GatherNDNode::setNthInput(unsigned idx, NodeValue val) { |
11077 | if (idx == 0) { Data_ = val; return; } |
11078 | if (idx == 1) { Indices_ = val; return; } |
11079 | idx -= 2; |
11080 | llvm_unreachable("Invalid index" ); |
11081 | } |
11082 | |
11083 | llvm::StringRef GatherNDNode::getOutputName(unsigned idx) const { |
11084 | if (idx == 0) { return "Result" ; } |
11085 | llvm_unreachable("Invalid index" ); |
11086 | } |
11087 | |
11088 | std::string GatherNDNode::getDebugDesc() const { |
11089 | DescriptionBuilder db(getKindName()); |
11090 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11091 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11092 | db |
11093 | .addParam("Data" , *(getData().getType())) |
11094 | .addParam("Indices" , *(getIndices().getType())) |
11095 | .addParam("BatchDims" , getBatchDims()) |
11096 | .addParam("Users" , getNumUsers()); |
11097 | db.addParam("Result" , *(getResult().getType())); |
11098 | return db; |
11099 | } |
11100 | |
11101 | void GatherNDNode::visit(Node *parent, NodeWalker *visitor) { |
11102 | if (!visitor->shouldVisit(parent, this)) { return; } |
11103 | visitor->pre(parent, this); |
11104 | if (hasPredicate()) |
11105 | getPredicate().getNode()->visit(this, visitor); |
11106 | getData().getNode()->visit(this, visitor); |
11107 | getIndices().getNode()->visit(this, visitor); |
11108 | visitor->post(parent, this); |
11109 | } |
11110 | |
11111 | bool GatherNDNode::isEqual(const GatherNDNode &other) const { |
11112 | return true && |
11113 | Data_ == other.Data_ && |
11114 | Indices_ == other.Indices_ && |
11115 | predicate_ == other.predicate_ && |
11116 | BatchDims_ == other.BatchDims_ && |
11117 | getType(0) == other.getType(0); |
11118 | } |
11119 | |
11120 | Node* GatherNDNode::clone() const { |
11121 | return new GatherNDNode(getName(), getResult().getType(), getData(), getIndices(), getBatchDims()); |
11122 | } |
11123 | |
11124 | llvm::hash_code GatherNDNode::getHash() const { |
11125 | return llvm::hash_combine( |
11126 | BatchDims_, |
11127 | Data_, |
11128 | Indices_); |
11129 | } |
11130 | |
11131 | unsigned GatherElementsNode::getNumInputs() const { |
11132 | return 2; |
11133 | } |
11134 | |
11135 | std::string GatherElementsNode::getInputName(unsigned idx) const { |
11136 | if (idx == 0) { return "Data" ; } |
11137 | if (idx == 1) { return "Indices" ; } |
11138 | idx -= 2; |
11139 | llvm_unreachable("Invalid index" ); |
11140 | } |
11141 | |
11142 | NodeValue GatherElementsNode::getNthInput(unsigned idx) { |
11143 | if (idx == 0) { return Data_; } |
11144 | if (idx == 1) { return Indices_; } |
11145 | idx -= 2; |
11146 | llvm_unreachable("Invalid index" ); |
11147 | } |
11148 | |
11149 | void GatherElementsNode::setNthInput(unsigned idx, NodeValue val) { |
11150 | if (idx == 0) { Data_ = val; return; } |
11151 | if (idx == 1) { Indices_ = val; return; } |
11152 | idx -= 2; |
11153 | llvm_unreachable("Invalid index" ); |
11154 | } |
11155 | |
11156 | llvm::StringRef GatherElementsNode::getOutputName(unsigned idx) const { |
11157 | if (idx == 0) { return "Result" ; } |
11158 | llvm_unreachable("Invalid index" ); |
11159 | } |
11160 | |
11161 | std::string GatherElementsNode::getDebugDesc() const { |
11162 | DescriptionBuilder db(getKindName()); |
11163 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11164 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11165 | db |
11166 | .addParam("Data" , *(getData().getType())) |
11167 | .addParam("Indices" , *(getIndices().getType())) |
11168 | .addParam("Dim" , getDim()) |
11169 | .addParam("Users" , getNumUsers()); |
11170 | db.addParam("Result" , *(getResult().getType())); |
11171 | return db; |
11172 | } |
11173 | |
11174 | void GatherElementsNode::visit(Node *parent, NodeWalker *visitor) { |
11175 | if (!visitor->shouldVisit(parent, this)) { return; } |
11176 | visitor->pre(parent, this); |
11177 | if (hasPredicate()) |
11178 | getPredicate().getNode()->visit(this, visitor); |
11179 | getData().getNode()->visit(this, visitor); |
11180 | getIndices().getNode()->visit(this, visitor); |
11181 | visitor->post(parent, this); |
11182 | } |
11183 | |
11184 | bool GatherElementsNode::isEqual(const GatherElementsNode &other) const { |
11185 | return true && |
11186 | Data_ == other.Data_ && |
11187 | Indices_ == other.Indices_ && |
11188 | predicate_ == other.predicate_ && |
11189 | Dim_ == other.Dim_ && |
11190 | getType(0) == other.getType(0); |
11191 | } |
11192 | |
11193 | Node* GatherElementsNode::clone() const { |
11194 | return new GatherElementsNode(getName(), getResult().getType(), getData(), getIndices(), getDim()); |
11195 | } |
11196 | |
11197 | llvm::hash_code GatherElementsNode::getHash() const { |
11198 | return llvm::hash_combine( |
11199 | Dim_, |
11200 | Data_, |
11201 | Indices_); |
11202 | } |
11203 | |
11204 | unsigned GatherRangesNode::getNumInputs() const { |
11205 | return 2; |
11206 | } |
11207 | |
11208 | std::string GatherRangesNode::getInputName(unsigned idx) const { |
11209 | if (idx == 0) { return "Data" ; } |
11210 | if (idx == 1) { return "Ranges" ; } |
11211 | idx -= 2; |
11212 | llvm_unreachable("Invalid index" ); |
11213 | } |
11214 | |
11215 | NodeValue GatherRangesNode::getNthInput(unsigned idx) { |
11216 | if (idx == 0) { return Data_; } |
11217 | if (idx == 1) { return Ranges_; } |
11218 | idx -= 2; |
11219 | llvm_unreachable("Invalid index" ); |
11220 | } |
11221 | |
11222 | void GatherRangesNode::setNthInput(unsigned idx, NodeValue val) { |
11223 | if (idx == 0) { Data_ = val; return; } |
11224 | if (idx == 1) { Ranges_ = val; return; } |
11225 | idx -= 2; |
11226 | llvm_unreachable("Invalid index" ); |
11227 | } |
11228 | |
11229 | llvm::StringRef GatherRangesNode::getOutputName(unsigned idx) const { |
11230 | if (idx == 0) { return "Output" ; } |
11231 | if (idx == 1) { return "Lengths" ; } |
11232 | llvm_unreachable("Invalid index" ); |
11233 | } |
11234 | |
11235 | std::string GatherRangesNode::getDebugDesc() const { |
11236 | DescriptionBuilder db(getKindName()); |
11237 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11238 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11239 | db |
11240 | .addParam("Data" , *(getData().getType())) |
11241 | .addParam("Ranges" , *(getRanges().getType())) |
11242 | .addParam("Users" , getNumUsers()); |
11243 | db.addParam("Output" , *(getOutput().getType())); |
11244 | db.addParam("Lengths" , *(getLengths().getType())); |
11245 | return db; |
11246 | } |
11247 | |
11248 | void GatherRangesNode::visit(Node *parent, NodeWalker *visitor) { |
11249 | if (!visitor->shouldVisit(parent, this)) { return; } |
11250 | visitor->pre(parent, this); |
11251 | if (hasPredicate()) |
11252 | getPredicate().getNode()->visit(this, visitor); |
11253 | getData().getNode()->visit(this, visitor); |
11254 | getRanges().getNode()->visit(this, visitor); |
11255 | visitor->post(parent, this); |
11256 | } |
11257 | |
11258 | bool GatherRangesNode::isEqual(const GatherRangesNode &other) const { |
11259 | return true && |
11260 | Data_ == other.Data_ && |
11261 | Ranges_ == other.Ranges_ && |
11262 | predicate_ == other.predicate_ && |
11263 | getType(0) == other.getType(0) && |
11264 | getType(1) == other.getType(1); |
11265 | } |
11266 | |
11267 | Node* GatherRangesNode::clone() const { |
11268 | return new GatherRangesNode(getName(), getOutput().getType(), getLengths().getType(), getData(), getRanges()); |
11269 | } |
11270 | |
11271 | llvm::hash_code GatherRangesNode::getHash() const { |
11272 | return llvm::hash_combine( |
11273 | Data_, |
11274 | Ranges_); |
11275 | } |
11276 | |
11277 | unsigned ScatterDataNode::getNumInputs() const { |
11278 | return 3; |
11279 | } |
11280 | |
11281 | std::string ScatterDataNode::getInputName(unsigned idx) const { |
11282 | if (idx == 0) { return "Data" ; } |
11283 | if (idx == 1) { return "Indices" ; } |
11284 | if (idx == 2) { return "Slices" ; } |
11285 | idx -= 3; |
11286 | llvm_unreachable("Invalid index" ); |
11287 | } |
11288 | |
11289 | NodeValue ScatterDataNode::getNthInput(unsigned idx) { |
11290 | if (idx == 0) { return Data_; } |
11291 | if (idx == 1) { return Indices_; } |
11292 | if (idx == 2) { return Slices_; } |
11293 | idx -= 3; |
11294 | llvm_unreachable("Invalid index" ); |
11295 | } |
11296 | |
11297 | void ScatterDataNode::setNthInput(unsigned idx, NodeValue val) { |
11298 | if (idx == 0) { Data_ = val; return; } |
11299 | if (idx == 1) { Indices_ = val; return; } |
11300 | if (idx == 2) { Slices_ = val; return; } |
11301 | idx -= 3; |
11302 | llvm_unreachable("Invalid index" ); |
11303 | } |
11304 | |
11305 | llvm::StringRef ScatterDataNode::getOutputName(unsigned idx) const { |
11306 | if (idx == 0) { return "Result" ; } |
11307 | llvm_unreachable("Invalid index" ); |
11308 | } |
11309 | |
11310 | std::string ScatterDataNode::getDebugDesc() const { |
11311 | DescriptionBuilder db(getKindName()); |
11312 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11313 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11314 | db |
11315 | .addParam("Data" , *(getData().getType())) |
11316 | .addParam("Indices" , *(getIndices().getType())) |
11317 | .addParam("Slices" , *(getSlices().getType())) |
11318 | .addParam("Cumulative" , getCumulative()) |
11319 | .addParam("Users" , getNumUsers()); |
11320 | db.addParam("Result" , *(getResult().getType())); |
11321 | return db; |
11322 | } |
11323 | |
11324 | void ScatterDataNode::visit(Node *parent, NodeWalker *visitor) { |
11325 | if (!visitor->shouldVisit(parent, this)) { return; } |
11326 | visitor->pre(parent, this); |
11327 | if (hasPredicate()) |
11328 | getPredicate().getNode()->visit(this, visitor); |
11329 | getData().getNode()->visit(this, visitor); |
11330 | getIndices().getNode()->visit(this, visitor); |
11331 | getSlices().getNode()->visit(this, visitor); |
11332 | visitor->post(parent, this); |
11333 | } |
11334 | |
11335 | bool ScatterDataNode::isEqual(const ScatterDataNode &other) const { |
11336 | return true && |
11337 | Data_ == other.Data_ && |
11338 | Indices_ == other.Indices_ && |
11339 | Slices_ == other.Slices_ && |
11340 | predicate_ == other.predicate_ && |
11341 | Cumulative_ == other.Cumulative_ && |
11342 | getType(0) == other.getType(0); |
11343 | } |
11344 | |
11345 | Node* ScatterDataNode::clone() const { |
11346 | return new ScatterDataNode(getName(), getData(), getIndices(), getSlices(), getCumulative()); |
11347 | } |
11348 | |
11349 | llvm::hash_code ScatterDataNode::getHash() const { |
11350 | return llvm::hash_combine( |
11351 | Cumulative_, |
11352 | Data_, |
11353 | Indices_, |
11354 | Slices_); |
11355 | } |
11356 | |
11357 | unsigned TileNode::getNumInputs() const { |
11358 | return 1; |
11359 | } |
11360 | |
11361 | std::string TileNode::getInputName(unsigned idx) const { |
11362 | if (idx == 0) { return "Input" ; } |
11363 | idx -= 1; |
11364 | llvm_unreachable("Invalid index" ); |
11365 | } |
11366 | |
11367 | NodeValue TileNode::getNthInput(unsigned idx) { |
11368 | if (idx == 0) { return Input_; } |
11369 | idx -= 1; |
11370 | llvm_unreachable("Invalid index" ); |
11371 | } |
11372 | |
11373 | void TileNode::setNthInput(unsigned idx, NodeValue val) { |
11374 | if (idx == 0) { Input_ = val; return; } |
11375 | idx -= 1; |
11376 | llvm_unreachable("Invalid index" ); |
11377 | } |
11378 | |
11379 | llvm::StringRef TileNode::getOutputName(unsigned idx) const { |
11380 | if (idx == 0) { return "Result" ; } |
11381 | llvm_unreachable("Invalid index" ); |
11382 | } |
11383 | |
11384 | std::string TileNode::getDebugDesc() const { |
11385 | DescriptionBuilder db(getKindName()); |
11386 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11387 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11388 | db |
11389 | .addParam("Input" , *(getInput().getType())) |
11390 | .addParam("Count" , getCount()) |
11391 | .addParam("Axis" , getAxis()) |
11392 | .addParam("Users" , getNumUsers()); |
11393 | db.addParam("Result" , *(getResult().getType())); |
11394 | return db; |
11395 | } |
11396 | |
11397 | void TileNode::visit(Node *parent, NodeWalker *visitor) { |
11398 | if (!visitor->shouldVisit(parent, this)) { return; } |
11399 | visitor->pre(parent, this); |
11400 | if (hasPredicate()) |
11401 | getPredicate().getNode()->visit(this, visitor); |
11402 | getInput().getNode()->visit(this, visitor); |
11403 | visitor->post(parent, this); |
11404 | } |
11405 | |
11406 | bool TileNode::isEqual(const TileNode &other) const { |
11407 | return true && |
11408 | Input_ == other.Input_ && |
11409 | predicate_ == other.predicate_ && |
11410 | Count_ == other.Count_ && |
11411 | Axis_ == other.Axis_ && |
11412 | getType(0) == other.getType(0); |
11413 | } |
11414 | |
11415 | Node* TileNode::clone() const { |
11416 | return new TileNode(getName(), getResult().getType(), getInput(), getCount(), getAxis()); |
11417 | } |
11418 | |
11419 | llvm::hash_code TileNode::getHash() const { |
11420 | return llvm::hash_combine( |
11421 | Count_, |
11422 | Axis_, |
11423 | Input_); |
11424 | } |
11425 | |
11426 | unsigned BatchOneHotNode::getNumInputs() const { |
11427 | return 3; |
11428 | } |
11429 | |
11430 | std::string BatchOneHotNode::getInputName(unsigned idx) const { |
11431 | if (idx == 0) { return "Data" ; } |
11432 | if (idx == 1) { return "Lengths" ; } |
11433 | if (idx == 2) { return "Values" ; } |
11434 | idx -= 3; |
11435 | llvm_unreachable("Invalid index" ); |
11436 | } |
11437 | |
11438 | NodeValue BatchOneHotNode::getNthInput(unsigned idx) { |
11439 | if (idx == 0) { return Data_; } |
11440 | if (idx == 1) { return Lengths_; } |
11441 | if (idx == 2) { return Values_; } |
11442 | idx -= 3; |
11443 | llvm_unreachable("Invalid index" ); |
11444 | } |
11445 | |
11446 | void BatchOneHotNode::setNthInput(unsigned idx, NodeValue val) { |
11447 | if (idx == 0) { Data_ = val; return; } |
11448 | if (idx == 1) { Lengths_ = val; return; } |
11449 | if (idx == 2) { Values_ = val; return; } |
11450 | idx -= 3; |
11451 | llvm_unreachable("Invalid index" ); |
11452 | } |
11453 | |
11454 | llvm::StringRef BatchOneHotNode::getOutputName(unsigned idx) const { |
11455 | if (idx == 0) { return "Result" ; } |
11456 | llvm_unreachable("Invalid index" ); |
11457 | } |
11458 | |
11459 | std::string BatchOneHotNode::getDebugDesc() const { |
11460 | DescriptionBuilder db(getKindName()); |
11461 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11462 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11463 | db |
11464 | .addParam("Data" , *(getData().getType())) |
11465 | .addParam("Lengths" , *(getLengths().getType())) |
11466 | .addParam("Values" , *(getValues().getType())) |
11467 | .addParam("Users" , getNumUsers()); |
11468 | db.addParam("Result" , *(getResult().getType())); |
11469 | return db; |
11470 | } |
11471 | |
11472 | void BatchOneHotNode::visit(Node *parent, NodeWalker *visitor) { |
11473 | if (!visitor->shouldVisit(parent, this)) { return; } |
11474 | visitor->pre(parent, this); |
11475 | if (hasPredicate()) |
11476 | getPredicate().getNode()->visit(this, visitor); |
11477 | getData().getNode()->visit(this, visitor); |
11478 | getLengths().getNode()->visit(this, visitor); |
11479 | getValues().getNode()->visit(this, visitor); |
11480 | visitor->post(parent, this); |
11481 | } |
11482 | |
11483 | bool BatchOneHotNode::isEqual(const BatchOneHotNode &other) const { |
11484 | return true && |
11485 | Data_ == other.Data_ && |
11486 | Lengths_ == other.Lengths_ && |
11487 | Values_ == other.Values_ && |
11488 | predicate_ == other.predicate_ && |
11489 | getType(0) == other.getType(0); |
11490 | } |
11491 | |
11492 | Node* BatchOneHotNode::clone() const { |
11493 | return new BatchOneHotNode(getName(), getResult().getType(), getData(), getLengths(), getValues()); |
11494 | } |
11495 | |
11496 | llvm::hash_code BatchOneHotNode::getHash() const { |
11497 | return llvm::hash_combine( |
11498 | Data_, |
11499 | Lengths_, |
11500 | Values_); |
11501 | } |
11502 | |
11503 | unsigned SpaceToDepthNode::getNumInputs() const { |
11504 | return 1; |
11505 | } |
11506 | |
11507 | std::string SpaceToDepthNode::getInputName(unsigned idx) const { |
11508 | if (idx == 0) { return "Input" ; } |
11509 | idx -= 1; |
11510 | llvm_unreachable("Invalid index" ); |
11511 | } |
11512 | |
11513 | NodeValue SpaceToDepthNode::getNthInput(unsigned idx) { |
11514 | if (idx == 0) { return Input_; } |
11515 | idx -= 1; |
11516 | llvm_unreachable("Invalid index" ); |
11517 | } |
11518 | |
11519 | void SpaceToDepthNode::setNthInput(unsigned idx, NodeValue val) { |
11520 | if (idx == 0) { Input_ = val; return; } |
11521 | idx -= 1; |
11522 | llvm_unreachable("Invalid index" ); |
11523 | } |
11524 | |
11525 | llvm::StringRef SpaceToDepthNode::getOutputName(unsigned idx) const { |
11526 | if (idx == 0) { return "Result" ; } |
11527 | llvm_unreachable("Invalid index" ); |
11528 | } |
11529 | |
11530 | std::string SpaceToDepthNode::getDebugDesc() const { |
11531 | DescriptionBuilder db(getKindName()); |
11532 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11533 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11534 | db |
11535 | .addParam("Input" , *(getInput().getType())) |
11536 | .addParam("BlockSize" , getBlockSize()) |
11537 | .addParam("Users" , getNumUsers()); |
11538 | db.addParam("Result" , *(getResult().getType())); |
11539 | return db; |
11540 | } |
11541 | |
11542 | void SpaceToDepthNode::visit(Node *parent, NodeWalker *visitor) { |
11543 | if (!visitor->shouldVisit(parent, this)) { return; } |
11544 | visitor->pre(parent, this); |
11545 | if (hasPredicate()) |
11546 | getPredicate().getNode()->visit(this, visitor); |
11547 | getInput().getNode()->visit(this, visitor); |
11548 | visitor->post(parent, this); |
11549 | } |
11550 | |
11551 | bool SpaceToDepthNode::isEqual(const SpaceToDepthNode &other) const { |
11552 | return true && |
11553 | Input_ == other.Input_ && |
11554 | predicate_ == other.predicate_ && |
11555 | BlockSize_ == other.BlockSize_ && |
11556 | getType(0) == other.getType(0); |
11557 | } |
11558 | |
11559 | Node* SpaceToDepthNode::clone() const { |
11560 | return new SpaceToDepthNode(getName(), getResult().getType(), getInput(), getBlockSize()); |
11561 | } |
11562 | |
11563 | llvm::hash_code SpaceToDepthNode::getHash() const { |
11564 | return llvm::hash_combine( |
11565 | BlockSize_, |
11566 | Input_); |
11567 | } |
11568 | |
11569 | unsigned ResizeNearestNode::getNumInputs() const { |
11570 | return 1; |
11571 | } |
11572 | |
11573 | std::string ResizeNearestNode::getInputName(unsigned idx) const { |
11574 | if (idx == 0) { return "Input" ; } |
11575 | idx -= 1; |
11576 | llvm_unreachable("Invalid index" ); |
11577 | } |
11578 | |
11579 | NodeValue ResizeNearestNode::getNthInput(unsigned idx) { |
11580 | if (idx == 0) { return Input_; } |
11581 | idx -= 1; |
11582 | llvm_unreachable("Invalid index" ); |
11583 | } |
11584 | |
11585 | void ResizeNearestNode::setNthInput(unsigned idx, NodeValue val) { |
11586 | if (idx == 0) { Input_ = val; return; } |
11587 | idx -= 1; |
11588 | llvm_unreachable("Invalid index" ); |
11589 | } |
11590 | |
11591 | llvm::StringRef ResizeNearestNode::getOutputName(unsigned idx) const { |
11592 | if (idx == 0) { return "Result" ; } |
11593 | llvm_unreachable("Invalid index" ); |
11594 | } |
11595 | |
11596 | std::string ResizeNearestNode::getDebugDesc() const { |
11597 | DescriptionBuilder db(getKindName()); |
11598 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11599 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11600 | db |
11601 | .addParam("Input" , *(getInput().getType())) |
11602 | .addParam("Scale" , getScale()) |
11603 | .addParam("Users" , getNumUsers()); |
11604 | db.addParam("Result" , *(getResult().getType())); |
11605 | return db; |
11606 | } |
11607 | |
11608 | void ResizeNearestNode::visit(Node *parent, NodeWalker *visitor) { |
11609 | if (!visitor->shouldVisit(parent, this)) { return; } |
11610 | visitor->pre(parent, this); |
11611 | if (hasPredicate()) |
11612 | getPredicate().getNode()->visit(this, visitor); |
11613 | getInput().getNode()->visit(this, visitor); |
11614 | visitor->post(parent, this); |
11615 | } |
11616 | |
11617 | bool ResizeNearestNode::isEqual(const ResizeNearestNode &other) const { |
11618 | return true && |
11619 | Input_ == other.Input_ && |
11620 | predicate_ == other.predicate_ && |
11621 | Scale_ == other.Scale_ && |
11622 | getType(0) == other.getType(0); |
11623 | } |
11624 | |
11625 | Node* ResizeNearestNode::clone() const { |
11626 | return new ResizeNearestNode(getName(), getResult().getType(), getInput(), getScale()); |
11627 | } |
11628 | |
11629 | llvm::hash_code ResizeNearestNode::getHash() const { |
11630 | return llvm::hash_combine( |
11631 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
11632 | std::vector<size_t> sizeVec = toBinary(floatVec); |
11633 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
11634 | }(Scale_), |
11635 | Input_); |
11636 | } |
11637 | |
11638 | unsigned ResizeBilinearNode::getNumInputs() const { |
11639 | return 1; |
11640 | } |
11641 | |
11642 | std::string ResizeBilinearNode::getInputName(unsigned idx) const { |
11643 | if (idx == 0) { return "Input" ; } |
11644 | idx -= 1; |
11645 | llvm_unreachable("Invalid index" ); |
11646 | } |
11647 | |
11648 | NodeValue ResizeBilinearNode::getNthInput(unsigned idx) { |
11649 | if (idx == 0) { return Input_; } |
11650 | idx -= 1; |
11651 | llvm_unreachable("Invalid index" ); |
11652 | } |
11653 | |
11654 | void ResizeBilinearNode::setNthInput(unsigned idx, NodeValue val) { |
11655 | if (idx == 0) { Input_ = val; return; } |
11656 | idx -= 1; |
11657 | llvm_unreachable("Invalid index" ); |
11658 | } |
11659 | |
11660 | llvm::StringRef ResizeBilinearNode::getOutputName(unsigned idx) const { |
11661 | if (idx == 0) { return "Result" ; } |
11662 | llvm_unreachable("Invalid index" ); |
11663 | } |
11664 | |
11665 | std::string ResizeBilinearNode::getDebugDesc() const { |
11666 | DescriptionBuilder db(getKindName()); |
11667 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11668 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11669 | db |
11670 | .addParam("Input" , *(getInput().getType())) |
11671 | .addParam("Scale" , getScale()) |
11672 | .addParam("Users" , getNumUsers()); |
11673 | db.addParam("Result" , *(getResult().getType())); |
11674 | return db; |
11675 | } |
11676 | |
11677 | void ResizeBilinearNode::visit(Node *parent, NodeWalker *visitor) { |
11678 | if (!visitor->shouldVisit(parent, this)) { return; } |
11679 | visitor->pre(parent, this); |
11680 | if (hasPredicate()) |
11681 | getPredicate().getNode()->visit(this, visitor); |
11682 | getInput().getNode()->visit(this, visitor); |
11683 | visitor->post(parent, this); |
11684 | } |
11685 | |
11686 | bool ResizeBilinearNode::isEqual(const ResizeBilinearNode &other) const { |
11687 | return true && |
11688 | Input_ == other.Input_ && |
11689 | predicate_ == other.predicate_ && |
11690 | Scale_ == other.Scale_ && |
11691 | getType(0) == other.getType(0); |
11692 | } |
11693 | |
11694 | Node* ResizeBilinearNode::clone() const { |
11695 | return new ResizeBilinearNode(getName(), getResult().getType(), getInput(), getScale()); |
11696 | } |
11697 | |
11698 | llvm::hash_code ResizeBilinearNode::getHash() const { |
11699 | return llvm::hash_combine( |
11700 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
11701 | std::vector<size_t> sizeVec = toBinary(floatVec); |
11702 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
11703 | }(Scale_), |
11704 | Input_); |
11705 | } |
11706 | |
11707 | unsigned BroadcastNode::getNumInputs() const { |
11708 | return 1; |
11709 | } |
11710 | |
11711 | std::string BroadcastNode::getInputName(unsigned idx) const { |
11712 | if (idx == 0) { return "Input" ; } |
11713 | idx -= 1; |
11714 | llvm_unreachable("Invalid index" ); |
11715 | } |
11716 | |
11717 | NodeValue BroadcastNode::getNthInput(unsigned idx) { |
11718 | if (idx == 0) { return Input_; } |
11719 | idx -= 1; |
11720 | llvm_unreachable("Invalid index" ); |
11721 | } |
11722 | |
11723 | void BroadcastNode::setNthInput(unsigned idx, NodeValue val) { |
11724 | if (idx == 0) { Input_ = val; return; } |
11725 | idx -= 1; |
11726 | llvm_unreachable("Invalid index" ); |
11727 | } |
11728 | |
11729 | llvm::StringRef BroadcastNode::getOutputName(unsigned idx) const { |
11730 | if (idx == 0) { return "Result" ; } |
11731 | llvm_unreachable("Invalid index" ); |
11732 | } |
11733 | |
11734 | std::string BroadcastNode::getDebugDesc() const { |
11735 | DescriptionBuilder db(getKindName()); |
11736 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11737 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11738 | db |
11739 | .addParam("Input" , *(getInput().getType())) |
11740 | .addParam("Axis" , getAxis()) |
11741 | .addParam("TargetDim" , getTargetDim()) |
11742 | .addParam("Users" , getNumUsers()); |
11743 | db.addParam("Result" , *(getResult().getType())); |
11744 | return db; |
11745 | } |
11746 | |
11747 | void BroadcastNode::visit(Node *parent, NodeWalker *visitor) { |
11748 | if (!visitor->shouldVisit(parent, this)) { return; } |
11749 | visitor->pre(parent, this); |
11750 | if (hasPredicate()) |
11751 | getPredicate().getNode()->visit(this, visitor); |
11752 | getInput().getNode()->visit(this, visitor); |
11753 | visitor->post(parent, this); |
11754 | } |
11755 | |
11756 | bool BroadcastNode::isEqual(const BroadcastNode &other) const { |
11757 | return true && |
11758 | Input_ == other.Input_ && |
11759 | predicate_ == other.predicate_ && |
11760 | Axis_ == other.Axis_ && |
11761 | TargetDim_ == other.TargetDim_ && |
11762 | getType(0) == other.getType(0); |
11763 | } |
11764 | |
11765 | Node* BroadcastNode::clone() const { |
11766 | return new BroadcastNode(getName(), getResult().getType(), getInput(), getAxis(), getTargetDim()); |
11767 | } |
11768 | |
11769 | llvm::hash_code BroadcastNode::getHash() const { |
11770 | return llvm::hash_combine( |
11771 | Axis_, |
11772 | llvm::hash_combine_range(TargetDim_.begin(), TargetDim_.end()), |
11773 | Input_); |
11774 | } |
11775 | |
11776 | unsigned SparseLabelSplitNode::getNumInputs() const { |
11777 | return 3; |
11778 | } |
11779 | |
11780 | std::string SparseLabelSplitNode::getInputName(unsigned idx) const { |
11781 | if (idx == 0) { return "Lengths" ; } |
11782 | if (idx == 1) { return "Indices" ; } |
11783 | if (idx == 2) { return "Values" ; } |
11784 | idx -= 3; |
11785 | llvm_unreachable("Invalid index" ); |
11786 | } |
11787 | |
11788 | NodeValue SparseLabelSplitNode::getNthInput(unsigned idx) { |
11789 | if (idx == 0) { return Lengths_; } |
11790 | if (idx == 1) { return Indices_; } |
11791 | if (idx == 2) { return Values_; } |
11792 | idx -= 3; |
11793 | llvm_unreachable("Invalid index" ); |
11794 | } |
11795 | |
11796 | void SparseLabelSplitNode::setNthInput(unsigned idx, NodeValue val) { |
11797 | if (idx == 0) { Lengths_ = val; return; } |
11798 | if (idx == 1) { Indices_ = val; return; } |
11799 | if (idx == 2) { Values_ = val; return; } |
11800 | idx -= 3; |
11801 | llvm_unreachable("Invalid index" ); |
11802 | } |
11803 | |
11804 | llvm::StringRef SparseLabelSplitNode::getOutputName(unsigned idx) const { |
11805 | if (idx == 0) { return "LabelValues" ; } |
11806 | if (idx == 1) { return "ExampleIds" ; } |
11807 | if (idx == 2) { return "GradientOffsetMap" ; } |
11808 | llvm_unreachable("Invalid index" ); |
11809 | } |
11810 | |
11811 | std::string SparseLabelSplitNode::getDebugDesc() const { |
11812 | DescriptionBuilder db(getKindName()); |
11813 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11814 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11815 | db |
11816 | .addParam("Lengths" , *(getLengths().getType())) |
11817 | .addParam("Indices" , *(getIndices().getType())) |
11818 | .addParam("Values" , *(getValues().getType())) |
11819 | .addParam("NumLabels" , getNumLabels()) |
11820 | .addParam("Users" , getNumUsers()); |
11821 | db.addParam("LabelValues" , *(getLabelValues().getType())); |
11822 | db.addParam("ExampleIds" , *(getExampleIds().getType())); |
11823 | db.addParam("GradientOffsetMap" , *(getGradientOffsetMap().getType())); |
11824 | return db; |
11825 | } |
11826 | |
11827 | void SparseLabelSplitNode::visit(Node *parent, NodeWalker *visitor) { |
11828 | if (!visitor->shouldVisit(parent, this)) { return; } |
11829 | visitor->pre(parent, this); |
11830 | if (hasPredicate()) |
11831 | getPredicate().getNode()->visit(this, visitor); |
11832 | getLengths().getNode()->visit(this, visitor); |
11833 | getIndices().getNode()->visit(this, visitor); |
11834 | getValues().getNode()->visit(this, visitor); |
11835 | visitor->post(parent, this); |
11836 | } |
11837 | |
11838 | bool SparseLabelSplitNode::isEqual(const SparseLabelSplitNode &other) const { |
11839 | return true && |
11840 | Lengths_ == other.Lengths_ && |
11841 | Indices_ == other.Indices_ && |
11842 | Values_ == other.Values_ && |
11843 | predicate_ == other.predicate_ && |
11844 | NumLabels_ == other.NumLabels_ && |
11845 | getType(0) == other.getType(0) && |
11846 | getType(1) == other.getType(1) && |
11847 | getType(2) == other.getType(2); |
11848 | } |
11849 | |
11850 | Node* SparseLabelSplitNode::clone() const { |
11851 | return new SparseLabelSplitNode(getName(), getLabelValues().getType(), getExampleIds().getType(), getGradientOffsetMap().getType(), getLengths(), getIndices(), getValues(), getNumLabels()); |
11852 | } |
11853 | |
11854 | llvm::hash_code SparseLabelSplitNode::getHash() const { |
11855 | return llvm::hash_combine( |
11856 | NumLabels_, |
11857 | Lengths_, |
11858 | Indices_, |
11859 | Values_); |
11860 | } |
11861 | |
11862 | unsigned FlipNode::getNumInputs() const { |
11863 | return 1; |
11864 | } |
11865 | |
11866 | std::string FlipNode::getInputName(unsigned idx) const { |
11867 | if (idx == 0) { return "Input" ; } |
11868 | idx -= 1; |
11869 | llvm_unreachable("Invalid index" ); |
11870 | } |
11871 | |
11872 | NodeValue FlipNode::getNthInput(unsigned idx) { |
11873 | if (idx == 0) { return Input_; } |
11874 | idx -= 1; |
11875 | llvm_unreachable("Invalid index" ); |
11876 | } |
11877 | |
11878 | void FlipNode::setNthInput(unsigned idx, NodeValue val) { |
11879 | if (idx == 0) { Input_ = val; return; } |
11880 | idx -= 1; |
11881 | llvm_unreachable("Invalid index" ); |
11882 | } |
11883 | |
11884 | llvm::StringRef FlipNode::getOutputName(unsigned idx) const { |
11885 | if (idx == 0) { return "Result" ; } |
11886 | llvm_unreachable("Invalid index" ); |
11887 | } |
11888 | |
11889 | std::string FlipNode::getDebugDesc() const { |
11890 | DescriptionBuilder db(getKindName()); |
11891 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11892 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11893 | db |
11894 | .addParam("Input" , *(getInput().getType())) |
11895 | .addParam("Axis" , getAxis()) |
11896 | .addParam("Users" , getNumUsers()); |
11897 | db.addParam("Result" , *(getResult().getType())); |
11898 | return db; |
11899 | } |
11900 | |
11901 | void FlipNode::visit(Node *parent, NodeWalker *visitor) { |
11902 | if (!visitor->shouldVisit(parent, this)) { return; } |
11903 | visitor->pre(parent, this); |
11904 | if (hasPredicate()) |
11905 | getPredicate().getNode()->visit(this, visitor); |
11906 | getInput().getNode()->visit(this, visitor); |
11907 | visitor->post(parent, this); |
11908 | } |
11909 | |
11910 | bool FlipNode::isEqual(const FlipNode &other) const { |
11911 | return true && |
11912 | Input_ == other.Input_ && |
11913 | predicate_ == other.predicate_ && |
11914 | Axis_ == other.Axis_ && |
11915 | getType(0) == other.getType(0); |
11916 | } |
11917 | |
11918 | Node* FlipNode::clone() const { |
11919 | return new FlipNode(getName(), getResult().getType(), getInput(), getAxis()); |
11920 | } |
11921 | |
11922 | llvm::hash_code FlipNode::getHash() const { |
11923 | return llvm::hash_combine( |
11924 | Axis_, |
11925 | Input_); |
11926 | } |
11927 | |
11928 | unsigned SplatNode::getNumInputs() const { |
11929 | return 0; |
11930 | } |
11931 | |
11932 | std::string SplatNode::getInputName(unsigned idx) const { |
11933 | idx -= 0; |
11934 | llvm_unreachable("Invalid index" ); |
11935 | } |
11936 | |
11937 | NodeValue SplatNode::getNthInput(unsigned idx) { |
11938 | idx -= 0; |
11939 | llvm_unreachable("Invalid index" ); |
11940 | } |
11941 | |
11942 | void SplatNode::setNthInput(unsigned idx, NodeValue val) { |
11943 | idx -= 0; |
11944 | llvm_unreachable("Invalid index" ); |
11945 | } |
11946 | |
11947 | llvm::StringRef SplatNode::getOutputName(unsigned idx) const { |
11948 | if (idx == 0) { return "Result" ; } |
11949 | llvm_unreachable("Invalid index" ); |
11950 | } |
11951 | |
11952 | std::string SplatNode::getDebugDesc() const { |
11953 | DescriptionBuilder db(getKindName()); |
11954 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
11955 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
11956 | db |
11957 | .addParam("Value" , getValue()) |
11958 | .addParam("Users" , getNumUsers()); |
11959 | db.addParam("Result" , *(getResult().getType())); |
11960 | return db; |
11961 | } |
11962 | |
11963 | void SplatNode::visit(Node *parent, NodeWalker *visitor) { |
11964 | if (!visitor->shouldVisit(parent, this)) { return; } |
11965 | visitor->pre(parent, this); |
11966 | if (hasPredicate()) |
11967 | getPredicate().getNode()->visit(this, visitor); |
11968 | visitor->post(parent, this); |
11969 | } |
11970 | |
11971 | bool SplatNode::isEqual(const SplatNode &other) const { |
11972 | return true && |
11973 | predicate_ == other.predicate_ && |
11974 | Value_ == other.Value_ && |
11975 | getType(0) == other.getType(0); |
11976 | } |
11977 | |
11978 | Node* SplatNode::clone() const { |
11979 | return new SplatNode(getName(), getResult().getType(), getValue()); |
11980 | } |
11981 | |
11982 | llvm::hash_code SplatNode::getHash() const { |
11983 | return llvm::hash_combine( |
11984 | toBinary(Value_)); |
11985 | } |
11986 | |
11987 | unsigned TouchNode::getNumInputs() const { |
11988 | return 0; |
11989 | } |
11990 | |
11991 | std::string TouchNode::getInputName(unsigned idx) const { |
11992 | idx -= 0; |
11993 | llvm_unreachable("Invalid index" ); |
11994 | } |
11995 | |
11996 | NodeValue TouchNode::getNthInput(unsigned idx) { |
11997 | idx -= 0; |
11998 | llvm_unreachable("Invalid index" ); |
11999 | } |
12000 | |
12001 | void TouchNode::setNthInput(unsigned idx, NodeValue val) { |
12002 | idx -= 0; |
12003 | llvm_unreachable("Invalid index" ); |
12004 | } |
12005 | |
12006 | llvm::StringRef TouchNode::getOutputName(unsigned idx) const { |
12007 | if (idx == 0) { return "Result" ; } |
12008 | llvm_unreachable("Invalid index" ); |
12009 | } |
12010 | |
12011 | std::string TouchNode::getDebugDesc() const { |
12012 | DescriptionBuilder db(getKindName()); |
12013 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12014 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12015 | db |
12016 | .addParam("Users" , getNumUsers()); |
12017 | db.addParam("Result" , *(getResult().getType())); |
12018 | return db; |
12019 | } |
12020 | |
12021 | void TouchNode::visit(Node *parent, NodeWalker *visitor) { |
12022 | if (!visitor->shouldVisit(parent, this)) { return; } |
12023 | visitor->pre(parent, this); |
12024 | if (hasPredicate()) |
12025 | getPredicate().getNode()->visit(this, visitor); |
12026 | visitor->post(parent, this); |
12027 | } |
12028 | |
12029 | bool TouchNode::isEqual(const TouchNode &other) const { |
12030 | return true && |
12031 | predicate_ == other.predicate_ && |
12032 | getType(0) == other.getType(0); |
12033 | } |
12034 | |
12035 | Node* TouchNode::clone() const { |
12036 | return new TouchNode(getName(), getResult().getType()); |
12037 | } |
12038 | |
12039 | llvm::hash_code TouchNode::getHash() const { |
12040 | return llvm::hash_combine(0); |
12041 | } |
12042 | |
12043 | unsigned SGDNode::getNumInputs() const { |
12044 | return 2; |
12045 | } |
12046 | |
12047 | std::string SGDNode::getInputName(unsigned idx) const { |
12048 | if (idx == 0) { return "Gradient" ; } |
12049 | if (idx == 1) { return "Weight" ; } |
12050 | idx -= 2; |
12051 | llvm_unreachable("Invalid index" ); |
12052 | } |
12053 | |
12054 | NodeValue SGDNode::getNthInput(unsigned idx) { |
12055 | if (idx == 0) { return Gradient_; } |
12056 | if (idx == 1) { return Weight_; } |
12057 | idx -= 2; |
12058 | llvm_unreachable("Invalid index" ); |
12059 | } |
12060 | |
12061 | void SGDNode::setNthInput(unsigned idx, NodeValue val) { |
12062 | if (idx == 0) { Gradient_ = val; return; } |
12063 | if (idx == 1) { Weight_ = val; return; } |
12064 | idx -= 2; |
12065 | llvm_unreachable("Invalid index" ); |
12066 | } |
12067 | |
12068 | llvm::StringRef SGDNode::getOutputName(unsigned idx) const { |
12069 | if (idx == 0) { return "UpdatedWeight" ; } |
12070 | llvm_unreachable("Invalid index" ); |
12071 | } |
12072 | |
12073 | std::string SGDNode::getDebugDesc() const { |
12074 | DescriptionBuilder db(getKindName()); |
12075 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12076 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12077 | db |
12078 | .addParam("Gradient" , *(getGradient().getType())) |
12079 | .addParam("Weight" , *(getWeight().getType())) |
12080 | .addParam("L1Decay" , getL1Decay()) |
12081 | .addParam("L2Decay" , getL2Decay()) |
12082 | .addParam("LearningRate" , getLearningRate()) |
12083 | .addParam("Momentum" , getMomentum()) |
12084 | .addParam("BatchSize" , getBatchSize()) |
12085 | .addParam("Users" , getNumUsers()); |
12086 | db.addParam("UpdatedWeight" , *(getUpdatedWeight().getType())); |
12087 | return db; |
12088 | } |
12089 | |
12090 | void SGDNode::visit(Node *parent, NodeWalker *visitor) { |
12091 | if (!visitor->shouldVisit(parent, this)) { return; } |
12092 | visitor->pre(parent, this); |
12093 | if (hasPredicate()) |
12094 | getPredicate().getNode()->visit(this, visitor); |
12095 | getGradient().getNode()->visit(this, visitor); |
12096 | getWeight().getNode()->visit(this, visitor); |
12097 | visitor->post(parent, this); |
12098 | } |
12099 | |
12100 | bool SGDNode::isEqual(const SGDNode &other) const { |
12101 | return true && |
12102 | Gradient_ == other.Gradient_ && |
12103 | Weight_ == other.Weight_ && |
12104 | predicate_ == other.predicate_ && |
12105 | L1Decay_ == other.L1Decay_ && |
12106 | L2Decay_ == other.L2Decay_ && |
12107 | LearningRate_ == other.LearningRate_ && |
12108 | Momentum_ == other.Momentum_ && |
12109 | BatchSize_ == other.BatchSize_ && |
12110 | getType(0) == other.getType(0); |
12111 | } |
12112 | |
12113 | Node* SGDNode::clone() const { |
12114 | return new SGDNode(getName(), getGradient(), getWeight(), getL1Decay(), getL2Decay(), getLearningRate(), getMomentum(), getBatchSize()); |
12115 | } |
12116 | |
12117 | llvm::hash_code SGDNode::getHash() const { |
12118 | return llvm::hash_combine( |
12119 | toBinary(L1Decay_), |
12120 | toBinary(L2Decay_), |
12121 | toBinary(LearningRate_), |
12122 | toBinary(Momentum_), |
12123 | BatchSize_, |
12124 | Gradient_, |
12125 | Weight_); |
12126 | } |
12127 | |
12128 | unsigned TraceEventNode::getNumInputs() const { |
12129 | return 1; |
12130 | } |
12131 | |
12132 | std::string TraceEventNode::getInputName(unsigned idx) const { |
12133 | if (idx == 0) { return "Data" ; } |
12134 | idx -= 1; |
12135 | llvm_unreachable("Invalid index" ); |
12136 | } |
12137 | |
12138 | NodeValue TraceEventNode::getNthInput(unsigned idx) { |
12139 | if (idx == 0) { return Data_; } |
12140 | idx -= 1; |
12141 | llvm_unreachable("Invalid index" ); |
12142 | } |
12143 | |
12144 | void TraceEventNode::setNthInput(unsigned idx, NodeValue val) { |
12145 | if (idx == 0) { Data_ = val; return; } |
12146 | idx -= 1; |
12147 | llvm_unreachable("Invalid index" ); |
12148 | } |
12149 | |
12150 | llvm::StringRef TraceEventNode::getOutputName(unsigned idx) const { |
12151 | llvm_unreachable("Invalid index" ); |
12152 | } |
12153 | |
12154 | std::string TraceEventNode::getDebugDesc() const { |
12155 | DescriptionBuilder db(getKindName()); |
12156 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12157 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12158 | db |
12159 | .addParam("Data" , *(getData().getType())) |
12160 | .addParam("EventName" , getEventName()) |
12161 | .addParam("EventType" , getEventType()) |
12162 | .addParam("Index" , getIndex()) |
12163 | .addParam("Users" , getNumUsers()); |
12164 | return db; |
12165 | } |
12166 | |
12167 | void TraceEventNode::visit(Node *parent, NodeWalker *visitor) { |
12168 | if (!visitor->shouldVisit(parent, this)) { return; } |
12169 | visitor->pre(parent, this); |
12170 | if (hasPredicate()) |
12171 | getPredicate().getNode()->visit(this, visitor); |
12172 | getData().getNode()->visit(this, visitor); |
12173 | visitor->post(parent, this); |
12174 | } |
12175 | |
12176 | bool TraceEventNode::isEqual(const TraceEventNode &other) const { |
12177 | return true && |
12178 | Data_ == other.Data_ && |
12179 | predicate_ == other.predicate_ && |
12180 | EventName_ == other.EventName_ && |
12181 | EventType_ == other.EventType_ && |
12182 | Index_ == other.Index_; |
12183 | } |
12184 | |
12185 | Node* TraceEventNode::clone() const { |
12186 | return new TraceEventNode(getName(), getData(), getEventName(), getEventType(), getIndex()); |
12187 | } |
12188 | |
12189 | llvm::hash_code TraceEventNode::getHash() const { |
12190 | return llvm::hash_combine( |
12191 | EventName_, |
12192 | EventType_, |
12193 | Index_, |
12194 | Data_); |
12195 | } |
12196 | |
12197 | unsigned QuantizationProfileNode::getNumInputs() const { |
12198 | return 3; |
12199 | } |
12200 | |
12201 | std::string QuantizationProfileNode::getInputName(unsigned idx) const { |
12202 | if (idx == 0) { return "Input" ; } |
12203 | if (idx == 1) { return "Histogram" ; } |
12204 | if (idx == 2) { return "ComputationInfo" ; } |
12205 | idx -= 3; |
12206 | llvm_unreachable("Invalid index" ); |
12207 | } |
12208 | |
12209 | NodeValue QuantizationProfileNode::getNthInput(unsigned idx) { |
12210 | if (idx == 0) { return Input_; } |
12211 | if (idx == 1) { return Histogram_; } |
12212 | if (idx == 2) { return ComputationInfo_; } |
12213 | idx -= 3; |
12214 | llvm_unreachable("Invalid index" ); |
12215 | } |
12216 | |
12217 | void QuantizationProfileNode::setNthInput(unsigned idx, NodeValue val) { |
12218 | if (idx == 0) { Input_ = val; return; } |
12219 | if (idx == 1) { Histogram_ = val; return; } |
12220 | if (idx == 2) { ComputationInfo_ = val; return; } |
12221 | idx -= 3; |
12222 | llvm_unreachable("Invalid index" ); |
12223 | } |
12224 | |
12225 | llvm::StringRef QuantizationProfileNode::getOutputName(unsigned idx) const { |
12226 | llvm_unreachable("Invalid index" ); |
12227 | } |
12228 | |
12229 | std::string QuantizationProfileNode::getDebugDesc() const { |
12230 | DescriptionBuilder db(getKindName()); |
12231 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12232 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12233 | db |
12234 | .addParam("Input" , *(getInput().getType())) |
12235 | .addParam("Histogram" , *(getHistogram().getType())) |
12236 | .addParam("ComputationInfo" , *(getComputationInfo().getType())) |
12237 | .addParam("ProfiledNodeName" , getProfiledNodeName()) |
12238 | .addParam("ProfiledOutputNumber" , getProfiledOutputNumber()) |
12239 | .addParam("Users" , getNumUsers()); |
12240 | return db; |
12241 | } |
12242 | |
12243 | void QuantizationProfileNode::visit(Node *parent, NodeWalker *visitor) { |
12244 | if (!visitor->shouldVisit(parent, this)) { return; } |
12245 | visitor->pre(parent, this); |
12246 | if (hasPredicate()) |
12247 | getPredicate().getNode()->visit(this, visitor); |
12248 | getInput().getNode()->visit(this, visitor); |
12249 | getHistogram().getNode()->visit(this, visitor); |
12250 | getComputationInfo().getNode()->visit(this, visitor); |
12251 | visitor->post(parent, this); |
12252 | } |
12253 | |
12254 | bool QuantizationProfileNode::isEqual(const QuantizationProfileNode &other) const { |
12255 | return true && |
12256 | Input_ == other.Input_ && |
12257 | Histogram_ == other.Histogram_ && |
12258 | ComputationInfo_ == other.ComputationInfo_ && |
12259 | predicate_ == other.predicate_ && |
12260 | ProfiledNodeName_ == other.ProfiledNodeName_ && |
12261 | ProfiledOutputNumber_ == other.ProfiledOutputNumber_; |
12262 | } |
12263 | |
12264 | Node* QuantizationProfileNode::clone() const { |
12265 | return new QuantizationProfileNode(getName(), getInput(), getHistogram(), getComputationInfo(), getProfiledNodeName(), getProfiledOutputNumber()); |
12266 | } |
12267 | |
12268 | llvm::hash_code QuantizationProfileNode::getHash() const { |
12269 | return llvm::hash_combine( |
12270 | ProfiledNodeName_, |
12271 | ProfiledOutputNumber_, |
12272 | Input_, |
12273 | Histogram_, |
12274 | ComputationInfo_); |
12275 | } |
12276 | Placeholder *QuantizationProfileNode::getHistogramPlaceholder() const { return llvm::cast<Placeholder>(Histogram_.getNode()); }; |
12277 | Placeholder *QuantizationProfileNode::getComputationInfoPlaceholder() const { return llvm::cast<Placeholder>(ComputationInfo_.getNode()); }; |
12278 | |
12279 | unsigned IntLookupTableNode::getNumInputs() const { |
12280 | return 2; |
12281 | } |
12282 | |
12283 | std::string IntLookupTableNode::getInputName(unsigned idx) const { |
12284 | if (idx == 0) { return "Input" ; } |
12285 | if (idx == 1) { return "Mapping" ; } |
12286 | idx -= 2; |
12287 | llvm_unreachable("Invalid index" ); |
12288 | } |
12289 | |
12290 | NodeValue IntLookupTableNode::getNthInput(unsigned idx) { |
12291 | if (idx == 0) { return Input_; } |
12292 | if (idx == 1) { return Mapping_; } |
12293 | idx -= 2; |
12294 | llvm_unreachable("Invalid index" ); |
12295 | } |
12296 | |
12297 | void IntLookupTableNode::setNthInput(unsigned idx, NodeValue val) { |
12298 | if (idx == 0) { Input_ = val; return; } |
12299 | if (idx == 1) { Mapping_ = val; return; } |
12300 | idx -= 2; |
12301 | llvm_unreachable("Invalid index" ); |
12302 | } |
12303 | |
12304 | llvm::StringRef IntLookupTableNode::getOutputName(unsigned idx) const { |
12305 | if (idx == 0) { return "Result" ; } |
12306 | llvm_unreachable("Invalid index" ); |
12307 | } |
12308 | |
12309 | std::string IntLookupTableNode::getDebugDesc() const { |
12310 | DescriptionBuilder db(getKindName()); |
12311 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12312 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12313 | db |
12314 | .addParam("Input" , *(getInput().getType())) |
12315 | .addParam("Mapping" , *(getMapping().getType())) |
12316 | .addParam("Users" , getNumUsers()); |
12317 | db.addParam("Result" , *(getResult().getType())); |
12318 | return db; |
12319 | } |
12320 | |
12321 | void IntLookupTableNode::visit(Node *parent, NodeWalker *visitor) { |
12322 | if (!visitor->shouldVisit(parent, this)) { return; } |
12323 | visitor->pre(parent, this); |
12324 | if (hasPredicate()) |
12325 | getPredicate().getNode()->visit(this, visitor); |
12326 | getInput().getNode()->visit(this, visitor); |
12327 | getMapping().getNode()->visit(this, visitor); |
12328 | visitor->post(parent, this); |
12329 | } |
12330 | |
12331 | bool IntLookupTableNode::isEqual(const IntLookupTableNode &other) const { |
12332 | return true && |
12333 | Input_ == other.Input_ && |
12334 | Mapping_ == other.Mapping_ && |
12335 | predicate_ == other.predicate_ && |
12336 | getType(0) == other.getType(0); |
12337 | } |
12338 | |
12339 | Node* IntLookupTableNode::clone() const { |
12340 | return new IntLookupTableNode(getName(), getResult().getType(), getInput(), getMapping()); |
12341 | } |
12342 | |
12343 | llvm::hash_code IntLookupTableNode::getHash() const { |
12344 | return llvm::hash_combine( |
12345 | Input_, |
12346 | Mapping_); |
12347 | } |
12348 | |
12349 | unsigned QuantizeNode::getNumInputs() const { |
12350 | return 1; |
12351 | } |
12352 | |
12353 | std::string QuantizeNode::getInputName(unsigned idx) const { |
12354 | if (idx == 0) { return "Input" ; } |
12355 | idx -= 1; |
12356 | llvm_unreachable("Invalid index" ); |
12357 | } |
12358 | |
12359 | NodeValue QuantizeNode::getNthInput(unsigned idx) { |
12360 | if (idx == 0) { return Input_; } |
12361 | idx -= 1; |
12362 | llvm_unreachable("Invalid index" ); |
12363 | } |
12364 | |
12365 | void QuantizeNode::setNthInput(unsigned idx, NodeValue val) { |
12366 | if (idx == 0) { Input_ = val; return; } |
12367 | idx -= 1; |
12368 | llvm_unreachable("Invalid index" ); |
12369 | } |
12370 | |
12371 | llvm::StringRef QuantizeNode::getOutputName(unsigned idx) const { |
12372 | if (idx == 0) { return "Result" ; } |
12373 | llvm_unreachable("Invalid index" ); |
12374 | } |
12375 | |
12376 | std::string QuantizeNode::getDebugDesc() const { |
12377 | DescriptionBuilder db(getKindName()); |
12378 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12379 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12380 | db |
12381 | .addParam("Input" , *(getInput().getType())) |
12382 | .addParam("Users" , getNumUsers()); |
12383 | db.addParam("Result" , *(getResult().getType())); |
12384 | return db; |
12385 | } |
12386 | |
12387 | void QuantizeNode::visit(Node *parent, NodeWalker *visitor) { |
12388 | if (!visitor->shouldVisit(parent, this)) { return; } |
12389 | visitor->pre(parent, this); |
12390 | if (hasPredicate()) |
12391 | getPredicate().getNode()->visit(this, visitor); |
12392 | getInput().getNode()->visit(this, visitor); |
12393 | visitor->post(parent, this); |
12394 | } |
12395 | |
12396 | bool QuantizeNode::isEqual(const QuantizeNode &other) const { |
12397 | return true && |
12398 | Input_ == other.Input_ && |
12399 | predicate_ == other.predicate_ && |
12400 | getType(0) == other.getType(0); |
12401 | } |
12402 | |
12403 | Node* QuantizeNode::clone() const { |
12404 | return new QuantizeNode(getName(), getResult().getType(), getInput()); |
12405 | } |
12406 | |
12407 | llvm::hash_code QuantizeNode::getHash() const { |
12408 | return llvm::hash_combine( |
12409 | Input_); |
12410 | } |
12411 | |
12412 | unsigned DequantizeNode::getNumInputs() const { |
12413 | return 1; |
12414 | } |
12415 | |
12416 | std::string DequantizeNode::getInputName(unsigned idx) const { |
12417 | if (idx == 0) { return "Input" ; } |
12418 | idx -= 1; |
12419 | llvm_unreachable("Invalid index" ); |
12420 | } |
12421 | |
12422 | NodeValue DequantizeNode::getNthInput(unsigned idx) { |
12423 | if (idx == 0) { return Input_; } |
12424 | idx -= 1; |
12425 | llvm_unreachable("Invalid index" ); |
12426 | } |
12427 | |
12428 | void DequantizeNode::setNthInput(unsigned idx, NodeValue val) { |
12429 | if (idx == 0) { Input_ = val; return; } |
12430 | idx -= 1; |
12431 | llvm_unreachable("Invalid index" ); |
12432 | } |
12433 | |
12434 | llvm::StringRef DequantizeNode::getOutputName(unsigned idx) const { |
12435 | if (idx == 0) { return "Result" ; } |
12436 | llvm_unreachable("Invalid index" ); |
12437 | } |
12438 | |
12439 | std::string DequantizeNode::getDebugDesc() const { |
12440 | DescriptionBuilder db(getKindName()); |
12441 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12442 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12443 | db |
12444 | .addParam("Input" , *(getInput().getType())) |
12445 | .addParam("Users" , getNumUsers()); |
12446 | db.addParam("Result" , *(getResult().getType())); |
12447 | return db; |
12448 | } |
12449 | |
12450 | void DequantizeNode::visit(Node *parent, NodeWalker *visitor) { |
12451 | if (!visitor->shouldVisit(parent, this)) { return; } |
12452 | visitor->pre(parent, this); |
12453 | if (hasPredicate()) |
12454 | getPredicate().getNode()->visit(this, visitor); |
12455 | getInput().getNode()->visit(this, visitor); |
12456 | visitor->post(parent, this); |
12457 | } |
12458 | |
12459 | bool DequantizeNode::isEqual(const DequantizeNode &other) const { |
12460 | return true && |
12461 | Input_ == other.Input_ && |
12462 | predicate_ == other.predicate_ && |
12463 | getType(0) == other.getType(0); |
12464 | } |
12465 | |
12466 | Node* DequantizeNode::clone() const { |
12467 | return new DequantizeNode(getName(), getResult().getType(), getInput()); |
12468 | } |
12469 | |
12470 | llvm::hash_code DequantizeNode::getHash() const { |
12471 | return llvm::hash_combine( |
12472 | Input_); |
12473 | } |
12474 | |
12475 | unsigned RescaleQuantizedNode::getNumInputs() const { |
12476 | return 1; |
12477 | } |
12478 | |
12479 | std::string RescaleQuantizedNode::getInputName(unsigned idx) const { |
12480 | if (idx == 0) { return "Input" ; } |
12481 | idx -= 1; |
12482 | llvm_unreachable("Invalid index" ); |
12483 | } |
12484 | |
12485 | NodeValue RescaleQuantizedNode::getNthInput(unsigned idx) { |
12486 | if (idx == 0) { return Input_; } |
12487 | idx -= 1; |
12488 | llvm_unreachable("Invalid index" ); |
12489 | } |
12490 | |
12491 | void RescaleQuantizedNode::setNthInput(unsigned idx, NodeValue val) { |
12492 | if (idx == 0) { Input_ = val; return; } |
12493 | idx -= 1; |
12494 | llvm_unreachable("Invalid index" ); |
12495 | } |
12496 | |
12497 | llvm::StringRef RescaleQuantizedNode::getOutputName(unsigned idx) const { |
12498 | if (idx == 0) { return "Result" ; } |
12499 | llvm_unreachable("Invalid index" ); |
12500 | } |
12501 | |
12502 | std::string RescaleQuantizedNode::getDebugDesc() const { |
12503 | DescriptionBuilder db(getKindName()); |
12504 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12505 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12506 | db |
12507 | .addParam("Input" , *(getInput().getType())) |
12508 | .addParam("Users" , getNumUsers()); |
12509 | db.addParam("Result" , *(getResult().getType())); |
12510 | return db; |
12511 | } |
12512 | |
12513 | void RescaleQuantizedNode::visit(Node *parent, NodeWalker *visitor) { |
12514 | if (!visitor->shouldVisit(parent, this)) { return; } |
12515 | visitor->pre(parent, this); |
12516 | if (hasPredicate()) |
12517 | getPredicate().getNode()->visit(this, visitor); |
12518 | getInput().getNode()->visit(this, visitor); |
12519 | visitor->post(parent, this); |
12520 | } |
12521 | |
12522 | bool RescaleQuantizedNode::isEqual(const RescaleQuantizedNode &other) const { |
12523 | return true && |
12524 | Input_ == other.Input_ && |
12525 | predicate_ == other.predicate_ && |
12526 | getType(0) == other.getType(0); |
12527 | } |
12528 | |
12529 | Node* RescaleQuantizedNode::clone() const { |
12530 | return new RescaleQuantizedNode(getName(), getResult().getType(), getInput()); |
12531 | } |
12532 | |
12533 | llvm::hash_code RescaleQuantizedNode::getHash() const { |
12534 | return llvm::hash_combine( |
12535 | Input_); |
12536 | } |
12537 | |
12538 | unsigned TopKNode::getNumInputs() const { |
12539 | return 1; |
12540 | } |
12541 | |
12542 | std::string TopKNode::getInputName(unsigned idx) const { |
12543 | if (idx == 0) { return "Input" ; } |
12544 | idx -= 1; |
12545 | llvm_unreachable("Invalid index" ); |
12546 | } |
12547 | |
12548 | NodeValue TopKNode::getNthInput(unsigned idx) { |
12549 | if (idx == 0) { return Input_; } |
12550 | idx -= 1; |
12551 | llvm_unreachable("Invalid index" ); |
12552 | } |
12553 | |
12554 | void TopKNode::setNthInput(unsigned idx, NodeValue val) { |
12555 | if (idx == 0) { Input_ = val; return; } |
12556 | idx -= 1; |
12557 | llvm_unreachable("Invalid index" ); |
12558 | } |
12559 | |
12560 | llvm::StringRef TopKNode::getOutputName(unsigned idx) const { |
12561 | if (idx == 0) { return "Values" ; } |
12562 | if (idx == 1) { return "Indices" ; } |
12563 | llvm_unreachable("Invalid index" ); |
12564 | } |
12565 | |
12566 | std::string TopKNode::getDebugDesc() const { |
12567 | DescriptionBuilder db(getKindName()); |
12568 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12569 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12570 | db |
12571 | .addParam("Input" , *(getInput().getType())) |
12572 | .addParam("K" , getK()) |
12573 | .addParam("Users" , getNumUsers()); |
12574 | db.addParam("Values" , *(getValues().getType())); |
12575 | db.addParam("Indices" , *(getIndices().getType())); |
12576 | return db; |
12577 | } |
12578 | |
12579 | void TopKNode::visit(Node *parent, NodeWalker *visitor) { |
12580 | if (!visitor->shouldVisit(parent, this)) { return; } |
12581 | visitor->pre(parent, this); |
12582 | if (hasPredicate()) |
12583 | getPredicate().getNode()->visit(this, visitor); |
12584 | getInput().getNode()->visit(this, visitor); |
12585 | visitor->post(parent, this); |
12586 | } |
12587 | |
12588 | bool TopKNode::isEqual(const TopKNode &other) const { |
12589 | return true && |
12590 | Input_ == other.Input_ && |
12591 | predicate_ == other.predicate_ && |
12592 | K_ == other.K_ && |
12593 | getType(0) == other.getType(0) && |
12594 | getType(1) == other.getType(1); |
12595 | } |
12596 | |
12597 | Node* TopKNode::clone() const { |
12598 | return new TopKNode(getName(), getValues().getType(), getIndices().getType(), getInput(), getK()); |
12599 | } |
12600 | |
12601 | llvm::hash_code TopKNode::getHash() const { |
12602 | return llvm::hash_combine( |
12603 | K_, |
12604 | Input_); |
12605 | } |
12606 | |
12607 | unsigned LSTMUnitNode::getNumInputs() const { |
12608 | return 2; |
12609 | } |
12610 | |
12611 | std::string LSTMUnitNode::getInputName(unsigned idx) const { |
12612 | if (idx == 0) { return "Input" ; } |
12613 | if (idx == 1) { return "C" ; } |
12614 | idx -= 2; |
12615 | llvm_unreachable("Invalid index" ); |
12616 | } |
12617 | |
12618 | NodeValue LSTMUnitNode::getNthInput(unsigned idx) { |
12619 | if (idx == 0) { return Input_; } |
12620 | if (idx == 1) { return C_; } |
12621 | idx -= 2; |
12622 | llvm_unreachable("Invalid index" ); |
12623 | } |
12624 | |
12625 | void LSTMUnitNode::setNthInput(unsigned idx, NodeValue val) { |
12626 | if (idx == 0) { Input_ = val; return; } |
12627 | if (idx == 1) { C_ = val; return; } |
12628 | idx -= 2; |
12629 | llvm_unreachable("Invalid index" ); |
12630 | } |
12631 | |
12632 | llvm::StringRef LSTMUnitNode::getOutputName(unsigned idx) const { |
12633 | if (idx == 0) { return "newC" ; } |
12634 | if (idx == 1) { return "newH" ; } |
12635 | llvm_unreachable("Invalid index" ); |
12636 | } |
12637 | |
12638 | std::string LSTMUnitNode::getDebugDesc() const { |
12639 | DescriptionBuilder db(getKindName()); |
12640 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12641 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12642 | db |
12643 | .addParam("Input" , *(getInput().getType())) |
12644 | .addParam("C" , *(getC().getType())) |
12645 | .addParam("Users" , getNumUsers()); |
12646 | db.addParam("newC" , *(getnewC().getType())); |
12647 | db.addParam("newH" , *(getnewH().getType())); |
12648 | return db; |
12649 | } |
12650 | |
12651 | void LSTMUnitNode::visit(Node *parent, NodeWalker *visitor) { |
12652 | if (!visitor->shouldVisit(parent, this)) { return; } |
12653 | visitor->pre(parent, this); |
12654 | if (hasPredicate()) |
12655 | getPredicate().getNode()->visit(this, visitor); |
12656 | getInput().getNode()->visit(this, visitor); |
12657 | getC().getNode()->visit(this, visitor); |
12658 | visitor->post(parent, this); |
12659 | } |
12660 | |
12661 | bool LSTMUnitNode::isEqual(const LSTMUnitNode &other) const { |
12662 | return true && |
12663 | Input_ == other.Input_ && |
12664 | C_ == other.C_ && |
12665 | predicate_ == other.predicate_ && |
12666 | getType(0) == other.getType(0) && |
12667 | getType(1) == other.getType(1); |
12668 | } |
12669 | |
12670 | Node* LSTMUnitNode::clone() const { |
12671 | return new LSTMUnitNode(getName(), getInput(), getC()); |
12672 | } |
12673 | |
12674 | llvm::hash_code LSTMUnitNode::getHash() const { |
12675 | return llvm::hash_combine( |
12676 | Input_, |
12677 | C_); |
12678 | } |
12679 | |
12680 | unsigned ConvertToNode::getNumInputs() const { |
12681 | return 1; |
12682 | } |
12683 | |
12684 | std::string ConvertToNode::getInputName(unsigned idx) const { |
12685 | if (idx == 0) { return "Input" ; } |
12686 | idx -= 1; |
12687 | llvm_unreachable("Invalid index" ); |
12688 | } |
12689 | |
12690 | NodeValue ConvertToNode::getNthInput(unsigned idx) { |
12691 | if (idx == 0) { return Input_; } |
12692 | idx -= 1; |
12693 | llvm_unreachable("Invalid index" ); |
12694 | } |
12695 | |
12696 | void ConvertToNode::setNthInput(unsigned idx, NodeValue val) { |
12697 | if (idx == 0) { Input_ = val; return; } |
12698 | idx -= 1; |
12699 | llvm_unreachable("Invalid index" ); |
12700 | } |
12701 | |
12702 | llvm::StringRef ConvertToNode::getOutputName(unsigned idx) const { |
12703 | if (idx == 0) { return "Result" ; } |
12704 | llvm_unreachable("Invalid index" ); |
12705 | } |
12706 | |
12707 | std::string ConvertToNode::getDebugDesc() const { |
12708 | DescriptionBuilder db(getKindName()); |
12709 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12710 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12711 | db |
12712 | .addParam("Input" , *(getInput().getType())) |
12713 | .addParam("Users" , getNumUsers()); |
12714 | db.addParam("Result" , *(getResult().getType())); |
12715 | return db; |
12716 | } |
12717 | |
12718 | void ConvertToNode::visit(Node *parent, NodeWalker *visitor) { |
12719 | if (!visitor->shouldVisit(parent, this)) { return; } |
12720 | visitor->pre(parent, this); |
12721 | if (hasPredicate()) |
12722 | getPredicate().getNode()->visit(this, visitor); |
12723 | getInput().getNode()->visit(this, visitor); |
12724 | visitor->post(parent, this); |
12725 | } |
12726 | |
12727 | bool ConvertToNode::isEqual(const ConvertToNode &other) const { |
12728 | return true && |
12729 | Input_ == other.Input_ && |
12730 | predicate_ == other.predicate_ && |
12731 | getType(0) == other.getType(0); |
12732 | } |
12733 | |
12734 | Node* ConvertToNode::clone() const { |
12735 | return new ConvertToNode(getName(), getResult().getType(), getInput()); |
12736 | } |
12737 | |
12738 | llvm::hash_code ConvertToNode::getHash() const { |
12739 | return llvm::hash_combine( |
12740 | Input_); |
12741 | } |
12742 | |
12743 | unsigned ExternalFunctionCallNode::getNumInputs() const { |
12744 | return 0 + Inputs_.size(); |
12745 | } |
12746 | |
12747 | std::string ExternalFunctionCallNode::getInputName(unsigned idx) const { |
12748 | idx -= 0; |
12749 | if (idx < Inputs_.size()) { return "Inputs" + std::to_string(idx); } |
12750 | idx -= Inputs_.size(); |
12751 | llvm_unreachable("Invalid index" ); |
12752 | } |
12753 | |
12754 | NodeValue ExternalFunctionCallNode::getNthInput(unsigned idx) { |
12755 | idx -= 0; |
12756 | if (idx < Inputs_.size()) { return Inputs_[idx]; } |
12757 | idx -= Inputs_.size(); |
12758 | llvm_unreachable("Invalid index" ); |
12759 | } |
12760 | |
12761 | void ExternalFunctionCallNode::setNthInput(unsigned idx, NodeValue val) { |
12762 | idx -= 0; |
12763 | if (idx < Inputs_.size()) { Inputs_[idx] = val; return; } |
12764 | idx -= Inputs_.size(); |
12765 | llvm_unreachable("Invalid index" ); |
12766 | } |
12767 | |
12768 | llvm::StringRef ExternalFunctionCallNode::getOutputName(unsigned idx) const { |
12769 | if (idx == 0) { return "Result" ; } |
12770 | llvm_unreachable("Invalid index" ); |
12771 | } |
12772 | |
12773 | std::string ExternalFunctionCallNode::getDebugDesc() const { |
12774 | DescriptionBuilder db(getKindName()); |
12775 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12776 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12777 | db |
12778 | .addParam("FunctionName" , getFunctionName()) |
12779 | .addParam("FunctionImpl" , getFunctionImpl()) |
12780 | .addParam("FunctionKind" , getFunctionKind()) |
12781 | .addParam("Users" , getNumUsers()); |
12782 | { |
12783 | unsigned mIndex = 0; |
12784 | for (const auto &II : getInputs()) { |
12785 | db.addParam("Inputs" +std::to_string(mIndex++), *II.getType()); |
12786 | } |
12787 | } |
12788 | db.addParam("Result" , *(getResult().getType())); |
12789 | return db; |
12790 | } |
12791 | |
12792 | void ExternalFunctionCallNode::visit(Node *parent, NodeWalker *visitor) { |
12793 | if (!visitor->shouldVisit(parent, this)) { return; } |
12794 | visitor->pre(parent, this); |
12795 | if (hasPredicate()) |
12796 | getPredicate().getNode()->visit(this, visitor); |
12797 | for (auto &I : Inputs_) { I.getNode()->visit(this, visitor); } |
12798 | visitor->post(parent, this); |
12799 | } |
12800 | |
12801 | bool ExternalFunctionCallNode::isEqual(const ExternalFunctionCallNode &other) const { |
12802 | return true && |
12803 | predicate_ == other.predicate_ && |
12804 | Inputs_ == other.Inputs_ && |
12805 | FunctionName_ == other.FunctionName_ && |
12806 | FunctionImpl_ == other.FunctionImpl_ && |
12807 | FunctionKind_ == other.FunctionKind_ && |
12808 | getType(0) == other.getType(0); |
12809 | } |
12810 | |
12811 | Node* ExternalFunctionCallNode::clone() const { |
12812 | return new ExternalFunctionCallNode(getName(), getResult().getType(), getInputs(), getFunctionName(), getFunctionImpl(), getFunctionKind()); |
12813 | } |
12814 | |
12815 | llvm::hash_code ExternalFunctionCallNode::getHash() const { |
12816 | return llvm::hash_combine( |
12817 | llvm::hash_combine_range(Inputs_.begin(), Inputs_.end()), |
12818 | FunctionName_, |
12819 | FunctionImpl_, |
12820 | FunctionKind_); |
12821 | } |
12822 | |
12823 | unsigned AudioSpectrogramNode::getNumInputs() const { |
12824 | return 5; |
12825 | } |
12826 | |
12827 | std::string AudioSpectrogramNode::getInputName(unsigned idx) const { |
12828 | if (idx == 0) { return "Input" ; } |
12829 | if (idx == 1) { return "Window" ; } |
12830 | if (idx == 2) { return "TwiddleFactors" ; } |
12831 | if (idx == 3) { return "BitReverseIndices" ; } |
12832 | if (idx == 4) { return "ComplexToRealWeights" ; } |
12833 | idx -= 5; |
12834 | llvm_unreachable("Invalid index" ); |
12835 | } |
12836 | |
12837 | NodeValue AudioSpectrogramNode::getNthInput(unsigned idx) { |
12838 | if (idx == 0) { return Input_; } |
12839 | if (idx == 1) { return Window_; } |
12840 | if (idx == 2) { return TwiddleFactors_; } |
12841 | if (idx == 3) { return BitReverseIndices_; } |
12842 | if (idx == 4) { return ComplexToRealWeights_; } |
12843 | idx -= 5; |
12844 | llvm_unreachable("Invalid index" ); |
12845 | } |
12846 | |
12847 | void AudioSpectrogramNode::setNthInput(unsigned idx, NodeValue val) { |
12848 | if (idx == 0) { Input_ = val; return; } |
12849 | if (idx == 1) { Window_ = val; return; } |
12850 | if (idx == 2) { TwiddleFactors_ = val; return; } |
12851 | if (idx == 3) { BitReverseIndices_ = val; return; } |
12852 | if (idx == 4) { ComplexToRealWeights_ = val; return; } |
12853 | idx -= 5; |
12854 | llvm_unreachable("Invalid index" ); |
12855 | } |
12856 | |
12857 | llvm::StringRef AudioSpectrogramNode::getOutputName(unsigned idx) const { |
12858 | if (idx == 0) { return "Spectrogram" ; } |
12859 | llvm_unreachable("Invalid index" ); |
12860 | } |
12861 | |
12862 | std::string AudioSpectrogramNode::getDebugDesc() const { |
12863 | DescriptionBuilder db(getKindName()); |
12864 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12865 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12866 | db |
12867 | .addParam("Input" , *(getInput().getType())) |
12868 | .addParam("Window" , *(getWindow().getType())) |
12869 | .addParam("TwiddleFactors" , *(getTwiddleFactors().getType())) |
12870 | .addParam("BitReverseIndices" , *(getBitReverseIndices().getType())) |
12871 | .addParam("ComplexToRealWeights" , *(getComplexToRealWeights().getType())) |
12872 | .addParam("WindowSize" , getWindowSize()) |
12873 | .addParam("WindowStride" , getWindowStride()) |
12874 | .addParam("MagnitudeSquared" , getMagnitudeSquared()) |
12875 | .addParam("Users" , getNumUsers()); |
12876 | db.addParam("Spectrogram" , *(getSpectrogram().getType())); |
12877 | return db; |
12878 | } |
12879 | |
12880 | void AudioSpectrogramNode::visit(Node *parent, NodeWalker *visitor) { |
12881 | if (!visitor->shouldVisit(parent, this)) { return; } |
12882 | visitor->pre(parent, this); |
12883 | if (hasPredicate()) |
12884 | getPredicate().getNode()->visit(this, visitor); |
12885 | getInput().getNode()->visit(this, visitor); |
12886 | getWindow().getNode()->visit(this, visitor); |
12887 | getTwiddleFactors().getNode()->visit(this, visitor); |
12888 | getBitReverseIndices().getNode()->visit(this, visitor); |
12889 | getComplexToRealWeights().getNode()->visit(this, visitor); |
12890 | visitor->post(parent, this); |
12891 | } |
12892 | |
12893 | bool AudioSpectrogramNode::isEqual(const AudioSpectrogramNode &other) const { |
12894 | return true && |
12895 | Input_ == other.Input_ && |
12896 | Window_ == other.Window_ && |
12897 | TwiddleFactors_ == other.TwiddleFactors_ && |
12898 | BitReverseIndices_ == other.BitReverseIndices_ && |
12899 | ComplexToRealWeights_ == other.ComplexToRealWeights_ && |
12900 | predicate_ == other.predicate_ && |
12901 | WindowSize_ == other.WindowSize_ && |
12902 | WindowStride_ == other.WindowStride_ && |
12903 | MagnitudeSquared_ == other.MagnitudeSquared_ && |
12904 | getType(0) == other.getType(0); |
12905 | } |
12906 | |
12907 | Node* AudioSpectrogramNode::clone() const { |
12908 | return new AudioSpectrogramNode(getName(), getSpectrogram().getType(), getInput(), getWindow(), getTwiddleFactors(), getBitReverseIndices(), getComplexToRealWeights(), getWindowSize(), getWindowStride(), getMagnitudeSquared()); |
12909 | } |
12910 | |
12911 | llvm::hash_code AudioSpectrogramNode::getHash() const { |
12912 | return llvm::hash_combine( |
12913 | WindowSize_, |
12914 | WindowStride_, |
12915 | MagnitudeSquared_, |
12916 | Input_, |
12917 | Window_, |
12918 | TwiddleFactors_, |
12919 | BitReverseIndices_, |
12920 | ComplexToRealWeights_); |
12921 | } |
12922 | |
12923 | unsigned MFCCNode::getNumInputs() const { |
12924 | return 4; |
12925 | } |
12926 | |
12927 | std::string MFCCNode::getInputName(unsigned idx) const { |
12928 | if (idx == 0) { return "Spectrogram" ; } |
12929 | if (idx == 1) { return "MelWeights" ; } |
12930 | if (idx == 2) { return "MelRanges" ; } |
12931 | if (idx == 3) { return "DctMat" ; } |
12932 | idx -= 4; |
12933 | llvm_unreachable("Invalid index" ); |
12934 | } |
12935 | |
12936 | NodeValue MFCCNode::getNthInput(unsigned idx) { |
12937 | if (idx == 0) { return Spectrogram_; } |
12938 | if (idx == 1) { return MelWeights_; } |
12939 | if (idx == 2) { return MelRanges_; } |
12940 | if (idx == 3) { return DctMat_; } |
12941 | idx -= 4; |
12942 | llvm_unreachable("Invalid index" ); |
12943 | } |
12944 | |
12945 | void MFCCNode::setNthInput(unsigned idx, NodeValue val) { |
12946 | if (idx == 0) { Spectrogram_ = val; return; } |
12947 | if (idx == 1) { MelWeights_ = val; return; } |
12948 | if (idx == 2) { MelRanges_ = val; return; } |
12949 | if (idx == 3) { DctMat_ = val; return; } |
12950 | idx -= 4; |
12951 | llvm_unreachable("Invalid index" ); |
12952 | } |
12953 | |
12954 | llvm::StringRef MFCCNode::getOutputName(unsigned idx) const { |
12955 | if (idx == 0) { return "Coefficients" ; } |
12956 | llvm_unreachable("Invalid index" ); |
12957 | } |
12958 | |
12959 | std::string MFCCNode::getDebugDesc() const { |
12960 | DescriptionBuilder db(getKindName()); |
12961 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
12962 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
12963 | db |
12964 | .addParam("Spectrogram" , *(getSpectrogram().getType())) |
12965 | .addParam("MelWeights" , *(getMelWeights().getType())) |
12966 | .addParam("MelRanges" , *(getMelRanges().getType())) |
12967 | .addParam("DctMat" , *(getDctMat().getType())) |
12968 | .addParam("SampleRate" , getSampleRate()) |
12969 | .addParam("LowerFrequency" , getLowerFrequency()) |
12970 | .addParam("UpperFrequency" , getUpperFrequency()) |
12971 | .addParam("FilterBankCount" , getFilterBankCount()) |
12972 | .addParam("NumCoefficients" , getNumCoefficients()) |
12973 | .addParam("Users" , getNumUsers()); |
12974 | db.addParam("Coefficients" , *(getCoefficients().getType())); |
12975 | return db; |
12976 | } |
12977 | |
12978 | void MFCCNode::visit(Node *parent, NodeWalker *visitor) { |
12979 | if (!visitor->shouldVisit(parent, this)) { return; } |
12980 | visitor->pre(parent, this); |
12981 | if (hasPredicate()) |
12982 | getPredicate().getNode()->visit(this, visitor); |
12983 | getSpectrogram().getNode()->visit(this, visitor); |
12984 | getMelWeights().getNode()->visit(this, visitor); |
12985 | getMelRanges().getNode()->visit(this, visitor); |
12986 | getDctMat().getNode()->visit(this, visitor); |
12987 | visitor->post(parent, this); |
12988 | } |
12989 | |
12990 | bool MFCCNode::isEqual(const MFCCNode &other) const { |
12991 | return true && |
12992 | Spectrogram_ == other.Spectrogram_ && |
12993 | MelWeights_ == other.MelWeights_ && |
12994 | MelRanges_ == other.MelRanges_ && |
12995 | DctMat_ == other.DctMat_ && |
12996 | predicate_ == other.predicate_ && |
12997 | SampleRate_ == other.SampleRate_ && |
12998 | LowerFrequency_ == other.LowerFrequency_ && |
12999 | UpperFrequency_ == other.UpperFrequency_ && |
13000 | FilterBankCount_ == other.FilterBankCount_ && |
13001 | NumCoefficients_ == other.NumCoefficients_ && |
13002 | getType(0) == other.getType(0); |
13003 | } |
13004 | |
13005 | Node* MFCCNode::clone() const { |
13006 | return new MFCCNode(getName(), getCoefficients().getType(), getSpectrogram(), getMelWeights(), getMelRanges(), getDctMat(), getSampleRate(), getLowerFrequency(), getUpperFrequency(), getFilterBankCount(), getNumCoefficients()); |
13007 | } |
13008 | |
13009 | llvm::hash_code MFCCNode::getHash() const { |
13010 | return llvm::hash_combine( |
13011 | toBinary(SampleRate_), |
13012 | toBinary(LowerFrequency_), |
13013 | toBinary(UpperFrequency_), |
13014 | FilterBankCount_, |
13015 | NumCoefficients_, |
13016 | Spectrogram_, |
13017 | MelWeights_, |
13018 | MelRanges_, |
13019 | DctMat_); |
13020 | } |
13021 | |
13022 | unsigned NonMaxSuppressionNode::getNumInputs() const { |
13023 | return 2; |
13024 | } |
13025 | |
13026 | std::string NonMaxSuppressionNode::getInputName(unsigned idx) const { |
13027 | if (idx == 0) { return "Boxes" ; } |
13028 | if (idx == 1) { return "Scores" ; } |
13029 | idx -= 2; |
13030 | llvm_unreachable("Invalid index" ); |
13031 | } |
13032 | |
13033 | NodeValue NonMaxSuppressionNode::getNthInput(unsigned idx) { |
13034 | if (idx == 0) { return Boxes_; } |
13035 | if (idx == 1) { return Scores_; } |
13036 | idx -= 2; |
13037 | llvm_unreachable("Invalid index" ); |
13038 | } |
13039 | |
13040 | void NonMaxSuppressionNode::setNthInput(unsigned idx, NodeValue val) { |
13041 | if (idx == 0) { Boxes_ = val; return; } |
13042 | if (idx == 1) { Scores_ = val; return; } |
13043 | idx -= 2; |
13044 | llvm_unreachable("Invalid index" ); |
13045 | } |
13046 | |
13047 | llvm::StringRef NonMaxSuppressionNode::getOutputName(unsigned idx) const { |
13048 | if (idx == 0) { return "Indices" ; } |
13049 | if (idx == 1) { return "NumberOfSelectedIndices" ; } |
13050 | llvm_unreachable("Invalid index" ); |
13051 | } |
13052 | |
13053 | std::string NonMaxSuppressionNode::getDebugDesc() const { |
13054 | DescriptionBuilder db(getKindName()); |
13055 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13056 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13057 | db |
13058 | .addParam("Boxes" , *(getBoxes().getType())) |
13059 | .addParam("Scores" , *(getScores().getType())) |
13060 | .addParam("CenterPointBox" , getCenterPointBox()) |
13061 | .addParam("MaxOutputBoxesPerClass" , getMaxOutputBoxesPerClass()) |
13062 | .addParam("IouThreshold" , getIouThreshold()) |
13063 | .addParam("ScoreThreshold" , getScoreThreshold()) |
13064 | .addParam("IsTFVersion4" , getIsTFVersion4()) |
13065 | .addParam("Users" , getNumUsers()); |
13066 | db.addParam("Indices" , *(getIndices().getType())); |
13067 | db.addParam("NumberOfSelectedIndices" , *(getNumberOfSelectedIndices().getType())); |
13068 | return db; |
13069 | } |
13070 | |
13071 | void NonMaxSuppressionNode::visit(Node *parent, NodeWalker *visitor) { |
13072 | if (!visitor->shouldVisit(parent, this)) { return; } |
13073 | visitor->pre(parent, this); |
13074 | if (hasPredicate()) |
13075 | getPredicate().getNode()->visit(this, visitor); |
13076 | getBoxes().getNode()->visit(this, visitor); |
13077 | getScores().getNode()->visit(this, visitor); |
13078 | visitor->post(parent, this); |
13079 | } |
13080 | |
13081 | bool NonMaxSuppressionNode::isEqual(const NonMaxSuppressionNode &other) const { |
13082 | return true && |
13083 | Boxes_ == other.Boxes_ && |
13084 | Scores_ == other.Scores_ && |
13085 | predicate_ == other.predicate_ && |
13086 | CenterPointBox_ == other.CenterPointBox_ && |
13087 | MaxOutputBoxesPerClass_ == other.MaxOutputBoxesPerClass_ && |
13088 | IouThreshold_ == other.IouThreshold_ && |
13089 | ScoreThreshold_ == other.ScoreThreshold_ && |
13090 | IsTFVersion4_ == other.IsTFVersion4_ && |
13091 | getType(0) == other.getType(0) && |
13092 | getType(1) == other.getType(1); |
13093 | } |
13094 | |
13095 | Node* NonMaxSuppressionNode::clone() const { |
13096 | return new NonMaxSuppressionNode(getName(), getIndices().getType(), getNumberOfSelectedIndices().getType(), getBoxes(), getScores(), getCenterPointBox(), getMaxOutputBoxesPerClass(), getIouThreshold(), getScoreThreshold(), getIsTFVersion4()); |
13097 | } |
13098 | |
13099 | llvm::hash_code NonMaxSuppressionNode::getHash() const { |
13100 | return llvm::hash_combine( |
13101 | CenterPointBox_, |
13102 | MaxOutputBoxesPerClass_, |
13103 | toBinary(IouThreshold_), |
13104 | toBinary(ScoreThreshold_), |
13105 | IsTFVersion4_, |
13106 | Boxes_, |
13107 | Scores_); |
13108 | } |
13109 | |
13110 | unsigned TFLiteDetectionPostProcessNode::getNumInputs() const { |
13111 | return 3; |
13112 | } |
13113 | |
13114 | std::string TFLiteDetectionPostProcessNode::getInputName(unsigned idx) const { |
13115 | if (idx == 0) { return "Boxes" ; } |
13116 | if (idx == 1) { return "Scores" ; } |
13117 | if (idx == 2) { return "Anchors" ; } |
13118 | idx -= 3; |
13119 | llvm_unreachable("Invalid index" ); |
13120 | } |
13121 | |
13122 | NodeValue TFLiteDetectionPostProcessNode::getNthInput(unsigned idx) { |
13123 | if (idx == 0) { return Boxes_; } |
13124 | if (idx == 1) { return Scores_; } |
13125 | if (idx == 2) { return Anchors_; } |
13126 | idx -= 3; |
13127 | llvm_unreachable("Invalid index" ); |
13128 | } |
13129 | |
13130 | void TFLiteDetectionPostProcessNode::setNthInput(unsigned idx, NodeValue val) { |
13131 | if (idx == 0) { Boxes_ = val; return; } |
13132 | if (idx == 1) { Scores_ = val; return; } |
13133 | if (idx == 2) { Anchors_ = val; return; } |
13134 | idx -= 3; |
13135 | llvm_unreachable("Invalid index" ); |
13136 | } |
13137 | |
13138 | llvm::StringRef TFLiteDetectionPostProcessNode::getOutputName(unsigned idx) const { |
13139 | if (idx == 0) { return "DetectionBoxes" ; } |
13140 | if (idx == 1) { return "DetectionClasses" ; } |
13141 | if (idx == 2) { return "DetectionScores" ; } |
13142 | if (idx == 3) { return "NumDetections" ; } |
13143 | llvm_unreachable("Invalid index" ); |
13144 | } |
13145 | |
13146 | std::string TFLiteDetectionPostProcessNode::getDebugDesc() const { |
13147 | DescriptionBuilder db(getKindName()); |
13148 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13149 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13150 | db |
13151 | .addParam("Boxes" , *(getBoxes().getType())) |
13152 | .addParam("Scores" , *(getScores().getType())) |
13153 | .addParam("Anchors" , *(getAnchors().getType())) |
13154 | .addParam("NumClasses" , getNumClasses()) |
13155 | .addParam("MaxDetections" , getMaxDetections()) |
13156 | .addParam("MaxClassesPerDetection" , getMaxClassesPerDetection()) |
13157 | .addParam("MaxDetectionsPerClass" , getMaxDetectionsPerClass()) |
13158 | .addParam("IouThreshold" , getIouThreshold()) |
13159 | .addParam("ScoreThreshold" , getScoreThreshold()) |
13160 | .addParam("XScale" , getXScale()) |
13161 | .addParam("YScale" , getYScale()) |
13162 | .addParam("HScale" , getHScale()) |
13163 | .addParam("WScale" , getWScale()) |
13164 | .addParam("RegularNMS" , getRegularNMS()) |
13165 | .addParam("Users" , getNumUsers()); |
13166 | db.addParam("DetectionBoxes" , *(getDetectionBoxes().getType())); |
13167 | db.addParam("DetectionClasses" , *(getDetectionClasses().getType())); |
13168 | db.addParam("DetectionScores" , *(getDetectionScores().getType())); |
13169 | db.addParam("NumDetections" , *(getNumDetections().getType())); |
13170 | return db; |
13171 | } |
13172 | |
13173 | void TFLiteDetectionPostProcessNode::visit(Node *parent, NodeWalker *visitor) { |
13174 | if (!visitor->shouldVisit(parent, this)) { return; } |
13175 | visitor->pre(parent, this); |
13176 | if (hasPredicate()) |
13177 | getPredicate().getNode()->visit(this, visitor); |
13178 | getBoxes().getNode()->visit(this, visitor); |
13179 | getScores().getNode()->visit(this, visitor); |
13180 | getAnchors().getNode()->visit(this, visitor); |
13181 | visitor->post(parent, this); |
13182 | } |
13183 | |
13184 | bool TFLiteDetectionPostProcessNode::isEqual(const TFLiteDetectionPostProcessNode &other) const { |
13185 | return true && |
13186 | Boxes_ == other.Boxes_ && |
13187 | Scores_ == other.Scores_ && |
13188 | Anchors_ == other.Anchors_ && |
13189 | predicate_ == other.predicate_ && |
13190 | NumClasses_ == other.NumClasses_ && |
13191 | MaxDetections_ == other.MaxDetections_ && |
13192 | MaxClassesPerDetection_ == other.MaxClassesPerDetection_ && |
13193 | MaxDetectionsPerClass_ == other.MaxDetectionsPerClass_ && |
13194 | IouThreshold_ == other.IouThreshold_ && |
13195 | ScoreThreshold_ == other.ScoreThreshold_ && |
13196 | XScale_ == other.XScale_ && |
13197 | YScale_ == other.YScale_ && |
13198 | HScale_ == other.HScale_ && |
13199 | WScale_ == other.WScale_ && |
13200 | RegularNMS_ == other.RegularNMS_ && |
13201 | getType(0) == other.getType(0) && |
13202 | getType(1) == other.getType(1) && |
13203 | getType(2) == other.getType(2) && |
13204 | getType(3) == other.getType(3); |
13205 | } |
13206 | |
13207 | Node* TFLiteDetectionPostProcessNode::clone() const { |
13208 | return new TFLiteDetectionPostProcessNode(getName(), getDetectionBoxes().getType(), getDetectionClasses().getType(), getDetectionScores().getType(), getNumDetections().getType(), getBoxes(), getScores(), getAnchors(), getNumClasses(), getMaxDetections(), getMaxClassesPerDetection(), getMaxDetectionsPerClass(), getIouThreshold(), getScoreThreshold(), getXScale(), getYScale(), getHScale(), getWScale(), getRegularNMS()); |
13209 | } |
13210 | |
13211 | llvm::hash_code TFLiteDetectionPostProcessNode::getHash() const { |
13212 | return llvm::hash_combine( |
13213 | NumClasses_, |
13214 | MaxDetections_, |
13215 | MaxClassesPerDetection_, |
13216 | MaxDetectionsPerClass_, |
13217 | toBinary(IouThreshold_), |
13218 | toBinary(ScoreThreshold_), |
13219 | toBinary(XScale_), |
13220 | toBinary(YScale_), |
13221 | toBinary(HScale_), |
13222 | toBinary(WScale_), |
13223 | RegularNMS_, |
13224 | Boxes_, |
13225 | Scores_, |
13226 | Anchors_); |
13227 | } |
13228 | |
13229 | unsigned ROIAlignNode::getNumInputs() const { |
13230 | return 3; |
13231 | } |
13232 | |
13233 | std::string ROIAlignNode::getInputName(unsigned idx) const { |
13234 | if (idx == 0) { return "FeatureMap" ; } |
13235 | if (idx == 1) { return "Boxes" ; } |
13236 | if (idx == 2) { return "BatchIndices" ; } |
13237 | idx -= 3; |
13238 | llvm_unreachable("Invalid index" ); |
13239 | } |
13240 | |
13241 | NodeValue ROIAlignNode::getNthInput(unsigned idx) { |
13242 | if (idx == 0) { return FeatureMap_; } |
13243 | if (idx == 1) { return Boxes_; } |
13244 | if (idx == 2) { return BatchIndices_; } |
13245 | idx -= 3; |
13246 | llvm_unreachable("Invalid index" ); |
13247 | } |
13248 | |
13249 | void ROIAlignNode::setNthInput(unsigned idx, NodeValue val) { |
13250 | if (idx == 0) { FeatureMap_ = val; return; } |
13251 | if (idx == 1) { Boxes_ = val; return; } |
13252 | if (idx == 2) { BatchIndices_ = val; return; } |
13253 | idx -= 3; |
13254 | llvm_unreachable("Invalid index" ); |
13255 | } |
13256 | |
13257 | llvm::StringRef ROIAlignNode::getOutputName(unsigned idx) const { |
13258 | if (idx == 0) { return "Result" ; } |
13259 | llvm_unreachable("Invalid index" ); |
13260 | } |
13261 | |
13262 | std::string ROIAlignNode::getDebugDesc() const { |
13263 | DescriptionBuilder db(getKindName()); |
13264 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13265 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13266 | db |
13267 | .addParam("FeatureMap" , *(getFeatureMap().getType())) |
13268 | .addParam("Boxes" , *(getBoxes().getType())) |
13269 | .addParam("BatchIndices" , *(getBatchIndices().getType())) |
13270 | .addParam("Mode" , getMode()) |
13271 | .addParam("OutputHeight" , getOutputHeight()) |
13272 | .addParam("OutputWidth" , getOutputWidth()) |
13273 | .addParam("SamplingRatio" , getSamplingRatio()) |
13274 | .addParam("SpatialScale" , getSpatialScale()) |
13275 | .addParam("Aligned" , getAligned()) |
13276 | .addParam("Rotated" , getRotated()) |
13277 | .addParam("Users" , getNumUsers()); |
13278 | db.addParam("Result" , *(getResult().getType())); |
13279 | return db; |
13280 | } |
13281 | |
13282 | void ROIAlignNode::visit(Node *parent, NodeWalker *visitor) { |
13283 | if (!visitor->shouldVisit(parent, this)) { return; } |
13284 | visitor->pre(parent, this); |
13285 | if (hasPredicate()) |
13286 | getPredicate().getNode()->visit(this, visitor); |
13287 | getFeatureMap().getNode()->visit(this, visitor); |
13288 | getBoxes().getNode()->visit(this, visitor); |
13289 | getBatchIndices().getNode()->visit(this, visitor); |
13290 | visitor->post(parent, this); |
13291 | } |
13292 | |
13293 | bool ROIAlignNode::isEqual(const ROIAlignNode &other) const { |
13294 | return true && |
13295 | FeatureMap_ == other.FeatureMap_ && |
13296 | Boxes_ == other.Boxes_ && |
13297 | BatchIndices_ == other.BatchIndices_ && |
13298 | predicate_ == other.predicate_ && |
13299 | Mode_ == other.Mode_ && |
13300 | OutputHeight_ == other.OutputHeight_ && |
13301 | OutputWidth_ == other.OutputWidth_ && |
13302 | SamplingRatio_ == other.SamplingRatio_ && |
13303 | SpatialScale_ == other.SpatialScale_ && |
13304 | Aligned_ == other.Aligned_ && |
13305 | Rotated_ == other.Rotated_ && |
13306 | getType(0) == other.getType(0); |
13307 | } |
13308 | |
13309 | Node* ROIAlignNode::clone() const { |
13310 | return new ROIAlignNode(getName(), getResult().getType(), getFeatureMap(), getBoxes(), getBatchIndices(), getMode(), getOutputHeight(), getOutputWidth(), getSamplingRatio(), getSpatialScale(), getAligned(), getRotated()); |
13311 | } |
13312 | |
13313 | llvm::hash_code ROIAlignNode::getHash() const { |
13314 | return llvm::hash_combine( |
13315 | Mode_, |
13316 | OutputHeight_, |
13317 | OutputWidth_, |
13318 | SamplingRatio_, |
13319 | toBinary(SpatialScale_), |
13320 | Aligned_, |
13321 | Rotated_, |
13322 | FeatureMap_, |
13323 | Boxes_, |
13324 | BatchIndices_); |
13325 | } |
13326 | |
13327 | unsigned BBoxTransformNode::getNumInputs() const { |
13328 | return 3; |
13329 | } |
13330 | |
13331 | std::string BBoxTransformNode::getInputName(unsigned idx) const { |
13332 | if (idx == 0) { return "Rois" ; } |
13333 | if (idx == 1) { return "Deltas" ; } |
13334 | if (idx == 2) { return "ImInfo" ; } |
13335 | idx -= 3; |
13336 | llvm_unreachable("Invalid index" ); |
13337 | } |
13338 | |
13339 | NodeValue BBoxTransformNode::getNthInput(unsigned idx) { |
13340 | if (idx == 0) { return Rois_; } |
13341 | if (idx == 1) { return Deltas_; } |
13342 | if (idx == 2) { return ImInfo_; } |
13343 | idx -= 3; |
13344 | llvm_unreachable("Invalid index" ); |
13345 | } |
13346 | |
13347 | void BBoxTransformNode::setNthInput(unsigned idx, NodeValue val) { |
13348 | if (idx == 0) { Rois_ = val; return; } |
13349 | if (idx == 1) { Deltas_ = val; return; } |
13350 | if (idx == 2) { ImInfo_ = val; return; } |
13351 | idx -= 3; |
13352 | llvm_unreachable("Invalid index" ); |
13353 | } |
13354 | |
13355 | llvm::StringRef BBoxTransformNode::getOutputName(unsigned idx) const { |
13356 | if (idx == 0) { return "BoxOut" ; } |
13357 | if (idx == 1) { return "RoiBatchSplits" ; } |
13358 | llvm_unreachable("Invalid index" ); |
13359 | } |
13360 | |
13361 | std::string BBoxTransformNode::getDebugDesc() const { |
13362 | DescriptionBuilder db(getKindName()); |
13363 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13364 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13365 | db |
13366 | .addParam("Rois" , *(getRois().getType())) |
13367 | .addParam("Deltas" , *(getDeltas().getType())) |
13368 | .addParam("ImInfo" , *(getImInfo().getType())) |
13369 | .addParam("Weights" , getWeights()) |
13370 | .addParam("ApplyScale" , getApplyScale()) |
13371 | .addParam("Rotated" , getRotated()) |
13372 | .addParam("AngleBoundOn" , getAngleBoundOn()) |
13373 | .addParam("AngleBoundLo" , getAngleBoundLo()) |
13374 | .addParam("AngleBoundHi" , getAngleBoundHi()) |
13375 | .addParam("ClipAngleThresh" , getClipAngleThresh()) |
13376 | .addParam("LegacyPlusOne" , getLegacyPlusOne()) |
13377 | .addParam("Users" , getNumUsers()); |
13378 | db.addParam("BoxOut" , *(getBoxOut().getType())); |
13379 | db.addParam("RoiBatchSplits" , *(getRoiBatchSplits().getType())); |
13380 | return db; |
13381 | } |
13382 | |
13383 | void BBoxTransformNode::visit(Node *parent, NodeWalker *visitor) { |
13384 | if (!visitor->shouldVisit(parent, this)) { return; } |
13385 | visitor->pre(parent, this); |
13386 | if (hasPredicate()) |
13387 | getPredicate().getNode()->visit(this, visitor); |
13388 | getRois().getNode()->visit(this, visitor); |
13389 | getDeltas().getNode()->visit(this, visitor); |
13390 | getImInfo().getNode()->visit(this, visitor); |
13391 | visitor->post(parent, this); |
13392 | } |
13393 | |
13394 | bool BBoxTransformNode::isEqual(const BBoxTransformNode &other) const { |
13395 | return true && |
13396 | Rois_ == other.Rois_ && |
13397 | Deltas_ == other.Deltas_ && |
13398 | ImInfo_ == other.ImInfo_ && |
13399 | predicate_ == other.predicate_ && |
13400 | Weights_ == other.Weights_ && |
13401 | ApplyScale_ == other.ApplyScale_ && |
13402 | Rotated_ == other.Rotated_ && |
13403 | AngleBoundOn_ == other.AngleBoundOn_ && |
13404 | AngleBoundLo_ == other.AngleBoundLo_ && |
13405 | AngleBoundHi_ == other.AngleBoundHi_ && |
13406 | ClipAngleThresh_ == other.ClipAngleThresh_ && |
13407 | LegacyPlusOne_ == other.LegacyPlusOne_ && |
13408 | getType(0) == other.getType(0) && |
13409 | getType(1) == other.getType(1); |
13410 | } |
13411 | |
13412 | Node* BBoxTransformNode::clone() const { |
13413 | return new BBoxTransformNode(getName(), getBoxOut().getType(), getRoiBatchSplits().getType(), getRois(), getDeltas(), getImInfo(), getWeights(), getApplyScale(), getRotated(), getAngleBoundOn(), getAngleBoundLo(), getAngleBoundHi(), getClipAngleThresh(), getLegacyPlusOne()); |
13414 | } |
13415 | |
13416 | llvm::hash_code BBoxTransformNode::getHash() const { |
13417 | return llvm::hash_combine( |
13418 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
13419 | std::vector<size_t> sizeVec = toBinary(floatVec); |
13420 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
13421 | }(Weights_), |
13422 | ApplyScale_, |
13423 | Rotated_, |
13424 | AngleBoundOn_, |
13425 | AngleBoundLo_, |
13426 | AngleBoundHi_, |
13427 | toBinary(ClipAngleThresh_), |
13428 | LegacyPlusOne_, |
13429 | Rois_, |
13430 | Deltas_, |
13431 | ImInfo_); |
13432 | } |
13433 | |
13434 | unsigned CollectRpnProposalsNode::getNumInputs() const { |
13435 | return 0 + RoisIn_.size() + RoisProbsIn_.size(); |
13436 | } |
13437 | |
13438 | std::string CollectRpnProposalsNode::getInputName(unsigned idx) const { |
13439 | idx -= 0; |
13440 | if (idx < RoisIn_.size()) { return "RoisIn" + std::to_string(idx); } |
13441 | idx -= RoisIn_.size(); |
13442 | if (idx < RoisProbsIn_.size()) { return "RoisProbsIn" + std::to_string(idx); } |
13443 | idx -= RoisProbsIn_.size(); |
13444 | llvm_unreachable("Invalid index" ); |
13445 | } |
13446 | |
13447 | NodeValue CollectRpnProposalsNode::getNthInput(unsigned idx) { |
13448 | idx -= 0; |
13449 | if (idx < RoisIn_.size()) { return RoisIn_[idx]; } |
13450 | idx -= RoisIn_.size(); |
13451 | if (idx < RoisProbsIn_.size()) { return RoisProbsIn_[idx]; } |
13452 | idx -= RoisProbsIn_.size(); |
13453 | llvm_unreachable("Invalid index" ); |
13454 | } |
13455 | |
13456 | void CollectRpnProposalsNode::setNthInput(unsigned idx, NodeValue val) { |
13457 | idx -= 0; |
13458 | if (idx < RoisIn_.size()) { RoisIn_[idx] = val; return; } |
13459 | idx -= RoisIn_.size(); |
13460 | if (idx < RoisProbsIn_.size()) { RoisProbsIn_[idx] = val; return; } |
13461 | idx -= RoisProbsIn_.size(); |
13462 | llvm_unreachable("Invalid index" ); |
13463 | } |
13464 | |
13465 | llvm::StringRef CollectRpnProposalsNode::getOutputName(unsigned idx) const { |
13466 | if (idx == 0) { return "Result" ; } |
13467 | llvm_unreachable("Invalid index" ); |
13468 | } |
13469 | |
13470 | std::string CollectRpnProposalsNode::getDebugDesc() const { |
13471 | DescriptionBuilder db(getKindName()); |
13472 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13473 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13474 | db |
13475 | .addParam("RpnMaxLevel" , getRpnMaxLevel()) |
13476 | .addParam("RpnMinLevel" , getRpnMinLevel()) |
13477 | .addParam("RpnPostNmsTopN" , getRpnPostNmsTopN()) |
13478 | .addParam("Users" , getNumUsers()); |
13479 | { |
13480 | unsigned mIndex = 0; |
13481 | for (const auto &II : getRoisIn()) { |
13482 | db.addParam("RoisIn" +std::to_string(mIndex++), *II.getType()); |
13483 | } |
13484 | } |
13485 | { |
13486 | unsigned mIndex = 0; |
13487 | for (const auto &II : getRoisProbsIn()) { |
13488 | db.addParam("RoisProbsIn" +std::to_string(mIndex++), *II.getType()); |
13489 | } |
13490 | } |
13491 | db.addParam("Result" , *(getResult().getType())); |
13492 | return db; |
13493 | } |
13494 | |
13495 | void CollectRpnProposalsNode::visit(Node *parent, NodeWalker *visitor) { |
13496 | if (!visitor->shouldVisit(parent, this)) { return; } |
13497 | visitor->pre(parent, this); |
13498 | if (hasPredicate()) |
13499 | getPredicate().getNode()->visit(this, visitor); |
13500 | for (auto &I : RoisIn_) { I.getNode()->visit(this, visitor); } |
13501 | for (auto &I : RoisProbsIn_) { I.getNode()->visit(this, visitor); } |
13502 | visitor->post(parent, this); |
13503 | } |
13504 | |
13505 | bool CollectRpnProposalsNode::isEqual(const CollectRpnProposalsNode &other) const { |
13506 | return true && |
13507 | predicate_ == other.predicate_ && |
13508 | RoisIn_ == other.RoisIn_ && |
13509 | RoisProbsIn_ == other.RoisProbsIn_ && |
13510 | RpnMaxLevel_ == other.RpnMaxLevel_ && |
13511 | RpnMinLevel_ == other.RpnMinLevel_ && |
13512 | RpnPostNmsTopN_ == other.RpnPostNmsTopN_ && |
13513 | getType(0) == other.getType(0); |
13514 | } |
13515 | |
13516 | Node* CollectRpnProposalsNode::clone() const { |
13517 | return new CollectRpnProposalsNode(getName(), getResult().getType(), getRoisIn(), getRoisProbsIn(), getRpnMaxLevel(), getRpnMinLevel(), getRpnPostNmsTopN()); |
13518 | } |
13519 | |
13520 | llvm::hash_code CollectRpnProposalsNode::getHash() const { |
13521 | return llvm::hash_combine( |
13522 | llvm::hash_combine_range(RoisIn_.begin(), RoisIn_.end()), |
13523 | llvm::hash_combine_range(RoisProbsIn_.begin(), RoisProbsIn_.end()), |
13524 | RpnMaxLevel_, |
13525 | RpnMinLevel_, |
13526 | RpnPostNmsTopN_); |
13527 | } |
13528 | |
13529 | unsigned LookupTableNode::getNumInputs() const { |
13530 | return 3; |
13531 | } |
13532 | |
13533 | std::string LookupTableNode::getInputName(unsigned idx) const { |
13534 | if (idx == 0) { return "Input" ; } |
13535 | if (idx == 1) { return "Table" ; } |
13536 | if (idx == 2) { return "TableIdx" ; } |
13537 | idx -= 3; |
13538 | llvm_unreachable("Invalid index" ); |
13539 | } |
13540 | |
13541 | NodeValue LookupTableNode::getNthInput(unsigned idx) { |
13542 | if (idx == 0) { return Input_; } |
13543 | if (idx == 1) { return Table_; } |
13544 | if (idx == 2) { return TableIdx_; } |
13545 | idx -= 3; |
13546 | llvm_unreachable("Invalid index" ); |
13547 | } |
13548 | |
13549 | void LookupTableNode::setNthInput(unsigned idx, NodeValue val) { |
13550 | if (idx == 0) { Input_ = val; return; } |
13551 | if (idx == 1) { Table_ = val; return; } |
13552 | if (idx == 2) { TableIdx_ = val; return; } |
13553 | idx -= 3; |
13554 | llvm_unreachable("Invalid index" ); |
13555 | } |
13556 | |
13557 | llvm::StringRef LookupTableNode::getOutputName(unsigned idx) const { |
13558 | if (idx == 0) { return "Result" ; } |
13559 | llvm_unreachable("Invalid index" ); |
13560 | } |
13561 | |
13562 | std::string LookupTableNode::getDebugDesc() const { |
13563 | DescriptionBuilder db(getKindName()); |
13564 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13565 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13566 | db |
13567 | .addParam("Input" , *(getInput().getType())) |
13568 | .addParam("Table" , *(getTable().getType())) |
13569 | .addParam("TableIdx" , *(getTableIdx().getType())) |
13570 | .addParam("Operator" , getOperator()) |
13571 | .addParam("OperatorArgs" , getOperatorArgs()) |
13572 | .addParam("Users" , getNumUsers()); |
13573 | db.addParam("Result" , *(getResult().getType())); |
13574 | return db; |
13575 | } |
13576 | |
13577 | void LookupTableNode::visit(Node *parent, NodeWalker *visitor) { |
13578 | if (!visitor->shouldVisit(parent, this)) { return; } |
13579 | visitor->pre(parent, this); |
13580 | if (hasPredicate()) |
13581 | getPredicate().getNode()->visit(this, visitor); |
13582 | getInput().getNode()->visit(this, visitor); |
13583 | getTable().getNode()->visit(this, visitor); |
13584 | getTableIdx().getNode()->visit(this, visitor); |
13585 | visitor->post(parent, this); |
13586 | } |
13587 | |
13588 | bool LookupTableNode::isEqual(const LookupTableNode &other) const { |
13589 | return true && |
13590 | Input_ == other.Input_ && |
13591 | Table_ == other.Table_ && |
13592 | TableIdx_ == other.TableIdx_ && |
13593 | predicate_ == other.predicate_ && |
13594 | Operator_ == other.Operator_ && |
13595 | OperatorArgs_ == other.OperatorArgs_ && |
13596 | getType(0) == other.getType(0); |
13597 | } |
13598 | |
13599 | Node* LookupTableNode::clone() const { |
13600 | return new LookupTableNode(getName(), getResult().getType(), getInput(), getTable(), getTableIdx(), getOperator(), getOperatorArgs()); |
13601 | } |
13602 | |
13603 | llvm::hash_code LookupTableNode::getHash() const { |
13604 | return llvm::hash_combine( |
13605 | Operator_, |
13606 | [](const std::vector<float>& floatVec) -> llvm::hash_code { |
13607 | std::vector<size_t> sizeVec = toBinary(floatVec); |
13608 | return llvm::hash_combine_range(sizeVec.begin(), sizeVec.end()); |
13609 | }(OperatorArgs_), |
13610 | Input_, |
13611 | Table_, |
13612 | TableIdx_); |
13613 | } |
13614 | |
13615 | unsigned CPUMaxSplatNode::getNumInputs() const { |
13616 | return 1; |
13617 | } |
13618 | |
13619 | std::string CPUMaxSplatNode::getInputName(unsigned idx) const { |
13620 | if (idx == 0) { return "Input" ; } |
13621 | idx -= 1; |
13622 | llvm_unreachable("Invalid index" ); |
13623 | } |
13624 | |
13625 | NodeValue CPUMaxSplatNode::getNthInput(unsigned idx) { |
13626 | if (idx == 0) { return Input_; } |
13627 | idx -= 1; |
13628 | llvm_unreachable("Invalid index" ); |
13629 | } |
13630 | |
13631 | void CPUMaxSplatNode::setNthInput(unsigned idx, NodeValue val) { |
13632 | if (idx == 0) { Input_ = val; return; } |
13633 | idx -= 1; |
13634 | llvm_unreachable("Invalid index" ); |
13635 | } |
13636 | |
13637 | llvm::StringRef CPUMaxSplatNode::getOutputName(unsigned idx) const { |
13638 | if (idx == 0) { return "Result" ; } |
13639 | llvm_unreachable("Invalid index" ); |
13640 | } |
13641 | |
13642 | std::string CPUMaxSplatNode::getDebugDesc() const { |
13643 | DescriptionBuilder db(getKindName()); |
13644 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13645 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13646 | db |
13647 | .addParam("Input" , *(getInput().getType())) |
13648 | .addParam("SplatValue" , getSplatValue()) |
13649 | .addParam("Users" , getNumUsers()); |
13650 | db.addParam("Result" , *(getResult().getType())); |
13651 | return db; |
13652 | } |
13653 | |
13654 | void CPUMaxSplatNode::visit(Node *parent, NodeWalker *visitor) { |
13655 | if (!visitor->shouldVisit(parent, this)) { return; } |
13656 | visitor->pre(parent, this); |
13657 | if (hasPredicate()) |
13658 | getPredicate().getNode()->visit(this, visitor); |
13659 | getInput().getNode()->visit(this, visitor); |
13660 | visitor->post(parent, this); |
13661 | } |
13662 | |
13663 | bool CPUMaxSplatNode::isEqual(const CPUMaxSplatNode &other) const { |
13664 | return true && |
13665 | Input_ == other.Input_ && |
13666 | predicate_ == other.predicate_ && |
13667 | SplatValue_ == other.SplatValue_ && |
13668 | getType(0) == other.getType(0); |
13669 | } |
13670 | |
13671 | Node* CPUMaxSplatNode::clone() const { |
13672 | return new CPUMaxSplatNode(getName(), getInput(), getSplatValue()); |
13673 | } |
13674 | |
13675 | llvm::hash_code CPUMaxSplatNode::getHash() const { |
13676 | return llvm::hash_combine( |
13677 | toBinary(SplatValue_), |
13678 | Input_); |
13679 | } |
13680 | |
13681 | unsigned CPUConvDKKC8Node::getNumInputs() const { |
13682 | return 3; |
13683 | } |
13684 | |
13685 | std::string CPUConvDKKC8Node::getInputName(unsigned idx) const { |
13686 | if (idx == 0) { return "Input" ; } |
13687 | if (idx == 1) { return "Filter" ; } |
13688 | if (idx == 2) { return "Bias" ; } |
13689 | idx -= 3; |
13690 | llvm_unreachable("Invalid index" ); |
13691 | } |
13692 | |
13693 | NodeValue CPUConvDKKC8Node::getNthInput(unsigned idx) { |
13694 | if (idx == 0) { return Input_; } |
13695 | if (idx == 1) { return Filter_; } |
13696 | if (idx == 2) { return Bias_; } |
13697 | idx -= 3; |
13698 | llvm_unreachable("Invalid index" ); |
13699 | } |
13700 | |
13701 | void CPUConvDKKC8Node::setNthInput(unsigned idx, NodeValue val) { |
13702 | if (idx == 0) { Input_ = val; return; } |
13703 | if (idx == 1) { Filter_ = val; return; } |
13704 | if (idx == 2) { Bias_ = val; return; } |
13705 | idx -= 3; |
13706 | llvm_unreachable("Invalid index" ); |
13707 | } |
13708 | |
13709 | llvm::StringRef CPUConvDKKC8Node::getOutputName(unsigned idx) const { |
13710 | if (idx == 0) { return "Result" ; } |
13711 | llvm_unreachable("Invalid index" ); |
13712 | } |
13713 | |
13714 | std::string CPUConvDKKC8Node::getDebugDesc() const { |
13715 | DescriptionBuilder db(getKindName()); |
13716 | db.addParam("Name" , separateString(getName(), 100, "\n" )); |
13717 | if (hasPredicate()) db.addParam("Predicate" , "Yes" ); |
13718 | db |
13719 | .addParam("Input" , *(getInput().getType())) |
13720 | .addParam("Filter" , *(getFilter().getType())) |
13721 | .addParam("Bias" , *(getBias().getType())) |
13722 | .addParam("Kernels" , getKernels()) |
13723 | .addParam("Strides" , getStrides()) |
13724 | .addParam("Pads" , getPads()) |
13725 | .addParam("Group" , getGroup()) |
13726 | .addParam("Users" , getNumUsers()); |
13727 | db.addParam("Result" , *(getResult().getType())); |
13728 | return db; |
13729 | } |
13730 | |
13731 | void CPUConvDKKC8Node::visit(Node *parent, NodeWalker *visitor) { |
13732 | if (!visitor->shouldVisit(parent, this)) { return; } |
13733 | visitor->pre(parent, this); |
13734 | if (hasPredicate()) |
13735 | getPredicate().getNode()->visit(this, visitor); |
13736 | getInput().getNode()->visit(this, visitor); |
13737 | getFilter().getNode()->visit(this, visitor); |
13738 | getBias().getNode()->visit(this, visitor); |
13739 | visitor->post(parent, this); |
13740 | } |
13741 | |
13742 | bool CPUConvDKKC8Node::isEqual(const CPUConvDKKC8Node &other) const { |
13743 | return true && |
13744 | Input_ == other.Input_ && |
13745 | Filter_ == other.Filter_ && |
13746 | Bias_ == other.Bias_ && |
13747 | predicate_ == other.predicate_ && |
13748 | Kernels_ == other.Kernels_ && |
13749 | Strides_ == other.Strides_ && |
13750 | Pads_ == other.Pads_ && |
13751 | Group_ == other.Group_ && |
13752 | getType(0) == other.getType(0); |
13753 | } |
13754 | |
13755 | Node* CPUConvDKKC8Node::clone() const { |
13756 | return new CPUConvDKKC8Node(getName(), getResult().getType(), getInput(), getFilter(), getBias(), getKernels(), getStrides(), getPads(), getGroup()); |
13757 | } |
13758 | |
13759 | llvm::hash_code CPUConvDKKC8Node::getHash() const { |
13760 | return llvm::hash_combine( |
13761 | llvm::hash_combine_range(Kernels_.begin(), Kernels_.end()), |
13762 | llvm::hash_combine_range(Strides_.begin(), Strides_.end()), |
13763 | llvm::hash_combine_range(Pads_.begin(), Pads_.end()), |
13764 | Group_, |
13765 | Input_, |
13766 | Filter_, |
13767 | Bias_); |
13768 | } |
13769 | |
13770 | #include "glow/CPUSpecificNodesVerification.h" |
13771 | |