1 | |
#include "glow/Graph/Nodes.h"

#include <string>
#include <utility>
#include <vector>
3 | |
4 | |
5 | namespace glow { |
6 | /// Specifies a node whose Input will be copied to Output.This node prevents graph optimizations from eliminating this node and all of its ancestor nodes. Generally intended to save the final result of a network. |
7 | class SaveNode final : public Node { |
8 | NodeHandle Input_; |
9 | NodeHandle Output_; |
10 | |
11 | public: |
12 | enum InputIndices { |
13 | InputIdx = 0, |
14 | OutputIdx = 1, |
15 | }; |
16 | |
17 | enum ResultIndices { |
18 | }; |
19 | |
20 | SaveNode(llvm::StringRef name, NodeValue Input, NodeValue Output) |
21 | : Node(Kinded::Kind::SaveNodeKind, name), Input_(this, Input), Output_(this, Output) { |
22 | } |
23 | const NodeValue getInput() const { return Input_; } |
24 | const NodeValue getOutput() const { return Output_; } |
25 | |
26 | static bool classof(const Kinded *k) { |
27 | return k->getKind() == Kinded::Kind::SaveNodeKind; |
28 | } |
29 | |
30 | |
31 | bool isOverwrittenNthInput(unsigned idx) const { |
32 | if (idx == 1) return true; |
33 | return false; |
34 | } |
35 | |
36 | unsigned getNumInputs() const; |
37 | std::string getInputName(unsigned idx) const; |
38 | NodeValue getNthInput(unsigned idx); |
39 | void setNthInput(unsigned idx, NodeValue val); |
40 | llvm::StringRef getOutputName(unsigned idx) const; |
41 | bool hasSideEffects() const { return 1; } |
42 | bool isCanonical() const { return 1; } |
43 | bool isDataParallel() const { return 1; } |
44 | std::string getDebugDesc() const; |
45 | bool isEqual(const SaveNode &other) const; |
46 | llvm::hash_code getHash() const; |
47 | void visit(Node *parent, NodeWalker *visitor); |
48 | Node* clone() const; |
49 | bool verify() const; |
50 | Placeholder *getPlaceholder() const;}; |
51 | } // namespace glow |
52 | |
53 | |
54 | namespace glow { |
55 | /// Performs padding of a given input tensor. The Padding information must be specified for each dimension of the tensor in Pads (start and end padding). In case the padding is negative, it means that the tensor must be cropped. Mode defines how extra padding elements are created. Supported modes are defined in the PaddingMode enum: CONSTANT, REFLECT, EDGE. Value is only used with the CONSTANT mode. |
56 | class PadNode final : public Node { |
57 | NodeHandle Input_; |
58 | unsigned_t Mode_; |
59 | std::vector<int> Pads_; |
60 | float Value_; |
61 | |
62 | public: |
63 | enum InputIndices { |
64 | InputIdx = 0, |
65 | }; |
66 | |
67 | enum ResultIndices { |
68 | ResultIdx = 0, |
69 | }; |
70 | |
71 | PadNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Mode, std::vector<int> Pads, float Value) |
72 | : Node(Kinded::Kind::PadNodeKind, name), Input_(this, Input), Mode_(Mode), Pads_(Pads), Value_(Value) { |
73 | addResult(Result); |
74 | } |
75 | const NodeValue getInput() const { return Input_; } |
76 | NodeValue getResult() { return getNthResult(0); } |
77 | const NodeValue getResult() const { return getNthResult(0); } |
78 | unsigned_t getMode() const { return Mode_; } |
79 | llvm::ArrayRef<int> getPads() const { return Pads_; } |
80 | float getValue() const { return Value_; } |
81 | |
82 | static bool classof(const Kinded *k) { |
83 | return k->getKind() == Kinded::Kind::PadNodeKind; |
84 | } |
85 | |
86 | |
87 | bool isOverwrittenNthInput(unsigned idx) const { |
88 | return false; |
89 | } |
90 | |
91 | unsigned getNumInputs() const; |
92 | std::string getInputName(unsigned idx) const; |
93 | NodeValue getNthInput(unsigned idx); |
94 | void setNthInput(unsigned idx, NodeValue val); |
95 | llvm::StringRef getOutputName(unsigned idx) const; |
96 | bool hasSideEffects() const { return 0; } |
97 | bool isCanonical() const { return 1; } |
98 | bool isDataParallel() const { return 0; } |
99 | std::string getDebugDesc() const; |
100 | bool isEqual(const PadNode &other) const; |
101 | llvm::hash_code getHash() const; |
102 | void visit(Node *parent, NodeWalker *visitor); |
103 | Node* clone() const; |
104 | bool verify() const; |
105 | }; |
106 | } // namespace glow |
107 | |
108 | |
109 | namespace glow { |
110 | class ConvolutionGradNode final : public Node { |
111 | NodeHandle Input_; |
112 | NodeHandle Filter_; |
113 | NodeHandle Bias_; |
114 | NodeHandle OriginalOutputForResult_; |
115 | NodeHandle GradOfOriginalOutputNamedResult_; |
116 | std::vector<unsigned_t> Kernels_; |
117 | std::vector<unsigned_t> Strides_; |
118 | std::vector<unsigned_t> Pads_; |
119 | unsigned_t Group_; |
120 | std::vector<unsigned_t> Dilation_; |
121 | glow::ConvolutionLayout Layout_; |
122 | glow::FusedActivation FusedActivation_; |
123 | std::vector<float> FusedActivationArgs_; |
124 | |
125 | public: |
126 | enum InputIndices { |
127 | InputIdx = 0, |
128 | FilterIdx = 1, |
129 | BiasIdx = 2, |
130 | OriginalOutputForResultIdx = 3, |
131 | GradOfOriginalOutputNamedResultIdx = 4, |
132 | }; |
133 | |
134 | enum ResultIndices { |
135 | GradOfInputNamedInputIdx = 0, |
136 | GradOfInputNamedFilterIdx = 1, |
137 | GradOfInputNamedBiasIdx = 2, |
138 | }; |
139 | |
140 | ConvolutionGradNode(llvm::StringRef name, NodeValue Input, NodeValue Filter, NodeValue Bias, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation, glow::ConvolutionLayout Layout, glow::FusedActivation FusedActivation, std::vector<float> FusedActivationArgs) |
141 | : Node(Kinded::Kind::ConvolutionGradNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation), Layout_(Layout), FusedActivation_(FusedActivation), FusedActivationArgs_(FusedActivationArgs) { |
142 | addResult(Input.getType()); |
143 | addResult(Filter.getType()); |
144 | addResult(Bias.getType()); |
145 | } |
146 | const NodeValue getInput() const { return Input_; } |
147 | const NodeValue getFilter() const { return Filter_; } |
148 | const NodeValue getBias() const { return Bias_; } |
149 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
150 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
151 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
152 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
153 | NodeValue getGradOfInputNamedFilter() { return getNthResult(1); } |
154 | const NodeValue getGradOfInputNamedFilter() const { return getNthResult(1); } |
155 | NodeValue getGradOfInputNamedBias() { return getNthResult(2); } |
156 | const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); } |
157 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
158 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
159 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
160 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
161 | unsigned_t getGroup() const { return Group_; } |
162 | void setGroup(unsigned_t a) {Group_ = a; } |
163 | llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; } |
164 | glow::ConvolutionLayout getLayout() const { return Layout_; } |
165 | glow::FusedActivation getFusedActivation() const { return FusedActivation_; } |
166 | void setFusedActivation(glow::FusedActivation a) {FusedActivation_ = a; } |
167 | llvm::ArrayRef<float> getFusedActivationArgs() const { return FusedActivationArgs_; } |
168 | void setFusedActivationArgs(llvm::ArrayRef<float> a) {FusedActivationArgs_ = a; } |
169 | |
170 | static bool classof(const Kinded *k) { |
171 | return k->getKind() == Kinded::Kind::ConvolutionGradNodeKind; |
172 | } |
173 | |
174 | |
175 | bool isOverwrittenNthInput(unsigned idx) const { |
176 | return false; |
177 | } |
178 | |
179 | unsigned getNumInputs() const; |
180 | std::string getInputName(unsigned idx) const; |
181 | NodeValue getNthInput(unsigned idx); |
182 | void setNthInput(unsigned idx, NodeValue val); |
183 | llvm::StringRef getOutputName(unsigned idx) const; |
184 | bool hasSideEffects() const { return 0; } |
185 | bool isCanonical() const { return 1; } |
186 | bool isDataParallel() const { return 0; } |
187 | std::string getDebugDesc() const; |
188 | bool isEqual(const ConvolutionGradNode &other) const; |
189 | llvm::hash_code getHash() const; |
190 | void visit(Node *parent, NodeWalker *visitor); |
191 | Node* clone() const; |
192 | bool verify() const; |
193 | }; |
194 | } // namespace glow |
195 | |
196 | |
197 | namespace glow { |
198 | /// Performs 2D Convolution using a given Input, Filter, and Bias tensors, as well as provided Kernels, Strides, Pads, Group and Dilation. Supported Layouts are defined in the ConvolutionLayout enum: NHWC and NCHW. Supported FusedActivations are defined in the FusedActivation enum. |
199 | class ConvolutionNode final : public Node { |
200 | NodeHandle Input_; |
201 | NodeHandle Filter_; |
202 | NodeHandle Bias_; |
203 | std::vector<unsigned_t> Kernels_; |
204 | std::vector<unsigned_t> Strides_; |
205 | std::vector<unsigned_t> Pads_; |
206 | unsigned_t Group_; |
207 | std::vector<unsigned_t> Dilation_; |
208 | glow::ConvolutionLayout Layout_; |
209 | glow::FusedActivation FusedActivation_; |
210 | std::vector<float> FusedActivationArgs_; |
211 | |
212 | public: |
213 | enum InputIndices { |
214 | InputIdx = 0, |
215 | FilterIdx = 1, |
216 | BiasIdx = 2, |
217 | }; |
218 | |
219 | enum ResultIndices { |
220 | ResultIdx = 0, |
221 | }; |
222 | |
223 | ConvolutionNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation, glow::ConvolutionLayout Layout, glow::FusedActivation FusedActivation, std::vector<float> FusedActivationArgs) |
224 | : Node(Kinded::Kind::ConvolutionNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation), Layout_(Layout), FusedActivation_(FusedActivation), FusedActivationArgs_(FusedActivationArgs) { |
225 | addResult(Result); |
226 | } |
227 | const NodeValue getInput() const { return Input_; } |
228 | const NodeValue getFilter() const { return Filter_; } |
229 | const NodeValue getBias() const { return Bias_; } |
230 | NodeValue getResult() { return getNthResult(0); } |
231 | const NodeValue getResult() const { return getNthResult(0); } |
232 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
233 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
234 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
235 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
236 | unsigned_t getGroup() const { return Group_; } |
237 | void setGroup(unsigned_t a) {Group_ = a; } |
238 | llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; } |
239 | glow::ConvolutionLayout getLayout() const { return Layout_; } |
240 | glow::FusedActivation getFusedActivation() const { return FusedActivation_; } |
241 | void setFusedActivation(glow::FusedActivation a) {FusedActivation_ = a; } |
242 | llvm::ArrayRef<float> getFusedActivationArgs() const { return FusedActivationArgs_; } |
243 | void setFusedActivationArgs(llvm::ArrayRef<float> a) {FusedActivationArgs_ = a; } |
244 | |
245 | static bool classof(const Kinded *k) { |
246 | return k->getKind() == Kinded::Kind::ConvolutionNodeKind; |
247 | } |
248 | |
249 | |
250 | bool isOverwrittenNthInput(unsigned idx) const { |
251 | return false; |
252 | } |
253 | |
254 | unsigned getNumInputs() const; |
255 | std::string getInputName(unsigned idx) const; |
256 | NodeValue getNthInput(unsigned idx); |
257 | void setNthInput(unsigned idx, NodeValue val); |
258 | llvm::StringRef getOutputName(unsigned idx) const; |
259 | bool hasSideEffects() const { return 0; } |
260 | bool isCanonical() const { return 1; } |
261 | bool isDataParallel() const { return 0; } |
262 | std::string getDebugDesc() const; |
263 | bool isEqual(const ConvolutionNode &other) const; |
264 | llvm::hash_code getHash() const; |
265 | void visit(Node *parent, NodeWalker *visitor); |
266 | Node* clone() const; |
267 | bool verify() const; |
268 | bool hasFusedActivation() const; ConvolutionGradNode *getGrad(GraphGradMapper &builder); |
269 | }; |
270 | } // namespace glow |
271 | |
272 | |
273 | namespace glow { |
274 | /// Performs 2D Convolution using a given Input, Filter, and Bias tensors, as well as provided Kernels, Strides, Pads, and Group. The filter channel wise quantization parameters are provided by FilterScales and FilterOffsets while the bias channel wise quantization parameters are provided by BiasScales and BiasOffsets. |
275 | class ChannelwiseQuantizedConvolutionNode final : public Node { |
276 | NodeHandle Input_; |
277 | NodeHandle Filter_; |
278 | NodeHandle Bias_; |
279 | NodeHandle FilterScales_; |
280 | NodeHandle FilterOffsets_; |
281 | NodeHandle BiasScales_; |
282 | NodeHandle BiasOffsets_; |
283 | std::vector<unsigned_t> Kernels_; |
284 | std::vector<unsigned_t> Strides_; |
285 | std::vector<unsigned_t> Pads_; |
286 | unsigned_t Group_; |
287 | std::vector<unsigned_t> Dilation_; |
288 | glow::FusedActivation FusedActivation_; |
289 | std::vector<float> FusedActivationArgs_; |
290 | |
291 | public: |
292 | enum InputIndices { |
293 | InputIdx = 0, |
294 | FilterIdx = 1, |
295 | BiasIdx = 2, |
296 | FilterScalesIdx = 3, |
297 | FilterOffsetsIdx = 4, |
298 | BiasScalesIdx = 5, |
299 | BiasOffsetsIdx = 6, |
300 | }; |
301 | |
302 | enum ResultIndices { |
303 | ResultIdx = 0, |
304 | }; |
305 | |
306 | ChannelwiseQuantizedConvolutionNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, NodeValue FilterScales, NodeValue FilterOffsets, NodeValue BiasScales, NodeValue BiasOffsets, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation, glow::FusedActivation FusedActivation, std::vector<float> FusedActivationArgs) |
307 | : Node(Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), FilterScales_(this, FilterScales), FilterOffsets_(this, FilterOffsets), BiasScales_(this, BiasScales), BiasOffsets_(this, BiasOffsets), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation), FusedActivation_(FusedActivation), FusedActivationArgs_(FusedActivationArgs) { |
308 | addResult(Result); |
309 | } |
310 | const NodeValue getInput() const { return Input_; } |
311 | const NodeValue getFilter() const { return Filter_; } |
312 | const NodeValue getBias() const { return Bias_; } |
313 | const NodeValue getFilterScales() const { return FilterScales_; } |
314 | const NodeValue getFilterOffsets() const { return FilterOffsets_; } |
315 | const NodeValue getBiasScales() const { return BiasScales_; } |
316 | const NodeValue getBiasOffsets() const { return BiasOffsets_; } |
317 | NodeValue getResult() { return getNthResult(0); } |
318 | const NodeValue getResult() const { return getNthResult(0); } |
319 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
320 | void setKernels(llvm::ArrayRef<unsigned_t> a) {Kernels_ = a; } |
321 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
322 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
323 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
324 | unsigned_t getGroup() const { return Group_; } |
325 | void setGroup(unsigned_t a) {Group_ = a; } |
326 | llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; } |
327 | glow::FusedActivation getFusedActivation() const { return FusedActivation_; } |
328 | void setFusedActivation(glow::FusedActivation a) {FusedActivation_ = a; } |
329 | llvm::ArrayRef<float> getFusedActivationArgs() const { return FusedActivationArgs_; } |
330 | void setFusedActivationArgs(llvm::ArrayRef<float> a) {FusedActivationArgs_ = a; } |
331 | |
332 | static bool classof(const Kinded *k) { |
333 | return k->getKind() == Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind; |
334 | } |
335 | |
336 | |
337 | bool isOverwrittenNthInput(unsigned idx) const { |
338 | return false; |
339 | } |
340 | |
341 | unsigned getNumInputs() const; |
342 | std::string getInputName(unsigned idx) const; |
343 | NodeValue getNthInput(unsigned idx); |
344 | void setNthInput(unsigned idx, NodeValue val); |
345 | llvm::StringRef getOutputName(unsigned idx) const; |
346 | bool hasSideEffects() const { return 0; } |
347 | bool isCanonical() const { return 1; } |
348 | bool isDataParallel() const { return 0; } |
349 | std::string getDebugDesc() const; |
350 | bool isEqual(const ChannelwiseQuantizedConvolutionNode &other) const; |
351 | llvm::hash_code getHash() const; |
352 | void visit(Node *parent, NodeWalker *visitor); |
353 | Node* clone() const; |
354 | bool verify() const; |
355 | bool hasFusedActivation() const;}; |
356 | } // namespace glow |
357 | |
358 | |
359 | namespace glow { |
360 | /// Performs 2D Transposed Convolution using a given Input,Filter, and Bias tensors, as well as provided Kernels,Strides, Pads, and Group. |
361 | class ConvTransposeNode final : public Node { |
362 | NodeHandle Input_; |
363 | NodeHandle Filter_; |
364 | NodeHandle Bias_; |
365 | std::vector<unsigned_t> Kernels_; |
366 | std::vector<unsigned_t> Strides_; |
367 | std::vector<unsigned_t> Pads_; |
368 | unsigned_t Group_; |
369 | std::vector<unsigned_t> Dilation_; |
370 | |
371 | public: |
372 | enum InputIndices { |
373 | InputIdx = 0, |
374 | FilterIdx = 1, |
375 | BiasIdx = 2, |
376 | }; |
377 | |
378 | enum ResultIndices { |
379 | ResultIdx = 0, |
380 | }; |
381 | |
382 | ConvTransposeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group, std::vector<unsigned_t> Dilation) |
383 | : Node(Kinded::Kind::ConvTransposeNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group), Dilation_(Dilation) { |
384 | addResult(Result); |
385 | } |
386 | const NodeValue getInput() const { return Input_; } |
387 | const NodeValue getFilter() const { return Filter_; } |
388 | const NodeValue getBias() const { return Bias_; } |
389 | NodeValue getResult() { return getNthResult(0); } |
390 | const NodeValue getResult() const { return getNthResult(0); } |
391 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
392 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
393 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
394 | unsigned_t getGroup() const { return Group_; } |
395 | llvm::ArrayRef<unsigned_t> getDilation() const { return Dilation_; } |
396 | |
397 | static bool classof(const Kinded *k) { |
398 | return k->getKind() == Kinded::Kind::ConvTransposeNodeKind; |
399 | } |
400 | |
401 | |
402 | bool isOverwrittenNthInput(unsigned idx) const { |
403 | return false; |
404 | } |
405 | |
406 | unsigned getNumInputs() const; |
407 | std::string getInputName(unsigned idx) const; |
408 | NodeValue getNthInput(unsigned idx); |
409 | void setNthInput(unsigned idx, NodeValue val); |
410 | llvm::StringRef getOutputName(unsigned idx) const; |
411 | bool hasSideEffects() const { return 0; } |
412 | bool isCanonical() const { return 1; } |
413 | bool isDataParallel() const { return 0; } |
414 | std::string getDebugDesc() const; |
415 | bool isEqual(const ConvTransposeNode &other) const; |
416 | llvm::hash_code getHash() const; |
417 | void visit(Node *parent, NodeWalker *visitor); |
418 | Node* clone() const; |
419 | bool verify() const; |
420 | }; |
421 | } // namespace glow |
422 | |
423 | |
424 | namespace glow { |
425 | class Convolution3DGradNode final : public Node { |
426 | NodeHandle Input_; |
427 | NodeHandle Filter_; |
428 | NodeHandle Bias_; |
429 | NodeHandle OriginalOutputForResult_; |
430 | NodeHandle GradOfOriginalOutputNamedResult_; |
431 | std::vector<unsigned_t> Kernels_; |
432 | std::vector<unsigned_t> Strides_; |
433 | std::vector<unsigned_t> Pads_; |
434 | unsigned_t Group_; |
435 | |
436 | public: |
437 | enum InputIndices { |
438 | InputIdx = 0, |
439 | FilterIdx = 1, |
440 | BiasIdx = 2, |
441 | OriginalOutputForResultIdx = 3, |
442 | GradOfOriginalOutputNamedResultIdx = 4, |
443 | }; |
444 | |
445 | enum ResultIndices { |
446 | GradOfInputNamedInputIdx = 0, |
447 | GradOfInputNamedFilterIdx = 1, |
448 | GradOfInputNamedBiasIdx = 2, |
449 | }; |
450 | |
451 | Convolution3DGradNode(llvm::StringRef name, NodeValue Input, NodeValue Filter, NodeValue Bias, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group) |
452 | : Node(Kinded::Kind::Convolution3DGradNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group) { |
453 | addResult(Input.getType()); |
454 | addResult(Filter.getType()); |
455 | addResult(Bias.getType()); |
456 | } |
457 | const NodeValue getInput() const { return Input_; } |
458 | const NodeValue getFilter() const { return Filter_; } |
459 | const NodeValue getBias() const { return Bias_; } |
460 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
461 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
462 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
463 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
464 | NodeValue getGradOfInputNamedFilter() { return getNthResult(1); } |
465 | const NodeValue getGradOfInputNamedFilter() const { return getNthResult(1); } |
466 | NodeValue getGradOfInputNamedBias() { return getNthResult(2); } |
467 | const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); } |
468 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
469 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
470 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
471 | unsigned_t getGroup() const { return Group_; } |
472 | |
473 | static bool classof(const Kinded *k) { |
474 | return k->getKind() == Kinded::Kind::Convolution3DGradNodeKind; |
475 | } |
476 | |
477 | |
478 | bool isOverwrittenNthInput(unsigned idx) const { |
479 | return false; |
480 | } |
481 | |
482 | unsigned getNumInputs() const; |
483 | std::string getInputName(unsigned idx) const; |
484 | NodeValue getNthInput(unsigned idx); |
485 | void setNthInput(unsigned idx, NodeValue val); |
486 | llvm::StringRef getOutputName(unsigned idx) const; |
487 | bool hasSideEffects() const { return 0; } |
488 | bool isCanonical() const { return 1; } |
489 | bool isDataParallel() const { return 0; } |
490 | std::string getDebugDesc() const; |
491 | bool isEqual(const Convolution3DGradNode &other) const; |
492 | llvm::hash_code getHash() const; |
493 | void visit(Node *parent, NodeWalker *visitor); |
494 | Node* clone() const; |
495 | bool verify() const; |
496 | }; |
497 | } // namespace glow |
498 | |
499 | |
500 | namespace glow { |
501 | /// Performs 3D Convolution using a given Input, Filter, and Bias tensors, as well as provided Kernels, Strides, Pads, and Group. |
502 | class Convolution3DNode final : public Node { |
503 | NodeHandle Input_; |
504 | NodeHandle Filter_; |
505 | NodeHandle Bias_; |
506 | std::vector<unsigned_t> Kernels_; |
507 | std::vector<unsigned_t> Strides_; |
508 | std::vector<unsigned_t> Pads_; |
509 | unsigned_t Group_; |
510 | |
511 | public: |
512 | enum InputIndices { |
513 | InputIdx = 0, |
514 | FilterIdx = 1, |
515 | BiasIdx = 2, |
516 | }; |
517 | |
518 | enum ResultIndices { |
519 | ResultIdx = 0, |
520 | }; |
521 | |
522 | Convolution3DNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group) |
523 | : Node(Kinded::Kind::Convolution3DNodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group) { |
524 | addResult(Result); |
525 | } |
526 | const NodeValue getInput() const { return Input_; } |
527 | const NodeValue getFilter() const { return Filter_; } |
528 | const NodeValue getBias() const { return Bias_; } |
529 | NodeValue getResult() { return getNthResult(0); } |
530 | const NodeValue getResult() const { return getNthResult(0); } |
531 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
532 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
533 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
534 | unsigned_t getGroup() const { return Group_; } |
535 | |
536 | static bool classof(const Kinded *k) { |
537 | return k->getKind() == Kinded::Kind::Convolution3DNodeKind; |
538 | } |
539 | |
540 | |
541 | bool isOverwrittenNthInput(unsigned idx) const { |
542 | return false; |
543 | } |
544 | |
545 | unsigned getNumInputs() const; |
546 | std::string getInputName(unsigned idx) const; |
547 | NodeValue getNthInput(unsigned idx); |
548 | void setNthInput(unsigned idx, NodeValue val); |
549 | llvm::StringRef getOutputName(unsigned idx) const; |
550 | bool hasSideEffects() const { return 0; } |
551 | bool isCanonical() const { return 1; } |
552 | bool isDataParallel() const { return 0; } |
553 | std::string getDebugDesc() const; |
554 | bool isEqual(const Convolution3DNode &other) const; |
555 | llvm::hash_code getHash() const; |
556 | void visit(Node *parent, NodeWalker *visitor); |
557 | Node* clone() const; |
558 | bool verify() const; |
559 | Convolution3DGradNode *getGrad(GraphGradMapper &builder); |
560 | }; |
561 | } // namespace glow |
562 | |
563 | |
564 | namespace glow { |
565 | class MaxPoolGradNode final : public Node { |
566 | NodeHandle Input_; |
567 | NodeHandle OriginalOutputForResult_; |
568 | NodeHandle GradOfOriginalOutputNamedResult_; |
569 | NodeHandle OriginalOutputForArgmax_; |
570 | NodeHandle GradOfOriginalOutputNamedArgmax_; |
571 | std::vector<unsigned_t> Kernels_; |
572 | std::vector<unsigned_t> Strides_; |
573 | std::vector<unsigned_t> Pads_; |
574 | unsigned_t Layout_; |
575 | |
576 | public: |
577 | enum InputIndices { |
578 | InputIdx = 0, |
579 | OriginalOutputForResultIdx = 1, |
580 | GradOfOriginalOutputNamedResultIdx = 2, |
581 | OriginalOutputForArgmaxIdx = 3, |
582 | GradOfOriginalOutputNamedArgmaxIdx = 4, |
583 | }; |
584 | |
585 | enum ResultIndices { |
586 | GradOfInputNamedInputIdx = 0, |
587 | }; |
588 | |
589 | MaxPoolGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, NodeValue OriginalOutputForArgmax, NodeValue GradOfOriginalOutputNamedArgmax, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout) |
590 | : Node(Kinded::Kind::MaxPoolGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), OriginalOutputForArgmax_(this, OriginalOutputForArgmax), GradOfOriginalOutputNamedArgmax_(this, GradOfOriginalOutputNamedArgmax), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout) { |
591 | addResult(Input.getType()); |
592 | } |
593 | const NodeValue getInput() const { return Input_; } |
594 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
595 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
596 | const NodeValue getOriginalOutputForArgmax() const { return OriginalOutputForArgmax_; } |
597 | const NodeValue getGradOfOriginalOutputNamedArgmax() const { return GradOfOriginalOutputNamedArgmax_; } |
598 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
599 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
600 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
601 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
602 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
603 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
604 | unsigned_t getLayout() const { return Layout_; } |
605 | |
606 | static bool classof(const Kinded *k) { |
607 | return k->getKind() == Kinded::Kind::MaxPoolGradNodeKind; |
608 | } |
609 | |
610 | |
611 | bool isOverwrittenNthInput(unsigned idx) const { |
612 | return false; |
613 | } |
614 | |
615 | unsigned getNumInputs() const; |
616 | std::string getInputName(unsigned idx) const; |
617 | NodeValue getNthInput(unsigned idx); |
618 | void setNthInput(unsigned idx, NodeValue val); |
619 | llvm::StringRef getOutputName(unsigned idx) const; |
620 | bool hasSideEffects() const { return 0; } |
621 | bool isCanonical() const { return 1; } |
622 | bool isDataParallel() const { return 0; } |
623 | std::string getDebugDesc() const; |
624 | bool isEqual(const MaxPoolGradNode &other) const; |
625 | llvm::hash_code getHash() const; |
626 | void visit(Node *parent, NodeWalker *visitor); |
627 | Node* clone() const; |
628 | bool verify() const; |
629 | }; |
630 | } // namespace glow |
631 | |
632 | |
633 | namespace glow { |
634 | /// Performs a Max Pool with Argmax operation on the Input given provided Kernels, Strides, and Pads. Argmax is a flattened index corresponding to respective max element. Supported layouts are defined in the ConvolutionLayout enum: NHWC and NCHW. |
635 | class MaxPoolNode final : public Node { |
636 | NodeHandle Input_; |
637 | std::vector<unsigned_t> Kernels_; |
638 | std::vector<unsigned_t> Strides_; |
639 | std::vector<unsigned_t> Pads_; |
640 | unsigned_t Layout_; |
641 | |
642 | public: |
643 | enum InputIndices { |
644 | InputIdx = 0, |
645 | }; |
646 | |
647 | enum ResultIndices { |
648 | ResultIdx = 0, |
649 | ArgmaxIdx = 1, |
650 | }; |
651 | |
652 | MaxPoolNode(llvm::StringRef name, TypeRef Result , TypeRef Argmax , NodeValue Input, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout) |
653 | : Node(Kinded::Kind::MaxPoolNodeKind, name), Input_(this, Input), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout) { |
654 | addResult(Result); |
655 | addResult(Argmax); |
656 | } |
657 | const NodeValue getInput() const { return Input_; } |
658 | NodeValue getResult() { return getNthResult(0); } |
659 | const NodeValue getResult() const { return getNthResult(0); } |
660 | NodeValue getArgmax() { return getNthResult(1); } |
661 | const NodeValue getArgmax() const { return getNthResult(1); } |
662 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
663 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
664 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
665 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
666 | unsigned_t getLayout() const { return Layout_; } |
667 | |
668 | static bool classof(const Kinded *k) { |
669 | return k->getKind() == Kinded::Kind::MaxPoolNodeKind; |
670 | } |
671 | |
672 | |
673 | bool isOverwrittenNthInput(unsigned idx) const { |
674 | return false; |
675 | } |
676 | |
677 | unsigned getNumInputs() const; |
678 | std::string getInputName(unsigned idx) const; |
679 | NodeValue getNthInput(unsigned idx); |
680 | void setNthInput(unsigned idx, NodeValue val); |
681 | llvm::StringRef getOutputName(unsigned idx) const; |
682 | bool hasSideEffects() const { return 0; } |
683 | bool isCanonical() const { return 1; } |
684 | bool isDataParallel() const { return 0; } |
685 | std::string getDebugDesc() const; |
686 | bool isEqual(const MaxPoolNode &other) const; |
687 | llvm::hash_code getHash() const; |
688 | void visit(Node *parent, NodeWalker *visitor); |
689 | Node* clone() const; |
690 | bool verify() const; |
691 | MaxPoolGradNode *getGrad(GraphGradMapper &builder); |
692 | }; |
693 | } // namespace glow |
694 | |
695 | |
696 | namespace glow { |
697 | /// Finds index of a maximum element along Axis. If KeepDims is not true, the axis is removed from output |
698 | class ArgMaxNode final : public Node { |
699 | NodeHandle Input_; |
700 | unsigned_t Axis_; |
701 | bool KeepDims_; |
702 | |
703 | public: |
704 | enum InputIndices { |
705 | InputIdx = 0, |
706 | }; |
707 | |
708 | enum ResultIndices { |
709 | ResultIdx = 0, |
710 | }; |
711 | |
712 | ArgMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, bool KeepDims) |
713 | : Node(Kinded::Kind::ArgMaxNodeKind, name), Input_(this, Input), Axis_(Axis), KeepDims_(KeepDims) { |
714 | addResult(Result); |
715 | } |
716 | const NodeValue getInput() const { return Input_; } |
717 | NodeValue getResult() { return getNthResult(0); } |
718 | const NodeValue getResult() const { return getNthResult(0); } |
719 | unsigned_t getAxis() const { return Axis_; } |
720 | bool getKeepDims() const { return KeepDims_; } |
721 | |
722 | static bool classof(const Kinded *k) { |
723 | return k->getKind() == Kinded::Kind::ArgMaxNodeKind; |
724 | } |
725 | |
726 | |
727 | bool isOverwrittenNthInput(unsigned idx) const { |
728 | return false; |
729 | } |
730 | |
731 | unsigned getNumInputs() const; |
732 | std::string getInputName(unsigned idx) const; |
733 | NodeValue getNthInput(unsigned idx); |
734 | void setNthInput(unsigned idx, NodeValue val); |
735 | llvm::StringRef getOutputName(unsigned idx) const; |
736 | bool hasSideEffects() const { return 0; } |
737 | bool isCanonical() const { return 1; } |
738 | bool isDataParallel() const { return 0; } |
739 | std::string getDebugDesc() const; |
740 | bool isEqual(const ArgMaxNode &other) const; |
741 | llvm::hash_code getHash() const; |
742 | void visit(Node *parent, NodeWalker *visitor); |
743 | Node* clone() const; |
744 | bool verify() const; |
745 | }; |
746 | } // namespace glow |
747 | |
748 | |
749 | namespace glow { |
750 | /// Finds index of a minimum element along Axis. If KeepDims is not true, the axis is removed from output |
751 | class ArgMinNode final : public Node { |
752 | NodeHandle Input_; |
753 | unsigned_t Axis_; |
754 | bool KeepDims_; |
755 | |
756 | public: |
757 | enum InputIndices { |
758 | InputIdx = 0, |
759 | }; |
760 | |
761 | enum ResultIndices { |
762 | ResultIdx = 0, |
763 | }; |
764 | |
765 | ArgMinNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, bool KeepDims) |
766 | : Node(Kinded::Kind::ArgMinNodeKind, name), Input_(this, Input), Axis_(Axis), KeepDims_(KeepDims) { |
767 | addResult(Result); |
768 | } |
769 | const NodeValue getInput() const { return Input_; } |
770 | NodeValue getResult() { return getNthResult(0); } |
771 | const NodeValue getResult() const { return getNthResult(0); } |
772 | unsigned_t getAxis() const { return Axis_; } |
773 | bool getKeepDims() const { return KeepDims_; } |
774 | |
775 | static bool classof(const Kinded *k) { |
776 | return k->getKind() == Kinded::Kind::ArgMinNodeKind; |
777 | } |
778 | |
779 | |
780 | bool isOverwrittenNthInput(unsigned idx) const { |
781 | return false; |
782 | } |
783 | |
784 | unsigned getNumInputs() const; |
785 | std::string getInputName(unsigned idx) const; |
786 | NodeValue getNthInput(unsigned idx); |
787 | void setNthInput(unsigned idx, NodeValue val); |
788 | llvm::StringRef getOutputName(unsigned idx) const; |
789 | bool hasSideEffects() const { return 0; } |
790 | bool isCanonical() const { return 1; } |
791 | bool isDataParallel() const { return 0; } |
792 | std::string getDebugDesc() const; |
793 | bool isEqual(const ArgMinNode &other) const; |
794 | llvm::hash_code getHash() const; |
795 | void visit(Node *parent, NodeWalker *visitor); |
796 | Node* clone() const; |
797 | bool verify() const; |
798 | }; |
799 | } // namespace glow |
800 | |
801 | |
802 | namespace glow { |
803 | class AvgPoolGradNode final : public Node { |
804 | NodeHandle Input_; |
805 | NodeHandle OriginalOutputForResult_; |
806 | NodeHandle GradOfOriginalOutputNamedResult_; |
807 | std::vector<unsigned_t> Kernels_; |
808 | std::vector<unsigned_t> Strides_; |
809 | std::vector<unsigned_t> Pads_; |
810 | unsigned_t Layout_; |
811 | bool CountIncludePads_; |
812 | |
813 | public: |
814 | enum InputIndices { |
815 | InputIdx = 0, |
816 | OriginalOutputForResultIdx = 1, |
817 | GradOfOriginalOutputNamedResultIdx = 2, |
818 | }; |
819 | |
820 | enum ResultIndices { |
821 | GradOfInputNamedInputIdx = 0, |
822 | }; |
823 | |
824 | AvgPoolGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout, bool CountIncludePads) |
825 | : Node(Kinded::Kind::AvgPoolGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout), CountIncludePads_(CountIncludePads) { |
826 | addResult(Input.getType()); |
827 | } |
828 | const NodeValue getInput() const { return Input_; } |
829 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
830 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
831 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
832 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
833 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
834 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
835 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
836 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
837 | unsigned_t getLayout() const { return Layout_; } |
838 | bool getCountIncludePads() const { return CountIncludePads_; } |
839 | |
840 | static bool classof(const Kinded *k) { |
841 | return k->getKind() == Kinded::Kind::AvgPoolGradNodeKind; |
842 | } |
843 | |
844 | |
845 | bool isOverwrittenNthInput(unsigned idx) const { |
846 | return false; |
847 | } |
848 | |
849 | unsigned getNumInputs() const; |
850 | std::string getInputName(unsigned idx) const; |
851 | NodeValue getNthInput(unsigned idx); |
852 | void setNthInput(unsigned idx, NodeValue val); |
853 | llvm::StringRef getOutputName(unsigned idx) const; |
854 | bool hasSideEffects() const { return 0; } |
855 | bool isCanonical() const { return 1; } |
856 | bool isDataParallel() const { return 0; } |
857 | std::string getDebugDesc() const; |
858 | bool isEqual(const AvgPoolGradNode &other) const; |
859 | llvm::hash_code getHash() const; |
860 | void visit(Node *parent, NodeWalker *visitor); |
861 | Node* clone() const; |
862 | bool verify() const; |
863 | }; |
864 | } // namespace glow |
865 | |
866 | |
867 | namespace glow { |
868 | /// Performs an Average Pool operation on the Input given provided Kernels, Strides, and Pads. Supported layouts are defined in the ConvolutionLayout enum: NHWC, NCHW, NTHWC and NCTHW. |
869 | class AvgPoolNode final : public Node { |
870 | NodeHandle Input_; |
871 | std::vector<unsigned_t> Kernels_; |
872 | std::vector<unsigned_t> Strides_; |
873 | std::vector<unsigned_t> Pads_; |
874 | unsigned_t Layout_; |
875 | bool CountIncludePads_; |
876 | |
877 | public: |
878 | enum InputIndices { |
879 | InputIdx = 0, |
880 | }; |
881 | |
882 | enum ResultIndices { |
883 | ResultIdx = 0, |
884 | }; |
885 | |
886 | AvgPoolNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Layout, bool CountIncludePads) |
887 | : Node(Kinded::Kind::AvgPoolNodeKind, name), Input_(this, Input), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Layout_(Layout), CountIncludePads_(CountIncludePads) { |
888 | addResult(Result); |
889 | } |
890 | const NodeValue getInput() const { return Input_; } |
891 | NodeValue getResult() { return getNthResult(0); } |
892 | const NodeValue getResult() const { return getNthResult(0); } |
893 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
894 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
895 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
896 | void setPads(llvm::ArrayRef<unsigned_t> a) {Pads_ = a; } |
897 | unsigned_t getLayout() const { return Layout_; } |
898 | bool getCountIncludePads() const { return CountIncludePads_; } |
899 | |
900 | static bool classof(const Kinded *k) { |
901 | return k->getKind() == Kinded::Kind::AvgPoolNodeKind; |
902 | } |
903 | |
904 | |
905 | bool isOverwrittenNthInput(unsigned idx) const { |
906 | return false; |
907 | } |
908 | |
909 | unsigned getNumInputs() const; |
910 | std::string getInputName(unsigned idx) const; |
911 | NodeValue getNthInput(unsigned idx); |
912 | void setNthInput(unsigned idx, NodeValue val); |
913 | llvm::StringRef getOutputName(unsigned idx) const; |
914 | bool hasSideEffects() const { return 0; } |
915 | bool isCanonical() const { return 1; } |
916 | bool isDataParallel() const { return 0; } |
917 | std::string getDebugDesc() const; |
918 | bool isEqual(const AvgPoolNode &other) const; |
919 | llvm::hash_code getHash() const; |
920 | void visit(Node *parent, NodeWalker *visitor); |
921 | Node* clone() const; |
922 | bool verify() const; |
923 | AvgPoolGradNode *getGrad(GraphGradMapper &builder); |
924 | }; |
925 | } // namespace glow |
926 | |
927 | |
928 | namespace glow { |
929 | class AdaptiveAvgPoolGradNode final : public Node { |
930 | NodeHandle Input_; |
931 | NodeHandle OriginalOutputForResult_; |
932 | NodeHandle GradOfOriginalOutputNamedResult_; |
933 | |
934 | public: |
935 | enum InputIndices { |
936 | InputIdx = 0, |
937 | OriginalOutputForResultIdx = 1, |
938 | GradOfOriginalOutputNamedResultIdx = 2, |
939 | }; |
940 | |
941 | enum ResultIndices { |
942 | GradOfInputNamedInputIdx = 0, |
943 | }; |
944 | |
945 | AdaptiveAvgPoolGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
946 | : Node(Kinded::Kind::AdaptiveAvgPoolGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
947 | addResult(Input.getType()); |
948 | } |
949 | const NodeValue getInput() const { return Input_; } |
950 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
951 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
952 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
953 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
954 | |
955 | static bool classof(const Kinded *k) { |
956 | return k->getKind() == Kinded::Kind::AdaptiveAvgPoolGradNodeKind; |
957 | } |
958 | |
959 | |
960 | bool isOverwrittenNthInput(unsigned idx) const { |
961 | return false; |
962 | } |
963 | |
964 | unsigned getNumInputs() const; |
965 | std::string getInputName(unsigned idx) const; |
966 | NodeValue getNthInput(unsigned idx); |
967 | void setNthInput(unsigned idx, NodeValue val); |
968 | llvm::StringRef getOutputName(unsigned idx) const; |
969 | bool hasSideEffects() const { return 0; } |
970 | bool isCanonical() const { return 1; } |
971 | bool isDataParallel() const { return 0; } |
972 | std::string getDebugDesc() const; |
973 | bool isEqual(const AdaptiveAvgPoolGradNode &other) const; |
974 | llvm::hash_code getHash() const; |
975 | void visit(Node *parent, NodeWalker *visitor); |
976 | Node* clone() const; |
977 | bool verify() const; |
978 | }; |
979 | } // namespace glow |
980 | |
981 | |
982 | namespace glow { |
983 | /// Performs an Adaptive Average Pool operation on the Input given |
984 | class AdaptiveAvgPoolNode final : public Node { |
985 | NodeHandle Input_; |
986 | |
987 | public: |
988 | enum InputIndices { |
989 | InputIdx = 0, |
990 | }; |
991 | |
992 | enum ResultIndices { |
993 | ResultIdx = 0, |
994 | }; |
995 | |
996 | AdaptiveAvgPoolNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
997 | : Node(Kinded::Kind::AdaptiveAvgPoolNodeKind, name), Input_(this, Input) { |
998 | addResult(Result); |
999 | } |
1000 | const NodeValue getInput() const { return Input_; } |
1001 | NodeValue getResult() { return getNthResult(0); } |
1002 | const NodeValue getResult() const { return getNthResult(0); } |
1003 | |
1004 | static bool classof(const Kinded *k) { |
1005 | return k->getKind() == Kinded::Kind::AdaptiveAvgPoolNodeKind; |
1006 | } |
1007 | |
1008 | |
1009 | bool isOverwrittenNthInput(unsigned idx) const { |
1010 | return false; |
1011 | } |
1012 | |
1013 | unsigned getNumInputs() const; |
1014 | std::string getInputName(unsigned idx) const; |
1015 | NodeValue getNthInput(unsigned idx); |
1016 | void setNthInput(unsigned idx, NodeValue val); |
1017 | llvm::StringRef getOutputName(unsigned idx) const; |
1018 | bool hasSideEffects() const { return 0; } |
1019 | bool isCanonical() const { return 1; } |
1020 | bool isDataParallel() const { return 0; } |
1021 | std::string getDebugDesc() const; |
1022 | bool isEqual(const AdaptiveAvgPoolNode &other) const; |
1023 | llvm::hash_code getHash() const; |
1024 | void visit(Node *parent, NodeWalker *visitor); |
1025 | Node* clone() const; |
1026 | bool verify() const; |
1027 | AdaptiveAvgPoolGradNode *getGrad(GraphGradMapper &builder); |
1028 | }; |
1029 | } // namespace glow |
1030 | |
1031 | |
1032 | namespace glow { |
1033 | /// Computes Y = Alpha * A * B + Beta * C where Alpha, Beta are scalars and A, B, C are matrices. If TransposeA or TransposeB is used then A or B is additionally transposed. |
1034 | class GemmNode final : public Node { |
1035 | NodeHandle A_; |
1036 | NodeHandle B_; |
1037 | NodeHandle C_; |
1038 | float Alpha_; |
1039 | float Beta_; |
1040 | bool TransposeA_; |
1041 | bool TransposeB_; |
1042 | |
1043 | public: |
1044 | enum InputIndices { |
1045 | AIdx = 0, |
1046 | BIdx = 1, |
1047 | CIdx = 2, |
1048 | }; |
1049 | |
1050 | enum ResultIndices { |
1051 | ResultIdx = 0, |
1052 | }; |
1053 | |
1054 | GemmNode(llvm::StringRef name, TypeRef Result , NodeValue A, NodeValue B, NodeValue C, float Alpha, float Beta, bool TransposeA, bool TransposeB) |
1055 | : Node(Kinded::Kind::GemmNodeKind, name), A_(this, A), B_(this, B), C_(this, C), Alpha_(Alpha), Beta_(Beta), TransposeA_(TransposeA), TransposeB_(TransposeB) { |
1056 | addResult(Result); |
1057 | } |
1058 | const NodeValue getA() const { return A_; } |
1059 | const NodeValue getB() const { return B_; } |
1060 | const NodeValue getC() const { return C_; } |
1061 | NodeValue getResult() { return getNthResult(0); } |
1062 | const NodeValue getResult() const { return getNthResult(0); } |
1063 | float getAlpha() const { return Alpha_; } |
1064 | float getBeta() const { return Beta_; } |
1065 | bool getTransposeA() const { return TransposeA_; } |
1066 | bool getTransposeB() const { return TransposeB_; } |
1067 | |
1068 | static bool classof(const Kinded *k) { |
1069 | return k->getKind() == Kinded::Kind::GemmNodeKind; |
1070 | } |
1071 | |
1072 | |
1073 | bool isOverwrittenNthInput(unsigned idx) const { |
1074 | return false; |
1075 | } |
1076 | |
1077 | unsigned getNumInputs() const; |
1078 | std::string getInputName(unsigned idx) const; |
1079 | NodeValue getNthInput(unsigned idx); |
1080 | void setNthInput(unsigned idx, NodeValue val); |
1081 | llvm::StringRef getOutputName(unsigned idx) const; |
1082 | bool hasSideEffects() const { return 0; } |
1083 | bool isCanonical() const { return 1; } |
1084 | bool isDataParallel() const { return 0; } |
1085 | std::string getDebugDesc() const; |
1086 | bool isEqual(const GemmNode &other) const; |
1087 | llvm::hash_code getHash() const; |
1088 | void visit(Node *parent, NodeWalker *visitor); |
1089 | Node* clone() const; |
1090 | bool verify() const; |
1091 | }; |
1092 | } // namespace glow |
1093 | |
1094 | |
1095 | namespace glow { |
1096 | class FullyConnectedGradNode final : public Node { |
1097 | NodeHandle Input_; |
1098 | NodeHandle Weights_; |
1099 | NodeHandle Bias_; |
1100 | NodeHandle OriginalOutputForResult_; |
1101 | NodeHandle GradOfOriginalOutputNamedResult_; |
1102 | |
1103 | public: |
1104 | enum InputIndices { |
1105 | InputIdx = 0, |
1106 | WeightsIdx = 1, |
1107 | BiasIdx = 2, |
1108 | OriginalOutputForResultIdx = 3, |
1109 | GradOfOriginalOutputNamedResultIdx = 4, |
1110 | }; |
1111 | |
1112 | enum ResultIndices { |
1113 | GradOfInputNamedInputIdx = 0, |
1114 | GradOfInputNamedWeightsIdx = 1, |
1115 | GradOfInputNamedBiasIdx = 2, |
1116 | }; |
1117 | |
1118 | FullyConnectedGradNode(llvm::StringRef name, NodeValue Input, NodeValue Weights, NodeValue Bias, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
1119 | : Node(Kinded::Kind::FullyConnectedGradNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
1120 | addResult(Input.getType()); |
1121 | addResult(Weights.getType()); |
1122 | addResult(Bias.getType()); |
1123 | } |
1124 | const NodeValue getInput() const { return Input_; } |
1125 | const NodeValue getWeights() const { return Weights_; } |
1126 | const NodeValue getBias() const { return Bias_; } |
1127 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
1128 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
1129 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
1130 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
1131 | NodeValue getGradOfInputNamedWeights() { return getNthResult(1); } |
1132 | const NodeValue getGradOfInputNamedWeights() const { return getNthResult(1); } |
1133 | NodeValue getGradOfInputNamedBias() { return getNthResult(2); } |
1134 | const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); } |
1135 | |
1136 | static bool classof(const Kinded *k) { |
1137 | return k->getKind() == Kinded::Kind::FullyConnectedGradNodeKind; |
1138 | } |
1139 | |
1140 | |
1141 | bool isOverwrittenNthInput(unsigned idx) const { |
1142 | return false; |
1143 | } |
1144 | |
1145 | unsigned getNumInputs() const; |
1146 | std::string getInputName(unsigned idx) const; |
1147 | NodeValue getNthInput(unsigned idx); |
1148 | void setNthInput(unsigned idx, NodeValue val); |
1149 | llvm::StringRef getOutputName(unsigned idx) const; |
1150 | bool hasSideEffects() const { return 0; } |
1151 | bool isCanonical() const { return 1; } |
1152 | bool isDataParallel() const { return 0; } |
1153 | std::string getDebugDesc() const; |
1154 | bool isEqual(const FullyConnectedGradNode &other) const; |
1155 | llvm::hash_code getHash() const; |
1156 | void visit(Node *parent, NodeWalker *visitor); |
1157 | Node* clone() const; |
1158 | bool verify() const; |
1159 | }; |
1160 | } // namespace glow |
1161 | |
1162 | |
1163 | namespace glow { |
1164 | /// Creates a FullyConnected node where the Input tensor and Weights tensor are multiplied, and then the Bias tensor is added to it, producing the Output. |
1165 | class FullyConnectedNode final : public Node { |
1166 | NodeHandle Input_; |
1167 | NodeHandle Weights_; |
1168 | NodeHandle Bias_; |
1169 | |
1170 | public: |
1171 | enum InputIndices { |
1172 | InputIdx = 0, |
1173 | WeightsIdx = 1, |
1174 | BiasIdx = 2, |
1175 | }; |
1176 | |
1177 | enum ResultIndices { |
1178 | ResultIdx = 0, |
1179 | }; |
1180 | |
1181 | FullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Bias) |
1182 | : Node(Kinded::Kind::FullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias) { |
1183 | addResult(Result); |
1184 | } |
1185 | const NodeValue getInput() const { return Input_; } |
1186 | const NodeValue getWeights() const { return Weights_; } |
1187 | const NodeValue getBias() const { return Bias_; } |
1188 | NodeValue getResult() { return getNthResult(0); } |
1189 | const NodeValue getResult() const { return getNthResult(0); } |
1190 | |
1191 | static bool classof(const Kinded *k) { |
1192 | return k->getKind() == Kinded::Kind::FullyConnectedNodeKind; |
1193 | } |
1194 | |
1195 | |
1196 | bool isOverwrittenNthInput(unsigned idx) const { |
1197 | return false; |
1198 | } |
1199 | |
1200 | unsigned getNumInputs() const; |
1201 | std::string getInputName(unsigned idx) const; |
1202 | NodeValue getNthInput(unsigned idx); |
1203 | void setNthInput(unsigned idx, NodeValue val); |
1204 | llvm::StringRef getOutputName(unsigned idx) const; |
1205 | bool hasSideEffects() const { return 0; } |
1206 | bool isCanonical() const { return 1; } |
1207 | bool isDataParallel() const { return 0; } |
1208 | std::string getDebugDesc() const; |
1209 | bool isEqual(const FullyConnectedNode &other) const; |
1210 | llvm::hash_code getHash() const; |
1211 | void visit(Node *parent, NodeWalker *visitor); |
1212 | Node* clone() const; |
1213 | bool verify() const; |
1214 | FullyConnectedGradNode *getGrad(GraphGradMapper &builder); |
1215 | }; |
1216 | } // namespace glow |
1217 | |
1218 | |
1219 | namespace glow { |
1220 | /// Creates a RowwiseQuantizedFullyConnected node where the Input matrix and the transpose of Weights matrix are multiplied, and then the Bias vector is broadcast-added to the result. Input, Bias and Result are regularly quantized, while Weights use row-wisequantization. |
1221 | class RowwiseQuantizedFullyConnectedNode final : public Node { |
1222 | NodeHandle Input_; |
1223 | NodeHandle Weights_; |
1224 | NodeHandle Scales_; |
1225 | NodeHandle Offsets_; |
1226 | NodeHandle Bias_; |
1227 | |
1228 | public: |
1229 | enum InputIndices { |
1230 | InputIdx = 0, |
1231 | WeightsIdx = 1, |
1232 | ScalesIdx = 2, |
1233 | OffsetsIdx = 3, |
1234 | BiasIdx = 4, |
1235 | }; |
1236 | |
1237 | enum ResultIndices { |
1238 | ResultIdx = 0, |
1239 | }; |
1240 | |
1241 | RowwiseQuantizedFullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Scales, NodeValue Offsets, NodeValue Bias) |
1242 | : Node(Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Scales_(this, Scales), Offsets_(this, Offsets), Bias_(this, Bias) { |
1243 | addResult(Result); |
1244 | } |
1245 | const NodeValue getInput() const { return Input_; } |
1246 | const NodeValue getWeights() const { return Weights_; } |
1247 | const NodeValue getScales() const { return Scales_; } |
1248 | const NodeValue getOffsets() const { return Offsets_; } |
1249 | const NodeValue getBias() const { return Bias_; } |
1250 | NodeValue getResult() { return getNthResult(0); } |
1251 | const NodeValue getResult() const { return getNthResult(0); } |
1252 | |
1253 | static bool classof(const Kinded *k) { |
1254 | return k->getKind() == Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind; |
1255 | } |
1256 | |
1257 | |
1258 | bool isOverwrittenNthInput(unsigned idx) const { |
1259 | return false; |
1260 | } |
1261 | |
1262 | unsigned getNumInputs() const; |
1263 | std::string getInputName(unsigned idx) const; |
1264 | NodeValue getNthInput(unsigned idx); |
1265 | void setNthInput(unsigned idx, NodeValue val); |
1266 | llvm::StringRef getOutputName(unsigned idx) const; |
1267 | bool hasSideEffects() const { return 0; } |
1268 | bool isCanonical() const { return 1; } |
1269 | bool isDataParallel() const { return 0; } |
1270 | std::string getDebugDesc() const; |
1271 | bool isEqual(const RowwiseQuantizedFullyConnectedNode &other) const; |
1272 | llvm::hash_code getHash() const; |
1273 | void visit(Node *parent, NodeWalker *visitor); |
1274 | Node* clone() const; |
1275 | bool verify() const; |
1276 | }; |
1277 | } // namespace glow |
1278 | |
1279 | |
1280 | namespace glow { |
1281 | /// Creates a DynamicQuantizedFullyConnectedNode which implement the functionality of dynamic_quantization => quantized_fc => dequantize, which support symmteric/asymmetric quantization. Quantize parameters are automatically selected from range of input, while weights are pre-quantized to int8 and bias are whether float or int32 |
1282 | class DynamicQuantizedFullyConnectedNode final : public Node { |
1283 | NodeHandle Input_; |
1284 | NodeHandle Weights_; |
1285 | NodeHandle Bias_; |
1286 | bool IsSymmetric_; |
1287 | bool IsPerBatchElement_; |
1288 | |
1289 | public: |
1290 | enum InputIndices { |
1291 | InputIdx = 0, |
1292 | WeightsIdx = 1, |
1293 | BiasIdx = 2, |
1294 | }; |
1295 | |
1296 | enum ResultIndices { |
1297 | ResultIdx = 0, |
1298 | }; |
1299 | |
1300 | DynamicQuantizedFullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Bias, bool IsSymmetric, bool IsPerBatchElement) |
1301 | : Node(Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias), IsSymmetric_(IsSymmetric), IsPerBatchElement_(IsPerBatchElement) { |
1302 | addResult(Result); |
1303 | } |
1304 | const NodeValue getInput() const { return Input_; } |
1305 | const NodeValue getWeights() const { return Weights_; } |
1306 | const NodeValue getBias() const { return Bias_; } |
1307 | NodeValue getResult() { return getNthResult(0); } |
1308 | const NodeValue getResult() const { return getNthResult(0); } |
1309 | bool getIsSymmetric() const { return IsSymmetric_; } |
1310 | bool getIsPerBatchElement() const { return IsPerBatchElement_; } |
1311 | |
1312 | static bool classof(const Kinded *k) { |
1313 | return k->getKind() == Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind; |
1314 | } |
1315 | |
1316 | |
1317 | bool isOverwrittenNthInput(unsigned idx) const { |
1318 | return false; |
1319 | } |
1320 | |
1321 | unsigned getNumInputs() const; |
1322 | std::string getInputName(unsigned idx) const; |
1323 | NodeValue getNthInput(unsigned idx); |
1324 | void setNthInput(unsigned idx, NodeValue val); |
1325 | llvm::StringRef getOutputName(unsigned idx) const; |
1326 | bool hasSideEffects() const { return 0; } |
1327 | bool isCanonical() const { return 1; } |
1328 | bool isDataParallel() const { return 0; } |
1329 | std::string getDebugDesc() const; |
1330 | bool isEqual(const DynamicQuantizedFullyConnectedNode &other) const; |
1331 | llvm::hash_code getHash() const; |
1332 | void visit(Node *parent, NodeWalker *visitor); |
1333 | Node* clone() const; |
1334 | bool verify() const; |
1335 | }; |
1336 | } // namespace glow |
1337 | |
1338 | |
1339 | namespace glow { |
1340 | /// Creates a DynamicRowwiseQuantizedFullyConnectedNode which implement the functionality of dynamic_quantization => quantized_fc => dequantize, which support symmteric/asymmetric quantization. Quantize parameters are automatically selected from range of input, while weights are pre-rowwise-quantized to int8, whose rowwise params stored in Scales and Offsets, and bias are whether float or int32 |
1341 | class DynamicRowwiseQuantizedFullyConnectedNode final : public Node { |
1342 | NodeHandle Input_; |
1343 | NodeHandle Weights_; |
1344 | NodeHandle Bias_; |
1345 | NodeHandle Scales_; |
1346 | NodeHandle Offsets_; |
1347 | bool IsSymmetric_; |
1348 | bool IsPerBatchElement_; |
1349 | |
1350 | public: |
1351 | enum InputIndices { |
1352 | InputIdx = 0, |
1353 | WeightsIdx = 1, |
1354 | BiasIdx = 2, |
1355 | ScalesIdx = 3, |
1356 | OffsetsIdx = 4, |
1357 | }; |
1358 | |
1359 | enum ResultIndices { |
1360 | ResultIdx = 0, |
1361 | }; |
1362 | |
1363 | DynamicRowwiseQuantizedFullyConnectedNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Weights, NodeValue Bias, NodeValue Scales, NodeValue Offsets, bool IsSymmetric, bool IsPerBatchElement) |
1364 | : Node(Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind, name), Input_(this, Input), Weights_(this, Weights), Bias_(this, Bias), Scales_(this, Scales), Offsets_(this, Offsets), IsSymmetric_(IsSymmetric), IsPerBatchElement_(IsPerBatchElement) { |
1365 | addResult(Result); |
1366 | } |
1367 | const NodeValue getInput() const { return Input_; } |
1368 | const NodeValue getWeights() const { return Weights_; } |
1369 | const NodeValue getBias() const { return Bias_; } |
1370 | const NodeValue getScales() const { return Scales_; } |
1371 | const NodeValue getOffsets() const { return Offsets_; } |
1372 | NodeValue getResult() { return getNthResult(0); } |
1373 | const NodeValue getResult() const { return getNthResult(0); } |
1374 | bool getIsSymmetric() const { return IsSymmetric_; } |
1375 | bool getIsPerBatchElement() const { return IsPerBatchElement_; } |
1376 | |
1377 | static bool classof(const Kinded *k) { |
1378 | return k->getKind() == Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind; |
1379 | } |
1380 | |
1381 | |
1382 | bool isOverwrittenNthInput(unsigned idx) const { |
1383 | return false; |
1384 | } |
1385 | |
1386 | unsigned getNumInputs() const; |
1387 | std::string getInputName(unsigned idx) const; |
1388 | NodeValue getNthInput(unsigned idx); |
1389 | void setNthInput(unsigned idx, NodeValue val); |
1390 | llvm::StringRef getOutputName(unsigned idx) const; |
1391 | bool hasSideEffects() const { return 0; } |
1392 | bool isCanonical() const { return 1; } |
1393 | bool isDataParallel() const { return 0; } |
1394 | std::string getDebugDesc() const; |
1395 | bool isEqual(const DynamicRowwiseQuantizedFullyConnectedNode &other) const; |
1396 | llvm::hash_code getHash() const; |
1397 | void visit(Node *parent, NodeWalker *visitor); |
1398 | Node* clone() const; |
1399 | bool verify() const; |
1400 | }; |
1401 | } // namespace glow |
1402 | |
1403 | |
1404 | namespace glow { |
1405 | class BatchNormalizationGradNode final : public Node { |
1406 | NodeHandle Input_; |
1407 | NodeHandle Scale_; |
1408 | NodeHandle Bias_; |
1409 | NodeHandle Mean_; |
1410 | NodeHandle Var_; |
1411 | NodeHandle OriginalOutputForResult_; |
1412 | NodeHandle GradOfOriginalOutputNamedResult_; |
1413 | unsigned_t ChannelIdx_; |
1414 | float Epsilon_; |
1415 | float Momentum_; |
1416 | |
1417 | public: |
1418 | enum InputIndices { |
1419 | InputIdx = 0, |
1420 | ScaleIdx = 1, |
1421 | BiasIdx = 2, |
1422 | MeanIdx = 3, |
1423 | VarIdx = 4, |
1424 | OriginalOutputForResultIdx = 5, |
1425 | GradOfOriginalOutputNamedResultIdx = 6, |
1426 | }; |
1427 | |
1428 | enum ResultIndices { |
1429 | GradOfInputNamedInputIdx = 0, |
1430 | GradOfInputNamedScaleIdx = 1, |
1431 | GradOfInputNamedBiasIdx = 2, |
1432 | GradOfInputNamedMeanIdx = 3, |
1433 | GradOfInputNamedVarIdx = 4, |
1434 | }; |
1435 | |
1436 | BatchNormalizationGradNode(llvm::StringRef name, NodeValue Input, NodeValue Scale, NodeValue Bias, NodeValue Mean, NodeValue Var, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, unsigned_t ChannelIdx, float Epsilon, float Momentum) |
1437 | : Node(Kinded::Kind::BatchNormalizationGradNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), Mean_(this, Mean), Var_(this, Var), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), ChannelIdx_(ChannelIdx), Epsilon_(Epsilon), Momentum_(Momentum) { |
1438 | addResult(Input.getType()); |
1439 | addResult(Scale.getType()); |
1440 | addResult(Bias.getType()); |
1441 | addResult(Mean.getType()); |
1442 | addResult(Var.getType()); |
1443 | } |
1444 | const NodeValue getInput() const { return Input_; } |
1445 | const NodeValue getScale() const { return Scale_; } |
1446 | const NodeValue getBias() const { return Bias_; } |
1447 | const NodeValue getMean() const { return Mean_; } |
1448 | const NodeValue getVar() const { return Var_; } |
1449 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
1450 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
1451 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
1452 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
1453 | NodeValue getGradOfInputNamedScale() { return getNthResult(1); } |
1454 | const NodeValue getGradOfInputNamedScale() const { return getNthResult(1); } |
1455 | NodeValue getGradOfInputNamedBias() { return getNthResult(2); } |
1456 | const NodeValue getGradOfInputNamedBias() const { return getNthResult(2); } |
1457 | NodeValue getGradOfInputNamedMean() { return getNthResult(3); } |
1458 | const NodeValue getGradOfInputNamedMean() const { return getNthResult(3); } |
1459 | NodeValue getGradOfInputNamedVar() { return getNthResult(4); } |
1460 | const NodeValue getGradOfInputNamedVar() const { return getNthResult(4); } |
1461 | unsigned_t getChannelIdx() const { return ChannelIdx_; } |
1462 | float getEpsilon() const { return Epsilon_; } |
1463 | float getMomentum() const { return Momentum_; } |
1464 | |
1465 | static bool classof(const Kinded *k) { |
1466 | return k->getKind() == Kinded::Kind::BatchNormalizationGradNodeKind; |
1467 | } |
1468 | |
1469 | |
1470 | bool isOverwrittenNthInput(unsigned idx) const { |
1471 | return false; |
1472 | } |
1473 | |
1474 | unsigned getNumInputs() const; |
1475 | std::string getInputName(unsigned idx) const; |
1476 | NodeValue getNthInput(unsigned idx); |
1477 | void setNthInput(unsigned idx, NodeValue val); |
1478 | llvm::StringRef getOutputName(unsigned idx) const; |
1479 | bool hasSideEffects() const { return 0; } |
1480 | bool isCanonical() const { return 1; } |
1481 | bool isDataParallel() const { return 0; } |
1482 | std::string getDebugDesc() const; |
1483 | bool isEqual(const BatchNormalizationGradNode &other) const; |
1484 | llvm::hash_code getHash() const; |
1485 | void visit(Node *parent, NodeWalker *visitor); |
1486 | Node* clone() const; |
1487 | bool verify() const; |
1488 | }; |
1489 | } // namespace glow |
1490 | |
1491 | |
1492 | namespace glow { |
1493 | /// Performs batch normalization on the Input tensor with the provided Scale, Bias, Mean, Var, ChannelIdx, Epsilon, and Momentum. Similar to Caffe2 SpatialBN, and ONNX BatchNormalization operator. |
1494 | class BatchNormalizationNode final : public Node { |
1495 | NodeHandle Input_; |
1496 | NodeHandle Scale_; |
1497 | NodeHandle Bias_; |
1498 | NodeHandle Mean_; |
1499 | NodeHandle Var_; |
1500 | unsigned_t ChannelIdx_; |
1501 | float Epsilon_; |
1502 | float Momentum_; |
1503 | |
1504 | public: |
1505 | enum InputIndices { |
1506 | InputIdx = 0, |
1507 | ScaleIdx = 1, |
1508 | BiasIdx = 2, |
1509 | MeanIdx = 3, |
1510 | VarIdx = 4, |
1511 | }; |
1512 | |
1513 | enum ResultIndices { |
1514 | ResultIdx = 0, |
1515 | }; |
1516 | |
1517 | BatchNormalizationNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Scale, NodeValue Bias, NodeValue Mean, NodeValue Var, unsigned_t ChannelIdx, float Epsilon, float Momentum) |
1518 | : Node(Kinded::Kind::BatchNormalizationNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), Mean_(this, Mean), Var_(this, Var), ChannelIdx_(ChannelIdx), Epsilon_(Epsilon), Momentum_(Momentum) { |
1519 | addResult(Result); |
1520 | } |
1521 | const NodeValue getInput() const { return Input_; } |
1522 | const NodeValue getScale() const { return Scale_; } |
1523 | const NodeValue getBias() const { return Bias_; } |
1524 | const NodeValue getMean() const { return Mean_; } |
1525 | const NodeValue getVar() const { return Var_; } |
1526 | NodeValue getResult() { return getNthResult(0); } |
1527 | const NodeValue getResult() const { return getNthResult(0); } |
1528 | unsigned_t getChannelIdx() const { return ChannelIdx_; } |
1529 | float getEpsilon() const { return Epsilon_; } |
1530 | float getMomentum() const { return Momentum_; } |
1531 | |
1532 | static bool classof(const Kinded *k) { |
1533 | return k->getKind() == Kinded::Kind::BatchNormalizationNodeKind; |
1534 | } |
1535 | |
1536 | |
1537 | bool isOverwrittenNthInput(unsigned idx) const { |
1538 | return false; |
1539 | } |
1540 | |
1541 | unsigned getNumInputs() const; |
1542 | std::string getInputName(unsigned idx) const; |
1543 | NodeValue getNthInput(unsigned idx); |
1544 | void setNthInput(unsigned idx, NodeValue val); |
1545 | llvm::StringRef getOutputName(unsigned idx) const; |
1546 | bool hasSideEffects() const { return 0; } |
1547 | bool isCanonical() const { return 1; } |
1548 | bool isDataParallel() const { return 0; } |
1549 | std::string getDebugDesc() const; |
1550 | bool isEqual(const BatchNormalizationNode &other) const; |
1551 | llvm::hash_code getHash() const; |
1552 | void visit(Node *parent, NodeWalker *visitor); |
1553 | Node* clone() const; |
1554 | bool verify() const; |
1555 | BatchNormalizationGradNode *getGrad(GraphGradMapper &builder); |
1556 | }; |
1557 | } // namespace glow |
1558 | |
1559 | |
1560 | namespace glow { |
1561 | /// Performs instance normalization on the Input tensor with the provided Scale, Bias, Epsilon. Similar to ONNX InstanceNormalization operator. |
1562 | class InstanceNormalizationNode final : public Node { |
1563 | NodeHandle Input_; |
1564 | NodeHandle Scale_; |
1565 | NodeHandle Bias_; |
1566 | unsigned_t ChannelIdx_; |
1567 | float Epsilon_; |
1568 | |
1569 | public: |
1570 | enum InputIndices { |
1571 | InputIdx = 0, |
1572 | ScaleIdx = 1, |
1573 | BiasIdx = 2, |
1574 | }; |
1575 | |
1576 | enum ResultIndices { |
1577 | ResultIdx = 0, |
1578 | }; |
1579 | |
1580 | InstanceNormalizationNode(llvm::StringRef name, NodeValue Input, NodeValue Scale, NodeValue Bias, unsigned_t ChannelIdx, float Epsilon) |
1581 | : Node(Kinded::Kind::InstanceNormalizationNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), ChannelIdx_(ChannelIdx), Epsilon_(Epsilon) { |
1582 | addResult(Input.getType()); |
1583 | } |
1584 | const NodeValue getInput() const { return Input_; } |
1585 | const NodeValue getScale() const { return Scale_; } |
1586 | const NodeValue getBias() const { return Bias_; } |
1587 | NodeValue getResult() { return getNthResult(0); } |
1588 | const NodeValue getResult() const { return getNthResult(0); } |
1589 | unsigned_t getChannelIdx() const { return ChannelIdx_; } |
1590 | float getEpsilon() const { return Epsilon_; } |
1591 | |
1592 | static bool classof(const Kinded *k) { |
1593 | return k->getKind() == Kinded::Kind::InstanceNormalizationNodeKind; |
1594 | } |
1595 | |
1596 | |
1597 | bool isOverwrittenNthInput(unsigned idx) const { |
1598 | return false; |
1599 | } |
1600 | |
1601 | unsigned getNumInputs() const; |
1602 | std::string getInputName(unsigned idx) const; |
1603 | NodeValue getNthInput(unsigned idx); |
1604 | void setNthInput(unsigned idx, NodeValue val); |
1605 | llvm::StringRef getOutputName(unsigned idx) const; |
1606 | bool hasSideEffects() const { return 0; } |
1607 | bool isCanonical() const { return 1; } |
1608 | bool isDataParallel() const { return 0; } |
1609 | std::string getDebugDesc() const; |
1610 | bool isEqual(const InstanceNormalizationNode &other) const; |
1611 | llvm::hash_code getHash() const; |
1612 | void visit(Node *parent, NodeWalker *visitor); |
1613 | Node* clone() const; |
1614 | bool verify() const; |
1615 | }; |
1616 | } // namespace glow |
1617 | |
1618 | |
1619 | namespace glow { |
1620 | /// Calculates new normalized mean and variance based on the input mean, variance, and input. |
1621 | class MeanVarNormalizationNode final : public Node { |
1622 | NodeHandle Input_; |
1623 | NodeHandle Mean_; |
1624 | NodeHandle Var_; |
1625 | unsigned_t ChannelIdx_; |
1626 | float Momentum_; |
1627 | |
1628 | public: |
1629 | enum InputIndices { |
1630 | InputIdx = 0, |
1631 | MeanIdx = 1, |
1632 | VarIdx = 2, |
1633 | }; |
1634 | |
1635 | enum ResultIndices { |
1636 | NewMeanIdx = 0, |
1637 | NewVarIdx = 1, |
1638 | }; |
1639 | |
1640 | MeanVarNormalizationNode(llvm::StringRef name, NodeValue Input, NodeValue Mean, NodeValue Var, unsigned_t ChannelIdx, float Momentum) |
1641 | : Node(Kinded::Kind::MeanVarNormalizationNodeKind, name), Input_(this, Input), Mean_(this, Mean), Var_(this, Var), ChannelIdx_(ChannelIdx), Momentum_(Momentum) { |
1642 | addResult(Mean.getType()); |
1643 | addResult(Var.getType()); |
1644 | } |
1645 | const NodeValue getInput() const { return Input_; } |
1646 | const NodeValue getMean() const { return Mean_; } |
1647 | const NodeValue getVar() const { return Var_; } |
1648 | NodeValue getNewMean() { return getNthResult(0); } |
1649 | const NodeValue getNewMean() const { return getNthResult(0); } |
1650 | NodeValue getNewVar() { return getNthResult(1); } |
1651 | const NodeValue getNewVar() const { return getNthResult(1); } |
1652 | unsigned_t getChannelIdx() const { return ChannelIdx_; } |
1653 | float getMomentum() const { return Momentum_; } |
1654 | |
1655 | static bool classof(const Kinded *k) { |
1656 | return k->getKind() == Kinded::Kind::MeanVarNormalizationNodeKind; |
1657 | } |
1658 | |
1659 | |
1660 | bool isOverwrittenNthInput(unsigned idx) const { |
1661 | return false; |
1662 | } |
1663 | |
1664 | unsigned getNumInputs() const; |
1665 | std::string getInputName(unsigned idx) const; |
1666 | NodeValue getNthInput(unsigned idx); |
1667 | void setNthInput(unsigned idx, NodeValue val); |
1668 | llvm::StringRef getOutputName(unsigned idx) const; |
1669 | bool hasSideEffects() const { return 0; } |
1670 | bool isCanonical() const { return 1; } |
1671 | bool isDataParallel() const { return 0; } |
1672 | std::string getDebugDesc() const; |
1673 | bool isEqual(const MeanVarNormalizationNode &other) const; |
1674 | llvm::hash_code getHash() const; |
1675 | void visit(Node *parent, NodeWalker *visitor); |
1676 | Node* clone() const; |
1677 | bool verify() const; |
1678 | }; |
1679 | } // namespace glow |
1680 | |
1681 | |
1682 | namespace glow { |
1683 | class LocalResponseNormalizationGradNode final : public Node { |
1684 | NodeHandle Input_; |
1685 | NodeHandle OriginalOutputForResult_; |
1686 | NodeHandle GradOfOriginalOutputNamedResult_; |
1687 | unsigned_t HalfWindowSize_; |
1688 | float Alpha_; |
1689 | float Beta_; |
1690 | float K_; |
1691 | |
1692 | public: |
1693 | enum InputIndices { |
1694 | InputIdx = 0, |
1695 | OriginalOutputForResultIdx = 1, |
1696 | GradOfOriginalOutputNamedResultIdx = 2, |
1697 | }; |
1698 | |
1699 | enum ResultIndices { |
1700 | GradOfInputNamedInputIdx = 0, |
1701 | }; |
1702 | |
1703 | LocalResponseNormalizationGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, unsigned_t HalfWindowSize, float Alpha, float Beta, float K) |
1704 | : Node(Kinded::Kind::LocalResponseNormalizationGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), HalfWindowSize_(HalfWindowSize), Alpha_(Alpha), Beta_(Beta), K_(K) { |
1705 | addResult(Input.getType()); |
1706 | } |
1707 | const NodeValue getInput() const { return Input_; } |
1708 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
1709 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
1710 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
1711 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
1712 | unsigned_t getHalfWindowSize() const { return HalfWindowSize_; } |
1713 | float getAlpha() const { return Alpha_; } |
1714 | float getBeta() const { return Beta_; } |
1715 | float getK() const { return K_; } |
1716 | |
1717 | static bool classof(const Kinded *k) { |
1718 | return k->getKind() == Kinded::Kind::LocalResponseNormalizationGradNodeKind; |
1719 | } |
1720 | |
1721 | |
1722 | bool isOverwrittenNthInput(unsigned idx) const { |
1723 | return false; |
1724 | } |
1725 | |
1726 | unsigned getNumInputs() const; |
1727 | std::string getInputName(unsigned idx) const; |
1728 | NodeValue getNthInput(unsigned idx); |
1729 | void setNthInput(unsigned idx, NodeValue val); |
1730 | llvm::StringRef getOutputName(unsigned idx) const; |
1731 | bool hasSideEffects() const { return 0; } |
1732 | bool isCanonical() const { return 1; } |
1733 | bool isDataParallel() const { return 0; } |
1734 | std::string getDebugDesc() const; |
1735 | bool isEqual(const LocalResponseNormalizationGradNode &other) const; |
1736 | llvm::hash_code getHash() const; |
1737 | void visit(Node *parent, NodeWalker *visitor); |
1738 | Node* clone() const; |
1739 | bool verify() const; |
1740 | }; |
1741 | } // namespace glow |
1742 | |
1743 | |
1744 | namespace glow { |
1745 | /// Performs local response normalization on the Input tensor with the provided Scale, Bias, Mean, Var, ChannelIdx, Epsilon, and Momentum. Similar to Caffe2 and ONNX LRN. |
1746 | class LocalResponseNormalizationNode final : public Node { |
1747 | NodeHandle Input_; |
1748 | unsigned_t HalfWindowSize_; |
1749 | float Alpha_; |
1750 | float Beta_; |
1751 | float K_; |
1752 | |
1753 | public: |
1754 | enum InputIndices { |
1755 | InputIdx = 0, |
1756 | }; |
1757 | |
1758 | enum ResultIndices { |
1759 | ResultIdx = 0, |
1760 | }; |
1761 | |
1762 | LocalResponseNormalizationNode(llvm::StringRef name, NodeValue Input, unsigned_t HalfWindowSize, float Alpha, float Beta, float K) |
1763 | : Node(Kinded::Kind::LocalResponseNormalizationNodeKind, name), Input_(this, Input), HalfWindowSize_(HalfWindowSize), Alpha_(Alpha), Beta_(Beta), K_(K) { |
1764 | addResult(Input.getType()); |
1765 | } |
1766 | const NodeValue getInput() const { return Input_; } |
1767 | NodeValue getResult() { return getNthResult(0); } |
1768 | const NodeValue getResult() const { return getNthResult(0); } |
1769 | unsigned_t getHalfWindowSize() const { return HalfWindowSize_; } |
1770 | float getAlpha() const { return Alpha_; } |
1771 | float getBeta() const { return Beta_; } |
1772 | float getK() const { return K_; } |
1773 | |
1774 | static bool classof(const Kinded *k) { |
1775 | return k->getKind() == Kinded::Kind::LocalResponseNormalizationNodeKind; |
1776 | } |
1777 | |
1778 | |
1779 | bool isOverwrittenNthInput(unsigned idx) const { |
1780 | return false; |
1781 | } |
1782 | |
1783 | unsigned getNumInputs() const; |
1784 | std::string getInputName(unsigned idx) const; |
1785 | NodeValue getNthInput(unsigned idx); |
1786 | void setNthInput(unsigned idx, NodeValue val); |
1787 | llvm::StringRef getOutputName(unsigned idx) const; |
1788 | bool hasSideEffects() const { return 0; } |
1789 | bool isCanonical() const { return 1; } |
1790 | bool isDataParallel() const { return 0; } |
1791 | std::string getDebugDesc() const; |
1792 | bool isEqual(const LocalResponseNormalizationNode &other) const; |
1793 | llvm::hash_code getHash() const; |
1794 | void visit(Node *parent, NodeWalker *visitor); |
1795 | Node* clone() const; |
1796 | bool verify() const; |
1797 | LocalResponseNormalizationGradNode *getGrad(GraphGradMapper &builder); |
1798 | }; |
1799 | } // namespace glow |
1800 | |
1801 | |
1802 | namespace glow { |
1803 | /// Performs layer normalization on the Input tensor with the provided Scale, Bias, and Epsilon. Layer sizes are determined by the dimensions of Scale and Bias. Similar to PyTorch layer_norm. |
1804 | class LayerNormalizationNode final : public Node { |
1805 | NodeHandle Input_; |
1806 | NodeHandle Scale_; |
1807 | NodeHandle Bias_; |
1808 | float Epsilon_; |
1809 | |
1810 | public: |
1811 | enum InputIndices { |
1812 | InputIdx = 0, |
1813 | ScaleIdx = 1, |
1814 | BiasIdx = 2, |
1815 | }; |
1816 | |
1817 | enum ResultIndices { |
1818 | ResultIdx = 0, |
1819 | }; |
1820 | |
1821 | LayerNormalizationNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Scale, NodeValue Bias, float Epsilon) |
1822 | : Node(Kinded::Kind::LayerNormalizationNodeKind, name), Input_(this, Input), Scale_(this, Scale), Bias_(this, Bias), Epsilon_(Epsilon) { |
1823 | addResult(Result); |
1824 | } |
1825 | const NodeValue getInput() const { return Input_; } |
1826 | const NodeValue getScale() const { return Scale_; } |
1827 | const NodeValue getBias() const { return Bias_; } |
1828 | NodeValue getResult() { return getNthResult(0); } |
1829 | const NodeValue getResult() const { return getNthResult(0); } |
1830 | float getEpsilon() const { return Epsilon_; } |
1831 | |
1832 | static bool classof(const Kinded *k) { |
1833 | return k->getKind() == Kinded::Kind::LayerNormalizationNodeKind; |
1834 | } |
1835 | |
1836 | |
1837 | bool isOverwrittenNthInput(unsigned idx) const { |
1838 | return false; |
1839 | } |
1840 | |
1841 | unsigned getNumInputs() const; |
1842 | std::string getInputName(unsigned idx) const; |
1843 | NodeValue getNthInput(unsigned idx); |
1844 | void setNthInput(unsigned idx, NodeValue val); |
1845 | llvm::StringRef getOutputName(unsigned idx) const; |
1846 | bool hasSideEffects() const { return 0; } |
1847 | bool isCanonical() const { return 1; } |
1848 | bool isDataParallel() const { return 0; } |
1849 | std::string getDebugDesc() const; |
1850 | bool isEqual(const LayerNormalizationNode &other) const; |
1851 | llvm::hash_code getHash() const; |
1852 | void visit(Node *parent, NodeWalker *visitor); |
1853 | Node* clone() const; |
1854 | bool verify() const; |
1855 | }; |
1856 | } // namespace glow |
1857 | |
1858 | |
1859 | namespace glow { |
1860 | /// Apply box-cox transform for each column for each column in NxD input tensor |
1861 | class BatchBoxCoxNode final : public Node { |
1862 | NodeHandle Input_; |
1863 | NodeHandle Lambda1_; |
1864 | NodeHandle Lambda2_; |
1865 | float Epsilon_; |
1866 | |
1867 | public: |
1868 | enum InputIndices { |
1869 | InputIdx = 0, |
1870 | Lambda1Idx = 1, |
1871 | Lambda2Idx = 2, |
1872 | }; |
1873 | |
1874 | enum ResultIndices { |
1875 | ResultIdx = 0, |
1876 | }; |
1877 | |
1878 | BatchBoxCoxNode(llvm::StringRef name, NodeValue Input, NodeValue Lambda1, NodeValue Lambda2, float Epsilon) |
1879 | : Node(Kinded::Kind::BatchBoxCoxNodeKind, name), Input_(this, Input), Lambda1_(this, Lambda1), Lambda2_(this, Lambda2), Epsilon_(Epsilon) { |
1880 | addResult(Input.getType()); |
1881 | } |
1882 | const NodeValue getInput() const { return Input_; } |
1883 | const NodeValue getLambda1() const { return Lambda1_; } |
1884 | const NodeValue getLambda2() const { return Lambda2_; } |
1885 | NodeValue getResult() { return getNthResult(0); } |
1886 | const NodeValue getResult() const { return getNthResult(0); } |
1887 | float getEpsilon() const { return Epsilon_; } |
1888 | |
1889 | static bool classof(const Kinded *k) { |
1890 | return k->getKind() == Kinded::Kind::BatchBoxCoxNodeKind; |
1891 | } |
1892 | |
1893 | |
1894 | bool isOverwrittenNthInput(unsigned idx) const { |
1895 | return false; |
1896 | } |
1897 | |
1898 | unsigned getNumInputs() const; |
1899 | std::string getInputName(unsigned idx) const; |
1900 | NodeValue getNthInput(unsigned idx); |
1901 | void setNthInput(unsigned idx, NodeValue val); |
1902 | llvm::StringRef getOutputName(unsigned idx) const; |
1903 | bool hasSideEffects() const { return 0; } |
1904 | bool isCanonical() const { return 1; } |
1905 | bool isDataParallel() const { return 0; } |
1906 | std::string getDebugDesc() const; |
1907 | bool isEqual(const BatchBoxCoxNode &other) const; |
1908 | llvm::hash_code getHash() const; |
1909 | void visit(Node *parent, NodeWalker *visitor); |
1910 | Node* clone() const; |
1911 | bool verify() const; |
1912 | }; |
1913 | } // namespace glow |
1914 | |
1915 | |
1916 | namespace glow { |
1917 | /// Performs L2 norm of the Input operand based on Axis. |
1918 | class VectorNormNode final : public Node { |
1919 | NodeHandle Input_; |
1920 | unsigned_t Axis_; |
1921 | unsigned_t P_; |
1922 | |
1923 | public: |
1924 | enum InputIndices { |
1925 | InputIdx = 0, |
1926 | }; |
1927 | |
1928 | enum ResultIndices { |
1929 | ResultIdx = 0, |
1930 | }; |
1931 | |
1932 | VectorNormNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, unsigned_t P) |
1933 | : Node(Kinded::Kind::VectorNormNodeKind, name), Input_(this, Input), Axis_(Axis), P_(P) { |
1934 | addResult(Result); |
1935 | } |
1936 | const NodeValue getInput() const { return Input_; } |
1937 | NodeValue getResult() { return getNthResult(0); } |
1938 | const NodeValue getResult() const { return getNthResult(0); } |
1939 | unsigned_t getAxis() const { return Axis_; } |
1940 | unsigned_t getP() const { return P_; } |
1941 | |
1942 | static bool classof(const Kinded *k) { |
1943 | return k->getKind() == Kinded::Kind::VectorNormNodeKind; |
1944 | } |
1945 | |
1946 | |
1947 | bool isOverwrittenNthInput(unsigned idx) const { |
1948 | return false; |
1949 | } |
1950 | |
1951 | unsigned getNumInputs() const; |
1952 | std::string getInputName(unsigned idx) const; |
1953 | NodeValue getNthInput(unsigned idx); |
1954 | void setNthInput(unsigned idx, NodeValue val); |
1955 | llvm::StringRef getOutputName(unsigned idx) const; |
1956 | bool hasSideEffects() const { return 0; } |
1957 | bool isCanonical() const { return 1; } |
1958 | bool isDataParallel() const { return 0; } |
1959 | std::string getDebugDesc() const; |
1960 | bool isEqual(const VectorNormNode &other) const; |
1961 | llvm::hash_code getHash() const; |
1962 | void visit(Node *parent, NodeWalker *visitor); |
1963 | Node* clone() const; |
1964 | bool verify() const; |
1965 | }; |
1966 | } // namespace glow |
1967 | |
1968 | |
1969 | namespace glow { |
1970 | /// Performs bucketization on the input given Boundaries |
1971 | class BucketizeNode final : public Node { |
1972 | NodeHandle Input_; |
1973 | std::vector<float> Boundaries_; |
1974 | |
1975 | public: |
1976 | enum InputIndices { |
1977 | InputIdx = 0, |
1978 | }; |
1979 | |
1980 | enum ResultIndices { |
1981 | ResultIdx = 0, |
1982 | }; |
1983 | |
1984 | BucketizeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<float> Boundaries) |
1985 | : Node(Kinded::Kind::BucketizeNodeKind, name), Input_(this, Input), Boundaries_(Boundaries) { |
1986 | addResult(Result); |
1987 | } |
1988 | const NodeValue getInput() const { return Input_; } |
1989 | NodeValue getResult() { return getNthResult(0); } |
1990 | const NodeValue getResult() const { return getNthResult(0); } |
1991 | llvm::ArrayRef<float> getBoundaries() const { return Boundaries_; } |
1992 | |
1993 | static bool classof(const Kinded *k) { |
1994 | return k->getKind() == Kinded::Kind::BucketizeNodeKind; |
1995 | } |
1996 | |
1997 | |
1998 | bool isOverwrittenNthInput(unsigned idx) const { |
1999 | return false; |
2000 | } |
2001 | |
2002 | unsigned getNumInputs() const; |
2003 | std::string getInputName(unsigned idx) const; |
2004 | NodeValue getNthInput(unsigned idx); |
2005 | void setNthInput(unsigned idx, NodeValue val); |
2006 | llvm::StringRef getOutputName(unsigned idx) const; |
2007 | bool hasSideEffects() const { return 0; } |
2008 | bool isCanonical() const { return 1; } |
2009 | bool isDataParallel() const { return 0; } |
2010 | std::string getDebugDesc() const; |
2011 | bool isEqual(const BucketizeNode &other) const; |
2012 | llvm::hash_code getHash() const; |
2013 | void visit(Node *parent, NodeWalker *visitor); |
2014 | Node* clone() const; |
2015 | bool verify() const; |
2016 | }; |
2017 | } // namespace glow |
2018 | |
2019 | |
2020 | namespace glow { |
2021 | class SoftMaxGradNode final : public Node { |
2022 | NodeHandle Input_; |
2023 | NodeHandle Selected_; |
2024 | NodeHandle OriginalOutputForResult_; |
2025 | NodeHandle GradOfOriginalOutputNamedResult_; |
2026 | |
2027 | public: |
2028 | enum InputIndices { |
2029 | InputIdx = 0, |
2030 | SelectedIdx = 1, |
2031 | OriginalOutputForResultIdx = 2, |
2032 | GradOfOriginalOutputNamedResultIdx = 3, |
2033 | }; |
2034 | |
2035 | enum ResultIndices { |
2036 | GradOfInputNamedInputIdx = 0, |
2037 | GradOfInputNamedSelectedIdx = 1, |
2038 | }; |
2039 | |
2040 | SoftMaxGradNode(llvm::StringRef name, NodeValue Input, NodeValue Selected, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2041 | : Node(Kinded::Kind::SoftMaxGradNodeKind, name), Input_(this, Input), Selected_(this, Selected), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2042 | addResult(Input.getType()); |
2043 | addResult(Selected.getType()); |
2044 | } |
2045 | const NodeValue getInput() const { return Input_; } |
2046 | const NodeValue getSelected() const { return Selected_; } |
2047 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2048 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2049 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
2050 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
2051 | NodeValue getGradOfInputNamedSelected() { return getNthResult(1); } |
2052 | const NodeValue getGradOfInputNamedSelected() const { return getNthResult(1); } |
2053 | |
2054 | static bool classof(const Kinded *k) { |
2055 | return k->getKind() == Kinded::Kind::SoftMaxGradNodeKind; |
2056 | } |
2057 | |
2058 | |
2059 | bool isOverwrittenNthInput(unsigned idx) const { |
2060 | return false; |
2061 | } |
2062 | |
2063 | unsigned getNumInputs() const; |
2064 | std::string getInputName(unsigned idx) const; |
2065 | NodeValue getNthInput(unsigned idx); |
2066 | void setNthInput(unsigned idx, NodeValue val); |
2067 | llvm::StringRef getOutputName(unsigned idx) const; |
2068 | bool hasSideEffects() const { return 0; } |
2069 | bool isCanonical() const { return 1; } |
2070 | bool isDataParallel() const { return 0; } |
2071 | std::string getDebugDesc() const; |
2072 | bool isEqual(const SoftMaxGradNode &other) const; |
2073 | llvm::hash_code getHash() const; |
2074 | void visit(Node *parent, NodeWalker *visitor); |
2075 | Node* clone() const; |
2076 | bool verify() const; |
2077 | }; |
2078 | } // namespace glow |
2079 | |
2080 | |
2081 | namespace glow { |
2082 | /// Performs SoftMax normalization on the Input tensor. |
2083 | class SoftMaxNode final : public Node { |
2084 | NodeHandle Input_; |
2085 | NodeHandle Selected_; |
2086 | |
2087 | public: |
2088 | enum InputIndices { |
2089 | InputIdx = 0, |
2090 | SelectedIdx = 1, |
2091 | }; |
2092 | |
2093 | enum ResultIndices { |
2094 | ResultIdx = 0, |
2095 | }; |
2096 | |
2097 | SoftMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Selected) |
2098 | : Node(Kinded::Kind::SoftMaxNodeKind, name), Input_(this, Input), Selected_(this, Selected) { |
2099 | addResult(Result); |
2100 | } |
2101 | const NodeValue getInput() const { return Input_; } |
2102 | const NodeValue getSelected() const { return Selected_; } |
2103 | NodeValue getResult() { return getNthResult(0); } |
2104 | const NodeValue getResult() const { return getNthResult(0); } |
2105 | |
2106 | static bool classof(const Kinded *k) { |
2107 | return k->getKind() == Kinded::Kind::SoftMaxNodeKind; |
2108 | } |
2109 | |
2110 | |
2111 | bool isOverwrittenNthInput(unsigned idx) const { |
2112 | return false; |
2113 | } |
2114 | |
2115 | unsigned getNumInputs() const; |
2116 | std::string getInputName(unsigned idx) const; |
2117 | NodeValue getNthInput(unsigned idx); |
2118 | void setNthInput(unsigned idx, NodeValue val); |
2119 | llvm::StringRef getOutputName(unsigned idx) const; |
2120 | bool hasSideEffects() const { return 0; } |
2121 | bool isCanonical() const { return 1; } |
2122 | bool isDataParallel() const { return 0; } |
2123 | std::string getDebugDesc() const; |
2124 | bool isEqual(const SoftMaxNode &other) const; |
2125 | llvm::hash_code getHash() const; |
2126 | void visit(Node *parent, NodeWalker *visitor); |
2127 | Node* clone() const; |
2128 | bool verify() const; |
2129 | SoftMaxGradNode *getGrad(GraphGradMapper &builder); |
2130 | }; |
2131 | } // namespace glow |
2132 | |
2133 | |
2134 | namespace glow { |
2135 | class LogSoftMaxGradNode final : public Node { |
2136 | NodeHandle Input_; |
2137 | NodeHandle Selected_; |
2138 | NodeHandle OriginalOutputForResult_; |
2139 | NodeHandle GradOfOriginalOutputNamedResult_; |
2140 | |
2141 | public: |
2142 | enum InputIndices { |
2143 | InputIdx = 0, |
2144 | SelectedIdx = 1, |
2145 | OriginalOutputForResultIdx = 2, |
2146 | GradOfOriginalOutputNamedResultIdx = 3, |
2147 | }; |
2148 | |
2149 | enum ResultIndices { |
2150 | GradOfInputNamedInputIdx = 0, |
2151 | GradOfInputNamedSelectedIdx = 1, |
2152 | }; |
2153 | |
2154 | LogSoftMaxGradNode(llvm::StringRef name, NodeValue Input, NodeValue Selected, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2155 | : Node(Kinded::Kind::LogSoftMaxGradNodeKind, name), Input_(this, Input), Selected_(this, Selected), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2156 | addResult(Input.getType()); |
2157 | addResult(Selected.getType()); |
2158 | } |
2159 | const NodeValue getInput() const { return Input_; } |
2160 | const NodeValue getSelected() const { return Selected_; } |
2161 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2162 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2163 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
2164 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
2165 | NodeValue getGradOfInputNamedSelected() { return getNthResult(1); } |
2166 | const NodeValue getGradOfInputNamedSelected() const { return getNthResult(1); } |
2167 | |
2168 | static bool classof(const Kinded *k) { |
2169 | return k->getKind() == Kinded::Kind::LogSoftMaxGradNodeKind; |
2170 | } |
2171 | |
2172 | |
2173 | bool isOverwrittenNthInput(unsigned idx) const { |
2174 | return false; |
2175 | } |
2176 | |
2177 | unsigned getNumInputs() const; |
2178 | std::string getInputName(unsigned idx) const; |
2179 | NodeValue getNthInput(unsigned idx); |
2180 | void setNthInput(unsigned idx, NodeValue val); |
2181 | llvm::StringRef getOutputName(unsigned idx) const; |
2182 | bool hasSideEffects() const { return 0; } |
2183 | bool isCanonical() const { return 1; } |
2184 | bool isDataParallel() const { return 0; } |
2185 | std::string getDebugDesc() const; |
2186 | bool isEqual(const LogSoftMaxGradNode &other) const; |
2187 | llvm::hash_code getHash() const; |
2188 | void visit(Node *parent, NodeWalker *visitor); |
2189 | Node* clone() const; |
2190 | bool verify() const; |
2191 | }; |
2192 | } // namespace glow |
2193 | |
2194 | |
2195 | namespace glow { |
2196 | /// Performs LogSoftMax normalization on the Input tensor. |
2197 | class LogSoftMaxNode final : public Node { |
2198 | NodeHandle Input_; |
2199 | NodeHandle Selected_; |
2200 | |
2201 | public: |
2202 | enum InputIndices { |
2203 | InputIdx = 0, |
2204 | SelectedIdx = 1, |
2205 | }; |
2206 | |
2207 | enum ResultIndices { |
2208 | ResultIdx = 0, |
2209 | }; |
2210 | |
2211 | LogSoftMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Selected) |
2212 | : Node(Kinded::Kind::LogSoftMaxNodeKind, name), Input_(this, Input), Selected_(this, Selected) { |
2213 | addResult(Result); |
2214 | } |
2215 | const NodeValue getInput() const { return Input_; } |
2216 | const NodeValue getSelected() const { return Selected_; } |
2217 | NodeValue getResult() { return getNthResult(0); } |
2218 | const NodeValue getResult() const { return getNthResult(0); } |
2219 | |
2220 | static bool classof(const Kinded *k) { |
2221 | return k->getKind() == Kinded::Kind::LogSoftMaxNodeKind; |
2222 | } |
2223 | |
2224 | |
2225 | bool isOverwrittenNthInput(unsigned idx) const { |
2226 | return false; |
2227 | } |
2228 | |
2229 | unsigned getNumInputs() const; |
2230 | std::string getInputName(unsigned idx) const; |
2231 | NodeValue getNthInput(unsigned idx); |
2232 | void setNthInput(unsigned idx, NodeValue val); |
2233 | llvm::StringRef getOutputName(unsigned idx) const; |
2234 | bool hasSideEffects() const { return 0; } |
2235 | bool isCanonical() const { return 1; } |
2236 | bool isDataParallel() const { return 0; } |
2237 | std::string getDebugDesc() const; |
2238 | bool isEqual(const LogSoftMaxNode &other) const; |
2239 | llvm::hash_code getHash() const; |
2240 | void visit(Node *parent, NodeWalker *visitor); |
2241 | Node* clone() const; |
2242 | bool verify() const; |
2243 | LogSoftMaxGradNode *getGrad(GraphGradMapper &builder); |
2244 | }; |
2245 | } // namespace glow |
2246 | |
2247 | |
2248 | namespace glow { |
2249 | class CrossEntropyLossGradNode final : public Node { |
2250 | NodeHandle P_; |
2251 | NodeHandle Labels_; |
2252 | NodeHandle OriginalOutputForCE_; |
2253 | NodeHandle GradOfOriginalOutputNamedCE_; |
2254 | |
2255 | public: |
2256 | enum InputIndices { |
2257 | PIdx = 0, |
2258 | LabelsIdx = 1, |
2259 | OriginalOutputForCEIdx = 2, |
2260 | GradOfOriginalOutputNamedCEIdx = 3, |
2261 | }; |
2262 | |
2263 | enum ResultIndices { |
2264 | GradOfInputNamedPIdx = 0, |
2265 | GradOfInputNamedLabelsIdx = 1, |
2266 | }; |
2267 | |
2268 | CrossEntropyLossGradNode(llvm::StringRef name, NodeValue P, NodeValue Labels, NodeValue OriginalOutputForCE, NodeValue GradOfOriginalOutputNamedCE) |
2269 | : Node(Kinded::Kind::CrossEntropyLossGradNodeKind, name), P_(this, P), Labels_(this, Labels), OriginalOutputForCE_(this, OriginalOutputForCE), GradOfOriginalOutputNamedCE_(this, GradOfOriginalOutputNamedCE) { |
2270 | addResult(P.getType()); |
2271 | addResult(Labels.getType()); |
2272 | } |
2273 | const NodeValue getP() const { return P_; } |
2274 | const NodeValue getLabels() const { return Labels_; } |
2275 | const NodeValue getOriginalOutputForCE() const { return OriginalOutputForCE_; } |
2276 | const NodeValue getGradOfOriginalOutputNamedCE() const { return GradOfOriginalOutputNamedCE_; } |
2277 | NodeValue getGradOfInputNamedP() { return getNthResult(0); } |
2278 | const NodeValue getGradOfInputNamedP() const { return getNthResult(0); } |
2279 | NodeValue getGradOfInputNamedLabels() { return getNthResult(1); } |
2280 | const NodeValue getGradOfInputNamedLabels() const { return getNthResult(1); } |
2281 | |
2282 | static bool classof(const Kinded *k) { |
2283 | return k->getKind() == Kinded::Kind::CrossEntropyLossGradNodeKind; |
2284 | } |
2285 | |
2286 | |
2287 | bool isOverwrittenNthInput(unsigned idx) const { |
2288 | return false; |
2289 | } |
2290 | |
2291 | unsigned getNumInputs() const; |
2292 | std::string getInputName(unsigned idx) const; |
2293 | NodeValue getNthInput(unsigned idx); |
2294 | void setNthInput(unsigned idx, NodeValue val); |
2295 | llvm::StringRef getOutputName(unsigned idx) const; |
2296 | bool hasSideEffects() const { return 0; } |
2297 | bool isCanonical() const { return 1; } |
2298 | bool isDataParallel() const { return 0; } |
2299 | std::string getDebugDesc() const; |
2300 | bool isEqual(const CrossEntropyLossGradNode &other) const; |
2301 | llvm::hash_code getHash() const; |
2302 | void visit(Node *parent, NodeWalker *visitor); |
2303 | Node* clone() const; |
2304 | bool verify() const; |
2305 | }; |
2306 | } // namespace glow |
2307 | |
2308 | |
2309 | namespace glow { |
2310 | /// Computes the average cross entropy loss of the input. |
2311 | class CrossEntropyLossNode final : public Node { |
2312 | NodeHandle P_; |
2313 | NodeHandle Labels_; |
2314 | |
2315 | public: |
2316 | enum InputIndices { |
2317 | PIdx = 0, |
2318 | LabelsIdx = 1, |
2319 | }; |
2320 | |
2321 | enum ResultIndices { |
2322 | CEIdx = 0, |
2323 | }; |
2324 | |
2325 | CrossEntropyLossNode(llvm::StringRef name, TypeRef CE , NodeValue P, NodeValue Labels) |
2326 | : Node(Kinded::Kind::CrossEntropyLossNodeKind, name), P_(this, P), Labels_(this, Labels) { |
2327 | addResult(CE); |
2328 | } |
2329 | const NodeValue getP() const { return P_; } |
2330 | const NodeValue getLabels() const { return Labels_; } |
2331 | NodeValue getCE() { return getNthResult(0); } |
2332 | const NodeValue getCE() const { return getNthResult(0); } |
2333 | |
2334 | static bool classof(const Kinded *k) { |
2335 | return k->getKind() == Kinded::Kind::CrossEntropyLossNodeKind; |
2336 | } |
2337 | |
2338 | |
2339 | bool isOverwrittenNthInput(unsigned idx) const { |
2340 | return false; |
2341 | } |
2342 | |
2343 | unsigned getNumInputs() const; |
2344 | std::string getInputName(unsigned idx) const; |
2345 | NodeValue getNthInput(unsigned idx); |
2346 | void setNthInput(unsigned idx, NodeValue val); |
2347 | llvm::StringRef getOutputName(unsigned idx) const; |
2348 | bool hasSideEffects() const { return 0; } |
2349 | bool isCanonical() const { return 1; } |
2350 | bool isDataParallel() const { return 0; } |
2351 | std::string getDebugDesc() const; |
2352 | bool isEqual(const CrossEntropyLossNode &other) const; |
2353 | llvm::hash_code getHash() const; |
2354 | void visit(Node *parent, NodeWalker *visitor); |
2355 | Node* clone() const; |
2356 | bool verify() const; |
2357 | CrossEntropyLossGradNode *getGrad(GraphGradMapper &builder); |
2358 | }; |
2359 | } // namespace glow |
2360 | |
2361 | |
2362 | namespace glow { |
2363 | class RegressionGradNode final : public Node { |
2364 | NodeHandle Input_; |
2365 | NodeHandle Expected_; |
2366 | NodeHandle OriginalOutputForResult_; |
2367 | NodeHandle GradOfOriginalOutputNamedResult_; |
2368 | |
2369 | public: |
2370 | enum InputIndices { |
2371 | InputIdx = 0, |
2372 | ExpectedIdx = 1, |
2373 | OriginalOutputForResultIdx = 2, |
2374 | GradOfOriginalOutputNamedResultIdx = 3, |
2375 | }; |
2376 | |
2377 | enum ResultIndices { |
2378 | GradOfInputNamedInputIdx = 0, |
2379 | GradOfInputNamedExpectedIdx = 1, |
2380 | }; |
2381 | |
2382 | RegressionGradNode(llvm::StringRef name, NodeValue Input, NodeValue Expected, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2383 | : Node(Kinded::Kind::RegressionGradNodeKind, name), Input_(this, Input), Expected_(this, Expected), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2384 | addResult(Input.getType()); |
2385 | addResult(Expected.getType()); |
2386 | } |
2387 | const NodeValue getInput() const { return Input_; } |
2388 | const NodeValue getExpected() const { return Expected_; } |
2389 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2390 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2391 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
2392 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
2393 | NodeValue getGradOfInputNamedExpected() { return getNthResult(1); } |
2394 | const NodeValue getGradOfInputNamedExpected() const { return getNthResult(1); } |
2395 | |
2396 | static bool classof(const Kinded *k) { |
2397 | return k->getKind() == Kinded::Kind::RegressionGradNodeKind; |
2398 | } |
2399 | |
2400 | |
2401 | bool isOverwrittenNthInput(unsigned idx) const { |
2402 | return false; |
2403 | } |
2404 | |
2405 | unsigned getNumInputs() const; |
2406 | std::string getInputName(unsigned idx) const; |
2407 | NodeValue getNthInput(unsigned idx); |
2408 | void setNthInput(unsigned idx, NodeValue val); |
2409 | llvm::StringRef getOutputName(unsigned idx) const; |
2410 | bool hasSideEffects() const { return 0; } |
2411 | bool isCanonical() const { return 1; } |
2412 | bool isDataParallel() const { return 0; } |
2413 | std::string getDebugDesc() const; |
2414 | bool isEqual(const RegressionGradNode &other) const; |
2415 | llvm::hash_code getHash() const; |
2416 | void visit(Node *parent, NodeWalker *visitor); |
2417 | Node* clone() const; |
2418 | bool verify() const; |
2419 | }; |
2420 | } // namespace glow |
2421 | |
2422 | |
2423 | namespace glow { |
2424 | /// Takes an Input tensor and creates a regression output layer. |
2425 | class RegressionNode final : public Node { |
2426 | NodeHandle Input_; |
2427 | NodeHandle Expected_; |
2428 | |
2429 | public: |
2430 | enum InputIndices { |
2431 | InputIdx = 0, |
2432 | ExpectedIdx = 1, |
2433 | }; |
2434 | |
2435 | enum ResultIndices { |
2436 | ResultIdx = 0, |
2437 | }; |
2438 | |
2439 | RegressionNode(llvm::StringRef name, NodeValue Input, NodeValue Expected) |
2440 | : Node(Kinded::Kind::RegressionNodeKind, name), Input_(this, Input), Expected_(this, Expected) { |
2441 | addResult(Input.getType()); |
2442 | } |
2443 | const NodeValue getInput() const { return Input_; } |
2444 | const NodeValue getExpected() const { return Expected_; } |
2445 | NodeValue getResult() { return getNthResult(0); } |
2446 | const NodeValue getResult() const { return getNthResult(0); } |
2447 | |
2448 | static bool classof(const Kinded *k) { |
2449 | return k->getKind() == Kinded::Kind::RegressionNodeKind; |
2450 | } |
2451 | |
2452 | |
2453 | bool isOverwrittenNthInput(unsigned idx) const { |
2454 | return false; |
2455 | } |
2456 | |
2457 | unsigned getNumInputs() const; |
2458 | std::string getInputName(unsigned idx) const; |
2459 | NodeValue getNthInput(unsigned idx); |
2460 | void setNthInput(unsigned idx, NodeValue val); |
2461 | llvm::StringRef getOutputName(unsigned idx) const; |
2462 | bool hasSideEffects() const { return 0; } |
2463 | bool isCanonical() const { return 1; } |
2464 | bool isDataParallel() const { return 0; } |
2465 | std::string getDebugDesc() const; |
2466 | bool isEqual(const RegressionNode &other) const; |
2467 | llvm::hash_code getHash() const; |
2468 | void visit(Node *parent, NodeWalker *visitor); |
2469 | Node* clone() const; |
2470 | bool verify() const; |
2471 | RegressionGradNode *getGrad(GraphGradMapper &builder); |
2472 | }; |
2473 | } // namespace glow |
2474 | |
2475 | |
2476 | namespace glow { |
2477 | /// Computes the sigmoid cross entropy between two inputs. |
2478 | class SigmoidCrossEntropyWithLogitsNode final : public Node { |
2479 | NodeHandle Logits_; |
2480 | NodeHandle Targets_; |
2481 | |
2482 | public: |
2483 | enum InputIndices { |
2484 | LogitsIdx = 0, |
2485 | TargetsIdx = 1, |
2486 | }; |
2487 | |
2488 | enum ResultIndices { |
2489 | ResultIdx = 0, |
2490 | }; |
2491 | |
2492 | SigmoidCrossEntropyWithLogitsNode(llvm::StringRef name, TypeRef Result , NodeValue Logits, NodeValue Targets) |
2493 | : Node(Kinded::Kind::SigmoidCrossEntropyWithLogitsNodeKind, name), Logits_(this, Logits), Targets_(this, Targets) { |
2494 | addResult(Result); |
2495 | } |
2496 | const NodeValue getLogits() const { return Logits_; } |
2497 | const NodeValue getTargets() const { return Targets_; } |
2498 | NodeValue getResult() { return getNthResult(0); } |
2499 | const NodeValue getResult() const { return getNthResult(0); } |
2500 | |
2501 | static bool classof(const Kinded *k) { |
2502 | return k->getKind() == Kinded::Kind::SigmoidCrossEntropyWithLogitsNodeKind; |
2503 | } |
2504 | |
2505 | |
2506 | bool isOverwrittenNthInput(unsigned idx) const { |
2507 | return false; |
2508 | } |
2509 | |
2510 | unsigned getNumInputs() const; |
2511 | std::string getInputName(unsigned idx) const; |
2512 | NodeValue getNthInput(unsigned idx); |
2513 | void setNthInput(unsigned idx, NodeValue val); |
2514 | llvm::StringRef getOutputName(unsigned idx) const; |
2515 | bool hasSideEffects() const { return 0; } |
2516 | bool isCanonical() const { return 1; } |
2517 | bool isDataParallel() const { return 0; } |
2518 | std::string getDebugDesc() const; |
2519 | bool isEqual(const SigmoidCrossEntropyWithLogitsNode &other) const; |
2520 | llvm::hash_code getHash() const; |
2521 | void visit(Node *parent, NodeWalker *visitor); |
2522 | Node* clone() const; |
2523 | bool verify() const; |
2524 | }; |
2525 | } // namespace glow |
2526 | |
2527 | |
2528 | namespace glow { |
2529 | class AddGradNode final : public Node { |
2530 | NodeHandle LHS_; |
2531 | NodeHandle RHS_; |
2532 | NodeHandle OriginalOutputForResult_; |
2533 | NodeHandle GradOfOriginalOutputNamedResult_; |
2534 | |
2535 | public: |
2536 | enum InputIndices { |
2537 | LHSIdx = 0, |
2538 | RHSIdx = 1, |
2539 | OriginalOutputForResultIdx = 2, |
2540 | GradOfOriginalOutputNamedResultIdx = 3, |
2541 | }; |
2542 | |
2543 | enum ResultIndices { |
2544 | GradOfInputNamedLHSIdx = 0, |
2545 | GradOfInputNamedRHSIdx = 1, |
2546 | }; |
2547 | |
2548 | AddGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2549 | : Node(Kinded::Kind::AddGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2550 | addResult(LHS.getType()); |
2551 | addResult(RHS.getType()); |
2552 | } |
2553 | const NodeValue getLHS() const { return LHS_; } |
2554 | const NodeValue getRHS() const { return RHS_; } |
2555 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2556 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2557 | NodeValue getGradOfInputNamedLHS() { return getNthResult(0); } |
2558 | const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); } |
2559 | NodeValue getGradOfInputNamedRHS() { return getNthResult(1); } |
2560 | const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); } |
2561 | |
2562 | static bool classof(const Kinded *k) { |
2563 | return k->getKind() == Kinded::Kind::AddGradNodeKind; |
2564 | } |
2565 | |
2566 | |
2567 | bool isOverwrittenNthInput(unsigned idx) const { |
2568 | return false; |
2569 | } |
2570 | |
2571 | unsigned getNumInputs() const; |
2572 | std::string getInputName(unsigned idx) const; |
2573 | NodeValue getNthInput(unsigned idx); |
2574 | void setNthInput(unsigned idx, NodeValue val); |
2575 | llvm::StringRef getOutputName(unsigned idx) const; |
2576 | bool hasSideEffects() const { return 0; } |
2577 | bool isCanonical() const { return 1; } |
2578 | bool isDataParallel() const { return 1; } |
2579 | std::string getDebugDesc() const; |
2580 | bool isEqual(const AddGradNode &other) const; |
2581 | llvm::hash_code getHash() const; |
2582 | void visit(Node *parent, NodeWalker *visitor); |
2583 | Node* clone() const; |
2584 | bool verify() const; |
2585 | }; |
2586 | } // namespace glow |
2587 | |
2588 | |
2589 | namespace glow { |
2590 | /// Performs Add on the LHS and RHS operands. |
2591 | class AddNode final : public Node { |
2592 | NodeHandle LHS_; |
2593 | NodeHandle RHS_; |
2594 | |
2595 | public: |
2596 | enum InputIndices { |
2597 | LHSIdx = 0, |
2598 | RHSIdx = 1, |
2599 | }; |
2600 | |
2601 | enum ResultIndices { |
2602 | ResultIdx = 0, |
2603 | }; |
2604 | |
2605 | AddNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
2606 | : Node(Kinded::Kind::AddNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
2607 | addResult(Result); |
2608 | } |
2609 | const NodeValue getLHS() const { return LHS_; } |
2610 | const NodeValue getRHS() const { return RHS_; } |
2611 | NodeValue getResult() { return getNthResult(0); } |
2612 | const NodeValue getResult() const { return getNthResult(0); } |
2613 | |
2614 | static bool classof(const Kinded *k) { |
2615 | return k->getKind() == Kinded::Kind::AddNodeKind; |
2616 | } |
2617 | |
2618 | |
2619 | bool isOverwrittenNthInput(unsigned idx) const { |
2620 | return false; |
2621 | } |
2622 | |
2623 | unsigned getNumInputs() const; |
2624 | std::string getInputName(unsigned idx) const; |
2625 | NodeValue getNthInput(unsigned idx); |
2626 | void setNthInput(unsigned idx, NodeValue val); |
2627 | llvm::StringRef getOutputName(unsigned idx) const; |
2628 | bool hasSideEffects() const { return 0; } |
2629 | bool isCanonical() const { return 1; } |
2630 | bool isDataParallel() const { return 1; } |
2631 | std::string getDebugDesc() const; |
2632 | bool isEqual(const AddNode &other) const; |
2633 | llvm::hash_code getHash() const; |
2634 | void visit(Node *parent, NodeWalker *visitor); |
2635 | Node* clone() const; |
2636 | bool verify() const; |
2637 | AddGradNode *getGrad(GraphGradMapper &builder); |
2638 | }; |
2639 | } // namespace glow |
2640 | |
2641 | |
2642 | namespace glow { |
2643 | class MulGradNode final : public Node { |
2644 | NodeHandle LHS_; |
2645 | NodeHandle RHS_; |
2646 | NodeHandle OriginalOutputForResult_; |
2647 | NodeHandle GradOfOriginalOutputNamedResult_; |
2648 | |
2649 | public: |
2650 | enum InputIndices { |
2651 | LHSIdx = 0, |
2652 | RHSIdx = 1, |
2653 | OriginalOutputForResultIdx = 2, |
2654 | GradOfOriginalOutputNamedResultIdx = 3, |
2655 | }; |
2656 | |
2657 | enum ResultIndices { |
2658 | GradOfInputNamedLHSIdx = 0, |
2659 | GradOfInputNamedRHSIdx = 1, |
2660 | }; |
2661 | |
2662 | MulGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2663 | : Node(Kinded::Kind::MulGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2664 | addResult(LHS.getType()); |
2665 | addResult(RHS.getType()); |
2666 | } |
2667 | const NodeValue getLHS() const { return LHS_; } |
2668 | const NodeValue getRHS() const { return RHS_; } |
2669 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2670 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2671 | NodeValue getGradOfInputNamedLHS() { return getNthResult(0); } |
2672 | const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); } |
2673 | NodeValue getGradOfInputNamedRHS() { return getNthResult(1); } |
2674 | const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); } |
2675 | |
2676 | static bool classof(const Kinded *k) { |
2677 | return k->getKind() == Kinded::Kind::MulGradNodeKind; |
2678 | } |
2679 | |
2680 | |
2681 | bool isOverwrittenNthInput(unsigned idx) const { |
2682 | return false; |
2683 | } |
2684 | |
2685 | unsigned getNumInputs() const; |
2686 | std::string getInputName(unsigned idx) const; |
2687 | NodeValue getNthInput(unsigned idx); |
2688 | void setNthInput(unsigned idx, NodeValue val); |
2689 | llvm::StringRef getOutputName(unsigned idx) const; |
2690 | bool hasSideEffects() const { return 0; } |
2691 | bool isCanonical() const { return 1; } |
2692 | bool isDataParallel() const { return 1; } |
2693 | std::string getDebugDesc() const; |
2694 | bool isEqual(const MulGradNode &other) const; |
2695 | llvm::hash_code getHash() const; |
2696 | void visit(Node *parent, NodeWalker *visitor); |
2697 | Node* clone() const; |
2698 | bool verify() const; |
2699 | }; |
2700 | } // namespace glow |
2701 | |
2702 | |
2703 | namespace glow { |
2704 | /// Performs Mul on the LHS and RHS operands. |
2705 | class MulNode final : public Node { |
2706 | NodeHandle LHS_; |
2707 | NodeHandle RHS_; |
2708 | |
2709 | public: |
2710 | enum InputIndices { |
2711 | LHSIdx = 0, |
2712 | RHSIdx = 1, |
2713 | }; |
2714 | |
2715 | enum ResultIndices { |
2716 | ResultIdx = 0, |
2717 | }; |
2718 | |
2719 | MulNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
2720 | : Node(Kinded::Kind::MulNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
2721 | addResult(Result); |
2722 | } |
2723 | const NodeValue getLHS() const { return LHS_; } |
2724 | const NodeValue getRHS() const { return RHS_; } |
2725 | NodeValue getResult() { return getNthResult(0); } |
2726 | const NodeValue getResult() const { return getNthResult(0); } |
2727 | |
2728 | static bool classof(const Kinded *k) { |
2729 | return k->getKind() == Kinded::Kind::MulNodeKind; |
2730 | } |
2731 | |
2732 | |
2733 | bool isOverwrittenNthInput(unsigned idx) const { |
2734 | return false; |
2735 | } |
2736 | |
2737 | unsigned getNumInputs() const; |
2738 | std::string getInputName(unsigned idx) const; |
2739 | NodeValue getNthInput(unsigned idx); |
2740 | void setNthInput(unsigned idx, NodeValue val); |
2741 | llvm::StringRef getOutputName(unsigned idx) const; |
2742 | bool hasSideEffects() const { return 0; } |
2743 | bool isCanonical() const { return 1; } |
2744 | bool isDataParallel() const { return 1; } |
2745 | std::string getDebugDesc() const; |
2746 | bool isEqual(const MulNode &other) const; |
2747 | llvm::hash_code getHash() const; |
2748 | void visit(Node *parent, NodeWalker *visitor); |
2749 | Node* clone() const; |
2750 | bool verify() const; |
2751 | MulGradNode *getGrad(GraphGradMapper &builder); |
2752 | }; |
2753 | } // namespace glow |
2754 | |
2755 | |
2756 | namespace glow { |
2757 | class SubGradNode final : public Node { |
2758 | NodeHandle LHS_; |
2759 | NodeHandle RHS_; |
2760 | NodeHandle OriginalOutputForResult_; |
2761 | NodeHandle GradOfOriginalOutputNamedResult_; |
2762 | |
2763 | public: |
2764 | enum InputIndices { |
2765 | LHSIdx = 0, |
2766 | RHSIdx = 1, |
2767 | OriginalOutputForResultIdx = 2, |
2768 | GradOfOriginalOutputNamedResultIdx = 3, |
2769 | }; |
2770 | |
2771 | enum ResultIndices { |
2772 | GradOfInputNamedLHSIdx = 0, |
2773 | GradOfInputNamedRHSIdx = 1, |
2774 | }; |
2775 | |
2776 | SubGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2777 | : Node(Kinded::Kind::SubGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2778 | addResult(LHS.getType()); |
2779 | addResult(RHS.getType()); |
2780 | } |
2781 | const NodeValue getLHS() const { return LHS_; } |
2782 | const NodeValue getRHS() const { return RHS_; } |
2783 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2784 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2785 | NodeValue getGradOfInputNamedLHS() { return getNthResult(0); } |
2786 | const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); } |
2787 | NodeValue getGradOfInputNamedRHS() { return getNthResult(1); } |
2788 | const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); } |
2789 | |
2790 | static bool classof(const Kinded *k) { |
2791 | return k->getKind() == Kinded::Kind::SubGradNodeKind; |
2792 | } |
2793 | |
2794 | |
2795 | bool isOverwrittenNthInput(unsigned idx) const { |
2796 | return false; |
2797 | } |
2798 | |
2799 | unsigned getNumInputs() const; |
2800 | std::string getInputName(unsigned idx) const; |
2801 | NodeValue getNthInput(unsigned idx); |
2802 | void setNthInput(unsigned idx, NodeValue val); |
2803 | llvm::StringRef getOutputName(unsigned idx) const; |
2804 | bool hasSideEffects() const { return 0; } |
2805 | bool isCanonical() const { return 1; } |
2806 | bool isDataParallel() const { return 1; } |
2807 | std::string getDebugDesc() const; |
2808 | bool isEqual(const SubGradNode &other) const; |
2809 | llvm::hash_code getHash() const; |
2810 | void visit(Node *parent, NodeWalker *visitor); |
2811 | Node* clone() const; |
2812 | bool verify() const; |
2813 | }; |
2814 | } // namespace glow |
2815 | |
2816 | |
2817 | namespace glow { |
2818 | /// Performs Sub on the LHS and RHS operands. |
2819 | class SubNode final : public Node { |
2820 | NodeHandle LHS_; |
2821 | NodeHandle RHS_; |
2822 | |
2823 | public: |
2824 | enum InputIndices { |
2825 | LHSIdx = 0, |
2826 | RHSIdx = 1, |
2827 | }; |
2828 | |
2829 | enum ResultIndices { |
2830 | ResultIdx = 0, |
2831 | }; |
2832 | |
2833 | SubNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
2834 | : Node(Kinded::Kind::SubNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
2835 | addResult(Result); |
2836 | } |
2837 | const NodeValue getLHS() const { return LHS_; } |
2838 | const NodeValue getRHS() const { return RHS_; } |
2839 | NodeValue getResult() { return getNthResult(0); } |
2840 | const NodeValue getResult() const { return getNthResult(0); } |
2841 | |
2842 | static bool classof(const Kinded *k) { |
2843 | return k->getKind() == Kinded::Kind::SubNodeKind; |
2844 | } |
2845 | |
2846 | |
2847 | bool isOverwrittenNthInput(unsigned idx) const { |
2848 | return false; |
2849 | } |
2850 | |
2851 | unsigned getNumInputs() const; |
2852 | std::string getInputName(unsigned idx) const; |
2853 | NodeValue getNthInput(unsigned idx); |
2854 | void setNthInput(unsigned idx, NodeValue val); |
2855 | llvm::StringRef getOutputName(unsigned idx) const; |
2856 | bool hasSideEffects() const { return 0; } |
2857 | bool isCanonical() const { return 1; } |
2858 | bool isDataParallel() const { return 1; } |
2859 | std::string getDebugDesc() const; |
2860 | bool isEqual(const SubNode &other) const; |
2861 | llvm::hash_code getHash() const; |
2862 | void visit(Node *parent, NodeWalker *visitor); |
2863 | Node* clone() const; |
2864 | bool verify() const; |
2865 | SubGradNode *getGrad(GraphGradMapper &builder); |
2866 | }; |
2867 | } // namespace glow |
2868 | |
2869 | |
2870 | namespace glow { |
2871 | class DivGradNode final : public Node { |
2872 | NodeHandle LHS_; |
2873 | NodeHandle RHS_; |
2874 | NodeHandle OriginalOutputForResult_; |
2875 | NodeHandle GradOfOriginalOutputNamedResult_; |
2876 | |
2877 | public: |
2878 | enum InputIndices { |
2879 | LHSIdx = 0, |
2880 | RHSIdx = 1, |
2881 | OriginalOutputForResultIdx = 2, |
2882 | GradOfOriginalOutputNamedResultIdx = 3, |
2883 | }; |
2884 | |
2885 | enum ResultIndices { |
2886 | GradOfInputNamedLHSIdx = 0, |
2887 | GradOfInputNamedRHSIdx = 1, |
2888 | }; |
2889 | |
2890 | DivGradNode(llvm::StringRef name, NodeValue LHS, NodeValue RHS, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
2891 | : Node(Kinded::Kind::DivGradNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
2892 | addResult(LHS.getType()); |
2893 | addResult(RHS.getType()); |
2894 | } |
2895 | const NodeValue getLHS() const { return LHS_; } |
2896 | const NodeValue getRHS() const { return RHS_; } |
2897 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
2898 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
2899 | NodeValue getGradOfInputNamedLHS() { return getNthResult(0); } |
2900 | const NodeValue getGradOfInputNamedLHS() const { return getNthResult(0); } |
2901 | NodeValue getGradOfInputNamedRHS() { return getNthResult(1); } |
2902 | const NodeValue getGradOfInputNamedRHS() const { return getNthResult(1); } |
2903 | |
2904 | static bool classof(const Kinded *k) { |
2905 | return k->getKind() == Kinded::Kind::DivGradNodeKind; |
2906 | } |
2907 | |
2908 | |
2909 | bool isOverwrittenNthInput(unsigned idx) const { |
2910 | return false; |
2911 | } |
2912 | |
2913 | unsigned getNumInputs() const; |
2914 | std::string getInputName(unsigned idx) const; |
2915 | NodeValue getNthInput(unsigned idx); |
2916 | void setNthInput(unsigned idx, NodeValue val); |
2917 | llvm::StringRef getOutputName(unsigned idx) const; |
2918 | bool hasSideEffects() const { return 0; } |
2919 | bool isCanonical() const { return 1; } |
2920 | bool isDataParallel() const { return 1; } |
2921 | std::string getDebugDesc() const; |
2922 | bool isEqual(const DivGradNode &other) const; |
2923 | llvm::hash_code getHash() const; |
2924 | void visit(Node *parent, NodeWalker *visitor); |
2925 | Node* clone() const; |
2926 | bool verify() const; |
2927 | }; |
2928 | } // namespace glow |
2929 | |
2930 | |
2931 | namespace glow { |
2932 | /// Performs Div on the LHS and RHS operands. |
2933 | class DivNode final : public Node { |
2934 | NodeHandle LHS_; |
2935 | NodeHandle RHS_; |
2936 | |
2937 | public: |
2938 | enum InputIndices { |
2939 | LHSIdx = 0, |
2940 | RHSIdx = 1, |
2941 | }; |
2942 | |
2943 | enum ResultIndices { |
2944 | ResultIdx = 0, |
2945 | }; |
2946 | |
2947 | DivNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
2948 | : Node(Kinded::Kind::DivNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
2949 | addResult(Result); |
2950 | } |
2951 | const NodeValue getLHS() const { return LHS_; } |
2952 | const NodeValue getRHS() const { return RHS_; } |
2953 | NodeValue getResult() { return getNthResult(0); } |
2954 | const NodeValue getResult() const { return getNthResult(0); } |
2955 | |
2956 | static bool classof(const Kinded *k) { |
2957 | return k->getKind() == Kinded::Kind::DivNodeKind; |
2958 | } |
2959 | |
2960 | |
2961 | bool isOverwrittenNthInput(unsigned idx) const { |
2962 | return false; |
2963 | } |
2964 | |
2965 | unsigned getNumInputs() const; |
2966 | std::string getInputName(unsigned idx) const; |
2967 | NodeValue getNthInput(unsigned idx); |
2968 | void setNthInput(unsigned idx, NodeValue val); |
2969 | llvm::StringRef getOutputName(unsigned idx) const; |
2970 | bool hasSideEffects() const { return 0; } |
2971 | bool isCanonical() const { return 1; } |
2972 | bool isDataParallel() const { return 1; } |
2973 | std::string getDebugDesc() const; |
2974 | bool isEqual(const DivNode &other) const; |
2975 | llvm::hash_code getHash() const; |
2976 | void visit(Node *parent, NodeWalker *visitor); |
2977 | Node* clone() const; |
2978 | bool verify() const; |
2979 | DivGradNode *getGrad(GraphGradMapper &builder); |
2980 | }; |
2981 | } // namespace glow |
2982 | |
2983 | |
2984 | namespace glow { |
2985 | /// Performs Div on the LHS and RHS operands, then Floor. If Truncate is set to true then truncate the quotient to zero instead. |
2986 | class FloorDivNode final : public Node { |
2987 | NodeHandle LHS_; |
2988 | NodeHandle RHS_; |
2989 | bool Truncate_; |
2990 | |
2991 | public: |
2992 | enum InputIndices { |
2993 | LHSIdx = 0, |
2994 | RHSIdx = 1, |
2995 | }; |
2996 | |
2997 | enum ResultIndices { |
2998 | ResultIdx = 0, |
2999 | }; |
3000 | |
3001 | FloorDivNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS, bool Truncate) |
3002 | : Node(Kinded::Kind::FloorDivNodeKind, name), LHS_(this, LHS), RHS_(this, RHS), Truncate_(Truncate) { |
3003 | addResult(Result); |
3004 | } |
3005 | const NodeValue getLHS() const { return LHS_; } |
3006 | const NodeValue getRHS() const { return RHS_; } |
3007 | NodeValue getResult() { return getNthResult(0); } |
3008 | const NodeValue getResult() const { return getNthResult(0); } |
3009 | bool getTruncate() const { return Truncate_; } |
3010 | |
3011 | static bool classof(const Kinded *k) { |
3012 | return k->getKind() == Kinded::Kind::FloorDivNodeKind; |
3013 | } |
3014 | |
3015 | |
3016 | bool isOverwrittenNthInput(unsigned idx) const { |
3017 | return false; |
3018 | } |
3019 | |
3020 | unsigned getNumInputs() const; |
3021 | std::string getInputName(unsigned idx) const; |
3022 | NodeValue getNthInput(unsigned idx); |
3023 | void setNthInput(unsigned idx, NodeValue val); |
3024 | llvm::StringRef getOutputName(unsigned idx) const; |
3025 | bool hasSideEffects() const { return 0; } |
3026 | bool isCanonical() const { return 1; } |
3027 | bool isDataParallel() const { return 1; } |
3028 | std::string getDebugDesc() const; |
3029 | bool isEqual(const FloorDivNode &other) const; |
3030 | llvm::hash_code getHash() const; |
3031 | void visit(Node *parent, NodeWalker *visitor); |
3032 | Node* clone() const; |
3033 | bool verify() const; |
3034 | }; |
3035 | } // namespace glow |
3036 | |
3037 | |
3038 | namespace glow { |
3039 | /// Computes the element-wise remainder of division. |
3040 | class FmodNode final : public Node { |
3041 | NodeHandle LHS_; |
3042 | NodeHandle RHS_; |
3043 | |
3044 | public: |
3045 | enum InputIndices { |
3046 | LHSIdx = 0, |
3047 | RHSIdx = 1, |
3048 | }; |
3049 | |
3050 | enum ResultIndices { |
3051 | ResultIdx = 0, |
3052 | }; |
3053 | |
3054 | FmodNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3055 | : Node(Kinded::Kind::FmodNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3056 | addResult(Result); |
3057 | } |
3058 | const NodeValue getLHS() const { return LHS_; } |
3059 | const NodeValue getRHS() const { return RHS_; } |
3060 | NodeValue getResult() { return getNthResult(0); } |
3061 | const NodeValue getResult() const { return getNthResult(0); } |
3062 | |
3063 | static bool classof(const Kinded *k) { |
3064 | return k->getKind() == Kinded::Kind::FmodNodeKind; |
3065 | } |
3066 | |
3067 | |
3068 | bool isOverwrittenNthInput(unsigned idx) const { |
3069 | return false; |
3070 | } |
3071 | |
3072 | unsigned getNumInputs() const; |
3073 | std::string getInputName(unsigned idx) const; |
3074 | NodeValue getNthInput(unsigned idx); |
3075 | void setNthInput(unsigned idx, NodeValue val); |
3076 | llvm::StringRef getOutputName(unsigned idx) const; |
3077 | bool hasSideEffects() const { return 0; } |
3078 | bool isCanonical() const { return 1; } |
3079 | bool isDataParallel() const { return 1; } |
3080 | std::string getDebugDesc() const; |
3081 | bool isEqual(const FmodNode &other) const; |
3082 | llvm::hash_code getHash() const; |
3083 | void visit(Node *parent, NodeWalker *visitor); |
3084 | Node* clone() const; |
3085 | bool verify() const; |
3086 | }; |
3087 | } // namespace glow |
3088 | |
3089 | |
3090 | namespace glow { |
3091 | /// Performs Max on the LHS and RHS operands. |
3092 | class MaxNode final : public Node { |
3093 | NodeHandle LHS_; |
3094 | NodeHandle RHS_; |
3095 | |
3096 | public: |
3097 | enum InputIndices { |
3098 | LHSIdx = 0, |
3099 | RHSIdx = 1, |
3100 | }; |
3101 | |
3102 | enum ResultIndices { |
3103 | ResultIdx = 0, |
3104 | }; |
3105 | |
3106 | MaxNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3107 | : Node(Kinded::Kind::MaxNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3108 | addResult(Result); |
3109 | } |
3110 | const NodeValue getLHS() const { return LHS_; } |
3111 | const NodeValue getRHS() const { return RHS_; } |
3112 | NodeValue getResult() { return getNthResult(0); } |
3113 | const NodeValue getResult() const { return getNthResult(0); } |
3114 | |
3115 | static bool classof(const Kinded *k) { |
3116 | return k->getKind() == Kinded::Kind::MaxNodeKind; |
3117 | } |
3118 | |
3119 | |
3120 | bool isOverwrittenNthInput(unsigned idx) const { |
3121 | return false; |
3122 | } |
3123 | |
3124 | unsigned getNumInputs() const; |
3125 | std::string getInputName(unsigned idx) const; |
3126 | NodeValue getNthInput(unsigned idx); |
3127 | void setNthInput(unsigned idx, NodeValue val); |
3128 | llvm::StringRef getOutputName(unsigned idx) const; |
3129 | bool hasSideEffects() const { return 0; } |
3130 | bool isCanonical() const { return 1; } |
3131 | bool isDataParallel() const { return 1; } |
3132 | std::string getDebugDesc() const; |
3133 | bool isEqual(const MaxNode &other) const; |
3134 | llvm::hash_code getHash() const; |
3135 | void visit(Node *parent, NodeWalker *visitor); |
3136 | Node* clone() const; |
3137 | bool verify() const; |
3138 | }; |
3139 | } // namespace glow |
3140 | |
3141 | |
3142 | namespace glow { |
3143 | /// Performs Min on the LHS and RHS operands. |
3144 | class MinNode final : public Node { |
3145 | NodeHandle LHS_; |
3146 | NodeHandle RHS_; |
3147 | |
3148 | public: |
3149 | enum InputIndices { |
3150 | LHSIdx = 0, |
3151 | RHSIdx = 1, |
3152 | }; |
3153 | |
3154 | enum ResultIndices { |
3155 | ResultIdx = 0, |
3156 | }; |
3157 | |
3158 | MinNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3159 | : Node(Kinded::Kind::MinNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3160 | addResult(Result); |
3161 | } |
3162 | const NodeValue getLHS() const { return LHS_; } |
3163 | const NodeValue getRHS() const { return RHS_; } |
3164 | NodeValue getResult() { return getNthResult(0); } |
3165 | const NodeValue getResult() const { return getNthResult(0); } |
3166 | |
3167 | static bool classof(const Kinded *k) { |
3168 | return k->getKind() == Kinded::Kind::MinNodeKind; |
3169 | } |
3170 | |
3171 | |
3172 | bool isOverwrittenNthInput(unsigned idx) const { |
3173 | return false; |
3174 | } |
3175 | |
3176 | unsigned getNumInputs() const; |
3177 | std::string getInputName(unsigned idx) const; |
3178 | NodeValue getNthInput(unsigned idx); |
3179 | void setNthInput(unsigned idx, NodeValue val); |
3180 | llvm::StringRef getOutputName(unsigned idx) const; |
3181 | bool hasSideEffects() const { return 0; } |
3182 | bool isCanonical() const { return 1; } |
3183 | bool isDataParallel() const { return 1; } |
3184 | std::string getDebugDesc() const; |
3185 | bool isEqual(const MinNode &other) const; |
3186 | llvm::hash_code getHash() const; |
3187 | void visit(Node *parent, NodeWalker *visitor); |
3188 | Node* clone() const; |
3189 | bool verify() const; |
3190 | }; |
3191 | } // namespace glow |
3192 | |
3193 | |
3194 | namespace glow { |
3195 | /// Performs an element-wise EQUAL comparison between the LHS and RHS operands. |
3196 | class CmpEQNode final : public Node { |
3197 | NodeHandle LHS_; |
3198 | NodeHandle RHS_; |
3199 | |
3200 | public: |
3201 | enum InputIndices { |
3202 | LHSIdx = 0, |
3203 | RHSIdx = 1, |
3204 | }; |
3205 | |
3206 | enum ResultIndices { |
3207 | ResultIdx = 0, |
3208 | }; |
3209 | |
3210 | CmpEQNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3211 | : Node(Kinded::Kind::CmpEQNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3212 | addResult(Result); |
3213 | } |
3214 | const NodeValue getLHS() const { return LHS_; } |
3215 | const NodeValue getRHS() const { return RHS_; } |
3216 | NodeValue getResult() { return getNthResult(0); } |
3217 | const NodeValue getResult() const { return getNthResult(0); } |
3218 | |
3219 | static bool classof(const Kinded *k) { |
3220 | return k->getKind() == Kinded::Kind::CmpEQNodeKind; |
3221 | } |
3222 | |
3223 | |
3224 | bool isOverwrittenNthInput(unsigned idx) const { |
3225 | return false; |
3226 | } |
3227 | |
3228 | unsigned getNumInputs() const; |
3229 | std::string getInputName(unsigned idx) const; |
3230 | NodeValue getNthInput(unsigned idx); |
3231 | void setNthInput(unsigned idx, NodeValue val); |
3232 | llvm::StringRef getOutputName(unsigned idx) const; |
3233 | bool hasSideEffects() const { return 0; } |
3234 | bool isCanonical() const { return 1; } |
3235 | bool isDataParallel() const { return 1; } |
3236 | std::string getDebugDesc() const; |
3237 | bool isEqual(const CmpEQNode &other) const; |
3238 | llvm::hash_code getHash() const; |
3239 | void visit(Node *parent, NodeWalker *visitor); |
3240 | Node* clone() const; |
3241 | bool verify() const; |
3242 | }; |
3243 | } // namespace glow |
3244 | |
3245 | |
3246 | namespace glow { |
3247 | /// Performs an element-wise NOT EQUAL comparison between the LHS and RHS operands. |
3248 | class CmpNEQNode final : public Node { |
3249 | NodeHandle LHS_; |
3250 | NodeHandle RHS_; |
3251 | |
3252 | public: |
3253 | enum InputIndices { |
3254 | LHSIdx = 0, |
3255 | RHSIdx = 1, |
3256 | }; |
3257 | |
3258 | enum ResultIndices { |
3259 | ResultIdx = 0, |
3260 | }; |
3261 | |
3262 | CmpNEQNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3263 | : Node(Kinded::Kind::CmpNEQNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3264 | addResult(Result); |
3265 | } |
3266 | const NodeValue getLHS() const { return LHS_; } |
3267 | const NodeValue getRHS() const { return RHS_; } |
3268 | NodeValue getResult() { return getNthResult(0); } |
3269 | const NodeValue getResult() const { return getNthResult(0); } |
3270 | |
3271 | static bool classof(const Kinded *k) { |
3272 | return k->getKind() == Kinded::Kind::CmpNEQNodeKind; |
3273 | } |
3274 | |
3275 | |
3276 | bool isOverwrittenNthInput(unsigned idx) const { |
3277 | return false; |
3278 | } |
3279 | |
3280 | unsigned getNumInputs() const; |
3281 | std::string getInputName(unsigned idx) const; |
3282 | NodeValue getNthInput(unsigned idx); |
3283 | void setNthInput(unsigned idx, NodeValue val); |
3284 | llvm::StringRef getOutputName(unsigned idx) const; |
3285 | bool hasSideEffects() const { return 0; } |
3286 | bool isCanonical() const { return 1; } |
3287 | bool isDataParallel() const { return 1; } |
3288 | std::string getDebugDesc() const; |
3289 | bool isEqual(const CmpNEQNode &other) const; |
3290 | llvm::hash_code getHash() const; |
3291 | void visit(Node *parent, NodeWalker *visitor); |
3292 | Node* clone() const; |
3293 | bool verify() const; |
3294 | }; |
3295 | } // namespace glow |
3296 | |
3297 | |
3298 | namespace glow { |
3299 | /// Performs an element-wise LESS THAN comparison between the LHS and RHS operands. |
3300 | class CmpLTNode final : public Node { |
3301 | NodeHandle LHS_; |
3302 | NodeHandle RHS_; |
3303 | |
3304 | public: |
3305 | enum InputIndices { |
3306 | LHSIdx = 0, |
3307 | RHSIdx = 1, |
3308 | }; |
3309 | |
3310 | enum ResultIndices { |
3311 | ResultIdx = 0, |
3312 | }; |
3313 | |
3314 | CmpLTNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3315 | : Node(Kinded::Kind::CmpLTNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3316 | addResult(Result); |
3317 | } |
3318 | const NodeValue getLHS() const { return LHS_; } |
3319 | const NodeValue getRHS() const { return RHS_; } |
3320 | NodeValue getResult() { return getNthResult(0); } |
3321 | const NodeValue getResult() const { return getNthResult(0); } |
3322 | |
3323 | static bool classof(const Kinded *k) { |
3324 | return k->getKind() == Kinded::Kind::CmpLTNodeKind; |
3325 | } |
3326 | |
3327 | |
3328 | bool isOverwrittenNthInput(unsigned idx) const { |
3329 | return false; |
3330 | } |
3331 | |
3332 | unsigned getNumInputs() const; |
3333 | std::string getInputName(unsigned idx) const; |
3334 | NodeValue getNthInput(unsigned idx); |
3335 | void setNthInput(unsigned idx, NodeValue val); |
3336 | llvm::StringRef getOutputName(unsigned idx) const; |
3337 | bool hasSideEffects() const { return 0; } |
3338 | bool isCanonical() const { return 1; } |
3339 | bool isDataParallel() const { return 1; } |
3340 | std::string getDebugDesc() const; |
3341 | bool isEqual(const CmpLTNode &other) const; |
3342 | llvm::hash_code getHash() const; |
3343 | void visit(Node *parent, NodeWalker *visitor); |
3344 | Node* clone() const; |
3345 | bool verify() const; |
3346 | }; |
3347 | } // namespace glow |
3348 | |
3349 | |
3350 | namespace glow { |
3351 | /// Performs an element-wise LESS THAN OR EQUAL comparison between the LHS and RHS operands. |
3352 | class CmpLTENode final : public Node { |
3353 | NodeHandle LHS_; |
3354 | NodeHandle RHS_; |
3355 | |
3356 | public: |
3357 | enum InputIndices { |
3358 | LHSIdx = 0, |
3359 | RHSIdx = 1, |
3360 | }; |
3361 | |
3362 | enum ResultIndices { |
3363 | ResultIdx = 0, |
3364 | }; |
3365 | |
3366 | CmpLTENode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3367 | : Node(Kinded::Kind::CmpLTENodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3368 | addResult(Result); |
3369 | } |
3370 | const NodeValue getLHS() const { return LHS_; } |
3371 | const NodeValue getRHS() const { return RHS_; } |
3372 | NodeValue getResult() { return getNthResult(0); } |
3373 | const NodeValue getResult() const { return getNthResult(0); } |
3374 | |
3375 | static bool classof(const Kinded *k) { |
3376 | return k->getKind() == Kinded::Kind::CmpLTENodeKind; |
3377 | } |
3378 | |
3379 | |
3380 | bool isOverwrittenNthInput(unsigned idx) const { |
3381 | return false; |
3382 | } |
3383 | |
3384 | unsigned getNumInputs() const; |
3385 | std::string getInputName(unsigned idx) const; |
3386 | NodeValue getNthInput(unsigned idx); |
3387 | void setNthInput(unsigned idx, NodeValue val); |
3388 | llvm::StringRef getOutputName(unsigned idx) const; |
3389 | bool hasSideEffects() const { return 0; } |
3390 | bool isCanonical() const { return 1; } |
3391 | bool isDataParallel() const { return 1; } |
3392 | std::string getDebugDesc() const; |
3393 | bool isEqual(const CmpLTENode &other) const; |
3394 | llvm::hash_code getHash() const; |
3395 | void visit(Node *parent, NodeWalker *visitor); |
3396 | Node* clone() const; |
3397 | bool verify() const; |
3398 | }; |
3399 | } // namespace glow |
3400 | |
3401 | |
3402 | namespace glow { |
3403 | /// Performs elementwise pow(LHS, RHS). |
3404 | class PowNode final : public Node { |
3405 | NodeHandle LHS_; |
3406 | NodeHandle RHS_; |
3407 | |
3408 | public: |
3409 | enum InputIndices { |
3410 | LHSIdx = 0, |
3411 | RHSIdx = 1, |
3412 | }; |
3413 | |
3414 | enum ResultIndices { |
3415 | ResultIdx = 0, |
3416 | }; |
3417 | |
3418 | PowNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3419 | : Node(Kinded::Kind::PowNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3420 | addResult(Result); |
3421 | } |
3422 | const NodeValue getLHS() const { return LHS_; } |
3423 | const NodeValue getRHS() const { return RHS_; } |
3424 | NodeValue getResult() { return getNthResult(0); } |
3425 | const NodeValue getResult() const { return getNthResult(0); } |
3426 | |
3427 | static bool classof(const Kinded *k) { |
3428 | return k->getKind() == Kinded::Kind::PowNodeKind; |
3429 | } |
3430 | |
3431 | |
3432 | bool isOverwrittenNthInput(unsigned idx) const { |
3433 | return false; |
3434 | } |
3435 | |
3436 | unsigned getNumInputs() const; |
3437 | std::string getInputName(unsigned idx) const; |
3438 | NodeValue getNthInput(unsigned idx); |
3439 | void setNthInput(unsigned idx, NodeValue val); |
3440 | llvm::StringRef getOutputName(unsigned idx) const; |
3441 | bool hasSideEffects() const { return 0; } |
3442 | bool isCanonical() const { return 1; } |
3443 | bool isDataParallel() const { return 1; } |
3444 | std::string getDebugDesc() const; |
3445 | bool isEqual(const PowNode &other) const; |
3446 | llvm::hash_code getHash() const; |
3447 | void visit(Node *parent, NodeWalker *visitor); |
3448 | Node* clone() const; |
3449 | bool verify() const; |
3450 | }; |
3451 | } // namespace glow |
3452 | |
3453 | |
3454 | namespace glow { |
3455 | /// Performs an element-wise logical AND between the LHS and RHS operands. |
3456 | class AndNode final : public Node { |
3457 | NodeHandle LHS_; |
3458 | NodeHandle RHS_; |
3459 | |
3460 | public: |
3461 | enum InputIndices { |
3462 | LHSIdx = 0, |
3463 | RHSIdx = 1, |
3464 | }; |
3465 | |
3466 | enum ResultIndices { |
3467 | ResultIdx = 0, |
3468 | }; |
3469 | |
3470 | AndNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3471 | : Node(Kinded::Kind::AndNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3472 | addResult(Result); |
3473 | } |
3474 | const NodeValue getLHS() const { return LHS_; } |
3475 | const NodeValue getRHS() const { return RHS_; } |
3476 | NodeValue getResult() { return getNthResult(0); } |
3477 | const NodeValue getResult() const { return getNthResult(0); } |
3478 | |
3479 | static bool classof(const Kinded *k) { |
3480 | return k->getKind() == Kinded::Kind::AndNodeKind; |
3481 | } |
3482 | |
3483 | |
3484 | bool isOverwrittenNthInput(unsigned idx) const { |
3485 | return false; |
3486 | } |
3487 | |
3488 | unsigned getNumInputs() const; |
3489 | std::string getInputName(unsigned idx) const; |
3490 | NodeValue getNthInput(unsigned idx); |
3491 | void setNthInput(unsigned idx, NodeValue val); |
3492 | llvm::StringRef getOutputName(unsigned idx) const; |
3493 | bool hasSideEffects() const { return 0; } |
3494 | bool isCanonical() const { return 1; } |
3495 | bool isDataParallel() const { return 1; } |
3496 | std::string getDebugDesc() const; |
3497 | bool isEqual(const AndNode &other) const; |
3498 | llvm::hash_code getHash() const; |
3499 | void visit(Node *parent, NodeWalker *visitor); |
3500 | Node* clone() const; |
3501 | bool verify() const; |
3502 | }; |
3503 | } // namespace glow |
3504 | |
3505 | |
3506 | namespace glow { |
3507 | /// Performs an element-wise bitwise AND between the LHS and RHS operands. |
3508 | class BitwiseAndNode final : public Node { |
3509 | NodeHandle LHS_; |
3510 | NodeHandle RHS_; |
3511 | |
3512 | public: |
3513 | enum InputIndices { |
3514 | LHSIdx = 0, |
3515 | RHSIdx = 1, |
3516 | }; |
3517 | |
3518 | enum ResultIndices { |
3519 | ResultIdx = 0, |
3520 | }; |
3521 | |
3522 | BitwiseAndNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3523 | : Node(Kinded::Kind::BitwiseAndNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3524 | addResult(Result); |
3525 | } |
3526 | const NodeValue getLHS() const { return LHS_; } |
3527 | const NodeValue getRHS() const { return RHS_; } |
3528 | NodeValue getResult() { return getNthResult(0); } |
3529 | const NodeValue getResult() const { return getNthResult(0); } |
3530 | |
3531 | static bool classof(const Kinded *k) { |
3532 | return k->getKind() == Kinded::Kind::BitwiseAndNodeKind; |
3533 | } |
3534 | |
3535 | |
3536 | bool isOverwrittenNthInput(unsigned idx) const { |
3537 | return false; |
3538 | } |
3539 | |
3540 | unsigned getNumInputs() const; |
3541 | std::string getInputName(unsigned idx) const; |
3542 | NodeValue getNthInput(unsigned idx); |
3543 | void setNthInput(unsigned idx, NodeValue val); |
3544 | llvm::StringRef getOutputName(unsigned idx) const; |
3545 | bool hasSideEffects() const { return 0; } |
3546 | bool isCanonical() const { return 1; } |
3547 | bool isDataParallel() const { return 1; } |
3548 | std::string getDebugDesc() const; |
3549 | bool isEqual(const BitwiseAndNode &other) const; |
3550 | llvm::hash_code getHash() const; |
3551 | void visit(Node *parent, NodeWalker *visitor); |
3552 | Node* clone() const; |
3553 | bool verify() const; |
3554 | }; |
3555 | } // namespace glow |
3556 | |
3557 | |
3558 | namespace glow { |
3559 | /// Performs an element-wise logical OR between the LHS and RHS operands. |
3560 | class OrNode final : public Node { |
3561 | NodeHandle LHS_; |
3562 | NodeHandle RHS_; |
3563 | |
3564 | public: |
3565 | enum InputIndices { |
3566 | LHSIdx = 0, |
3567 | RHSIdx = 1, |
3568 | }; |
3569 | |
3570 | enum ResultIndices { |
3571 | ResultIdx = 0, |
3572 | }; |
3573 | |
3574 | OrNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3575 | : Node(Kinded::Kind::OrNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3576 | addResult(Result); |
3577 | } |
3578 | const NodeValue getLHS() const { return LHS_; } |
3579 | const NodeValue getRHS() const { return RHS_; } |
3580 | NodeValue getResult() { return getNthResult(0); } |
3581 | const NodeValue getResult() const { return getNthResult(0); } |
3582 | |
3583 | static bool classof(const Kinded *k) { |
3584 | return k->getKind() == Kinded::Kind::OrNodeKind; |
3585 | } |
3586 | |
3587 | |
3588 | bool isOverwrittenNthInput(unsigned idx) const { |
3589 | return false; |
3590 | } |
3591 | |
3592 | unsigned getNumInputs() const; |
3593 | std::string getInputName(unsigned idx) const; |
3594 | NodeValue getNthInput(unsigned idx); |
3595 | void setNthInput(unsigned idx, NodeValue val); |
3596 | llvm::StringRef getOutputName(unsigned idx) const; |
3597 | bool hasSideEffects() const { return 0; } |
3598 | bool isCanonical() const { return 1; } |
3599 | bool isDataParallel() const { return 1; } |
3600 | std::string getDebugDesc() const; |
3601 | bool isEqual(const OrNode &other) const; |
3602 | llvm::hash_code getHash() const; |
3603 | void visit(Node *parent, NodeWalker *visitor); |
3604 | Node* clone() const; |
3605 | bool verify() const; |
3606 | }; |
3607 | } // namespace glow |
3608 | |
3609 | |
3610 | namespace glow { |
3611 | /// Performs an element-wise bitwise OR between the LHS and RHS operands. |
3612 | class BitwiseOrNode final : public Node { |
3613 | NodeHandle LHS_; |
3614 | NodeHandle RHS_; |
3615 | |
3616 | public: |
3617 | enum InputIndices { |
3618 | LHSIdx = 0, |
3619 | RHSIdx = 1, |
3620 | }; |
3621 | |
3622 | enum ResultIndices { |
3623 | ResultIdx = 0, |
3624 | }; |
3625 | |
3626 | BitwiseOrNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3627 | : Node(Kinded::Kind::BitwiseOrNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3628 | addResult(Result); |
3629 | } |
3630 | const NodeValue getLHS() const { return LHS_; } |
3631 | const NodeValue getRHS() const { return RHS_; } |
3632 | NodeValue getResult() { return getNthResult(0); } |
3633 | const NodeValue getResult() const { return getNthResult(0); } |
3634 | |
3635 | static bool classof(const Kinded *k) { |
3636 | return k->getKind() == Kinded::Kind::BitwiseOrNodeKind; |
3637 | } |
3638 | |
3639 | |
3640 | bool isOverwrittenNthInput(unsigned idx) const { |
3641 | return false; |
3642 | } |
3643 | |
3644 | unsigned getNumInputs() const; |
3645 | std::string getInputName(unsigned idx) const; |
3646 | NodeValue getNthInput(unsigned idx); |
3647 | void setNthInput(unsigned idx, NodeValue val); |
3648 | llvm::StringRef getOutputName(unsigned idx) const; |
3649 | bool hasSideEffects() const { return 0; } |
3650 | bool isCanonical() const { return 1; } |
3651 | bool isDataParallel() const { return 1; } |
3652 | std::string getDebugDesc() const; |
3653 | bool isEqual(const BitwiseOrNode &other) const; |
3654 | llvm::hash_code getHash() const; |
3655 | void visit(Node *parent, NodeWalker *visitor); |
3656 | Node* clone() const; |
3657 | bool verify() const; |
3658 | }; |
3659 | } // namespace glow |
3660 | |
3661 | |
3662 | namespace glow { |
3663 | /// Performs an element-wise logical XOR between the LHS and RHS operands. |
3664 | class XorNode final : public Node { |
3665 | NodeHandle LHS_; |
3666 | NodeHandle RHS_; |
3667 | |
3668 | public: |
3669 | enum InputIndices { |
3670 | LHSIdx = 0, |
3671 | RHSIdx = 1, |
3672 | }; |
3673 | |
3674 | enum ResultIndices { |
3675 | ResultIdx = 0, |
3676 | }; |
3677 | |
3678 | XorNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3679 | : Node(Kinded::Kind::XorNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3680 | addResult(Result); |
3681 | } |
3682 | const NodeValue getLHS() const { return LHS_; } |
3683 | const NodeValue getRHS() const { return RHS_; } |
3684 | NodeValue getResult() { return getNthResult(0); } |
3685 | const NodeValue getResult() const { return getNthResult(0); } |
3686 | |
3687 | static bool classof(const Kinded *k) { |
3688 | return k->getKind() == Kinded::Kind::XorNodeKind; |
3689 | } |
3690 | |
3691 | |
3692 | bool isOverwrittenNthInput(unsigned idx) const { |
3693 | return false; |
3694 | } |
3695 | |
3696 | unsigned getNumInputs() const; |
3697 | std::string getInputName(unsigned idx) const; |
3698 | NodeValue getNthInput(unsigned idx); |
3699 | void setNthInput(unsigned idx, NodeValue val); |
3700 | llvm::StringRef getOutputName(unsigned idx) const; |
3701 | bool hasSideEffects() const { return 0; } |
3702 | bool isCanonical() const { return 1; } |
3703 | bool isDataParallel() const { return 1; } |
3704 | std::string getDebugDesc() const; |
3705 | bool isEqual(const XorNode &other) const; |
3706 | llvm::hash_code getHash() const; |
3707 | void visit(Node *parent, NodeWalker *visitor); |
3708 | Node* clone() const; |
3709 | bool verify() const; |
3710 | }; |
3711 | } // namespace glow |
3712 | |
3713 | |
3714 | namespace glow { |
3715 | /// Performs an element-wise bitwise XOR between the LHS and RHS operands. |
3716 | class BitwiseXorNode final : public Node { |
3717 | NodeHandle LHS_; |
3718 | NodeHandle RHS_; |
3719 | |
3720 | public: |
3721 | enum InputIndices { |
3722 | LHSIdx = 0, |
3723 | RHSIdx = 1, |
3724 | }; |
3725 | |
3726 | enum ResultIndices { |
3727 | ResultIdx = 0, |
3728 | }; |
3729 | |
3730 | BitwiseXorNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
3731 | : Node(Kinded::Kind::BitwiseXorNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
3732 | addResult(Result); |
3733 | } |
3734 | const NodeValue getLHS() const { return LHS_; } |
3735 | const NodeValue getRHS() const { return RHS_; } |
3736 | NodeValue getResult() { return getNthResult(0); } |
3737 | const NodeValue getResult() const { return getNthResult(0); } |
3738 | |
3739 | static bool classof(const Kinded *k) { |
3740 | return k->getKind() == Kinded::Kind::BitwiseXorNodeKind; |
3741 | } |
3742 | |
3743 | |
3744 | bool isOverwrittenNthInput(unsigned idx) const { |
3745 | return false; |
3746 | } |
3747 | |
3748 | unsigned getNumInputs() const; |
3749 | std::string getInputName(unsigned idx) const; |
3750 | NodeValue getNthInput(unsigned idx); |
3751 | void setNthInput(unsigned idx, NodeValue val); |
3752 | llvm::StringRef getOutputName(unsigned idx) const; |
3753 | bool hasSideEffects() const { return 0; } |
3754 | bool isCanonical() const { return 1; } |
3755 | bool isDataParallel() const { return 1; } |
3756 | std::string getDebugDesc() const; |
3757 | bool isEqual(const BitwiseXorNode &other) const; |
3758 | llvm::hash_code getHash() const; |
3759 | void visit(Node *parent, NodeWalker *visitor); |
3760 | Node* clone() const; |
3761 | bool verify() const; |
3762 | }; |
3763 | } // namespace glow |
3764 | |
3765 | |
3766 | namespace glow { |
3767 | /// Performs an element-wise logical NOT of the Input operand. |
3768 | class NotNode final : public Node { |
3769 | NodeHandle Input_; |
3770 | |
3771 | public: |
3772 | enum InputIndices { |
3773 | InputIdx = 0, |
3774 | }; |
3775 | |
3776 | enum ResultIndices { |
3777 | ResultIdx = 0, |
3778 | }; |
3779 | |
3780 | NotNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
3781 | : Node(Kinded::Kind::NotNodeKind, name), Input_(this, Input) { |
3782 | addResult(Result); |
3783 | } |
3784 | const NodeValue getInput() const { return Input_; } |
3785 | NodeValue getResult() { return getNthResult(0); } |
3786 | const NodeValue getResult() const { return getNthResult(0); } |
3787 | |
3788 | static bool classof(const Kinded *k) { |
3789 | return k->getKind() == Kinded::Kind::NotNodeKind; |
3790 | } |
3791 | |
3792 | |
3793 | bool isOverwrittenNthInput(unsigned idx) const { |
3794 | return false; |
3795 | } |
3796 | |
3797 | unsigned getNumInputs() const; |
3798 | std::string getInputName(unsigned idx) const; |
3799 | NodeValue getNthInput(unsigned idx); |
3800 | void setNthInput(unsigned idx, NodeValue val); |
3801 | llvm::StringRef getOutputName(unsigned idx) const; |
3802 | bool hasSideEffects() const { return 0; } |
3803 | bool isCanonical() const { return 1; } |
3804 | bool isDataParallel() const { return 1; } |
3805 | std::string getDebugDesc() const; |
3806 | bool isEqual(const NotNode &other) const; |
3807 | llvm::hash_code getHash() const; |
3808 | void visit(Node *parent, NodeWalker *visitor); |
3809 | Node* clone() const; |
3810 | bool verify() const; |
3811 | }; |
3812 | } // namespace glow |
3813 | |
3814 | |
3815 | namespace glow { |
3816 | /// Performs an element-wise bitwise NOT of the Input operand. |
3817 | class BitwiseNotNode final : public Node { |
3818 | NodeHandle Input_; |
3819 | |
3820 | public: |
3821 | enum InputIndices { |
3822 | InputIdx = 0, |
3823 | }; |
3824 | |
3825 | enum ResultIndices { |
3826 | ResultIdx = 0, |
3827 | }; |
3828 | |
3829 | BitwiseNotNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
3830 | : Node(Kinded::Kind::BitwiseNotNodeKind, name), Input_(this, Input) { |
3831 | addResult(Result); |
3832 | } |
3833 | const NodeValue getInput() const { return Input_; } |
3834 | NodeValue getResult() { return getNthResult(0); } |
3835 | const NodeValue getResult() const { return getNthResult(0); } |
3836 | |
3837 | static bool classof(const Kinded *k) { |
3838 | return k->getKind() == Kinded::Kind::BitwiseNotNodeKind; |
3839 | } |
3840 | |
3841 | |
3842 | bool isOverwrittenNthInput(unsigned idx) const { |
3843 | return false; |
3844 | } |
3845 | |
3846 | unsigned getNumInputs() const; |
3847 | std::string getInputName(unsigned idx) const; |
3848 | NodeValue getNthInput(unsigned idx); |
3849 | void setNthInput(unsigned idx, NodeValue val); |
3850 | llvm::StringRef getOutputName(unsigned idx) const; |
3851 | bool hasSideEffects() const { return 0; } |
3852 | bool isCanonical() const { return 1; } |
3853 | bool isDataParallel() const { return 1; } |
3854 | std::string getDebugDesc() const; |
3855 | bool isEqual(const BitwiseNotNode &other) const; |
3856 | llvm::hash_code getHash() const; |
3857 | void visit(Node *parent, NodeWalker *visitor); |
3858 | Node* clone() const; |
3859 | bool verify() const; |
3860 | }; |
3861 | } // namespace glow |
3862 | |
3863 | |
3864 | namespace glow { |
3865 | /// Performs an element-wise negation (sign flip) of the Input operand. |
3866 | class NegNode final : public Node { |
3867 | NodeHandle Input_; |
3868 | |
3869 | public: |
3870 | enum InputIndices { |
3871 | InputIdx = 0, |
3872 | }; |
3873 | |
3874 | enum ResultIndices { |
3875 | ResultIdx = 0, |
3876 | }; |
3877 | |
3878 | NegNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
3879 | : Node(Kinded::Kind::NegNodeKind, name), Input_(this, Input) { |
3880 | addResult(Result); |
3881 | } |
3882 | const NodeValue getInput() const { return Input_; } |
3883 | NodeValue getResult() { return getNthResult(0); } |
3884 | const NodeValue getResult() const { return getNthResult(0); } |
3885 | |
3886 | static bool classof(const Kinded *k) { |
3887 | return k->getKind() == Kinded::Kind::NegNodeKind; |
3888 | } |
3889 | |
3890 | |
3891 | bool isOverwrittenNthInput(unsigned idx) const { |
3892 | return false; |
3893 | } |
3894 | |
3895 | unsigned getNumInputs() const; |
3896 | std::string getInputName(unsigned idx) const; |
3897 | NodeValue getNthInput(unsigned idx); |
3898 | void setNthInput(unsigned idx, NodeValue val); |
3899 | llvm::StringRef getOutputName(unsigned idx) const; |
3900 | bool hasSideEffects() const { return 0; } |
3901 | bool isCanonical() const { return 1; } |
3902 | bool isDataParallel() const { return 1; } |
3903 | std::string getDebugDesc() const; |
3904 | bool isEqual(const NegNode &other) const; |
3905 | llvm::hash_code getHash() const; |
3906 | void visit(Node *parent, NodeWalker *visitor); |
3907 | Node* clone() const; |
3908 | bool verify() const; |
3909 | }; |
3910 | } // namespace glow |
3911 | |
3912 | |
3913 | namespace glow { |
3914 | /// Performs an element-wise ABS(x) of the Input operand. |
3915 | class AbsNode final : public Node { |
3916 | NodeHandle Input_; |
3917 | |
3918 | public: |
3919 | enum InputIndices { |
3920 | InputIdx = 0, |
3921 | }; |
3922 | |
3923 | enum ResultIndices { |
3924 | ResultIdx = 0, |
3925 | }; |
3926 | |
3927 | AbsNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
3928 | : Node(Kinded::Kind::AbsNodeKind, name), Input_(this, Input) { |
3929 | addResult(Result); |
3930 | } |
3931 | const NodeValue getInput() const { return Input_; } |
3932 | NodeValue getResult() { return getNthResult(0); } |
3933 | const NodeValue getResult() const { return getNthResult(0); } |
3934 | |
3935 | static bool classof(const Kinded *k) { |
3936 | return k->getKind() == Kinded::Kind::AbsNodeKind; |
3937 | } |
3938 | |
3939 | |
3940 | bool isOverwrittenNthInput(unsigned idx) const { |
3941 | return false; |
3942 | } |
3943 | |
3944 | unsigned getNumInputs() const; |
3945 | std::string getInputName(unsigned idx) const; |
3946 | NodeValue getNthInput(unsigned idx); |
3947 | void setNthInput(unsigned idx, NodeValue val); |
3948 | llvm::StringRef getOutputName(unsigned idx) const; |
3949 | bool hasSideEffects() const { return 0; } |
3950 | bool isCanonical() const { return 1; } |
3951 | bool isDataParallel() const { return 1; } |
3952 | std::string getDebugDesc() const; |
3953 | bool isEqual(const AbsNode &other) const; |
3954 | llvm::hash_code getHash() const; |
3955 | void visit(Node *parent, NodeWalker *visitor); |
3956 | Node* clone() const; |
3957 | bool verify() const; |
3958 | }; |
3959 | } // namespace glow |
3960 | |
3961 | |
3962 | namespace glow { |
3963 | /// Performs an element-wise FLOOR(x) of the Input operand. |
3964 | class FloorNode final : public Node { |
3965 | NodeHandle Input_; |
3966 | |
3967 | public: |
3968 | enum InputIndices { |
3969 | InputIdx = 0, |
3970 | }; |
3971 | |
3972 | enum ResultIndices { |
3973 | ResultIdx = 0, |
3974 | }; |
3975 | |
3976 | FloorNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
3977 | : Node(Kinded::Kind::FloorNodeKind, name), Input_(this, Input) { |
3978 | addResult(Result); |
3979 | } |
3980 | const NodeValue getInput() const { return Input_; } |
3981 | NodeValue getResult() { return getNthResult(0); } |
3982 | const NodeValue getResult() const { return getNthResult(0); } |
3983 | |
3984 | static bool classof(const Kinded *k) { |
3985 | return k->getKind() == Kinded::Kind::FloorNodeKind; |
3986 | } |
3987 | |
3988 | |
3989 | bool isOverwrittenNthInput(unsigned idx) const { |
3990 | return false; |
3991 | } |
3992 | |
3993 | unsigned getNumInputs() const; |
3994 | std::string getInputName(unsigned idx) const; |
3995 | NodeValue getNthInput(unsigned idx); |
3996 | void setNthInput(unsigned idx, NodeValue val); |
3997 | llvm::StringRef getOutputName(unsigned idx) const; |
3998 | bool hasSideEffects() const { return 0; } |
3999 | bool isCanonical() const { return 1; } |
4000 | bool isDataParallel() const { return 1; } |
4001 | std::string getDebugDesc() const; |
4002 | bool isEqual(const FloorNode &other) const; |
4003 | llvm::hash_code getHash() const; |
4004 | void visit(Node *parent, NodeWalker *visitor); |
4005 | Node* clone() const; |
4006 | bool verify() const; |
4007 | }; |
4008 | } // namespace glow |
4009 | |
4010 | |
4011 | namespace glow { |
4012 | /// Performs an element-wise Sign(x) of the Input operand |
4013 | class SignNode final : public Node { |
4014 | NodeHandle Input_; |
4015 | |
4016 | public: |
4017 | enum InputIndices { |
4018 | InputIdx = 0, |
4019 | }; |
4020 | |
4021 | enum ResultIndices { |
4022 | ResultIdx = 0, |
4023 | }; |
4024 | |
4025 | SignNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4026 | : Node(Kinded::Kind::SignNodeKind, name), Input_(this, Input) { |
4027 | addResult(Result); |
4028 | } |
4029 | const NodeValue getInput() const { return Input_; } |
4030 | NodeValue getResult() { return getNthResult(0); } |
4031 | const NodeValue getResult() const { return getNthResult(0); } |
4032 | |
4033 | static bool classof(const Kinded *k) { |
4034 | return k->getKind() == Kinded::Kind::SignNodeKind; |
4035 | } |
4036 | |
4037 | |
4038 | bool isOverwrittenNthInput(unsigned idx) const { |
4039 | return false; |
4040 | } |
4041 | |
4042 | unsigned getNumInputs() const; |
4043 | std::string getInputName(unsigned idx) const; |
4044 | NodeValue getNthInput(unsigned idx); |
4045 | void setNthInput(unsigned idx, NodeValue val); |
4046 | llvm::StringRef getOutputName(unsigned idx) const; |
4047 | bool hasSideEffects() const { return 0; } |
4048 | bool isCanonical() const { return 1; } |
4049 | bool isDataParallel() const { return 1; } |
4050 | std::string getDebugDesc() const; |
4051 | bool isEqual(const SignNode &other) const; |
4052 | llvm::hash_code getHash() const; |
4053 | void visit(Node *parent, NodeWalker *visitor); |
4054 | Node* clone() const; |
4055 | bool verify() const; |
4056 | }; |
4057 | } // namespace glow |
4058 | |
4059 | |
4060 | namespace glow { |
4061 | /// Performs an element-wise CEIL(x) of the Input operand. |
4062 | class CeilNode final : public Node { |
4063 | NodeHandle Input_; |
4064 | |
4065 | public: |
4066 | enum InputIndices { |
4067 | InputIdx = 0, |
4068 | }; |
4069 | |
4070 | enum ResultIndices { |
4071 | ResultIdx = 0, |
4072 | }; |
4073 | |
4074 | CeilNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4075 | : Node(Kinded::Kind::CeilNodeKind, name), Input_(this, Input) { |
4076 | addResult(Result); |
4077 | } |
4078 | const NodeValue getInput() const { return Input_; } |
4079 | NodeValue getResult() { return getNthResult(0); } |
4080 | const NodeValue getResult() const { return getNthResult(0); } |
4081 | |
4082 | static bool classof(const Kinded *k) { |
4083 | return k->getKind() == Kinded::Kind::CeilNodeKind; |
4084 | } |
4085 | |
4086 | |
4087 | bool isOverwrittenNthInput(unsigned idx) const { |
4088 | return false; |
4089 | } |
4090 | |
4091 | unsigned getNumInputs() const; |
4092 | std::string getInputName(unsigned idx) const; |
4093 | NodeValue getNthInput(unsigned idx); |
4094 | void setNthInput(unsigned idx, NodeValue val); |
4095 | llvm::StringRef getOutputName(unsigned idx) const; |
4096 | bool hasSideEffects() const { return 0; } |
4097 | bool isCanonical() const { return 1; } |
4098 | bool isDataParallel() const { return 1; } |
4099 | std::string getDebugDesc() const; |
4100 | bool isEqual(const CeilNode &other) const; |
4101 | llvm::hash_code getHash() const; |
4102 | void visit(Node *parent, NodeWalker *visitor); |
4103 | Node* clone() const; |
4104 | bool verify() const; |
4105 | }; |
4106 | } // namespace glow |
4107 | |
4108 | |
4109 | namespace glow { |
4110 | /// Performs an element-wise ROUND(x) of the Input operand. |
4111 | class RoundNode final : public Node { |
4112 | NodeHandle Input_; |
4113 | |
4114 | public: |
4115 | enum InputIndices { |
4116 | InputIdx = 0, |
4117 | }; |
4118 | |
4119 | enum ResultIndices { |
4120 | ResultIdx = 0, |
4121 | }; |
4122 | |
4123 | RoundNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4124 | : Node(Kinded::Kind::RoundNodeKind, name), Input_(this, Input) { |
4125 | addResult(Result); |
4126 | } |
4127 | const NodeValue getInput() const { return Input_; } |
4128 | NodeValue getResult() { return getNthResult(0); } |
4129 | const NodeValue getResult() const { return getNthResult(0); } |
4130 | |
4131 | static bool classof(const Kinded *k) { |
4132 | return k->getKind() == Kinded::Kind::RoundNodeKind; |
4133 | } |
4134 | |
4135 | |
4136 | bool isOverwrittenNthInput(unsigned idx) const { |
4137 | return false; |
4138 | } |
4139 | |
4140 | unsigned getNumInputs() const; |
4141 | std::string getInputName(unsigned idx) const; |
4142 | NodeValue getNthInput(unsigned idx); |
4143 | void setNthInput(unsigned idx, NodeValue val); |
4144 | llvm::StringRef getOutputName(unsigned idx) const; |
4145 | bool hasSideEffects() const { return 0; } |
4146 | bool isCanonical() const { return 1; } |
4147 | bool isDataParallel() const { return 1; } |
4148 | std::string getDebugDesc() const; |
4149 | bool isEqual(const RoundNode &other) const; |
4150 | llvm::hash_code getHash() const; |
4151 | void visit(Node *parent, NodeWalker *visitor); |
4152 | Node* clone() const; |
4153 | bool verify() const; |
4154 | }; |
4155 | } // namespace glow |
4156 | |
4157 | |
4158 | namespace glow { |
4159 | /// Performs an element-wise TRUNCATE(x) of the Input operand. |
4160 | class TruncateNode final : public Node { |
4161 | NodeHandle Input_; |
4162 | |
4163 | public: |
4164 | enum InputIndices { |
4165 | InputIdx = 0, |
4166 | }; |
4167 | |
4168 | enum ResultIndices { |
4169 | ResultIdx = 0, |
4170 | }; |
4171 | |
4172 | TruncateNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4173 | : Node(Kinded::Kind::TruncateNodeKind, name), Input_(this, Input) { |
4174 | addResult(Result); |
4175 | } |
4176 | const NodeValue getInput() const { return Input_; } |
4177 | NodeValue getResult() { return getNthResult(0); } |
4178 | const NodeValue getResult() const { return getNthResult(0); } |
4179 | |
4180 | static bool classof(const Kinded *k) { |
4181 | return k->getKind() == Kinded::Kind::TruncateNodeKind; |
4182 | } |
4183 | |
4184 | |
4185 | bool isOverwrittenNthInput(unsigned idx) const { |
4186 | return false; |
4187 | } |
4188 | |
4189 | unsigned getNumInputs() const; |
4190 | std::string getInputName(unsigned idx) const; |
4191 | NodeValue getNthInput(unsigned idx); |
4192 | void setNthInput(unsigned idx, NodeValue val); |
4193 | llvm::StringRef getOutputName(unsigned idx) const; |
4194 | bool hasSideEffects() const { return 0; } |
4195 | bool isCanonical() const { return 1; } |
4196 | bool isDataParallel() const { return 1; } |
4197 | std::string getDebugDesc() const; |
4198 | bool isEqual(const TruncateNode &other) const; |
4199 | llvm::hash_code getHash() const; |
4200 | void visit(Node *parent, NodeWalker *visitor); |
4201 | Node* clone() const; |
4202 | bool verify() const; |
4203 | }; |
4204 | } // namespace glow |
4205 | |
4206 | |
4207 | namespace glow { |
4208 | /// Performs an element-wise SQRT(x) of the Input operand. |
4209 | class SqrtNode final : public Node { |
4210 | NodeHandle Input_; |
4211 | |
4212 | public: |
4213 | enum InputIndices { |
4214 | InputIdx = 0, |
4215 | }; |
4216 | |
4217 | enum ResultIndices { |
4218 | ResultIdx = 0, |
4219 | }; |
4220 | |
4221 | SqrtNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4222 | : Node(Kinded::Kind::SqrtNodeKind, name), Input_(this, Input) { |
4223 | addResult(Result); |
4224 | } |
4225 | const NodeValue getInput() const { return Input_; } |
4226 | NodeValue getResult() { return getNthResult(0); } |
4227 | const NodeValue getResult() const { return getNthResult(0); } |
4228 | |
4229 | static bool classof(const Kinded *k) { |
4230 | return k->getKind() == Kinded::Kind::SqrtNodeKind; |
4231 | } |
4232 | |
4233 | |
4234 | bool isOverwrittenNthInput(unsigned idx) const { |
4235 | return false; |
4236 | } |
4237 | |
4238 | unsigned getNumInputs() const; |
4239 | std::string getInputName(unsigned idx) const; |
4240 | NodeValue getNthInput(unsigned idx); |
4241 | void setNthInput(unsigned idx, NodeValue val); |
4242 | llvm::StringRef getOutputName(unsigned idx) const; |
4243 | bool hasSideEffects() const { return 0; } |
4244 | bool isCanonical() const { return 1; } |
4245 | bool isDataParallel() const { return 1; } |
4246 | std::string getDebugDesc() const; |
4247 | bool isEqual(const SqrtNode &other) const; |
4248 | llvm::hash_code getHash() const; |
4249 | void visit(Node *parent, NodeWalker *visitor); |
4250 | Node* clone() const; |
4251 | bool verify() const; |
4252 | }; |
4253 | } // namespace glow |
4254 | |
4255 | |
4256 | namespace glow { |
4257 | /// Performs an element-wise RSQRT(x) = 1 / SQRT(x) of the Input operand. |
4258 | class RsqrtNode final : public Node { |
4259 | NodeHandle Input_; |
4260 | |
4261 | public: |
4262 | enum InputIndices { |
4263 | InputIdx = 0, |
4264 | }; |
4265 | |
4266 | enum ResultIndices { |
4267 | ResultIdx = 0, |
4268 | }; |
4269 | |
4270 | RsqrtNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4271 | : Node(Kinded::Kind::RsqrtNodeKind, name), Input_(this, Input) { |
4272 | addResult(Result); |
4273 | } |
4274 | const NodeValue getInput() const { return Input_; } |
4275 | NodeValue getResult() { return getNthResult(0); } |
4276 | const NodeValue getResult() const { return getNthResult(0); } |
4277 | |
4278 | static bool classof(const Kinded *k) { |
4279 | return k->getKind() == Kinded::Kind::RsqrtNodeKind; |
4280 | } |
4281 | |
4282 | |
4283 | bool isOverwrittenNthInput(unsigned idx) const { |
4284 | return false; |
4285 | } |
4286 | |
4287 | unsigned getNumInputs() const; |
4288 | std::string getInputName(unsigned idx) const; |
4289 | NodeValue getNthInput(unsigned idx); |
4290 | void setNthInput(unsigned idx, NodeValue val); |
4291 | llvm::StringRef getOutputName(unsigned idx) const; |
4292 | bool hasSideEffects() const { return 0; } |
4293 | bool isCanonical() const { return 1; } |
4294 | bool isDataParallel() const { return 1; } |
4295 | std::string getDebugDesc() const; |
4296 | bool isEqual(const RsqrtNode &other) const; |
4297 | llvm::hash_code getHash() const; |
4298 | void visit(Node *parent, NodeWalker *visitor); |
4299 | Node* clone() const; |
4300 | bool verify() const; |
4301 | }; |
4302 | } // namespace glow |
4303 | |
4304 | |
4305 | namespace glow { |
4306 | /// Performs an element-wise RECIPROCAL(x) = 1 / x of the Input operand. |
4307 | class ReciprocalNode final : public Node { |
4308 | NodeHandle Input_; |
4309 | |
4310 | public: |
4311 | enum InputIndices { |
4312 | InputIdx = 0, |
4313 | }; |
4314 | |
4315 | enum ResultIndices { |
4316 | ResultIdx = 0, |
4317 | }; |
4318 | |
4319 | ReciprocalNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4320 | : Node(Kinded::Kind::ReciprocalNodeKind, name), Input_(this, Input) { |
4321 | addResult(Result); |
4322 | } |
4323 | const NodeValue getInput() const { return Input_; } |
4324 | NodeValue getResult() { return getNthResult(0); } |
4325 | const NodeValue getResult() const { return getNthResult(0); } |
4326 | |
4327 | static bool classof(const Kinded *k) { |
4328 | return k->getKind() == Kinded::Kind::ReciprocalNodeKind; |
4329 | } |
4330 | |
4331 | |
4332 | bool isOverwrittenNthInput(unsigned idx) const { |
4333 | return false; |
4334 | } |
4335 | |
4336 | unsigned getNumInputs() const; |
4337 | std::string getInputName(unsigned idx) const; |
4338 | NodeValue getNthInput(unsigned idx); |
4339 | void setNthInput(unsigned idx, NodeValue val); |
4340 | llvm::StringRef getOutputName(unsigned idx) const; |
4341 | bool hasSideEffects() const { return 0; } |
4342 | bool isCanonical() const { return 1; } |
4343 | bool isDataParallel() const { return 1; } |
4344 | std::string getDebugDesc() const; |
4345 | bool isEqual(const ReciprocalNode &other) const; |
4346 | llvm::hash_code getHash() const; |
4347 | void visit(Node *parent, NodeWalker *visitor); |
4348 | Node* clone() const; |
4349 | bool verify() const; |
4350 | }; |
4351 | } // namespace glow |
4352 | |
4353 | |
4354 | namespace glow { |
4355 | /// Performs an element-wise SIN(x) of the Input operand. |
4356 | class SinNode final : public Node { |
4357 | NodeHandle Input_; |
4358 | |
4359 | public: |
4360 | enum InputIndices { |
4361 | InputIdx = 0, |
4362 | }; |
4363 | |
4364 | enum ResultIndices { |
4365 | ResultIdx = 0, |
4366 | }; |
4367 | |
4368 | SinNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4369 | : Node(Kinded::Kind::SinNodeKind, name), Input_(this, Input) { |
4370 | addResult(Result); |
4371 | } |
4372 | const NodeValue getInput() const { return Input_; } |
4373 | NodeValue getResult() { return getNthResult(0); } |
4374 | const NodeValue getResult() const { return getNthResult(0); } |
4375 | |
4376 | static bool classof(const Kinded *k) { |
4377 | return k->getKind() == Kinded::Kind::SinNodeKind; |
4378 | } |
4379 | |
4380 | |
4381 | bool isOverwrittenNthInput(unsigned idx) const { |
4382 | return false; |
4383 | } |
4384 | |
4385 | unsigned getNumInputs() const; |
4386 | std::string getInputName(unsigned idx) const; |
4387 | NodeValue getNthInput(unsigned idx); |
4388 | void setNthInput(unsigned idx, NodeValue val); |
4389 | llvm::StringRef getOutputName(unsigned idx) const; |
4390 | bool hasSideEffects() const { return 0; } |
4391 | bool isCanonical() const { return 1; } |
4392 | bool isDataParallel() const { return 1; } |
4393 | std::string getDebugDesc() const; |
4394 | bool isEqual(const SinNode &other) const; |
4395 | llvm::hash_code getHash() const; |
4396 | void visit(Node *parent, NodeWalker *visitor); |
4397 | Node* clone() const; |
4398 | bool verify() const; |
4399 | }; |
4400 | } // namespace glow |
4401 | |
4402 | |
4403 | namespace glow { |
4404 | /// Performs an element-wise COS(x) of the Input operand. |
4405 | class CosNode final : public Node { |
4406 | NodeHandle Input_; |
4407 | |
4408 | public: |
4409 | enum InputIndices { |
4410 | InputIdx = 0, |
4411 | }; |
4412 | |
4413 | enum ResultIndices { |
4414 | ResultIdx = 0, |
4415 | }; |
4416 | |
4417 | CosNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4418 | : Node(Kinded::Kind::CosNodeKind, name), Input_(this, Input) { |
4419 | addResult(Result); |
4420 | } |
4421 | const NodeValue getInput() const { return Input_; } |
4422 | NodeValue getResult() { return getNthResult(0); } |
4423 | const NodeValue getResult() const { return getNthResult(0); } |
4424 | |
4425 | static bool classof(const Kinded *k) { |
4426 | return k->getKind() == Kinded::Kind::CosNodeKind; |
4427 | } |
4428 | |
4429 | |
4430 | bool isOverwrittenNthInput(unsigned idx) const { |
4431 | return false; |
4432 | } |
4433 | |
4434 | unsigned getNumInputs() const; |
4435 | std::string getInputName(unsigned idx) const; |
4436 | NodeValue getNthInput(unsigned idx); |
4437 | void setNthInput(unsigned idx, NodeValue val); |
4438 | llvm::StringRef getOutputName(unsigned idx) const; |
4439 | bool hasSideEffects() const { return 0; } |
4440 | bool isCanonical() const { return 1; } |
4441 | bool isDataParallel() const { return 1; } |
4442 | std::string getDebugDesc() const; |
4443 | bool isEqual(const CosNode &other) const; |
4444 | llvm::hash_code getHash() const; |
4445 | void visit(Node *parent, NodeWalker *visitor); |
4446 | Node* clone() const; |
4447 | bool verify() const; |
4448 | }; |
4449 | } // namespace glow |
4450 | |
4451 | |
4452 | namespace glow { |
4453 | /// Performs element-wise natural log to the Input. |
4454 | class LogNode final : public Node { |
4455 | NodeHandle Input_; |
4456 | |
4457 | public: |
4458 | enum InputIndices { |
4459 | InputIdx = 0, |
4460 | }; |
4461 | |
4462 | enum ResultIndices { |
4463 | ResultIdx = 0, |
4464 | }; |
4465 | |
4466 | LogNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4467 | : Node(Kinded::Kind::LogNodeKind, name), Input_(this, Input) { |
4468 | addResult(Result); |
4469 | } |
4470 | const NodeValue getInput() const { return Input_; } |
4471 | NodeValue getResult() { return getNthResult(0); } |
4472 | const NodeValue getResult() const { return getNthResult(0); } |
4473 | |
4474 | static bool classof(const Kinded *k) { |
4475 | return k->getKind() == Kinded::Kind::LogNodeKind; |
4476 | } |
4477 | |
4478 | |
4479 | bool isOverwrittenNthInput(unsigned idx) const { |
4480 | return false; |
4481 | } |
4482 | |
4483 | unsigned getNumInputs() const; |
4484 | std::string getInputName(unsigned idx) const; |
4485 | NodeValue getNthInput(unsigned idx); |
4486 | void setNthInput(unsigned idx, NodeValue val); |
4487 | llvm::StringRef getOutputName(unsigned idx) const; |
4488 | bool hasSideEffects() const { return 0; } |
4489 | bool isCanonical() const { return 1; } |
4490 | bool isDataParallel() const { return 1; } |
4491 | std::string getDebugDesc() const; |
4492 | bool isEqual(const LogNode &other) const; |
4493 | llvm::hash_code getHash() const; |
4494 | void visit(Node *parent, NodeWalker *visitor); |
4495 | Node* clone() const; |
4496 | bool verify() const; |
4497 | }; |
4498 | } // namespace glow |
4499 | |
4500 | |
4501 | namespace glow { |
4502 | /// Performs an element-wise Arccosine(x) of the Input operand. |
4503 | class AcosNode final : public Node { |
4504 | NodeHandle Input_; |
4505 | |
4506 | public: |
4507 | enum InputIndices { |
4508 | InputIdx = 0, |
4509 | }; |
4510 | |
4511 | enum ResultIndices { |
4512 | ResultIdx = 0, |
4513 | }; |
4514 | |
4515 | AcosNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4516 | : Node(Kinded::Kind::AcosNodeKind, name), Input_(this, Input) { |
4517 | addResult(Result); |
4518 | } |
4519 | const NodeValue getInput() const { return Input_; } |
4520 | NodeValue getResult() { return getNthResult(0); } |
4521 | const NodeValue getResult() const { return getNthResult(0); } |
4522 | |
4523 | static bool classof(const Kinded *k) { |
4524 | return k->getKind() == Kinded::Kind::AcosNodeKind; |
4525 | } |
4526 | |
4527 | |
4528 | bool isOverwrittenNthInput(unsigned idx) const { |
4529 | return false; |
4530 | } |
4531 | |
4532 | unsigned getNumInputs() const; |
4533 | std::string getInputName(unsigned idx) const; |
4534 | NodeValue getNthInput(unsigned idx); |
4535 | void setNthInput(unsigned idx, NodeValue val); |
4536 | llvm::StringRef getOutputName(unsigned idx) const; |
4537 | bool hasSideEffects() const { return 0; } |
4538 | bool isCanonical() const { return 1; } |
4539 | bool isDataParallel() const { return 1; } |
4540 | std::string getDebugDesc() const; |
4541 | bool isEqual(const AcosNode &other) const; |
4542 | llvm::hash_code getHash() const; |
4543 | void visit(Node *parent, NodeWalker *visitor); |
4544 | Node* clone() const; |
4545 | bool verify() const; |
4546 | }; |
4547 | } // namespace glow |
4548 | |
4549 | |
4550 | namespace glow { |
4551 | /// Performs an element-wise Arcsine(x) of the Input operand. |
4552 | class AsinNode final : public Node { |
4553 | NodeHandle Input_; |
4554 | |
4555 | public: |
4556 | enum InputIndices { |
4557 | InputIdx = 0, |
4558 | }; |
4559 | |
4560 | enum ResultIndices { |
4561 | ResultIdx = 0, |
4562 | }; |
4563 | |
4564 | AsinNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4565 | : Node(Kinded::Kind::AsinNodeKind, name), Input_(this, Input) { |
4566 | addResult(Result); |
4567 | } |
4568 | const NodeValue getInput() const { return Input_; } |
4569 | NodeValue getResult() { return getNthResult(0); } |
4570 | const NodeValue getResult() const { return getNthResult(0); } |
4571 | |
4572 | static bool classof(const Kinded *k) { |
4573 | return k->getKind() == Kinded::Kind::AsinNodeKind; |
4574 | } |
4575 | |
4576 | |
4577 | bool isOverwrittenNthInput(unsigned idx) const { |
4578 | return false; |
4579 | } |
4580 | |
4581 | unsigned getNumInputs() const; |
4582 | std::string getInputName(unsigned idx) const; |
4583 | NodeValue getNthInput(unsigned idx); |
4584 | void setNthInput(unsigned idx, NodeValue val); |
4585 | llvm::StringRef getOutputName(unsigned idx) const; |
4586 | bool hasSideEffects() const { return 0; } |
4587 | bool isCanonical() const { return 1; } |
4588 | bool isDataParallel() const { return 1; } |
4589 | std::string getDebugDesc() const; |
4590 | bool isEqual(const AsinNode &other) const; |
4591 | llvm::hash_code getHash() const; |
4592 | void visit(Node *parent, NodeWalker *visitor); |
4593 | Node* clone() const; |
4594 | bool verify() const; |
4595 | }; |
4596 | } // namespace glow |
4597 | |
4598 | |
4599 | namespace glow { |
4600 | /// Performs an element-wise Arctan(x) of the Input operand. |
4601 | class AtanNode final : public Node { |
4602 | NodeHandle Input_; |
4603 | |
4604 | public: |
4605 | enum InputIndices { |
4606 | InputIdx = 0, |
4607 | }; |
4608 | |
4609 | enum ResultIndices { |
4610 | ResultIdx = 0, |
4611 | }; |
4612 | |
4613 | AtanNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4614 | : Node(Kinded::Kind::AtanNodeKind, name), Input_(this, Input) { |
4615 | addResult(Result); |
4616 | } |
4617 | const NodeValue getInput() const { return Input_; } |
4618 | NodeValue getResult() { return getNthResult(0); } |
4619 | const NodeValue getResult() const { return getNthResult(0); } |
4620 | |
4621 | static bool classof(const Kinded *k) { |
4622 | return k->getKind() == Kinded::Kind::AtanNodeKind; |
4623 | } |
4624 | |
4625 | |
4626 | bool isOverwrittenNthInput(unsigned idx) const { |
4627 | return false; |
4628 | } |
4629 | |
4630 | unsigned getNumInputs() const; |
4631 | std::string getInputName(unsigned idx) const; |
4632 | NodeValue getNthInput(unsigned idx); |
4633 | void setNthInput(unsigned idx, NodeValue val); |
4634 | llvm::StringRef getOutputName(unsigned idx) const; |
4635 | bool hasSideEffects() const { return 0; } |
4636 | bool isCanonical() const { return 1; } |
4637 | bool isDataParallel() const { return 1; } |
4638 | std::string getDebugDesc() const; |
4639 | bool isEqual(const AtanNode &other) const; |
4640 | llvm::hash_code getHash() const; |
4641 | void visit(Node *parent, NodeWalker *visitor); |
4642 | Node* clone() const; |
4643 | bool verify() const; |
4644 | }; |
4645 | } // namespace glow |
4646 | |
4647 | |
4648 | namespace glow { |
4649 | /// Performs an element-wise Erf(x) of the Input operand. |
4650 | class ErfNode final : public Node { |
4651 | NodeHandle Input_; |
4652 | |
4653 | public: |
4654 | enum InputIndices { |
4655 | InputIdx = 0, |
4656 | }; |
4657 | |
4658 | enum ResultIndices { |
4659 | ResultIdx = 0, |
4660 | }; |
4661 | |
4662 | ErfNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4663 | : Node(Kinded::Kind::ErfNodeKind, name), Input_(this, Input) { |
4664 | addResult(Result); |
4665 | } |
4666 | const NodeValue getInput() const { return Input_; } |
4667 | NodeValue getResult() { return getNthResult(0); } |
4668 | const NodeValue getResult() const { return getNthResult(0); } |
4669 | |
4670 | static bool classof(const Kinded *k) { |
4671 | return k->getKind() == Kinded::Kind::ErfNodeKind; |
4672 | } |
4673 | |
4674 | |
4675 | bool isOverwrittenNthInput(unsigned idx) const { |
4676 | return false; |
4677 | } |
4678 | |
4679 | unsigned getNumInputs() const; |
4680 | std::string getInputName(unsigned idx) const; |
4681 | NodeValue getNthInput(unsigned idx); |
4682 | void setNthInput(unsigned idx, NodeValue val); |
4683 | llvm::StringRef getOutputName(unsigned idx) const; |
4684 | bool hasSideEffects() const { return 0; } |
4685 | bool isCanonical() const { return 1; } |
4686 | bool isDataParallel() const { return 1; } |
4687 | std::string getDebugDesc() const; |
4688 | bool isEqual(const ErfNode &other) const; |
4689 | llvm::hash_code getHash() const; |
4690 | void visit(Node *parent, NodeWalker *visitor); |
4691 | Node* clone() const; |
4692 | bool verify() const; |
4693 | }; |
4694 | } // namespace glow |
4695 | |
4696 | |
4697 | namespace glow { |
4698 | /// Performs element-wise exponential to the Input. |
4699 | class ExpNode final : public Node { |
4700 | NodeHandle Input_; |
4701 | |
4702 | public: |
4703 | enum InputIndices { |
4704 | InputIdx = 0, |
4705 | }; |
4706 | |
4707 | enum ResultIndices { |
4708 | ResultIdx = 0, |
4709 | }; |
4710 | |
4711 | ExpNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
4712 | : Node(Kinded::Kind::ExpNodeKind, name), Input_(this, Input) { |
4713 | addResult(Result); |
4714 | } |
4715 | const NodeValue getInput() const { return Input_; } |
4716 | NodeValue getResult() { return getNthResult(0); } |
4717 | const NodeValue getResult() const { return getNthResult(0); } |
4718 | |
4719 | static bool classof(const Kinded *k) { |
4720 | return k->getKind() == Kinded::Kind::ExpNodeKind; |
4721 | } |
4722 | |
4723 | |
4724 | bool isOverwrittenNthInput(unsigned idx) const { |
4725 | return false; |
4726 | } |
4727 | |
4728 | unsigned getNumInputs() const; |
4729 | std::string getInputName(unsigned idx) const; |
4730 | NodeValue getNthInput(unsigned idx); |
4731 | void setNthInput(unsigned idx, NodeValue val); |
4732 | llvm::StringRef getOutputName(unsigned idx) const; |
4733 | bool hasSideEffects() const { return 0; } |
4734 | bool isCanonical() const { return 1; } |
4735 | bool isDataParallel() const { return 1; } |
4736 | std::string getDebugDesc() const; |
4737 | bool isEqual(const ExpNode &other) const; |
4738 | llvm::hash_code getHash() const; |
4739 | void visit(Node *parent, NodeWalker *visitor); |
4740 | Node* clone() const; |
4741 | bool verify() const; |
4742 | }; |
4743 | } // namespace glow |
4744 | |
4745 | |
4746 | namespace glow { |
4747 | /// Computes elementwise: result = log(input / (1 - input)). |
4748 | class LogitNode final : public Node { |
4749 | NodeHandle Input_; |
4750 | float Epsilon_; |
4751 | |
4752 | public: |
4753 | enum InputIndices { |
4754 | InputIdx = 0, |
4755 | }; |
4756 | |
4757 | enum ResultIndices { |
4758 | ResultIdx = 0, |
4759 | }; |
4760 | |
4761 | LogitNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Epsilon) |
4762 | : Node(Kinded::Kind::LogitNodeKind, name), Input_(this, Input), Epsilon_(Epsilon) { |
4763 | addResult(Result); |
4764 | } |
4765 | const NodeValue getInput() const { return Input_; } |
4766 | NodeValue getResult() { return getNthResult(0); } |
4767 | const NodeValue getResult() const { return getNthResult(0); } |
4768 | float getEpsilon() const { return Epsilon_; } |
4769 | |
4770 | static bool classof(const Kinded *k) { |
4771 | return k->getKind() == Kinded::Kind::LogitNodeKind; |
4772 | } |
4773 | |
4774 | |
4775 | bool isOverwrittenNthInput(unsigned idx) const { |
4776 | return false; |
4777 | } |
4778 | |
4779 | unsigned getNumInputs() const; |
4780 | std::string getInputName(unsigned idx) const; |
4781 | NodeValue getNthInput(unsigned idx); |
4782 | void setNthInput(unsigned idx, NodeValue val); |
4783 | llvm::StringRef getOutputName(unsigned idx) const; |
4784 | bool hasSideEffects() const { return 0; } |
4785 | bool isCanonical() const { return 1; } |
4786 | bool isDataParallel() const { return 1; } |
4787 | std::string getDebugDesc() const; |
4788 | bool isEqual(const LogitNode &other) const; |
4789 | llvm::hash_code getHash() const; |
4790 | void visit(Node *parent, NodeWalker *visitor); |
4791 | Node* clone() const; |
4792 | bool verify() const; |
4793 | }; |
4794 | } // namespace glow |
4795 | |
4796 | |
4797 | namespace glow { |
4798 | /// Selects indices of the true elements in Cond |
4799 | class NonZeroNode final : public Node { |
4800 | NodeHandle Cond_; |
4801 | |
4802 | public: |
4803 | enum InputIndices { |
4804 | CondIdx = 0, |
4805 | }; |
4806 | |
4807 | enum ResultIndices { |
4808 | ResultIdx = 0, |
4809 | }; |
4810 | |
4811 | NonZeroNode(llvm::StringRef name, TypeRef Result , NodeValue Cond) |
4812 | : Node(Kinded::Kind::NonZeroNodeKind, name), Cond_(this, Cond) { |
4813 | addResult(Result); |
4814 | } |
4815 | const NodeValue getCond() const { return Cond_; } |
4816 | NodeValue getResult() { return getNthResult(0); } |
4817 | const NodeValue getResult() const { return getNthResult(0); } |
4818 | |
4819 | static bool classof(const Kinded *k) { |
4820 | return k->getKind() == Kinded::Kind::NonZeroNodeKind; |
4821 | } |
4822 | |
4823 | |
4824 | bool isOverwrittenNthInput(unsigned idx) const { |
4825 | return false; |
4826 | } |
4827 | |
4828 | unsigned getNumInputs() const; |
4829 | std::string getInputName(unsigned idx) const; |
4830 | NodeValue getNthInput(unsigned idx); |
4831 | void setNthInput(unsigned idx, NodeValue val); |
4832 | llvm::StringRef getOutputName(unsigned idx) const; |
4833 | bool hasSideEffects() const { return 0; } |
4834 | bool isCanonical() const { return 1; } |
4835 | bool isDataParallel() const { return 1; } |
4836 | std::string getDebugDesc() const; |
4837 | bool isEqual(const NonZeroNode &other) const; |
4838 | llvm::hash_code getHash() const; |
4839 | void visit(Node *parent, NodeWalker *visitor); |
4840 | Node* clone() const; |
4841 | bool verify() const; |
4842 | }; |
4843 | } // namespace glow |
4844 | |
4845 | |
4846 | namespace glow { |
4847 | /// Selects between values on the LHS or RHS, depending on the value of Cond. Cond is generated by the compare instruction, and is target- and type-specific. |
4848 | class SelectNode final : public Node { |
4849 | NodeHandle Cond_; |
4850 | NodeHandle LHS_; |
4851 | NodeHandle RHS_; |
4852 | |
4853 | public: |
4854 | enum InputIndices { |
4855 | CondIdx = 0, |
4856 | LHSIdx = 1, |
4857 | RHSIdx = 2, |
4858 | }; |
4859 | |
4860 | enum ResultIndices { |
4861 | ResultIdx = 0, |
4862 | }; |
4863 | |
4864 | SelectNode(llvm::StringRef name, TypeRef Result , NodeValue Cond, NodeValue LHS, NodeValue RHS) |
4865 | : Node(Kinded::Kind::SelectNodeKind, name), Cond_(this, Cond), LHS_(this, LHS), RHS_(this, RHS) { |
4866 | addResult(Result); |
4867 | } |
4868 | const NodeValue getCond() const { return Cond_; } |
4869 | const NodeValue getLHS() const { return LHS_; } |
4870 | const NodeValue getRHS() const { return RHS_; } |
4871 | NodeValue getResult() { return getNthResult(0); } |
4872 | const NodeValue getResult() const { return getNthResult(0); } |
4873 | |
4874 | static bool classof(const Kinded *k) { |
4875 | return k->getKind() == Kinded::Kind::SelectNodeKind; |
4876 | } |
4877 | |
4878 | |
4879 | bool isOverwrittenNthInput(unsigned idx) const { |
4880 | return false; |
4881 | } |
4882 | |
4883 | unsigned getNumInputs() const; |
4884 | std::string getInputName(unsigned idx) const; |
4885 | NodeValue getNthInput(unsigned idx); |
4886 | void setNthInput(unsigned idx, NodeValue val); |
4887 | llvm::StringRef getOutputName(unsigned idx) const; |
4888 | bool hasSideEffects() const { return 0; } |
4889 | bool isCanonical() const { return 1; } |
4890 | bool isDataParallel() const { return 1; } |
4891 | std::string getDebugDesc() const; |
4892 | bool isEqual(const SelectNode &other) const; |
4893 | llvm::hash_code getHash() const; |
4894 | void visit(Node *parent, NodeWalker *visitor); |
4895 | Node* clone() const; |
4896 | bool verify() const; |
4897 | }; |
4898 | } // namespace glow |
4899 | |
4900 | |
4901 | namespace glow { |
4902 | /// Adds the 'Slice' operand to each one of the slices in the batch. |
4903 | class BatchedAddNode final : public Node { |
4904 | NodeHandle Batch_; |
4905 | NodeHandle Slice_; |
4906 | |
4907 | public: |
4908 | enum InputIndices { |
4909 | BatchIdx = 0, |
4910 | SliceIdx = 1, |
4911 | }; |
4912 | |
4913 | enum ResultIndices { |
4914 | ResultIdx = 0, |
4915 | }; |
4916 | |
4917 | BatchedAddNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, NodeValue Slice) |
4918 | : Node(Kinded::Kind::BatchedAddNodeKind, name), Batch_(this, Batch), Slice_(this, Slice) { |
4919 | addResult(Result); |
4920 | } |
4921 | const NodeValue getBatch() const { return Batch_; } |
4922 | const NodeValue getSlice() const { return Slice_; } |
4923 | NodeValue getResult() { return getNthResult(0); } |
4924 | const NodeValue getResult() const { return getNthResult(0); } |
4925 | |
4926 | static bool classof(const Kinded *k) { |
4927 | return k->getKind() == Kinded::Kind::BatchedAddNodeKind; |
4928 | } |
4929 | |
4930 | |
4931 | bool isOverwrittenNthInput(unsigned idx) const { |
4932 | return false; |
4933 | } |
4934 | |
4935 | unsigned getNumInputs() const; |
4936 | std::string getInputName(unsigned idx) const; |
4937 | NodeValue getNthInput(unsigned idx); |
4938 | void setNthInput(unsigned idx, NodeValue val); |
4939 | llvm::StringRef getOutputName(unsigned idx) const; |
4940 | bool hasSideEffects() const { return 0; } |
4941 | bool isCanonical() const { return 1; } |
4942 | bool isDataParallel() const { return 0; } |
4943 | std::string getDebugDesc() const; |
4944 | bool isEqual(const BatchedAddNode &other) const; |
4945 | llvm::hash_code getHash() const; |
4946 | void visit(Node *parent, NodeWalker *visitor); |
4947 | Node* clone() const; |
4948 | bool verify() const; |
4949 | }; |
4950 | } // namespace glow |
4951 | |
4952 | |
4953 | namespace glow { |
4954 | /// Multiplies the 'Slice' operand to each one of the slices in the batch. |
4955 | class BatchedMulNode final : public Node { |
4956 | NodeHandle Batch_; |
4957 | NodeHandle Slice_; |
4958 | |
4959 | public: |
4960 | enum InputIndices { |
4961 | BatchIdx = 0, |
4962 | SliceIdx = 1, |
4963 | }; |
4964 | |
4965 | enum ResultIndices { |
4966 | ResultIdx = 0, |
4967 | }; |
4968 | |
4969 | BatchedMulNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, NodeValue Slice) |
4970 | : Node(Kinded::Kind::BatchedMulNodeKind, name), Batch_(this, Batch), Slice_(this, Slice) { |
4971 | addResult(Result); |
4972 | } |
4973 | const NodeValue getBatch() const { return Batch_; } |
4974 | const NodeValue getSlice() const { return Slice_; } |
4975 | NodeValue getResult() { return getNthResult(0); } |
4976 | const NodeValue getResult() const { return getNthResult(0); } |
4977 | |
4978 | static bool classof(const Kinded *k) { |
4979 | return k->getKind() == Kinded::Kind::BatchedMulNodeKind; |
4980 | } |
4981 | |
4982 | |
4983 | bool isOverwrittenNthInput(unsigned idx) const { |
4984 | return false; |
4985 | } |
4986 | |
4987 | unsigned getNumInputs() const; |
4988 | std::string getInputName(unsigned idx) const; |
4989 | NodeValue getNthInput(unsigned idx); |
4990 | void setNthInput(unsigned idx, NodeValue val); |
4991 | llvm::StringRef getOutputName(unsigned idx) const; |
4992 | bool hasSideEffects() const { return 0; } |
4993 | bool isCanonical() const { return 1; } |
4994 | bool isDataParallel() const { return 0; } |
4995 | std::string getDebugDesc() const; |
4996 | bool isEqual(const BatchedMulNode &other) const; |
4997 | llvm::hash_code getHash() const; |
4998 | void visit(Node *parent, NodeWalker *visitor); |
4999 | Node* clone() const; |
5000 | bool verify() const; |
5001 | }; |
5002 | } // namespace glow |
5003 | |
5004 | |
5005 | namespace glow { |
5006 | /// Performs matrix multiplication between the LHS and RHS.Example: (A, Z) x (Z, B) => (A, B) |
5007 | class MatMulNode final : public Node { |
5008 | NodeHandle LHS_; |
5009 | NodeHandle RHS_; |
5010 | |
5011 | public: |
5012 | enum InputIndices { |
5013 | LHSIdx = 0, |
5014 | RHSIdx = 1, |
5015 | }; |
5016 | |
5017 | enum ResultIndices { |
5018 | ResultIdx = 0, |
5019 | }; |
5020 | |
5021 | MatMulNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
5022 | : Node(Kinded::Kind::MatMulNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
5023 | addResult(Result); |
5024 | } |
5025 | const NodeValue getLHS() const { return LHS_; } |
5026 | const NodeValue getRHS() const { return RHS_; } |
5027 | NodeValue getResult() { return getNthResult(0); } |
5028 | const NodeValue getResult() const { return getNthResult(0); } |
5029 | |
5030 | static bool classof(const Kinded *k) { |
5031 | return k->getKind() == Kinded::Kind::MatMulNodeKind; |
5032 | } |
5033 | |
5034 | |
5035 | bool isOverwrittenNthInput(unsigned idx) const { |
5036 | return false; |
5037 | } |
5038 | |
5039 | unsigned getNumInputs() const; |
5040 | std::string getInputName(unsigned idx) const; |
5041 | NodeValue getNthInput(unsigned idx); |
5042 | void setNthInput(unsigned idx, NodeValue val); |
5043 | llvm::StringRef getOutputName(unsigned idx) const; |
5044 | bool hasSideEffects() const { return 0; } |
5045 | bool isCanonical() const { return 1; } |
5046 | bool isDataParallel() const { return 0; } |
5047 | std::string getDebugDesc() const; |
5048 | bool isEqual(const MatMulNode &other) const; |
5049 | llvm::hash_code getHash() const; |
5050 | void visit(Node *parent, NodeWalker *visitor); |
5051 | Node* clone() const; |
5052 | bool verify() const; |
5053 | }; |
5054 | } // namespace glow |
5055 | |
5056 | |
5057 | namespace glow { |
5058 | /// Performs batch matrix multiplication between the LHS and RHS. The operands are a stack of two dimensional matrices. Example: (N, A, Z) x (N, Z, B) => (N, A, B) |
5059 | class BatchMatMulNode final : public Node { |
5060 | NodeHandle LHS_; |
5061 | NodeHandle RHS_; |
5062 | |
5063 | public: |
5064 | enum InputIndices { |
5065 | LHSIdx = 0, |
5066 | RHSIdx = 1, |
5067 | }; |
5068 | |
5069 | enum ResultIndices { |
5070 | ResultIdx = 0, |
5071 | }; |
5072 | |
5073 | BatchMatMulNode(llvm::StringRef name, TypeRef Result , NodeValue LHS, NodeValue RHS) |
5074 | : Node(Kinded::Kind::BatchMatMulNodeKind, name), LHS_(this, LHS), RHS_(this, RHS) { |
5075 | addResult(Result); |
5076 | } |
5077 | const NodeValue getLHS() const { return LHS_; } |
5078 | const NodeValue getRHS() const { return RHS_; } |
5079 | NodeValue getResult() { return getNthResult(0); } |
5080 | const NodeValue getResult() const { return getNthResult(0); } |
5081 | |
5082 | static bool classof(const Kinded *k) { |
5083 | return k->getKind() == Kinded::Kind::BatchMatMulNodeKind; |
5084 | } |
5085 | |
5086 | |
5087 | bool isOverwrittenNthInput(unsigned idx) const { |
5088 | return false; |
5089 | } |
5090 | |
5091 | unsigned getNumInputs() const; |
5092 | std::string getInputName(unsigned idx) const; |
5093 | NodeValue getNthInput(unsigned idx); |
5094 | void setNthInput(unsigned idx, NodeValue val); |
5095 | llvm::StringRef getOutputName(unsigned idx) const; |
5096 | bool hasSideEffects() const { return 0; } |
5097 | bool isCanonical() const { return 1; } |
5098 | bool isDataParallel() const { return 0; } |
5099 | std::string getDebugDesc() const; |
5100 | bool isEqual(const BatchMatMulNode &other) const; |
5101 | llvm::hash_code getHash() const; |
5102 | void visit(Node *parent, NodeWalker *visitor); |
5103 | Node* clone() const; |
5104 | bool verify() const; |
5105 | }; |
5106 | } // namespace glow |
5107 | |
5108 | |
5109 | namespace glow { |
5110 | /// Accumulates all of the layers in the batch and produce a tensor that has the same dimensions as the input tensor without the first dimension. |
5111 | class BatchedReduceAddNode final : public Node { |
5112 | NodeHandle Batch_; |
5113 | unsigned_t Axis_; |
5114 | |
5115 | public: |
5116 | enum InputIndices { |
5117 | BatchIdx = 0, |
5118 | }; |
5119 | |
5120 | enum ResultIndices { |
5121 | ResultIdx = 0, |
5122 | }; |
5123 | |
5124 | BatchedReduceAddNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, unsigned_t Axis) |
5125 | : Node(Kinded::Kind::BatchedReduceAddNodeKind, name), Batch_(this, Batch), Axis_(Axis) { |
5126 | addResult(Result); |
5127 | } |
5128 | const NodeValue getBatch() const { return Batch_; } |
5129 | NodeValue getResult() { return getNthResult(0); } |
5130 | const NodeValue getResult() const { return getNthResult(0); } |
5131 | unsigned_t getAxis() const { return Axis_; } |
5132 | |
5133 | static bool classof(const Kinded *k) { |
5134 | return k->getKind() == Kinded::Kind::BatchedReduceAddNodeKind; |
5135 | } |
5136 | |
5137 | |
5138 | bool isOverwrittenNthInput(unsigned idx) const { |
5139 | return false; |
5140 | } |
5141 | |
5142 | unsigned getNumInputs() const; |
5143 | std::string getInputName(unsigned idx) const; |
5144 | NodeValue getNthInput(unsigned idx); |
5145 | void setNthInput(unsigned idx, NodeValue val); |
5146 | llvm::StringRef getOutputName(unsigned idx) const; |
5147 | bool hasSideEffects() const { return 0; } |
5148 | bool isCanonical() const { return 1; } |
5149 | bool isDataParallel() const { return 0; } |
5150 | std::string getDebugDesc() const; |
5151 | bool isEqual(const BatchedReduceAddNode &other) const; |
5152 | llvm::hash_code getHash() const; |
5153 | void visit(Node *parent, NodeWalker *visitor); |
5154 | Node* clone() const; |
5155 | bool verify() const; |
5156 | }; |
5157 | } // namespace glow |
5158 | |
5159 | |
5160 | namespace glow { |
5161 | /// Accumulates squares of all of the layers in the batch and produce a tensor that has the same dimensions as the input tensor without the first dimension. |
5162 | class BatchedReduceSumSquareNode final : public Node { |
5163 | NodeHandle Batch_; |
5164 | unsigned_t Axis_; |
5165 | |
5166 | public: |
5167 | enum InputIndices { |
5168 | BatchIdx = 0, |
5169 | }; |
5170 | |
5171 | enum ResultIndices { |
5172 | ResultIdx = 0, |
5173 | }; |
5174 | |
5175 | BatchedReduceSumSquareNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, unsigned_t Axis) |
5176 | : Node(Kinded::Kind::BatchedReduceSumSquareNodeKind, name), Batch_(this, Batch), Axis_(Axis) { |
5177 | addResult(Result); |
5178 | } |
5179 | const NodeValue getBatch() const { return Batch_; } |
5180 | NodeValue getResult() { return getNthResult(0); } |
5181 | const NodeValue getResult() const { return getNthResult(0); } |
5182 | unsigned_t getAxis() const { return Axis_; } |
5183 | |
5184 | static bool classof(const Kinded *k) { |
5185 | return k->getKind() == Kinded::Kind::BatchedReduceSumSquareNodeKind; |
5186 | } |
5187 | |
5188 | |
5189 | bool isOverwrittenNthInput(unsigned idx) const { |
5190 | return false; |
5191 | } |
5192 | |
5193 | unsigned getNumInputs() const; |
5194 | std::string getInputName(unsigned idx) const; |
5195 | NodeValue getNthInput(unsigned idx); |
5196 | void setNthInput(unsigned idx, NodeValue val); |
5197 | llvm::StringRef getOutputName(unsigned idx) const; |
5198 | bool hasSideEffects() const { return 0; } |
5199 | bool isCanonical() const { return 1; } |
5200 | bool isDataParallel() const { return 0; } |
5201 | std::string getDebugDesc() const; |
5202 | bool isEqual(const BatchedReduceSumSquareNode &other) const; |
5203 | llvm::hash_code getHash() const; |
5204 | void visit(Node *parent, NodeWalker *visitor); |
5205 | Node* clone() const; |
5206 | bool verify() const; |
5207 | }; |
5208 | } // namespace glow |
5209 | |
5210 | |
5211 | namespace glow { |
5212 | /// Performs Average Mean operation on the Input given Axes. |
5213 | class BatchedReduceMeanNode final : public Node { |
5214 | NodeHandle Batch_; |
5215 | std::vector<unsigned_t> Axes_; |
5216 | |
5217 | public: |
5218 | enum InputIndices { |
5219 | BatchIdx = 0, |
5220 | }; |
5221 | |
5222 | enum ResultIndices { |
5223 | ResultIdx = 0, |
5224 | }; |
5225 | |
5226 | BatchedReduceMeanNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, std::vector<unsigned_t> Axes) |
5227 | : Node(Kinded::Kind::BatchedReduceMeanNodeKind, name), Batch_(this, Batch), Axes_(Axes) { |
5228 | addResult(Result); |
5229 | } |
5230 | const NodeValue getBatch() const { return Batch_; } |
5231 | NodeValue getResult() { return getNthResult(0); } |
5232 | const NodeValue getResult() const { return getNthResult(0); } |
5233 | llvm::ArrayRef<unsigned_t> getAxes() const { return Axes_; } |
5234 | |
5235 | static bool classof(const Kinded *k) { |
5236 | return k->getKind() == Kinded::Kind::BatchedReduceMeanNodeKind; |
5237 | } |
5238 | |
5239 | |
5240 | bool isOverwrittenNthInput(unsigned idx) const { |
5241 | return false; |
5242 | } |
5243 | |
5244 | unsigned getNumInputs() const; |
5245 | std::string getInputName(unsigned idx) const; |
5246 | NodeValue getNthInput(unsigned idx); |
5247 | void setNthInput(unsigned idx, NodeValue val); |
5248 | llvm::StringRef getOutputName(unsigned idx) const; |
5249 | bool hasSideEffects() const { return 0; } |
5250 | bool isCanonical() const { return 1; } |
5251 | bool isDataParallel() const { return 0; } |
5252 | std::string getDebugDesc() const; |
5253 | bool isEqual(const BatchedReduceMeanNode &other) const; |
5254 | llvm::hash_code getHash() const; |
5255 | void visit(Node *parent, NodeWalker *visitor); |
5256 | Node* clone() const; |
5257 | bool verify() const; |
5258 | }; |
5259 | } // namespace glow |
5260 | |
5261 | |
5262 | namespace glow { |
5263 | /// Performs Reduce Min operation on the Input given Axes. |
5264 | class BatchedReduceMinNode final : public Node { |
5265 | NodeHandle Batch_; |
5266 | std::vector<unsigned_t> Axes_; |
5267 | |
5268 | public: |
5269 | enum InputIndices { |
5270 | BatchIdx = 0, |
5271 | }; |
5272 | |
5273 | enum ResultIndices { |
5274 | ResultIdx = 0, |
5275 | }; |
5276 | |
5277 | BatchedReduceMinNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, std::vector<unsigned_t> Axes) |
5278 | : Node(Kinded::Kind::BatchedReduceMinNodeKind, name), Batch_(this, Batch), Axes_(Axes) { |
5279 | addResult(Result); |
5280 | } |
5281 | const NodeValue getBatch() const { return Batch_; } |
5282 | NodeValue getResult() { return getNthResult(0); } |
5283 | const NodeValue getResult() const { return getNthResult(0); } |
5284 | llvm::ArrayRef<unsigned_t> getAxes() const { return Axes_; } |
5285 | |
5286 | static bool classof(const Kinded *k) { |
5287 | return k->getKind() == Kinded::Kind::BatchedReduceMinNodeKind; |
5288 | } |
5289 | |
5290 | |
5291 | bool isOverwrittenNthInput(unsigned idx) const { |
5292 | return false; |
5293 | } |
5294 | |
5295 | unsigned getNumInputs() const; |
5296 | std::string getInputName(unsigned idx) const; |
5297 | NodeValue getNthInput(unsigned idx); |
5298 | void setNthInput(unsigned idx, NodeValue val); |
5299 | llvm::StringRef getOutputName(unsigned idx) const; |
5300 | bool hasSideEffects() const { return 0; } |
5301 | bool isCanonical() const { return 1; } |
5302 | bool isDataParallel() const { return 0; } |
5303 | std::string getDebugDesc() const; |
5304 | bool isEqual(const BatchedReduceMinNode &other) const; |
5305 | llvm::hash_code getHash() const; |
5306 | void visit(Node *parent, NodeWalker *visitor); |
5307 | Node* clone() const; |
5308 | bool verify() const; |
5309 | }; |
5310 | } // namespace glow |
5311 | |
5312 | |
5313 | namespace glow { |
5314 | /// Performs Reduce Max operation on the Input given Axes. |
5315 | class BatchedReduceMaxNode final : public Node { |
5316 | NodeHandle Batch_; |
5317 | std::vector<unsigned_t> Axes_; |
5318 | |
5319 | public: |
5320 | enum InputIndices { |
5321 | BatchIdx = 0, |
5322 | }; |
5323 | |
5324 | enum ResultIndices { |
5325 | ResultIdx = 0, |
5326 | }; |
5327 | |
5328 | BatchedReduceMaxNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, std::vector<unsigned_t> Axes) |
5329 | : Node(Kinded::Kind::BatchedReduceMaxNodeKind, name), Batch_(this, Batch), Axes_(Axes) { |
5330 | addResult(Result); |
5331 | } |
5332 | const NodeValue getBatch() const { return Batch_; } |
5333 | NodeValue getResult() { return getNthResult(0); } |
5334 | const NodeValue getResult() const { return getNthResult(0); } |
5335 | llvm::ArrayRef<unsigned_t> getAxes() const { return Axes_; } |
5336 | |
5337 | static bool classof(const Kinded *k) { |
5338 | return k->getKind() == Kinded::Kind::BatchedReduceMaxNodeKind; |
5339 | } |
5340 | |
5341 | |
5342 | bool isOverwrittenNthInput(unsigned idx) const { |
5343 | return false; |
5344 | } |
5345 | |
5346 | unsigned getNumInputs() const; |
5347 | std::string getInputName(unsigned idx) const; |
5348 | NodeValue getNthInput(unsigned idx); |
5349 | void setNthInput(unsigned idx, NodeValue val); |
5350 | llvm::StringRef getOutputName(unsigned idx) const; |
5351 | bool hasSideEffects() const { return 0; } |
5352 | bool isCanonical() const { return 1; } |
5353 | bool isDataParallel() const { return 0; } |
5354 | std::string getDebugDesc() const; |
5355 | bool isEqual(const BatchedReduceMaxNode &other) const; |
5356 | llvm::hash_code getHash() const; |
5357 | void visit(Node *parent, NodeWalker *visitor); |
5358 | Node* clone() const; |
5359 | bool verify() const; |
5360 | }; |
5361 | } // namespace glow |
5362 | |
5363 | |
5364 | namespace glow { |
5365 | /// Accumulates the product all of the layers in the batch and produce a tensor that has the same dimensions as the input tensor without the first dimension. |
5366 | class BatchedReduceProdNode final : public Node { |
5367 | NodeHandle Batch_; |
5368 | unsigned_t Axis_; |
5369 | |
5370 | public: |
5371 | enum InputIndices { |
5372 | BatchIdx = 0, |
5373 | }; |
5374 | |
5375 | enum ResultIndices { |
5376 | ResultIdx = 0, |
5377 | }; |
5378 | |
5379 | BatchedReduceProdNode(llvm::StringRef name, TypeRef Result , NodeValue Batch, unsigned_t Axis) |
5380 | : Node(Kinded::Kind::BatchedReduceProdNodeKind, name), Batch_(this, Batch), Axis_(Axis) { |
5381 | addResult(Result); |
5382 | } |
5383 | const NodeValue getBatch() const { return Batch_; } |
5384 | NodeValue getResult() { return getNthResult(0); } |
5385 | const NodeValue getResult() const { return getNthResult(0); } |
5386 | unsigned_t getAxis() const { return Axis_; } |
5387 | |
5388 | static bool classof(const Kinded *k) { |
5389 | return k->getKind() == Kinded::Kind::BatchedReduceProdNodeKind; |
5390 | } |
5391 | |
5392 | |
5393 | bool isOverwrittenNthInput(unsigned idx) const { |
5394 | return false; |
5395 | } |
5396 | |
5397 | unsigned getNumInputs() const; |
5398 | std::string getInputName(unsigned idx) const; |
5399 | NodeValue getNthInput(unsigned idx); |
5400 | void setNthInput(unsigned idx, NodeValue val); |
5401 | llvm::StringRef getOutputName(unsigned idx) const; |
5402 | bool hasSideEffects() const { return 0; } |
5403 | bool isCanonical() const { return 1; } |
5404 | bool isDataParallel() const { return 0; } |
5405 | std::string getDebugDesc() const; |
5406 | bool isEqual(const BatchedReduceProdNode &other) const; |
5407 | llvm::hash_code getHash() const; |
5408 | void visit(Node *parent, NodeWalker *visitor); |
5409 | Node* clone() const; |
5410 | bool verify() const; |
5411 | }; |
5412 | } // namespace glow |
5413 | |
5414 | |
5415 | namespace glow { |
5416 | /// Performs Channel shuffle. |
5417 | class ChannelShuffleNode final : public Node { |
5418 | NodeHandle Input_; |
5419 | unsigned_t Group_; |
5420 | unsigned_t Kernel_; |
5421 | |
5422 | public: |
5423 | enum InputIndices { |
5424 | InputIdx = 0, |
5425 | }; |
5426 | |
5427 | enum ResultIndices { |
5428 | ResultIdx = 0, |
5429 | }; |
5430 | |
5431 | ChannelShuffleNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Group, unsigned_t Kernel) |
5432 | : Node(Kinded::Kind::ChannelShuffleNodeKind, name), Input_(this, Input), Group_(Group), Kernel_(Kernel) { |
5433 | addResult(Result); |
5434 | } |
5435 | const NodeValue getInput() const { return Input_; } |
5436 | NodeValue getResult() { return getNthResult(0); } |
5437 | const NodeValue getResult() const { return getNthResult(0); } |
5438 | unsigned_t getGroup() const { return Group_; } |
5439 | unsigned_t getKernel() const { return Kernel_; } |
5440 | |
5441 | static bool classof(const Kinded *k) { |
5442 | return k->getKind() == Kinded::Kind::ChannelShuffleNodeKind; |
5443 | } |
5444 | |
5445 | |
5446 | bool isOverwrittenNthInput(unsigned idx) const { |
5447 | return false; |
5448 | } |
5449 | |
5450 | unsigned getNumInputs() const; |
5451 | std::string getInputName(unsigned idx) const; |
5452 | NodeValue getNthInput(unsigned idx); |
5453 | void setNthInput(unsigned idx, NodeValue val); |
5454 | llvm::StringRef getOutputName(unsigned idx) const; |
5455 | bool hasSideEffects() const { return 0; } |
5456 | bool isCanonical() const { return 1; } |
5457 | bool isDataParallel() const { return 0; } |
5458 | std::string getDebugDesc() const; |
5459 | bool isEqual(const ChannelShuffleNode &other) const; |
5460 | llvm::hash_code getHash() const; |
5461 | void visit(Node *parent, NodeWalker *visitor); |
5462 | Node* clone() const; |
5463 | bool verify() const; |
5464 | }; |
5465 | } // namespace glow |
5466 | |
5467 | |
5468 | namespace glow { |
5469 | /// Performs a Cumulative Sum operation over a 1D vector with flags for working in exclusive mode and in reverse. In each case the output size is the same as in input size.e.g (default) [1, 2, 3, 4] -> [1, 3, 6, 10]. (exclusive) [1, 2, 3, 4] -> [0, 1, 3, 6]. (reverse) [1, 2, 3, 4] -> [10, 9, 7, 4]. |
5470 | class CumSumNode final : public Node { |
5471 | NodeHandle Input_; |
5472 | int64_t Dim_; |
5473 | unsigned_t Exclusive_; |
5474 | unsigned_t Reverse_; |
5475 | |
5476 | public: |
5477 | enum InputIndices { |
5478 | InputIdx = 0, |
5479 | }; |
5480 | |
5481 | enum ResultIndices { |
5482 | ResultIdx = 0, |
5483 | }; |
5484 | |
5485 | CumSumNode(llvm::StringRef name, TypeRef Result , NodeValue Input, int64_t Dim, unsigned_t Exclusive, unsigned_t Reverse) |
5486 | : Node(Kinded::Kind::CumSumNodeKind, name), Input_(this, Input), Dim_(Dim), Exclusive_(Exclusive), Reverse_(Reverse) { |
5487 | addResult(Result); |
5488 | } |
5489 | const NodeValue getInput() const { return Input_; } |
5490 | NodeValue getResult() { return getNthResult(0); } |
5491 | const NodeValue getResult() const { return getNthResult(0); } |
5492 | int64_t getDim() const { return Dim_; } |
5493 | unsigned_t getExclusive() const { return Exclusive_; } |
5494 | unsigned_t getReverse() const { return Reverse_; } |
5495 | |
5496 | static bool classof(const Kinded *k) { |
5497 | return k->getKind() == Kinded::Kind::CumSumNodeKind; |
5498 | } |
5499 | |
5500 | |
5501 | bool isOverwrittenNthInput(unsigned idx) const { |
5502 | return false; |
5503 | } |
5504 | |
5505 | unsigned getNumInputs() const; |
5506 | std::string getInputName(unsigned idx) const; |
5507 | NodeValue getNthInput(unsigned idx); |
5508 | void setNthInput(unsigned idx, NodeValue val); |
5509 | llvm::StringRef getOutputName(unsigned idx) const; |
5510 | bool hasSideEffects() const { return 0; } |
5511 | bool isCanonical() const { return 1; } |
5512 | bool isDataParallel() const { return 1; } |
5513 | std::string getDebugDesc() const; |
5514 | bool isEqual(const CumSumNode &other) const; |
5515 | llvm::hash_code getHash() const; |
5516 | void visit(Node *parent, NodeWalker *visitor); |
5517 | Node* clone() const; |
5518 | bool verify() const; |
5519 | }; |
5520 | } // namespace glow |
5521 | |
5522 | |
5523 | namespace glow { |
5524 | /// Sums slices of the outermost dimension of Data in groups defined by Lengths. The first Lengths[0] slices are added together and stored in Result[0], the subsequent Lengths[1] slices are added together and stored in Result[1], etc. |
5525 | class LengthsSumNode final : public Node { |
5526 | NodeHandle Data_; |
5527 | NodeHandle Lengths_; |
5528 | |
5529 | public: |
5530 | enum InputIndices { |
5531 | DataIdx = 0, |
5532 | LengthsIdx = 1, |
5533 | }; |
5534 | |
5535 | enum ResultIndices { |
5536 | ResultIdx = 0, |
5537 | }; |
5538 | |
5539 | LengthsSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Lengths) |
5540 | : Node(Kinded::Kind::LengthsSumNodeKind, name), Data_(this, Data), Lengths_(this, Lengths) { |
5541 | addResult(Result); |
5542 | } |
5543 | const NodeValue getData() const { return Data_; } |
5544 | const NodeValue getLengths() const { return Lengths_; } |
5545 | NodeValue getResult() { return getNthResult(0); } |
5546 | const NodeValue getResult() const { return getNthResult(0); } |
5547 | |
5548 | static bool classof(const Kinded *k) { |
5549 | return k->getKind() == Kinded::Kind::LengthsSumNodeKind; |
5550 | } |
5551 | |
5552 | |
5553 | bool isOverwrittenNthInput(unsigned idx) const { |
5554 | return false; |
5555 | } |
5556 | |
5557 | unsigned getNumInputs() const; |
5558 | std::string getInputName(unsigned idx) const; |
5559 | NodeValue getNthInput(unsigned idx); |
5560 | void setNthInput(unsigned idx, NodeValue val); |
5561 | llvm::StringRef getOutputName(unsigned idx) const; |
5562 | bool hasSideEffects() const { return 0; } |
5563 | bool isCanonical() const { return 1; } |
5564 | bool isDataParallel() const { return 0; } |
5565 | std::string getDebugDesc() const; |
5566 | bool isEqual(const LengthsSumNode &other) const; |
5567 | llvm::hash_code getHash() const; |
5568 | void visit(Node *parent, NodeWalker *visitor); |
5569 | Node* clone() const; |
5570 | bool verify() const; |
5571 | }; |
5572 | } // namespace glow |
5573 | |
5574 | |
5575 | namespace glow { |
5576 | class SparseLengthsSumGradNode final : public Node { |
5577 | NodeHandle Data_; |
5578 | NodeHandle Indices_; |
5579 | NodeHandle Lengths_; |
5580 | NodeHandle OriginalOutputForResult_; |
5581 | NodeHandle GradOfOriginalOutputNamedResult_; |
5582 | glow::LengthsMode LengthsMode_; |
5583 | float AvgLength_; |
5584 | |
5585 | public: |
5586 | enum InputIndices { |
5587 | DataIdx = 0, |
5588 | IndicesIdx = 1, |
5589 | LengthsIdx = 2, |
5590 | OriginalOutputForResultIdx = 3, |
5591 | GradOfOriginalOutputNamedResultIdx = 4, |
5592 | }; |
5593 | |
5594 | enum ResultIndices { |
5595 | GradOfInputNamedDataIdx = 0, |
5596 | GradOfInputNamedIndicesIdx = 1, |
5597 | GradOfInputNamedLengthsIdx = 2, |
5598 | }; |
5599 | |
5600 | SparseLengthsSumGradNode(llvm::StringRef name, NodeValue Data, NodeValue Indices, NodeValue Lengths, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, glow::LengthsMode LengthsMode, float AvgLength) |
5601 | : Node(Kinded::Kind::SparseLengthsSumGradNodeKind, name), Data_(this, Data), Indices_(this, Indices), Lengths_(this, Lengths), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
5602 | addResult(Data.getType()); |
5603 | addResult(Indices.getType()); |
5604 | addResult(Lengths.getType()); |
5605 | } |
5606 | const NodeValue getData() const { return Data_; } |
5607 | const NodeValue getIndices() const { return Indices_; } |
5608 | const NodeValue getLengths() const { return Lengths_; } |
5609 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
5610 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
5611 | NodeValue getGradOfInputNamedData() { return getNthResult(0); } |
5612 | const NodeValue getGradOfInputNamedData() const { return getNthResult(0); } |
5613 | NodeValue getGradOfInputNamedIndices() { return getNthResult(1); } |
5614 | const NodeValue getGradOfInputNamedIndices() const { return getNthResult(1); } |
5615 | NodeValue getGradOfInputNamedLengths() { return getNthResult(2); } |
5616 | const NodeValue getGradOfInputNamedLengths() const { return getNthResult(2); } |
5617 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
5618 | float getAvgLength() const { return AvgLength_; } |
5619 | |
5620 | static bool classof(const Kinded *k) { |
5621 | return k->getKind() == Kinded::Kind::SparseLengthsSumGradNodeKind; |
5622 | } |
5623 | |
5624 | |
5625 | bool isOverwrittenNthInput(unsigned idx) const { |
5626 | return false; |
5627 | } |
5628 | |
5629 | unsigned getNumInputs() const; |
5630 | std::string getInputName(unsigned idx) const; |
5631 | NodeValue getNthInput(unsigned idx); |
5632 | void setNthInput(unsigned idx, NodeValue val); |
5633 | llvm::StringRef getOutputName(unsigned idx) const; |
5634 | bool hasSideEffects() const { return 0; } |
5635 | bool isCanonical() const { return 1; } |
5636 | bool isDataParallel() const { return 0; } |
5637 | std::string getDebugDesc() const; |
5638 | bool isEqual(const SparseLengthsSumGradNode &other) const; |
5639 | llvm::hash_code getHash() const; |
5640 | void visit(Node *parent, NodeWalker *visitor); |
5641 | Node* clone() const; |
5642 | bool verify() const; |
5643 | }; |
5644 | } // namespace glow |
5645 | |
5646 | |
5647 | namespace glow { |
5648 | /// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). |
5649 | class SparseLengthsSumNode final : public Node { |
5650 | NodeHandle Data_; |
5651 | NodeHandle Indices_; |
5652 | NodeHandle Lengths_; |
5653 | glow::LengthsMode LengthsMode_; |
5654 | float AvgLength_; |
5655 | |
5656 | public: |
5657 | enum InputIndices { |
5658 | DataIdx = 0, |
5659 | IndicesIdx = 1, |
5660 | LengthsIdx = 2, |
5661 | }; |
5662 | |
5663 | enum ResultIndices { |
5664 | ResultIdx = 0, |
5665 | }; |
5666 | |
5667 | SparseLengthsSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, NodeValue Lengths, glow::LengthsMode LengthsMode, float AvgLength) |
5668 | : Node(Kinded::Kind::SparseLengthsSumNodeKind, name), Data_(this, Data), Indices_(this, Indices), Lengths_(this, Lengths), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
5669 | addResult(Result); |
5670 | } |
5671 | const NodeValue getData() const { return Data_; } |
5672 | const NodeValue getIndices() const { return Indices_; } |
5673 | const NodeValue getLengths() const { return Lengths_; } |
5674 | NodeValue getResult() { return getNthResult(0); } |
5675 | const NodeValue getResult() const { return getNthResult(0); } |
5676 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
5677 | float getAvgLength() const { return AvgLength_; } |
5678 | |
5679 | static bool classof(const Kinded *k) { |
5680 | return k->getKind() == Kinded::Kind::SparseLengthsSumNodeKind; |
5681 | } |
5682 | |
5683 | |
5684 | bool isOverwrittenNthInput(unsigned idx) const { |
5685 | return false; |
5686 | } |
5687 | |
5688 | unsigned getNumInputs() const; |
5689 | std::string getInputName(unsigned idx) const; |
5690 | NodeValue getNthInput(unsigned idx); |
5691 | void setNthInput(unsigned idx, NodeValue val); |
5692 | llvm::StringRef getOutputName(unsigned idx) const; |
5693 | bool hasSideEffects() const { return 0; } |
5694 | bool isCanonical() const { return 1; } |
5695 | bool isDataParallel() const { return 0; } |
5696 | std::string getDebugDesc() const; |
5697 | bool isEqual(const SparseLengthsSumNode &other) const; |
5698 | llvm::hash_code getHash() const; |
5699 | void visit(Node *parent, NodeWalker *visitor); |
5700 | Node* clone() const; |
5701 | bool verify() const; |
5702 | SparseLengthsSumGradNode *getGrad(GraphGradMapper &builder); |
5703 | }; |
5704 | } // namespace glow |
5705 | |
5706 | |
5707 | namespace glow { |
5708 | class SparseLengthsWeightedSumGradNode final : public Node { |
5709 | NodeHandle Data_; |
5710 | NodeHandle Weights_; |
5711 | NodeHandle Indices_; |
5712 | NodeHandle Lengths_; |
5713 | NodeHandle OriginalOutputForResult_; |
5714 | NodeHandle GradOfOriginalOutputNamedResult_; |
5715 | glow::LengthsMode LengthsMode_; |
5716 | float AvgLength_; |
5717 | |
5718 | public: |
5719 | enum InputIndices { |
5720 | DataIdx = 0, |
5721 | WeightsIdx = 1, |
5722 | IndicesIdx = 2, |
5723 | LengthsIdx = 3, |
5724 | OriginalOutputForResultIdx = 4, |
5725 | GradOfOriginalOutputNamedResultIdx = 5, |
5726 | }; |
5727 | |
5728 | enum ResultIndices { |
5729 | GradOfInputNamedDataIdx = 0, |
5730 | GradOfInputNamedWeightsIdx = 1, |
5731 | GradOfInputNamedIndicesIdx = 2, |
5732 | GradOfInputNamedLengthsIdx = 3, |
5733 | }; |
5734 | |
5735 | SparseLengthsWeightedSumGradNode(llvm::StringRef name, NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Lengths, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult, glow::LengthsMode LengthsMode, float AvgLength) |
5736 | : Node(Kinded::Kind::SparseLengthsWeightedSumGradNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
5737 | addResult(Data.getType()); |
5738 | addResult(Weights.getType()); |
5739 | addResult(Indices.getType()); |
5740 | addResult(Lengths.getType()); |
5741 | } |
5742 | const NodeValue getData() const { return Data_; } |
5743 | const NodeValue getWeights() const { return Weights_; } |
5744 | const NodeValue getIndices() const { return Indices_; } |
5745 | const NodeValue getLengths() const { return Lengths_; } |
5746 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
5747 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
5748 | NodeValue getGradOfInputNamedData() { return getNthResult(0); } |
5749 | const NodeValue getGradOfInputNamedData() const { return getNthResult(0); } |
5750 | NodeValue getGradOfInputNamedWeights() { return getNthResult(1); } |
5751 | const NodeValue getGradOfInputNamedWeights() const { return getNthResult(1); } |
5752 | NodeValue getGradOfInputNamedIndices() { return getNthResult(2); } |
5753 | const NodeValue getGradOfInputNamedIndices() const { return getNthResult(2); } |
5754 | NodeValue getGradOfInputNamedLengths() { return getNthResult(3); } |
5755 | const NodeValue getGradOfInputNamedLengths() const { return getNthResult(3); } |
5756 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
5757 | float getAvgLength() const { return AvgLength_; } |
5758 | |
5759 | static bool classof(const Kinded *k) { |
5760 | return k->getKind() == Kinded::Kind::SparseLengthsWeightedSumGradNodeKind; |
5761 | } |
5762 | |
5763 | |
5764 | bool isOverwrittenNthInput(unsigned idx) const { |
5765 | return false; |
5766 | } |
5767 | |
5768 | unsigned getNumInputs() const; |
5769 | std::string getInputName(unsigned idx) const; |
5770 | NodeValue getNthInput(unsigned idx); |
5771 | void setNthInput(unsigned idx, NodeValue val); |
5772 | llvm::StringRef getOutputName(unsigned idx) const; |
5773 | bool hasSideEffects() const { return 0; } |
5774 | bool isCanonical() const { return 1; } |
5775 | bool isDataParallel() const { return 0; } |
5776 | std::string getDebugDesc() const; |
5777 | bool isEqual(const SparseLengthsWeightedSumGradNode &other) const; |
5778 | llvm::hash_code getHash() const; |
5779 | void visit(Node *parent, NodeWalker *visitor); |
5780 | Node* clone() const; |
5781 | bool verify() const; |
5782 | }; |
5783 | } // namespace glow |
5784 | |
5785 | |
5786 | namespace glow { |
5787 | /// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices). |
5788 | class SparseLengthsWeightedSumNode final : public Node { |
5789 | NodeHandle Data_; |
5790 | NodeHandle Weights_; |
5791 | NodeHandle Indices_; |
5792 | NodeHandle Lengths_; |
5793 | glow::LengthsMode LengthsMode_; |
5794 | float AvgLength_; |
5795 | |
5796 | public: |
5797 | enum InputIndices { |
5798 | DataIdx = 0, |
5799 | WeightsIdx = 1, |
5800 | IndicesIdx = 2, |
5801 | LengthsIdx = 3, |
5802 | }; |
5803 | |
5804 | enum ResultIndices { |
5805 | ResultIdx = 0, |
5806 | }; |
5807 | |
5808 | SparseLengthsWeightedSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Lengths, glow::LengthsMode LengthsMode, float AvgLength) |
5809 | : Node(Kinded::Kind::SparseLengthsWeightedSumNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
5810 | addResult(Result); |
5811 | } |
5812 | const NodeValue getData() const { return Data_; } |
5813 | const NodeValue getWeights() const { return Weights_; } |
5814 | const NodeValue getIndices() const { return Indices_; } |
5815 | const NodeValue getLengths() const { return Lengths_; } |
5816 | NodeValue getResult() { return getNthResult(0); } |
5817 | const NodeValue getResult() const { return getNthResult(0); } |
5818 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
5819 | float getAvgLength() const { return AvgLength_; } |
5820 | |
5821 | static bool classof(const Kinded *k) { |
5822 | return k->getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind; |
5823 | } |
5824 | |
5825 | |
5826 | bool isOverwrittenNthInput(unsigned idx) const { |
5827 | return false; |
5828 | } |
5829 | |
5830 | unsigned getNumInputs() const; |
5831 | std::string getInputName(unsigned idx) const; |
5832 | NodeValue getNthInput(unsigned idx); |
5833 | void setNthInput(unsigned idx, NodeValue val); |
5834 | llvm::StringRef getOutputName(unsigned idx) const; |
5835 | bool hasSideEffects() const { return 0; } |
5836 | bool isCanonical() const { return 1; } |
5837 | bool isDataParallel() const { return 0; } |
5838 | std::string getDebugDesc() const; |
5839 | bool isEqual(const SparseLengthsWeightedSumNode &other) const; |
5840 | llvm::hash_code getHash() const; |
5841 | void visit(Node *parent, NodeWalker *visitor); |
5842 | Node* clone() const; |
5843 | bool verify() const; |
5844 | SparseLengthsWeightedSumGradNode *getGrad(GraphGradMapper &builder); |
5845 | }; |
5846 | } // namespace glow |
5847 | |
5848 | |
5849 | namespace glow { |
5850 | /// Gathers slices of the outer-most dimension of Weights indexed by Indices tensor. |
5851 | class EmbeddingNode final : public Node { |
5852 | NodeHandle Weights_; |
5853 | NodeHandle Indices_; |
5854 | int64_t PadIdx_; |
5855 | bool Scale_; |
5856 | bool Sparse_; |
5857 | |
5858 | public: |
5859 | enum InputIndices { |
5860 | WeightsIdx = 0, |
5861 | IndicesIdx = 1, |
5862 | }; |
5863 | |
5864 | enum ResultIndices { |
5865 | ResultIdx = 0, |
5866 | }; |
5867 | |
5868 | EmbeddingNode(llvm::StringRef name, TypeRef Result , NodeValue Weights, NodeValue Indices, int64_t PadIdx, bool Scale, bool Sparse) |
5869 | : Node(Kinded::Kind::EmbeddingNodeKind, name), Weights_(this, Weights), Indices_(this, Indices), PadIdx_(PadIdx), Scale_(Scale), Sparse_(Sparse) { |
5870 | addResult(Result); |
5871 | } |
5872 | const NodeValue getWeights() const { return Weights_; } |
5873 | const NodeValue getIndices() const { return Indices_; } |
5874 | NodeValue getResult() { return getNthResult(0); } |
5875 | const NodeValue getResult() const { return getNthResult(0); } |
5876 | int64_t getPadIdx() const { return PadIdx_; } |
5877 | bool getScale() const { return Scale_; } |
5878 | bool getSparse() const { return Sparse_; } |
5879 | |
5880 | static bool classof(const Kinded *k) { |
5881 | return k->getKind() == Kinded::Kind::EmbeddingNodeKind; |
5882 | } |
5883 | |
5884 | |
5885 | bool isOverwrittenNthInput(unsigned idx) const { |
5886 | return false; |
5887 | } |
5888 | |
5889 | unsigned getNumInputs() const; |
5890 | std::string getInputName(unsigned idx) const; |
5891 | NodeValue getNthInput(unsigned idx); |
5892 | void setNthInput(unsigned idx, NodeValue val); |
5893 | llvm::StringRef getOutputName(unsigned idx) const; |
5894 | bool hasSideEffects() const { return 0; } |
5895 | bool isCanonical() const { return 1; } |
5896 | bool isDataParallel() const { return 0; } |
5897 | std::string getDebugDesc() const; |
5898 | bool isEqual(const EmbeddingNode &other) const; |
5899 | llvm::hash_code getHash() const; |
5900 | void visit(Node *parent, NodeWalker *visitor); |
5901 | Node* clone() const; |
5902 | bool verify() const; |
5903 | }; |
5904 | } // namespace glow |
5905 | |
5906 | |
5907 | namespace glow { |
5908 | /// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Offsets) entries: first slice between Offsets[0] and Offsets[1] (or total length if there's only one elem in Offsets) are aggregated to Result[0], etc. I.e. largest offset must be less than or equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices). |
5909 | class EmbeddingBagNode final : public Node { |
5910 | NodeHandle Data_; |
5911 | NodeHandle Weights_; |
5912 | NodeHandle Indices_; |
5913 | NodeHandle Offsets_; |
5914 | bool HasEndOffset_; |
5915 | glow::LengthsMode LengthsMode_; |
5916 | float AvgLength_; |
5917 | |
5918 | public: |
5919 | enum InputIndices { |
5920 | DataIdx = 0, |
5921 | WeightsIdx = 1, |
5922 | IndicesIdx = 2, |
5923 | OffsetsIdx = 3, |
5924 | }; |
5925 | |
5926 | enum ResultIndices { |
5927 | ResultIdx = 0, |
5928 | }; |
5929 | |
5930 | EmbeddingBagNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Offsets, bool HasEndOffset, glow::LengthsMode LengthsMode, float AvgLength) |
5931 | : Node(Kinded::Kind::EmbeddingBagNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Offsets_(this, Offsets), HasEndOffset_(HasEndOffset), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
5932 | addResult(Result); |
5933 | } |
5934 | const NodeValue getData() const { return Data_; } |
5935 | const NodeValue getWeights() const { return Weights_; } |
5936 | const NodeValue getIndices() const { return Indices_; } |
5937 | const NodeValue getOffsets() const { return Offsets_; } |
5938 | NodeValue getResult() { return getNthResult(0); } |
5939 | const NodeValue getResult() const { return getNthResult(0); } |
5940 | bool getHasEndOffset() const { return HasEndOffset_; } |
5941 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
5942 | float getAvgLength() const { return AvgLength_; } |
5943 | |
5944 | static bool classof(const Kinded *k) { |
5945 | return k->getKind() == Kinded::Kind::EmbeddingBagNodeKind; |
5946 | } |
5947 | |
5948 | |
5949 | bool isOverwrittenNthInput(unsigned idx) const { |
5950 | return false; |
5951 | } |
5952 | |
5953 | unsigned getNumInputs() const; |
5954 | std::string getInputName(unsigned idx) const; |
5955 | NodeValue getNthInput(unsigned idx); |
5956 | void setNthInput(unsigned idx, NodeValue val); |
5957 | llvm::StringRef getOutputName(unsigned idx) const; |
5958 | bool hasSideEffects() const { return 0; } |
5959 | bool isCanonical() const { return 1; } |
5960 | bool isDataParallel() const { return 0; } |
5961 | std::string getDebugDesc() const; |
5962 | bool isEqual(const EmbeddingBagNode &other) const; |
5963 | llvm::hash_code getHash() const; |
5964 | void visit(Node *parent, NodeWalker *visitor); |
5965 | Node* clone() const; |
5966 | bool verify() const; |
5967 | }; |
5968 | } // namespace glow |
5969 | |
5970 | |
5971 | namespace glow { |
5972 | /// Same as FusedRowwiseQuantizedSparseLengthsWeightedSum but using offsets instead of lengths. |
5973 | class EmbeddingBagByteRowwiseOffsetsNode final : public Node { |
5974 | NodeHandle Data_; |
5975 | NodeHandle Weights_; |
5976 | NodeHandle Indices_; |
5977 | NodeHandle Offsets_; |
5978 | bool UseFP16Accumulation_; |
5979 | bool HasEndOffset_; |
5980 | glow::LengthsMode LengthsMode_; |
5981 | float AvgLength_; |
5982 | |
5983 | public: |
5984 | enum InputIndices { |
5985 | DataIdx = 0, |
5986 | WeightsIdx = 1, |
5987 | IndicesIdx = 2, |
5988 | OffsetsIdx = 3, |
5989 | }; |
5990 | |
5991 | enum ResultIndices { |
5992 | ResultIdx = 0, |
5993 | }; |
5994 | |
5995 | EmbeddingBagByteRowwiseOffsetsNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Offsets, bool UseFP16Accumulation, bool HasEndOffset, glow::LengthsMode LengthsMode, float AvgLength) |
5996 | : Node(Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Offsets_(this, Offsets), UseFP16Accumulation_(UseFP16Accumulation), HasEndOffset_(HasEndOffset), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
5997 | addResult(Result); |
5998 | } |
5999 | const NodeValue getData() const { return Data_; } |
6000 | const NodeValue getWeights() const { return Weights_; } |
6001 | const NodeValue getIndices() const { return Indices_; } |
6002 | const NodeValue getOffsets() const { return Offsets_; } |
6003 | NodeValue getResult() { return getNthResult(0); } |
6004 | const NodeValue getResult() const { return getNthResult(0); } |
6005 | bool getUseFP16Accumulation() const { return UseFP16Accumulation_; } |
6006 | void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; } |
6007 | bool getHasEndOffset() const { return HasEndOffset_; } |
6008 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
6009 | float getAvgLength() const { return AvgLength_; } |
6010 | |
6011 | static bool classof(const Kinded *k) { |
6012 | return k->getKind() == Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind; |
6013 | } |
6014 | |
6015 | |
6016 | bool isOverwrittenNthInput(unsigned idx) const { |
6017 | return false; |
6018 | } |
6019 | |
6020 | unsigned getNumInputs() const; |
6021 | std::string getInputName(unsigned idx) const; |
6022 | NodeValue getNthInput(unsigned idx); |
6023 | void setNthInput(unsigned idx, NodeValue val); |
6024 | llvm::StringRef getOutputName(unsigned idx) const; |
6025 | bool hasSideEffects() const { return 0; } |
6026 | bool isCanonical() const { return 1; } |
6027 | bool isDataParallel() const { return 0; } |
6028 | std::string getDebugDesc() const; |
6029 | bool isEqual(const EmbeddingBagByteRowwiseOffsetsNode &other) const; |
6030 | llvm::hash_code getHash() const; |
6031 | void visit(Node *parent, NodeWalker *visitor); |
6032 | Node* clone() const; |
6033 | bool verify() const; |
6034 | }; |
6035 | } // namespace glow |
6036 | |
6037 | |
6038 | namespace glow { |
6039 | /// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices). The input data is rowwise-quantized, where the Scales and Offsets are 1D tensors of length equal to the first dim of Data. |
6040 | class RowwiseQuantizedSparseLengthsWeightedSumNode final : public Node { |
6041 | NodeHandle Data_; |
6042 | NodeHandle Scales_; |
6043 | NodeHandle Offsets_; |
6044 | NodeHandle Weights_; |
6045 | NodeHandle Indices_; |
6046 | NodeHandle Lengths_; |
6047 | bool UseFP16Accumulation_; |
6048 | glow::LengthsMode LengthsMode_; |
6049 | float AvgLength_; |
6050 | |
6051 | public: |
6052 | enum InputIndices { |
6053 | DataIdx = 0, |
6054 | ScalesIdx = 1, |
6055 | OffsetsIdx = 2, |
6056 | WeightsIdx = 3, |
6057 | IndicesIdx = 4, |
6058 | LengthsIdx = 5, |
6059 | }; |
6060 | |
6061 | enum ResultIndices { |
6062 | ResultIdx = 0, |
6063 | }; |
6064 | |
6065 | RowwiseQuantizedSparseLengthsWeightedSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Scales, NodeValue Offsets, NodeValue Weights, NodeValue Indices, NodeValue Lengths, bool UseFP16Accumulation, glow::LengthsMode LengthsMode, float AvgLength) |
6066 | : Node(Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind, name), Data_(this, Data), Scales_(this, Scales), Offsets_(this, Offsets), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), UseFP16Accumulation_(UseFP16Accumulation), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
6067 | addResult(Result); |
6068 | } |
6069 | const NodeValue getData() const { return Data_; } |
6070 | const NodeValue getScales() const { return Scales_; } |
6071 | const NodeValue getOffsets() const { return Offsets_; } |
6072 | const NodeValue getWeights() const { return Weights_; } |
6073 | const NodeValue getIndices() const { return Indices_; } |
6074 | const NodeValue getLengths() const { return Lengths_; } |
6075 | NodeValue getResult() { return getNthResult(0); } |
6076 | const NodeValue getResult() const { return getNthResult(0); } |
6077 | bool getUseFP16Accumulation() const { return UseFP16Accumulation_; } |
6078 | void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; } |
6079 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
6080 | float getAvgLength() const { return AvgLength_; } |
6081 | |
6082 | static bool classof(const Kinded *k) { |
6083 | return k->getKind() == Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind; |
6084 | } |
6085 | |
6086 | |
6087 | bool isOverwrittenNthInput(unsigned idx) const { |
6088 | return false; |
6089 | } |
6090 | |
6091 | unsigned getNumInputs() const; |
6092 | std::string getInputName(unsigned idx) const; |
6093 | NodeValue getNthInput(unsigned idx); |
6094 | void setNthInput(unsigned idx, NodeValue val); |
6095 | llvm::StringRef getOutputName(unsigned idx) const; |
6096 | bool hasSideEffects() const { return 0; } |
6097 | bool isCanonical() const { return 1; } |
6098 | bool isDataParallel() const { return 0; } |
6099 | std::string getDebugDesc() const; |
6100 | bool isEqual(const RowwiseQuantizedSparseLengthsWeightedSumNode &other) const; |
6101 | llvm::hash_code getHash() const; |
6102 | void visit(Node *parent, NodeWalker *visitor); |
6103 | Node* clone() const; |
6104 | bool verify() const; |
6105 | }; |
6106 | } // namespace glow |
6107 | |
6108 | |
6109 | namespace glow { |
6110 | /// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). Before doing aggregation, each individual slice is scaled by its weight: Result[0] = Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... It implies that len(Weights) == len(Indices). The input data is fused rowwise-quantized, where the Scales and Offsets are appended to the end of each row. Thus, Data must be a two-dimensional tensor. |
6111 | class FusedRowwiseQuantizedSparseLengthsWeightedSumNode final : public Node { |
6112 | NodeHandle Data_; |
6113 | NodeHandle Weights_; |
6114 | NodeHandle Indices_; |
6115 | NodeHandle Lengths_; |
6116 | bool UseFP16Accumulation_; |
6117 | glow::LengthsMode LengthsMode_; |
6118 | float AvgLength_; |
6119 | |
6120 | public: |
6121 | enum InputIndices { |
6122 | DataIdx = 0, |
6123 | WeightsIdx = 1, |
6124 | IndicesIdx = 2, |
6125 | LengthsIdx = 3, |
6126 | }; |
6127 | |
6128 | enum ResultIndices { |
6129 | ResultIdx = 0, |
6130 | }; |
6131 | |
6132 | FusedRowwiseQuantizedSparseLengthsWeightedSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Weights, NodeValue Indices, NodeValue Lengths, bool UseFP16Accumulation, glow::LengthsMode LengthsMode, float AvgLength) |
6133 | : Node(Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind, name), Data_(this, Data), Weights_(this, Weights), Indices_(this, Indices), Lengths_(this, Lengths), UseFP16Accumulation_(UseFP16Accumulation), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
6134 | addResult(Result); |
6135 | } |
6136 | const NodeValue getData() const { return Data_; } |
6137 | const NodeValue getWeights() const { return Weights_; } |
6138 | const NodeValue getIndices() const { return Indices_; } |
6139 | const NodeValue getLengths() const { return Lengths_; } |
6140 | NodeValue getResult() { return getNthResult(0); } |
6141 | const NodeValue getResult() const { return getNthResult(0); } |
6142 | bool getUseFP16Accumulation() const { return UseFP16Accumulation_; } |
6143 | void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; } |
6144 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
6145 | float getAvgLength() const { return AvgLength_; } |
6146 | |
6147 | static bool classof(const Kinded *k) { |
6148 | return k->getKind() == Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind; |
6149 | } |
6150 | |
6151 | |
6152 | bool isOverwrittenNthInput(unsigned idx) const { |
6153 | return false; |
6154 | } |
6155 | |
6156 | unsigned getNumInputs() const; |
6157 | std::string getInputName(unsigned idx) const; |
6158 | NodeValue getNthInput(unsigned idx); |
6159 | void setNthInput(unsigned idx, NodeValue val); |
6160 | llvm::StringRef getOutputName(unsigned idx) const; |
6161 | bool hasSideEffects() const { return 0; } |
6162 | bool isCanonical() const { return 1; } |
6163 | bool isDataParallel() const { return 0; } |
6164 | std::string getDebugDesc() const; |
6165 | bool isEqual(const FusedRowwiseQuantizedSparseLengthsWeightedSumNode &other) const; |
6166 | llvm::hash_code getHash() const; |
6167 | void visit(Node *parent, NodeWalker *visitor); |
6168 | Node* clone() const; |
6169 | bool verify() const; |
6170 | }; |
6171 | } // namespace glow |
6172 | |
6173 | |
6174 | namespace glow { |
6175 | /// Gathers slices of the outer-most dimension of Data indexed by Indices vector, and then accumulates them into len(Lengths) entries: first Lengths[0] slices are aggregated to Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices). The input data is fused rowwise-quantized, where the Scales and Offsets are appended to the end of each row. Thus, Data must be a two-dimensional tensor. |
6176 | class FusedRowwiseQuantizedSparseLengthsSumNode final : public Node { |
6177 | NodeHandle Data_; |
6178 | NodeHandle Indices_; |
6179 | NodeHandle Lengths_; |
6180 | bool UseFP16Accumulation_; |
6181 | glow::LengthsMode LengthsMode_; |
6182 | float AvgLength_; |
6183 | |
6184 | public: |
6185 | enum InputIndices { |
6186 | DataIdx = 0, |
6187 | IndicesIdx = 1, |
6188 | LengthsIdx = 2, |
6189 | }; |
6190 | |
6191 | enum ResultIndices { |
6192 | ResultIdx = 0, |
6193 | }; |
6194 | |
6195 | FusedRowwiseQuantizedSparseLengthsSumNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, NodeValue Lengths, bool UseFP16Accumulation, glow::LengthsMode LengthsMode, float AvgLength) |
6196 | : Node(Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind, name), Data_(this, Data), Indices_(this, Indices), Lengths_(this, Lengths), UseFP16Accumulation_(UseFP16Accumulation), LengthsMode_(LengthsMode), AvgLength_(AvgLength) { |
6197 | addResult(Result); |
6198 | } |
6199 | const NodeValue getData() const { return Data_; } |
6200 | const NodeValue getIndices() const { return Indices_; } |
6201 | const NodeValue getLengths() const { return Lengths_; } |
6202 | NodeValue getResult() { return getNthResult(0); } |
6203 | const NodeValue getResult() const { return getNthResult(0); } |
6204 | bool getUseFP16Accumulation() const { return UseFP16Accumulation_; } |
6205 | void setUseFP16Accumulation(bool a) {UseFP16Accumulation_ = a; } |
6206 | glow::LengthsMode getLengthsMode() const { return LengthsMode_; } |
6207 | float getAvgLength() const { return AvgLength_; } |
6208 | |
6209 | static bool classof(const Kinded *k) { |
6210 | return k->getKind() == Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind; |
6211 | } |
6212 | |
6213 | |
6214 | bool isOverwrittenNthInput(unsigned idx) const { |
6215 | return false; |
6216 | } |
6217 | |
6218 | unsigned getNumInputs() const; |
6219 | std::string getInputName(unsigned idx) const; |
6220 | NodeValue getNthInput(unsigned idx); |
6221 | void setNthInput(unsigned idx, NodeValue val); |
6222 | llvm::StringRef getOutputName(unsigned idx) const; |
6223 | bool hasSideEffects() const { return 0; } |
6224 | bool isCanonical() const { return 1; } |
6225 | bool isDataParallel() const { return 0; } |
6226 | std::string getDebugDesc() const; |
6227 | bool isEqual(const FusedRowwiseQuantizedSparseLengthsSumNode &other) const; |
6228 | llvm::hash_code getHash() const; |
6229 | void visit(Node *parent, NodeWalker *visitor); |
6230 | Node* clone() const; |
6231 | bool verify() const; |
6232 | }; |
6233 | } // namespace glow |
6234 | |
6235 | |
6236 | namespace glow { |
6237 | /// Given a vector of segment lengths, calculates offsets of each segment and packs them next to the lengths. For the input vector of length N the output is a Nx2 matrix with (offset, lengths) packaged for each segment. |
6238 | class LengthsToRangesNode final : public Node { |
6239 | NodeHandle Lengths_; |
6240 | |
6241 | public: |
6242 | enum InputIndices { |
6243 | LengthsIdx = 0, |
6244 | }; |
6245 | |
6246 | enum ResultIndices { |
6247 | ResultIdx = 0, |
6248 | }; |
6249 | |
6250 | LengthsToRangesNode(llvm::StringRef name, TypeRef Result , NodeValue Lengths) |
6251 | : Node(Kinded::Kind::LengthsToRangesNodeKind, name), Lengths_(this, Lengths) { |
6252 | addResult(Result); |
6253 | } |
6254 | const NodeValue getLengths() const { return Lengths_; } |
6255 | NodeValue getResult() { return getNthResult(0); } |
6256 | const NodeValue getResult() const { return getNthResult(0); } |
6257 | |
6258 | static bool classof(const Kinded *k) { |
6259 | return k->getKind() == Kinded::Kind::LengthsToRangesNodeKind; |
6260 | } |
6261 | |
6262 | |
6263 | bool isOverwrittenNthInput(unsigned idx) const { |
6264 | return false; |
6265 | } |
6266 | |
6267 | unsigned getNumInputs() const; |
6268 | std::string getInputName(unsigned idx) const; |
6269 | NodeValue getNthInput(unsigned idx); |
6270 | void setNthInput(unsigned idx, NodeValue val); |
6271 | llvm::StringRef getOutputName(unsigned idx) const; |
6272 | bool hasSideEffects() const { return 0; } |
6273 | bool isCanonical() const { return 1; } |
6274 | bool isDataParallel() const { return 0; } |
6275 | std::string getDebugDesc() const; |
6276 | bool isEqual(const LengthsToRangesNode &other) const; |
6277 | llvm::hash_code getHash() const; |
6278 | void visit(Node *parent, NodeWalker *visitor); |
6279 | Node* clone() const; |
6280 | bool verify() const; |
6281 | }; |
6282 | } // namespace glow |
6283 | |
6284 | |
6285 | namespace glow { |
6286 | /// Converts an input Lengths 1D vector into a range sequence. |
6287 | class LengthsRangeFillNode final : public Node { |
6288 | NodeHandle Lengths_; |
6289 | |
6290 | public: |
6291 | enum InputIndices { |
6292 | LengthsIdx = 0, |
6293 | }; |
6294 | |
6295 | enum ResultIndices { |
6296 | ResultIdx = 0, |
6297 | }; |
6298 | |
6299 | LengthsRangeFillNode(llvm::StringRef name, TypeRef Result , NodeValue Lengths) |
6300 | : Node(Kinded::Kind::LengthsRangeFillNodeKind, name), Lengths_(this, Lengths) { |
6301 | addResult(Result); |
6302 | } |
6303 | const NodeValue getLengths() const { return Lengths_; } |
6304 | NodeValue getResult() { return getNthResult(0); } |
6305 | const NodeValue getResult() const { return getNthResult(0); } |
6306 | |
6307 | static bool classof(const Kinded *k) { |
6308 | return k->getKind() == Kinded::Kind::LengthsRangeFillNodeKind; |
6309 | } |
6310 | |
6311 | |
6312 | bool isOverwrittenNthInput(unsigned idx) const { |
6313 | return false; |
6314 | } |
6315 | |
6316 | unsigned getNumInputs() const; |
6317 | std::string getInputName(unsigned idx) const; |
6318 | NodeValue getNthInput(unsigned idx); |
6319 | void setNthInput(unsigned idx, NodeValue val); |
6320 | llvm::StringRef getOutputName(unsigned idx) const; |
6321 | bool hasSideEffects() const { return 0; } |
6322 | bool isCanonical() const { return 1; } |
6323 | bool isDataParallel() const { return 0; } |
6324 | std::string getDebugDesc() const; |
6325 | bool isEqual(const LengthsRangeFillNode &other) const; |
6326 | llvm::hash_code getHash() const; |
6327 | void visit(Node *parent, NodeWalker *visitor); |
6328 | Node* clone() const; |
6329 | bool verify() const; |
6330 | }; |
6331 | } // namespace glow |
6332 | |
6333 | |
6334 | namespace glow { |
6335 | /// Converts the sparse representation specified by (Lengths, Indices, Values) into a dense one. In the dense representation, elements of the lengths vector represent the number of indices in the corresponding batch, where each batch contains each value from Values at the corresponding index specified in Indices, and is filled with DefaultValue otherwise. Within each batch, Indices shouldn't contain duplicate indices. |
6336 | class BatchSparseToDenseNode final : public Node { |
6337 | NodeHandle Lengths_; |
6338 | NodeHandle Indices_; |
6339 | NodeHandle Values_; |
6340 | float DefaultValue_; |
6341 | unsigned_t DenseLastDim_; |
6342 | |
6343 | public: |
6344 | enum InputIndices { |
6345 | LengthsIdx = 0, |
6346 | IndicesIdx = 1, |
6347 | ValuesIdx = 2, |
6348 | }; |
6349 | |
6350 | enum ResultIndices { |
6351 | ResultIdx = 0, |
6352 | }; |
6353 | |
6354 | BatchSparseToDenseNode(llvm::StringRef name, TypeRef Result , NodeValue Lengths, NodeValue Indices, NodeValue Values, float DefaultValue, unsigned_t DenseLastDim) |
6355 | : Node(Kinded::Kind::BatchSparseToDenseNodeKind, name), Lengths_(this, Lengths), Indices_(this, Indices), Values_(this, Values), DefaultValue_(DefaultValue), DenseLastDim_(DenseLastDim) { |
6356 | addResult(Result); |
6357 | } |
6358 | const NodeValue getLengths() const { return Lengths_; } |
6359 | const NodeValue getIndices() const { return Indices_; } |
6360 | const NodeValue getValues() const { return Values_; } |
6361 | NodeValue getResult() { return getNthResult(0); } |
6362 | const NodeValue getResult() const { return getNthResult(0); } |
6363 | float getDefaultValue() const { return DefaultValue_; } |
6364 | unsigned_t getDenseLastDim() const { return DenseLastDim_; } |
6365 | |
6366 | static bool classof(const Kinded *k) { |
6367 | return k->getKind() == Kinded::Kind::BatchSparseToDenseNodeKind; |
6368 | } |
6369 | |
6370 | |
6371 | bool isOverwrittenNthInput(unsigned idx) const { |
6372 | return false; |
6373 | } |
6374 | |
6375 | unsigned getNumInputs() const; |
6376 | std::string getInputName(unsigned idx) const; |
6377 | NodeValue getNthInput(unsigned idx); |
6378 | void setNthInput(unsigned idx, NodeValue val); |
6379 | llvm::StringRef getOutputName(unsigned idx) const; |
6380 | bool hasSideEffects() const { return 0; } |
6381 | bool isCanonical() const { return 1; } |
6382 | bool isDataParallel() const { return 0; } |
6383 | std::string getDebugDesc() const; |
6384 | bool isEqual(const BatchSparseToDenseNode &other) const; |
6385 | llvm::hash_code getHash() const; |
6386 | void visit(Node *parent, NodeWalker *visitor); |
6387 | Node* clone() const; |
6388 | bool verify() const; |
6389 | }; |
6390 | } // namespace glow |
6391 | |
6392 | |
6393 | namespace glow { |
6394 | /// Inserts zeros into data along axis=0 for indices where indicator is zero. |
6395 | class FillExamplesWithIndicatorNode final : public Node { |
6396 | NodeHandle Data_; |
6397 | NodeHandle Indicator_; |
6398 | |
6399 | public: |
6400 | enum InputIndices { |
6401 | DataIdx = 0, |
6402 | IndicatorIdx = 1, |
6403 | }; |
6404 | |
6405 | enum ResultIndices { |
6406 | ResultIdx = 0, |
6407 | }; |
6408 | |
6409 | FillExamplesWithIndicatorNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indicator) |
6410 | : Node(Kinded::Kind::FillExamplesWithIndicatorNodeKind, name), Data_(this, Data), Indicator_(this, Indicator) { |
6411 | addResult(Result); |
6412 | } |
6413 | const NodeValue getData() const { return Data_; } |
6414 | const NodeValue getIndicator() const { return Indicator_; } |
6415 | NodeValue getResult() { return getNthResult(0); } |
6416 | const NodeValue getResult() const { return getNthResult(0); } |
6417 | |
6418 | static bool classof(const Kinded *k) { |
6419 | return k->getKind() == Kinded::Kind::FillExamplesWithIndicatorNodeKind; |
6420 | } |
6421 | |
6422 | |
6423 | bool isOverwrittenNthInput(unsigned idx) const { |
6424 | return false; |
6425 | } |
6426 | |
6427 | unsigned getNumInputs() const; |
6428 | std::string getInputName(unsigned idx) const; |
6429 | NodeValue getNthInput(unsigned idx); |
6430 | void setNthInput(unsigned idx, NodeValue val); |
6431 | llvm::StringRef getOutputName(unsigned idx) const; |
6432 | bool hasSideEffects() const { return 0; } |
6433 | bool isCanonical() const { return 1; } |
6434 | bool isDataParallel() const { return 0; } |
6435 | std::string getDebugDesc() const; |
6436 | bool isEqual(const FillExamplesWithIndicatorNode &other) const; |
6437 | llvm::hash_code getHash() const; |
6438 | void visit(Node *parent, NodeWalker *visitor); |
6439 | Node* clone() const; |
6440 | bool verify() const; |
6441 | }; |
6442 | } // namespace glow |
6443 | |
6444 | |
6445 | namespace glow { |
6446 | /// Converts the sparse representation specified by the pair (Indices, Values) into a dense one, where compacted tensor only contains IDs from given Mask. Indices cannot contain duplicate values. Lengths is used to distinguish elements from different examples of one batch. That is, first Lengths[0] index-value pairs belong to batch's example 0, next Lengths[1] pairs belong to example 1, and so on. |
6447 | class SparseToDenseMaskNode final : public Node { |
6448 | NodeHandle Indices_; |
6449 | NodeHandle Values_; |
6450 | NodeHandle DefaultValue_; |
6451 | NodeHandle Lengths_; |
6452 | std::vector<dim_t> Mask_; |
6453 | |
6454 | public: |
6455 | enum InputIndices { |
6456 | IndicesIdx = 0, |
6457 | ValuesIdx = 1, |
6458 | DefaultValueIdx = 2, |
6459 | LengthsIdx = 3, |
6460 | }; |
6461 | |
6462 | enum ResultIndices { |
6463 | ResultIdx = 0, |
6464 | }; |
6465 | |
6466 | SparseToDenseMaskNode(llvm::StringRef name, TypeRef Result , NodeValue Indices, NodeValue Values, NodeValue DefaultValue, NodeValue Lengths, std::vector<dim_t> Mask) |
6467 | : Node(Kinded::Kind::SparseToDenseMaskNodeKind, name), Indices_(this, Indices), Values_(this, Values), DefaultValue_(this, DefaultValue), Lengths_(this, Lengths), Mask_(Mask) { |
6468 | addResult(Result); |
6469 | } |
6470 | const NodeValue getIndices() const { return Indices_; } |
6471 | const NodeValue getValues() const { return Values_; } |
6472 | const NodeValue getDefaultValue() const { return DefaultValue_; } |
6473 | const NodeValue getLengths() const { return Lengths_; } |
6474 | NodeValue getResult() { return getNthResult(0); } |
6475 | const NodeValue getResult() const { return getNthResult(0); } |
6476 | llvm::ArrayRef<dim_t> getMask() const { return Mask_; } |
6477 | |
6478 | static bool classof(const Kinded *k) { |
6479 | return k->getKind() == Kinded::Kind::SparseToDenseMaskNodeKind; |
6480 | } |
6481 | |
6482 | |
6483 | bool isOverwrittenNthInput(unsigned idx) const { |
6484 | return false; |
6485 | } |
6486 | |
6487 | unsigned getNumInputs() const; |
6488 | std::string getInputName(unsigned idx) const; |
6489 | NodeValue getNthInput(unsigned idx); |
6490 | void setNthInput(unsigned idx, NodeValue val); |
6491 | llvm::StringRef getOutputName(unsigned idx) const; |
6492 | bool hasSideEffects() const { return 0; } |
6493 | bool isCanonical() const { return 1; } |
6494 | bool isDataParallel() const { return 0; } |
6495 | std::string getDebugDesc() const; |
6496 | bool isEqual(const SparseToDenseMaskNode &other) const; |
6497 | llvm::hash_code getHash() const; |
6498 | void visit(Node *parent, NodeWalker *visitor); |
6499 | Node* clone() const; |
6500 | bool verify() const; |
6501 | }; |
6502 | } // namespace glow |
6503 | |
6504 | |
6505 | namespace glow { |
6506 | /// Determines whether each element of the Input is NaN and generates a mask that can be consumed by a Select node. |
6507 | class IsNaNNode final : public Node { |
6508 | NodeHandle Input_; |
6509 | |
6510 | public: |
6511 | enum InputIndices { |
6512 | InputIdx = 0, |
6513 | }; |
6514 | |
6515 | enum ResultIndices { |
6516 | ResultIdx = 0, |
6517 | }; |
6518 | |
6519 | IsNaNNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
6520 | : Node(Kinded::Kind::IsNaNNodeKind, name), Input_(this, Input) { |
6521 | addResult(Result); |
6522 | } |
6523 | const NodeValue getInput() const { return Input_; } |
6524 | NodeValue getResult() { return getNthResult(0); } |
6525 | const NodeValue getResult() const { return getNthResult(0); } |
6526 | |
6527 | static bool classof(const Kinded *k) { |
6528 | return k->getKind() == Kinded::Kind::IsNaNNodeKind; |
6529 | } |
6530 | |
6531 | |
6532 | bool isOverwrittenNthInput(unsigned idx) const { |
6533 | return false; |
6534 | } |
6535 | |
6536 | unsigned getNumInputs() const; |
6537 | std::string getInputName(unsigned idx) const; |
6538 | NodeValue getNthInput(unsigned idx); |
6539 | void setNthInput(unsigned idx, NodeValue val); |
6540 | llvm::StringRef getOutputName(unsigned idx) const; |
6541 | bool hasSideEffects() const { return 0; } |
6542 | bool isCanonical() const { return 1; } |
6543 | bool isDataParallel() const { return 1; } |
6544 | std::string getDebugDesc() const; |
6545 | bool isEqual(const IsNaNNode &other) const; |
6546 | llvm::hash_code getHash() const; |
6547 | void visit(Node *parent, NodeWalker *visitor); |
6548 | Node* clone() const; |
6549 | bool verify() const; |
6550 | }; |
6551 | } // namespace glow |
6552 | |
6553 | |
6554 | namespace glow { |
6555 | /// Replaces NaNs found in Input with Value. |
6556 | class ReplaceNaNNode final : public Node { |
6557 | NodeHandle Input_; |
6558 | float Value_; |
6559 | |
6560 | public: |
6561 | enum InputIndices { |
6562 | InputIdx = 0, |
6563 | }; |
6564 | |
6565 | enum ResultIndices { |
6566 | ResultIdx = 0, |
6567 | }; |
6568 | |
6569 | ReplaceNaNNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Value) |
6570 | : Node(Kinded::Kind::ReplaceNaNNodeKind, name), Input_(this, Input), Value_(Value) { |
6571 | addResult(Result); |
6572 | } |
6573 | const NodeValue getInput() const { return Input_; } |
6574 | NodeValue getResult() { return getNthResult(0); } |
6575 | const NodeValue getResult() const { return getNthResult(0); } |
6576 | float getValue() const { return Value_; } |
6577 | |
6578 | static bool classof(const Kinded *k) { |
6579 | return k->getKind() == Kinded::Kind::ReplaceNaNNodeKind; |
6580 | } |
6581 | |
6582 | |
6583 | bool isOverwrittenNthInput(unsigned idx) const { |
6584 | return false; |
6585 | } |
6586 | |
6587 | unsigned getNumInputs() const; |
6588 | std::string getInputName(unsigned idx) const; |
6589 | NodeValue getNthInput(unsigned idx); |
6590 | void setNthInput(unsigned idx, NodeValue val); |
6591 | llvm::StringRef getOutputName(unsigned idx) const; |
6592 | bool hasSideEffects() const { return 0; } |
6593 | bool isCanonical() const { return 1; } |
6594 | bool isDataParallel() const { return 0; } |
6595 | std::string getDebugDesc() const; |
6596 | bool isEqual(const ReplaceNaNNode &other) const; |
6597 | llvm::hash_code getHash() const; |
6598 | void visit(Node *parent, NodeWalker *visitor); |
6599 | Node* clone() const; |
6600 | bool verify() const; |
6601 | }; |
6602 | } // namespace glow |
6603 | |
6604 | |
6605 | namespace glow { |
6606 | /// Performs elementwise modulo operation on the input where each element in the output is the corresponding element in the input data modulo Divisor. |
6607 | class ModuloNode final : public Node { |
6608 | NodeHandle Input_; |
6609 | int64_t Divisor_; |
6610 | bool SignFollowDivisor_; |
6611 | |
6612 | public: |
6613 | enum InputIndices { |
6614 | InputIdx = 0, |
6615 | }; |
6616 | |
6617 | enum ResultIndices { |
6618 | ResultIdx = 0, |
6619 | }; |
6620 | |
6621 | ModuloNode(llvm::StringRef name, TypeRef Result , NodeValue Input, int64_t Divisor, bool SignFollowDivisor) |
6622 | : Node(Kinded::Kind::ModuloNodeKind, name), Input_(this, Input), Divisor_(Divisor), SignFollowDivisor_(SignFollowDivisor) { |
6623 | addResult(Result); |
6624 | } |
6625 | const NodeValue getInput() const { return Input_; } |
6626 | NodeValue getResult() { return getNthResult(0); } |
6627 | const NodeValue getResult() const { return getNthResult(0); } |
6628 | int64_t getDivisor() const { return Divisor_; } |
6629 | bool getSignFollowDivisor() const { return SignFollowDivisor_; } |
6630 | |
6631 | static bool classof(const Kinded *k) { |
6632 | return k->getKind() == Kinded::Kind::ModuloNodeKind; |
6633 | } |
6634 | |
6635 | |
6636 | bool isOverwrittenNthInput(unsigned idx) const { |
6637 | return false; |
6638 | } |
6639 | |
6640 | unsigned getNumInputs() const; |
6641 | std::string getInputName(unsigned idx) const; |
6642 | NodeValue getNthInput(unsigned idx); |
6643 | void setNthInput(unsigned idx, NodeValue val); |
6644 | llvm::StringRef getOutputName(unsigned idx) const; |
6645 | bool hasSideEffects() const { return 0; } |
6646 | bool isCanonical() const { return 1; } |
6647 | bool isDataParallel() const { return 1; } |
6648 | std::string getDebugDesc() const; |
6649 | bool isEqual(const ModuloNode &other) const; |
6650 | llvm::hash_code getHash() const; |
6651 | void visit(Node *parent, NodeWalker *visitor); |
6652 | Node* clone() const; |
6653 | bool verify() const; |
6654 | }; |
6655 | } // namespace glow |
6656 | |
6657 | |
6658 | namespace glow { |
6659 | /// Performs batched pairwise dot products of the input vectors |
6660 | class BatchedPairwiseDotProductNode final : public Node { |
6661 | std::vector<NodeHandle> Inputs_; |
6662 | |
6663 | public: |
6664 | enum InputIndices { |
6665 | }; |
6666 | |
6667 | enum ResultIndices { |
6668 | ResultIdx = 0, |
6669 | }; |
6670 | |
6671 | BatchedPairwiseDotProductNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> Inputs) |
6672 | : Node(Kinded::Kind::BatchedPairwiseDotProductNodeKind, name) { |
6673 | addResult(Result); |
6674 | Inputs_.resize(Inputs.size()); |
6675 | for (size_t idx = 0, e = Inputs.size(); idx < e; ++idx) { |
6676 | Inputs_[idx] = Inputs[idx]; |
6677 | Inputs_[idx].setParent(this); |
6678 | } |
6679 | } |
6680 | NodeValue getResult() { return getNthResult(0); } |
6681 | const NodeValue getResult() const { return getNthResult(0); } |
6682 | NodeValueArrayRef getInputs() const { return Inputs_; } |
6683 | |
6684 | static bool classof(const Kinded *k) { |
6685 | return k->getKind() == Kinded::Kind::BatchedPairwiseDotProductNodeKind; |
6686 | } |
6687 | |
6688 | |
6689 | bool isOverwrittenNthInput(unsigned idx) const { |
6690 | return false; |
6691 | } |
6692 | |
6693 | unsigned getNumInputs() const; |
6694 | std::string getInputName(unsigned idx) const; |
6695 | NodeValue getNthInput(unsigned idx); |
6696 | void setNthInput(unsigned idx, NodeValue val); |
6697 | llvm::StringRef getOutputName(unsigned idx) const; |
6698 | bool hasSideEffects() const { return 0; } |
6699 | bool isCanonical() const { return 1; } |
6700 | bool isDataParallel() const { return 0; } |
6701 | std::string getDebugDesc() const; |
6702 | bool isEqual(const BatchedPairwiseDotProductNode &other) const; |
6703 | llvm::hash_code getHash() const; |
6704 | void visit(Node *parent, NodeWalker *visitor); |
6705 | Node* clone() const; |
6706 | bool verify() const; |
6707 | }; |
6708 | } // namespace glow |
6709 | |
6710 | |
6711 | namespace glow { |
6712 | /// Performs the gradient operation for BatchedPairwiseDotProduct |
6713 | class BatchedPairwiseDotProductGradNode final : public Node { |
6714 | NodeHandle OutputGrad_; |
6715 | std::vector<NodeHandle> OriginalInputs_; |
6716 | |
6717 | public: |
6718 | enum InputIndices { |
6719 | OutputGradIdx = 0, |
6720 | }; |
6721 | |
6722 | enum ResultIndices { |
6723 | }; |
6724 | |
6725 | BatchedPairwiseDotProductGradNode(llvm::StringRef name, NodeValue OutputGrad, std::vector<NodeValue> OriginalInputs) |
6726 | : Node(Kinded::Kind::BatchedPairwiseDotProductGradNodeKind, name), OutputGrad_(this, OutputGrad) { |
6727 | OriginalInputs_.resize(OriginalInputs.size()); |
6728 | for (size_t idx = 0, e = OriginalInputs.size(); idx < e; ++idx) { |
6729 | OriginalInputs_[idx] = OriginalInputs[idx]; |
6730 | OriginalInputs_[idx].setParent(this); |
6731 | } |
6732 | } |
6733 | const NodeValue getOutputGrad() const { return OutputGrad_; } |
6734 | NodeValueArrayRef getOriginalInputs() const { return OriginalInputs_; } |
6735 | |
6736 | static bool classof(const Kinded *k) { |
6737 | return k->getKind() == Kinded::Kind::BatchedPairwiseDotProductGradNodeKind; |
6738 | } |
6739 | |
6740 | |
6741 | bool isOverwrittenNthInput(unsigned idx) const { |
6742 | return false; |
6743 | } |
6744 | |
6745 | unsigned getNumInputs() const; |
6746 | std::string getInputName(unsigned idx) const; |
6747 | NodeValue getNthInput(unsigned idx); |
6748 | void setNthInput(unsigned idx, NodeValue val); |
6749 | llvm::StringRef getOutputName(unsigned idx) const; |
6750 | bool hasSideEffects() const { return 0; } |
6751 | bool isCanonical() const { return 1; } |
6752 | bool isDataParallel() const { return 0; } |
6753 | std::string getDebugDesc() const; |
6754 | bool isEqual(const BatchedPairwiseDotProductGradNode &other) const; |
6755 | llvm::hash_code getHash() const; |
6756 | void visit(Node *parent, NodeWalker *visitor); |
6757 | Node* clone() const; |
6758 | bool verify() const; |
6759 | void (TypeRef T) { addResult(T); } |
6760 | }; |
6761 | } // namespace glow |
6762 | |
6763 | |
6764 | namespace glow { |
6765 | /// Sum weight embeddings according to offsets and indices |
6766 | class BatchedUnaryEmbeddingsBagsNode final : public Node { |
6767 | NodeHandle Weights_; |
6768 | NodeHandle TableOffsets_; |
6769 | NodeHandle Offsets_; |
6770 | NodeHandle Indices_; |
6771 | |
6772 | public: |
6773 | enum InputIndices { |
6774 | WeightsIdx = 0, |
6775 | TableOffsetsIdx = 1, |
6776 | OffsetsIdx = 2, |
6777 | IndicesIdx = 3, |
6778 | }; |
6779 | |
6780 | enum ResultIndices { |
6781 | ResultIdx = 0, |
6782 | }; |
6783 | |
6784 | BatchedUnaryEmbeddingsBagsNode(llvm::StringRef name, TypeRef Result , NodeValue Weights, NodeValue TableOffsets, NodeValue Offsets, NodeValue Indices) |
6785 | : Node(Kinded::Kind::BatchedUnaryEmbeddingsBagsNodeKind, name), Weights_(this, Weights), TableOffsets_(this, TableOffsets), Offsets_(this, Offsets), Indices_(this, Indices) { |
6786 | addResult(Result); |
6787 | } |
6788 | const NodeValue getWeights() const { return Weights_; } |
6789 | const NodeValue getTableOffsets() const { return TableOffsets_; } |
6790 | const NodeValue getOffsets() const { return Offsets_; } |
6791 | const NodeValue getIndices() const { return Indices_; } |
6792 | NodeValue getResult() { return getNthResult(0); } |
6793 | const NodeValue getResult() const { return getNthResult(0); } |
6794 | |
6795 | static bool classof(const Kinded *k) { |
6796 | return k->getKind() == Kinded::Kind::BatchedUnaryEmbeddingsBagsNodeKind; |
6797 | } |
6798 | |
6799 | |
6800 | bool isOverwrittenNthInput(unsigned idx) const { |
6801 | return false; |
6802 | } |
6803 | |
6804 | unsigned getNumInputs() const; |
6805 | std::string getInputName(unsigned idx) const; |
6806 | NodeValue getNthInput(unsigned idx); |
6807 | void setNthInput(unsigned idx, NodeValue val); |
6808 | llvm::StringRef getOutputName(unsigned idx) const; |
6809 | bool hasSideEffects() const { return 0; } |
6810 | bool isCanonical() const { return 1; } |
6811 | bool isDataParallel() const { return 0; } |
6812 | std::string getDebugDesc() const; |
6813 | bool isEqual(const BatchedUnaryEmbeddingsBagsNode &other) const; |
6814 | llvm::hash_code getHash() const; |
6815 | void visit(Node *parent, NodeWalker *visitor); |
6816 | Node* clone() const; |
6817 | bool verify() const; |
6818 | }; |
6819 | } // namespace glow |
6820 | |
6821 | |
6822 | namespace glow { |
6823 | /// Table based batched embeddingbags with quantization support. Experimental only and subject to change. |
6824 | class IntNBitSplitEmbeddingBagsNode final : public Node { |
6825 | NodeHandle DevWeights_; |
6826 | NodeHandle UvmWeights_; |
6827 | NodeHandle WeightsPlacements_; |
6828 | NodeHandle WeightsOffsets_; |
6829 | NodeHandle WeightsTys_; |
6830 | NodeHandle DimOffsets_; |
6831 | NodeHandle Indices_; |
6832 | NodeHandle Offsets_; |
6833 | int64_t TotalDims_; |
6834 | glow::SplitEmbeddingPoolingMode PoolingMode_; |
6835 | glow::SplitEmbeddingSparseType OutputDType_; |
6836 | |
6837 | public: |
6838 | enum InputIndices { |
6839 | DevWeightsIdx = 0, |
6840 | UvmWeightsIdx = 1, |
6841 | WeightsPlacementsIdx = 2, |
6842 | WeightsOffsetsIdx = 3, |
6843 | WeightsTysIdx = 4, |
6844 | DimOffsetsIdx = 5, |
6845 | IndicesIdx = 6, |
6846 | OffsetsIdx = 7, |
6847 | }; |
6848 | |
6849 | enum ResultIndices { |
6850 | ResultIdx = 0, |
6851 | }; |
6852 | |
6853 | IntNBitSplitEmbeddingBagsNode(llvm::StringRef name, TypeRef Result , NodeValue DevWeights, NodeValue UvmWeights, NodeValue WeightsPlacements, NodeValue WeightsOffsets, NodeValue WeightsTys, NodeValue DimOffsets, NodeValue Indices, NodeValue Offsets, int64_t TotalDims, glow::SplitEmbeddingPoolingMode PoolingMode, glow::SplitEmbeddingSparseType OutputDType) |
6854 | : Node(Kinded::Kind::IntNBitSplitEmbeddingBagsNodeKind, name), DevWeights_(this, DevWeights), UvmWeights_(this, UvmWeights), WeightsPlacements_(this, WeightsPlacements), WeightsOffsets_(this, WeightsOffsets), WeightsTys_(this, WeightsTys), DimOffsets_(this, DimOffsets), Indices_(this, Indices), Offsets_(this, Offsets), TotalDims_(TotalDims), PoolingMode_(PoolingMode), OutputDType_(OutputDType) { |
6855 | addResult(Result); |
6856 | } |
6857 | const NodeValue getDevWeights() const { return DevWeights_; } |
6858 | const NodeValue getUvmWeights() const { return UvmWeights_; } |
6859 | const NodeValue getWeightsPlacements() const { return WeightsPlacements_; } |
6860 | const NodeValue getWeightsOffsets() const { return WeightsOffsets_; } |
6861 | const NodeValue getWeightsTys() const { return WeightsTys_; } |
6862 | const NodeValue getDimOffsets() const { return DimOffsets_; } |
6863 | const NodeValue getIndices() const { return Indices_; } |
6864 | const NodeValue getOffsets() const { return Offsets_; } |
6865 | NodeValue getResult() { return getNthResult(0); } |
6866 | const NodeValue getResult() const { return getNthResult(0); } |
6867 | int64_t getTotalDims() const { return TotalDims_; } |
6868 | glow::SplitEmbeddingPoolingMode getPoolingMode() const { return PoolingMode_; } |
6869 | glow::SplitEmbeddingSparseType getOutputDType() const { return OutputDType_; } |
6870 | |
6871 | static bool classof(const Kinded *k) { |
6872 | return k->getKind() == Kinded::Kind::IntNBitSplitEmbeddingBagsNodeKind; |
6873 | } |
6874 | |
6875 | |
6876 | bool isOverwrittenNthInput(unsigned idx) const { |
6877 | return false; |
6878 | } |
6879 | |
6880 | unsigned getNumInputs() const; |
6881 | std::string getInputName(unsigned idx) const; |
6882 | NodeValue getNthInput(unsigned idx); |
6883 | void setNthInput(unsigned idx, NodeValue val); |
6884 | llvm::StringRef getOutputName(unsigned idx) const; |
6885 | bool hasSideEffects() const { return 0; } |
6886 | bool isCanonical() const { return 1; } |
6887 | bool isDataParallel() const { return 0; } |
6888 | std::string getDebugDesc() const; |
6889 | bool isEqual(const IntNBitSplitEmbeddingBagsNode &other) const; |
6890 | llvm::hash_code getHash() const; |
6891 | void visit(Node *parent, NodeWalker *visitor); |
6892 | Node* clone() const; |
6893 | bool verify() const; |
6894 | }; |
6895 | } // namespace glow |
6896 | |
6897 | |
6898 | namespace glow { |
6899 | /// Table based batched embeddingbags with quantization support and indice weights. Experimental only and subject to change. |
6900 | class IntNBitSplitEmbeddingWeightedBagsNode final : public Node { |
6901 | NodeHandle DevWeights_; |
6902 | NodeHandle UvmWeights_; |
6903 | NodeHandle WeightsPlacements_; |
6904 | NodeHandle WeightsOffsets_; |
6905 | NodeHandle WeightsTys_; |
6906 | NodeHandle DimOffsets_; |
6907 | NodeHandle Indices_; |
6908 | NodeHandle Offsets_; |
6909 | NodeHandle IndiceWeight_; |
6910 | int64_t TotalDims_; |
6911 | glow::SplitEmbeddingPoolingMode PoolingMode_; |
6912 | glow::SplitEmbeddingSparseType OutputDType_; |
6913 | |
6914 | public: |
6915 | enum InputIndices { |
6916 | DevWeightsIdx = 0, |
6917 | UvmWeightsIdx = 1, |
6918 | WeightsPlacementsIdx = 2, |
6919 | WeightsOffsetsIdx = 3, |
6920 | WeightsTysIdx = 4, |
6921 | DimOffsetsIdx = 5, |
6922 | IndicesIdx = 6, |
6923 | OffsetsIdx = 7, |
6924 | IndiceWeightIdx = 8, |
6925 | }; |
6926 | |
6927 | enum ResultIndices { |
6928 | ResultIdx = 0, |
6929 | }; |
6930 | |
6931 | IntNBitSplitEmbeddingWeightedBagsNode(llvm::StringRef name, TypeRef Result , NodeValue DevWeights, NodeValue UvmWeights, NodeValue WeightsPlacements, NodeValue WeightsOffsets, NodeValue WeightsTys, NodeValue DimOffsets, NodeValue Indices, NodeValue Offsets, NodeValue IndiceWeight, int64_t TotalDims, glow::SplitEmbeddingPoolingMode PoolingMode, glow::SplitEmbeddingSparseType OutputDType) |
6932 | : Node(Kinded::Kind::IntNBitSplitEmbeddingWeightedBagsNodeKind, name), DevWeights_(this, DevWeights), UvmWeights_(this, UvmWeights), WeightsPlacements_(this, WeightsPlacements), WeightsOffsets_(this, WeightsOffsets), WeightsTys_(this, WeightsTys), DimOffsets_(this, DimOffsets), Indices_(this, Indices), Offsets_(this, Offsets), IndiceWeight_(this, IndiceWeight), TotalDims_(TotalDims), PoolingMode_(PoolingMode), OutputDType_(OutputDType) { |
6933 | addResult(Result); |
6934 | } |
6935 | const NodeValue getDevWeights() const { return DevWeights_; } |
6936 | const NodeValue getUvmWeights() const { return UvmWeights_; } |
6937 | const NodeValue getWeightsPlacements() const { return WeightsPlacements_; } |
6938 | const NodeValue getWeightsOffsets() const { return WeightsOffsets_; } |
6939 | const NodeValue getWeightsTys() const { return WeightsTys_; } |
6940 | const NodeValue getDimOffsets() const { return DimOffsets_; } |
6941 | const NodeValue getIndices() const { return Indices_; } |
6942 | const NodeValue getOffsets() const { return Offsets_; } |
6943 | const NodeValue getIndiceWeight() const { return IndiceWeight_; } |
6944 | NodeValue getResult() { return getNthResult(0); } |
6945 | const NodeValue getResult() const { return getNthResult(0); } |
6946 | int64_t getTotalDims() const { return TotalDims_; } |
6947 | glow::SplitEmbeddingPoolingMode getPoolingMode() const { return PoolingMode_; } |
6948 | glow::SplitEmbeddingSparseType getOutputDType() const { return OutputDType_; } |
6949 | |
6950 | static bool classof(const Kinded *k) { |
6951 | return k->getKind() == Kinded::Kind::IntNBitSplitEmbeddingWeightedBagsNodeKind; |
6952 | } |
6953 | |
6954 | |
6955 | bool isOverwrittenNthInput(unsigned idx) const { |
6956 | return false; |
6957 | } |
6958 | |
6959 | unsigned getNumInputs() const; |
6960 | std::string getInputName(unsigned idx) const; |
6961 | NodeValue getNthInput(unsigned idx); |
6962 | void setNthInput(unsigned idx, NodeValue val); |
6963 | llvm::StringRef getOutputName(unsigned idx) const; |
6964 | bool hasSideEffects() const { return 0; } |
6965 | bool isCanonical() const { return 1; } |
6966 | bool isDataParallel() const { return 0; } |
6967 | std::string getDebugDesc() const; |
6968 | bool isEqual(const IntNBitSplitEmbeddingWeightedBagsNode &other) const; |
6969 | llvm::hash_code getHash() const; |
6970 | void visit(Node *parent, NodeWalker *visitor); |
6971 | Node* clone() const; |
6972 | bool verify() const; |
6973 | }; |
6974 | } // namespace glow |
6975 | |
6976 | |
6977 | namespace glow { |
6978 | /// Fills an output tensor with samples drawn from a normal distribution specified by the mean and standard deviation arguments. The output tensor shape is determined by the input shape if provided, and shape otherwise |
6979 | class GaussianFillNode final : public Node { |
6980 | NodeHandle Input_; |
6981 | float Mean_; |
6982 | float Scale_; |
6983 | float Seed_; |
6984 | |
6985 | public: |
6986 | enum InputIndices { |
6987 | InputIdx = 0, |
6988 | }; |
6989 | |
6990 | enum ResultIndices { |
6991 | ResultIdx = 0, |
6992 | }; |
6993 | |
6994 | GaussianFillNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Mean, float Scale, float Seed) |
6995 | : Node(Kinded::Kind::GaussianFillNodeKind, name), Input_(this, Input), Mean_(Mean), Scale_(Scale), Seed_(Seed) { |
6996 | addResult(Result); |
6997 | } |
6998 | const NodeValue getInput() const { return Input_; } |
6999 | NodeValue getResult() { return getNthResult(0); } |
7000 | const NodeValue getResult() const { return getNthResult(0); } |
7001 | float getMean() const { return Mean_; } |
7002 | float getScale() const { return Scale_; } |
7003 | float getSeed() const { return Seed_; } |
7004 | |
7005 | static bool classof(const Kinded *k) { |
7006 | return k->getKind() == Kinded::Kind::GaussianFillNodeKind; |
7007 | } |
7008 | |
7009 | |
7010 | bool isOverwrittenNthInput(unsigned idx) const { |
7011 | return false; |
7012 | } |
7013 | |
7014 | unsigned getNumInputs() const; |
7015 | std::string getInputName(unsigned idx) const; |
7016 | NodeValue getNthInput(unsigned idx); |
7017 | void setNthInput(unsigned idx, NodeValue val); |
7018 | llvm::StringRef getOutputName(unsigned idx) const; |
7019 | bool hasSideEffects() const { return 0; } |
7020 | bool isCanonical() const { return 1; } |
7021 | bool isDataParallel() const { return 0; } |
7022 | std::string getDebugDesc() const; |
7023 | bool isEqual(const GaussianFillNode &other) const; |
7024 | llvm::hash_code getHash() const; |
7025 | void visit(Node *parent, NodeWalker *visitor); |
7026 | Node* clone() const; |
7027 | bool verify() const; |
7028 | }; |
7029 | } // namespace glow |
7030 | |
7031 | |
7032 | namespace glow { |
7033 | class ReluGradNode final : public Node { |
7034 | NodeHandle Input_; |
7035 | NodeHandle OriginalOutputForResult_; |
7036 | NodeHandle GradOfOriginalOutputNamedResult_; |
7037 | |
7038 | public: |
7039 | enum InputIndices { |
7040 | InputIdx = 0, |
7041 | OriginalOutputForResultIdx = 1, |
7042 | GradOfOriginalOutputNamedResultIdx = 2, |
7043 | }; |
7044 | |
7045 | enum ResultIndices { |
7046 | GradOfInputNamedInputIdx = 0, |
7047 | }; |
7048 | |
7049 | ReluGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
7050 | : Node(Kinded::Kind::ReluGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
7051 | addResult(Input.getType()); |
7052 | } |
7053 | const NodeValue getInput() const { return Input_; } |
7054 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
7055 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
7056 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
7057 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
7058 | |
7059 | static bool classof(const Kinded *k) { |
7060 | return k->getKind() == Kinded::Kind::ReluGradNodeKind; |
7061 | } |
7062 | |
7063 | |
7064 | bool isOverwrittenNthInput(unsigned idx) const { |
7065 | return false; |
7066 | } |
7067 | |
7068 | unsigned getNumInputs() const; |
7069 | std::string getInputName(unsigned idx) const; |
7070 | NodeValue getNthInput(unsigned idx); |
7071 | void setNthInput(unsigned idx, NodeValue val); |
7072 | llvm::StringRef getOutputName(unsigned idx) const; |
7073 | bool hasSideEffects() const { return 0; } |
7074 | bool isCanonical() const { return 1; } |
7075 | bool isDataParallel() const { return 1; } |
7076 | std::string getDebugDesc() const; |
7077 | bool isEqual(const ReluGradNode &other) const; |
7078 | llvm::hash_code getHash() const; |
7079 | void visit(Node *parent, NodeWalker *visitor); |
7080 | Node* clone() const; |
7081 | bool verify() const; |
7082 | }; |
7083 | } // namespace glow |
7084 | |
7085 | |
7086 | namespace glow { |
7087 | /// Applies ReLU, max(0, x), to each element in the Input tensor. |
7088 | class ReluNode final : public Node { |
7089 | NodeHandle Input_; |
7090 | |
7091 | public: |
7092 | enum InputIndices { |
7093 | InputIdx = 0, |
7094 | }; |
7095 | |
7096 | enum ResultIndices { |
7097 | ResultIdx = 0, |
7098 | }; |
7099 | |
7100 | ReluNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7101 | : Node(Kinded::Kind::ReluNodeKind, name), Input_(this, Input) { |
7102 | addResult(Result); |
7103 | } |
7104 | const NodeValue getInput() const { return Input_; } |
7105 | NodeValue getResult() { return getNthResult(0); } |
7106 | const NodeValue getResult() const { return getNthResult(0); } |
7107 | |
7108 | static bool classof(const Kinded *k) { |
7109 | return k->getKind() == Kinded::Kind::ReluNodeKind; |
7110 | } |
7111 | |
7112 | |
7113 | bool isOverwrittenNthInput(unsigned idx) const { |
7114 | return false; |
7115 | } |
7116 | |
7117 | unsigned getNumInputs() const; |
7118 | std::string getInputName(unsigned idx) const; |
7119 | NodeValue getNthInput(unsigned idx); |
7120 | void setNthInput(unsigned idx, NodeValue val); |
7121 | llvm::StringRef getOutputName(unsigned idx) const; |
7122 | bool hasSideEffects() const { return 0; } |
7123 | bool isCanonical() const { return 1; } |
7124 | bool isDataParallel() const { return 1; } |
7125 | std::string getDebugDesc() const; |
7126 | bool isEqual(const ReluNode &other) const; |
7127 | llvm::hash_code getHash() const; |
7128 | void visit(Node *parent, NodeWalker *visitor); |
7129 | Node* clone() const; |
7130 | bool verify() const; |
7131 | ReluGradNode *getGrad(GraphGradMapper &builder); |
7132 | }; |
7133 | } // namespace glow |
7134 | |
7135 | |
7136 | namespace glow { |
7137 | /// Applies HardSwish to each element in the Input tensor. |
7138 | class HardSwishNode final : public Node { |
7139 | NodeHandle Input_; |
7140 | |
7141 | public: |
7142 | enum InputIndices { |
7143 | InputIdx = 0, |
7144 | }; |
7145 | |
7146 | enum ResultIndices { |
7147 | ResultIdx = 0, |
7148 | }; |
7149 | |
7150 | HardSwishNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7151 | : Node(Kinded::Kind::HardSwishNodeKind, name), Input_(this, Input) { |
7152 | addResult(Result); |
7153 | } |
7154 | const NodeValue getInput() const { return Input_; } |
7155 | NodeValue getResult() { return getNthResult(0); } |
7156 | const NodeValue getResult() const { return getNthResult(0); } |
7157 | |
7158 | static bool classof(const Kinded *k) { |
7159 | return k->getKind() == Kinded::Kind::HardSwishNodeKind; |
7160 | } |
7161 | |
7162 | |
7163 | bool isOverwrittenNthInput(unsigned idx) const { |
7164 | return false; |
7165 | } |
7166 | |
7167 | unsigned getNumInputs() const; |
7168 | std::string getInputName(unsigned idx) const; |
7169 | NodeValue getNthInput(unsigned idx); |
7170 | void setNthInput(unsigned idx, NodeValue val); |
7171 | llvm::StringRef getOutputName(unsigned idx) const; |
7172 | bool hasSideEffects() const { return 0; } |
7173 | bool isCanonical() const { return 1; } |
7174 | bool isDataParallel() const { return 1; } |
7175 | std::string getDebugDesc() const; |
7176 | bool isEqual(const HardSwishNode &other) const; |
7177 | llvm::hash_code getHash() const; |
7178 | void visit(Node *parent, NodeWalker *visitor); |
7179 | Node* clone() const; |
7180 | bool verify() const; |
7181 | }; |
7182 | } // namespace glow |
7183 | |
7184 | |
7185 | namespace glow { |
7186 | /// Applies GeLU, to each element in the Input tensor. |
7187 | class GeluNode final : public Node { |
7188 | NodeHandle Input_; |
7189 | |
7190 | public: |
7191 | enum InputIndices { |
7192 | InputIdx = 0, |
7193 | }; |
7194 | |
7195 | enum ResultIndices { |
7196 | ResultIdx = 0, |
7197 | }; |
7198 | |
7199 | GeluNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7200 | : Node(Kinded::Kind::GeluNodeKind, name), Input_(this, Input) { |
7201 | addResult(Result); |
7202 | } |
7203 | const NodeValue getInput() const { return Input_; } |
7204 | NodeValue getResult() { return getNthResult(0); } |
7205 | const NodeValue getResult() const { return getNthResult(0); } |
7206 | |
7207 | static bool classof(const Kinded *k) { |
7208 | return k->getKind() == Kinded::Kind::GeluNodeKind; |
7209 | } |
7210 | |
7211 | |
7212 | bool isOverwrittenNthInput(unsigned idx) const { |
7213 | return false; |
7214 | } |
7215 | |
7216 | unsigned getNumInputs() const; |
7217 | std::string getInputName(unsigned idx) const; |
7218 | NodeValue getNthInput(unsigned idx); |
7219 | void setNthInput(unsigned idx, NodeValue val); |
7220 | llvm::StringRef getOutputName(unsigned idx) const; |
7221 | bool hasSideEffects() const { return 0; } |
7222 | bool isCanonical() const { return 1; } |
7223 | bool isDataParallel() const { return 1; } |
7224 | std::string getDebugDesc() const; |
7225 | bool isEqual(const GeluNode &other) const; |
7226 | llvm::hash_code getHash() const; |
7227 | void visit(Node *parent, NodeWalker *visitor); |
7228 | Node* clone() const; |
7229 | bool verify() const; |
7230 | }; |
7231 | } // namespace glow |
7232 | |
7233 | |
7234 | namespace glow { |
7235 | /// Clip range of inputs to lie in [Min, Max]. |
7236 | class ClipNode final : public Node { |
7237 | NodeHandle Input_; |
7238 | float Min_; |
7239 | float Max_; |
7240 | |
7241 | public: |
7242 | enum InputIndices { |
7243 | InputIdx = 0, |
7244 | }; |
7245 | |
7246 | enum ResultIndices { |
7247 | ResultIdx = 0, |
7248 | }; |
7249 | |
7250 | ClipNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Min, float Max) |
7251 | : Node(Kinded::Kind::ClipNodeKind, name), Input_(this, Input), Min_(Min), Max_(Max) { |
7252 | addResult(Result); |
7253 | } |
7254 | const NodeValue getInput() const { return Input_; } |
7255 | NodeValue getResult() { return getNthResult(0); } |
7256 | const NodeValue getResult() const { return getNthResult(0); } |
7257 | float getMin() const { return Min_; } |
7258 | float getMax() const { return Max_; } |
7259 | |
7260 | static bool classof(const Kinded *k) { |
7261 | return k->getKind() == Kinded::Kind::ClipNodeKind; |
7262 | } |
7263 | |
7264 | |
7265 | bool isOverwrittenNthInput(unsigned idx) const { |
7266 | return false; |
7267 | } |
7268 | |
7269 | unsigned getNumInputs() const; |
7270 | std::string getInputName(unsigned idx) const; |
7271 | NodeValue getNthInput(unsigned idx); |
7272 | void setNthInput(unsigned idx, NodeValue val); |
7273 | llvm::StringRef getOutputName(unsigned idx) const; |
7274 | bool hasSideEffects() const { return 0; } |
7275 | bool isCanonical() const { return 1; } |
7276 | bool isDataParallel() const { return 1; } |
7277 | std::string getDebugDesc() const; |
7278 | bool isEqual(const ClipNode &other) const; |
7279 | llvm::hash_code getHash() const; |
7280 | void visit(Node *parent, NodeWalker *visitor); |
7281 | Node* clone() const; |
7282 | bool verify() const; |
7283 | }; |
7284 | } // namespace glow |
7285 | |
7286 | |
7287 | namespace glow { |
7288 | /// Applies PReLU, slope * min(0, x) + max(0, x), to each element in the Input tensor. |
7289 | class PReluNode final : public Node { |
7290 | NodeHandle Input_; |
7291 | NodeHandle Slope_; |
7292 | |
7293 | public: |
7294 | enum InputIndices { |
7295 | InputIdx = 0, |
7296 | SlopeIdx = 1, |
7297 | }; |
7298 | |
7299 | enum ResultIndices { |
7300 | ResultIdx = 0, |
7301 | }; |
7302 | |
7303 | PReluNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Slope) |
7304 | : Node(Kinded::Kind::PReluNodeKind, name), Input_(this, Input), Slope_(this, Slope) { |
7305 | addResult(Result); |
7306 | } |
7307 | const NodeValue getInput() const { return Input_; } |
7308 | const NodeValue getSlope() const { return Slope_; } |
7309 | NodeValue getResult() { return getNthResult(0); } |
7310 | const NodeValue getResult() const { return getNthResult(0); } |
7311 | |
7312 | static bool classof(const Kinded *k) { |
7313 | return k->getKind() == Kinded::Kind::PReluNodeKind; |
7314 | } |
7315 | |
7316 | |
7317 | bool isOverwrittenNthInput(unsigned idx) const { |
7318 | return false; |
7319 | } |
7320 | |
7321 | unsigned getNumInputs() const; |
7322 | std::string getInputName(unsigned idx) const; |
7323 | NodeValue getNthInput(unsigned idx); |
7324 | void setNthInput(unsigned idx, NodeValue val); |
7325 | llvm::StringRef getOutputName(unsigned idx) const; |
7326 | bool hasSideEffects() const { return 0; } |
7327 | bool isCanonical() const { return 1; } |
7328 | bool isDataParallel() const { return 1; } |
7329 | std::string getDebugDesc() const; |
7330 | bool isEqual(const PReluNode &other) const; |
7331 | llvm::hash_code getHash() const; |
7332 | void visit(Node *parent, NodeWalker *visitor); |
7333 | Node* clone() const; |
7334 | bool verify() const; |
7335 | }; |
7336 | } // namespace glow |
7337 | |
7338 | |
7339 | namespace glow { |
7340 | class SigmoidGradNode final : public Node { |
7341 | NodeHandle Input_; |
7342 | NodeHandle OriginalOutputForResult_; |
7343 | NodeHandle GradOfOriginalOutputNamedResult_; |
7344 | |
7345 | public: |
7346 | enum InputIndices { |
7347 | InputIdx = 0, |
7348 | OriginalOutputForResultIdx = 1, |
7349 | GradOfOriginalOutputNamedResultIdx = 2, |
7350 | }; |
7351 | |
7352 | enum ResultIndices { |
7353 | GradOfInputNamedInputIdx = 0, |
7354 | }; |
7355 | |
7356 | SigmoidGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
7357 | : Node(Kinded::Kind::SigmoidGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
7358 | addResult(Input.getType()); |
7359 | } |
7360 | const NodeValue getInput() const { return Input_; } |
7361 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
7362 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
7363 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
7364 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
7365 | |
7366 | static bool classof(const Kinded *k) { |
7367 | return k->getKind() == Kinded::Kind::SigmoidGradNodeKind; |
7368 | } |
7369 | |
7370 | |
7371 | bool isOverwrittenNthInput(unsigned idx) const { |
7372 | return false; |
7373 | } |
7374 | |
7375 | unsigned getNumInputs() const; |
7376 | std::string getInputName(unsigned idx) const; |
7377 | NodeValue getNthInput(unsigned idx); |
7378 | void setNthInput(unsigned idx, NodeValue val); |
7379 | llvm::StringRef getOutputName(unsigned idx) const; |
7380 | bool hasSideEffects() const { return 0; } |
7381 | bool isCanonical() const { return 1; } |
7382 | bool isDataParallel() const { return 1; } |
7383 | std::string getDebugDesc() const; |
7384 | bool isEqual(const SigmoidGradNode &other) const; |
7385 | llvm::hash_code getHash() const; |
7386 | void visit(Node *parent, NodeWalker *visitor); |
7387 | Node* clone() const; |
7388 | bool verify() const; |
7389 | }; |
7390 | } // namespace glow |
7391 | |
7392 | |
7393 | namespace glow { |
7394 | /// Applies Sigmoid, 1 / (1 + exp(-x)), to each element in the Input tensor. |
7395 | class SigmoidNode final : public Node { |
7396 | NodeHandle Input_; |
7397 | |
7398 | public: |
7399 | enum InputIndices { |
7400 | InputIdx = 0, |
7401 | }; |
7402 | |
7403 | enum ResultIndices { |
7404 | ResultIdx = 0, |
7405 | }; |
7406 | |
7407 | SigmoidNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7408 | : Node(Kinded::Kind::SigmoidNodeKind, name), Input_(this, Input) { |
7409 | addResult(Result); |
7410 | } |
7411 | const NodeValue getInput() const { return Input_; } |
7412 | NodeValue getResult() { return getNthResult(0); } |
7413 | const NodeValue getResult() const { return getNthResult(0); } |
7414 | |
7415 | static bool classof(const Kinded *k) { |
7416 | return k->getKind() == Kinded::Kind::SigmoidNodeKind; |
7417 | } |
7418 | |
7419 | |
7420 | bool isOverwrittenNthInput(unsigned idx) const { |
7421 | return false; |
7422 | } |
7423 | |
7424 | unsigned getNumInputs() const; |
7425 | std::string getInputName(unsigned idx) const; |
7426 | NodeValue getNthInput(unsigned idx); |
7427 | void setNthInput(unsigned idx, NodeValue val); |
7428 | llvm::StringRef getOutputName(unsigned idx) const; |
7429 | bool hasSideEffects() const { return 0; } |
7430 | bool isCanonical() const { return 1; } |
7431 | bool isDataParallel() const { return 1; } |
7432 | std::string getDebugDesc() const; |
7433 | bool isEqual(const SigmoidNode &other) const; |
7434 | llvm::hash_code getHash() const; |
7435 | void visit(Node *parent, NodeWalker *visitor); |
7436 | Node* clone() const; |
7437 | bool verify() const; |
7438 | SigmoidGradNode *getGrad(GraphGradMapper &builder); |
7439 | }; |
7440 | } // namespace glow |
7441 | |
7442 | |
7443 | namespace glow { |
7444 | /// Applies Swish, X * Sigmoid(X), to each element in the Input tensor. |
7445 | class SwishNode final : public Node { |
7446 | NodeHandle Input_; |
7447 | |
7448 | public: |
7449 | enum InputIndices { |
7450 | InputIdx = 0, |
7451 | }; |
7452 | |
7453 | enum ResultIndices { |
7454 | ResultIdx = 0, |
7455 | }; |
7456 | |
7457 | SwishNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7458 | : Node(Kinded::Kind::SwishNodeKind, name), Input_(this, Input) { |
7459 | addResult(Result); |
7460 | } |
7461 | const NodeValue getInput() const { return Input_; } |
7462 | NodeValue getResult() { return getNthResult(0); } |
7463 | const NodeValue getResult() const { return getNthResult(0); } |
7464 | |
7465 | static bool classof(const Kinded *k) { |
7466 | return k->getKind() == Kinded::Kind::SwishNodeKind; |
7467 | } |
7468 | |
7469 | |
7470 | bool isOverwrittenNthInput(unsigned idx) const { |
7471 | return false; |
7472 | } |
7473 | |
7474 | unsigned getNumInputs() const; |
7475 | std::string getInputName(unsigned idx) const; |
7476 | NodeValue getNthInput(unsigned idx); |
7477 | void setNthInput(unsigned idx, NodeValue val); |
7478 | llvm::StringRef getOutputName(unsigned idx) const; |
7479 | bool hasSideEffects() const { return 0; } |
7480 | bool isCanonical() const { return 1; } |
7481 | bool isDataParallel() const { return 1; } |
7482 | std::string getDebugDesc() const; |
7483 | bool isEqual(const SwishNode &other) const; |
7484 | llvm::hash_code getHash() const; |
7485 | void visit(Node *parent, NodeWalker *visitor); |
7486 | Node* clone() const; |
7487 | bool verify() const; |
7488 | }; |
7489 | } // namespace glow |
7490 | |
7491 | |
7492 | namespace glow { |
7493 | class TanhGradNode final : public Node { |
7494 | NodeHandle Input_; |
7495 | NodeHandle OriginalOutputForResult_; |
7496 | NodeHandle GradOfOriginalOutputNamedResult_; |
7497 | |
7498 | public: |
7499 | enum InputIndices { |
7500 | InputIdx = 0, |
7501 | OriginalOutputForResultIdx = 1, |
7502 | GradOfOriginalOutputNamedResultIdx = 2, |
7503 | }; |
7504 | |
7505 | enum ResultIndices { |
7506 | GradOfInputNamedInputIdx = 0, |
7507 | }; |
7508 | |
7509 | TanhGradNode(llvm::StringRef name, NodeValue Input, NodeValue OriginalOutputForResult, NodeValue GradOfOriginalOutputNamedResult) |
7510 | : Node(Kinded::Kind::TanhGradNodeKind, name), Input_(this, Input), OriginalOutputForResult_(this, OriginalOutputForResult), GradOfOriginalOutputNamedResult_(this, GradOfOriginalOutputNamedResult) { |
7511 | addResult(Input.getType()); |
7512 | } |
7513 | const NodeValue getInput() const { return Input_; } |
7514 | const NodeValue getOriginalOutputForResult() const { return OriginalOutputForResult_; } |
7515 | const NodeValue getGradOfOriginalOutputNamedResult() const { return GradOfOriginalOutputNamedResult_; } |
7516 | NodeValue getGradOfInputNamedInput() { return getNthResult(0); } |
7517 | const NodeValue getGradOfInputNamedInput() const { return getNthResult(0); } |
7518 | |
7519 | static bool classof(const Kinded *k) { |
7520 | return k->getKind() == Kinded::Kind::TanhGradNodeKind; |
7521 | } |
7522 | |
7523 | |
7524 | bool isOverwrittenNthInput(unsigned idx) const { |
7525 | return false; |
7526 | } |
7527 | |
7528 | unsigned getNumInputs() const; |
7529 | std::string getInputName(unsigned idx) const; |
7530 | NodeValue getNthInput(unsigned idx); |
7531 | void setNthInput(unsigned idx, NodeValue val); |
7532 | llvm::StringRef getOutputName(unsigned idx) const; |
7533 | bool hasSideEffects() const { return 0; } |
7534 | bool isCanonical() const { return 1; } |
7535 | bool isDataParallel() const { return 1; } |
7536 | std::string getDebugDesc() const; |
7537 | bool isEqual(const TanhGradNode &other) const; |
7538 | llvm::hash_code getHash() const; |
7539 | void visit(Node *parent, NodeWalker *visitor); |
7540 | Node* clone() const; |
7541 | bool verify() const; |
7542 | }; |
7543 | } // namespace glow |
7544 | |
7545 | |
7546 | namespace glow { |
7547 | /// Applies hyperbolic tangent to each element in the Input tensor. |
7548 | class TanhNode final : public Node { |
7549 | NodeHandle Input_; |
7550 | |
7551 | public: |
7552 | enum InputIndices { |
7553 | InputIdx = 0, |
7554 | }; |
7555 | |
7556 | enum ResultIndices { |
7557 | ResultIdx = 0, |
7558 | }; |
7559 | |
7560 | TanhNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7561 | : Node(Kinded::Kind::TanhNodeKind, name), Input_(this, Input) { |
7562 | addResult(Result); |
7563 | } |
7564 | const NodeValue getInput() const { return Input_; } |
7565 | NodeValue getResult() { return getNthResult(0); } |
7566 | const NodeValue getResult() const { return getNthResult(0); } |
7567 | |
7568 | static bool classof(const Kinded *k) { |
7569 | return k->getKind() == Kinded::Kind::TanhNodeKind; |
7570 | } |
7571 | |
7572 | |
7573 | bool isOverwrittenNthInput(unsigned idx) const { |
7574 | return false; |
7575 | } |
7576 | |
7577 | unsigned getNumInputs() const; |
7578 | std::string getInputName(unsigned idx) const; |
7579 | NodeValue getNthInput(unsigned idx); |
7580 | void setNthInput(unsigned idx, NodeValue val); |
7581 | llvm::StringRef getOutputName(unsigned idx) const; |
7582 | bool hasSideEffects() const { return 0; } |
7583 | bool isCanonical() const { return 1; } |
7584 | bool isDataParallel() const { return 1; } |
7585 | std::string getDebugDesc() const; |
7586 | bool isEqual(const TanhNode &other) const; |
7587 | llvm::hash_code getHash() const; |
7588 | void visit(Node *parent, NodeWalker *visitor); |
7589 | Node* clone() const; |
7590 | bool verify() const; |
7591 | TanhGradNode *getGrad(GraphGradMapper &builder); |
7592 | }; |
7593 | } // namespace glow |
7594 | |
7595 | |
7596 | namespace glow { |
7597 | /// Applies LeakyReLU = x for positive x and alpha * x for negative x to each element in the Input tensor. |
7598 | class LeakyReluNode final : public Node { |
7599 | NodeHandle Input_; |
7600 | float Alpha_; |
7601 | |
7602 | public: |
7603 | enum InputIndices { |
7604 | InputIdx = 0, |
7605 | }; |
7606 | |
7607 | enum ResultIndices { |
7608 | ResultIdx = 0, |
7609 | }; |
7610 | |
7611 | LeakyReluNode(llvm::StringRef name, TypeRef Result , NodeValue Input, float Alpha) |
7612 | : Node(Kinded::Kind::LeakyReluNodeKind, name), Input_(this, Input), Alpha_(Alpha) { |
7613 | addResult(Result); |
7614 | } |
7615 | const NodeValue getInput() const { return Input_; } |
7616 | NodeValue getResult() { return getNthResult(0); } |
7617 | const NodeValue getResult() const { return getNthResult(0); } |
7618 | float getAlpha() const { return Alpha_; } |
7619 | |
7620 | static bool classof(const Kinded *k) { |
7621 | return k->getKind() == Kinded::Kind::LeakyReluNodeKind; |
7622 | } |
7623 | |
7624 | |
7625 | bool isOverwrittenNthInput(unsigned idx) const { |
7626 | return false; |
7627 | } |
7628 | |
7629 | unsigned getNumInputs() const; |
7630 | std::string getInputName(unsigned idx) const; |
7631 | NodeValue getNthInput(unsigned idx); |
7632 | void setNthInput(unsigned idx, NodeValue val); |
7633 | llvm::StringRef getOutputName(unsigned idx) const; |
7634 | bool hasSideEffects() const { return 0; } |
7635 | bool isCanonical() const { return 1; } |
7636 | bool isDataParallel() const { return 1; } |
7637 | std::string getDebugDesc() const; |
7638 | bool isEqual(const LeakyReluNode &other) const; |
7639 | llvm::hash_code getHash() const; |
7640 | void visit(Node *parent, NodeWalker *visitor); |
7641 | Node* clone() const; |
7642 | bool verify() const; |
7643 | }; |
7644 | } // namespace glow |
7645 | |
7646 | |
7647 | namespace glow { |
7648 | /// Performs SoftPlus, ln(exp(x) + 1), to each element in the Input tensor. |
7649 | class SoftPlusNode final : public Node { |
7650 | NodeHandle Input_; |
7651 | |
7652 | public: |
7653 | enum InputIndices { |
7654 | InputIdx = 0, |
7655 | }; |
7656 | |
7657 | enum ResultIndices { |
7658 | ResultIdx = 0, |
7659 | }; |
7660 | |
7661 | SoftPlusNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
7662 | : Node(Kinded::Kind::SoftPlusNodeKind, name), Input_(this, Input) { |
7663 | addResult(Result); |
7664 | } |
7665 | const NodeValue getInput() const { return Input_; } |
7666 | NodeValue getResult() { return getNthResult(0); } |
7667 | const NodeValue getResult() const { return getNthResult(0); } |
7668 | |
7669 | static bool classof(const Kinded *k) { |
7670 | return k->getKind() == Kinded::Kind::SoftPlusNodeKind; |
7671 | } |
7672 | |
7673 | |
7674 | bool isOverwrittenNthInput(unsigned idx) const { |
7675 | return false; |
7676 | } |
7677 | |
7678 | unsigned getNumInputs() const; |
7679 | std::string getInputName(unsigned idx) const; |
7680 | NodeValue getNthInput(unsigned idx); |
7681 | void setNthInput(unsigned idx, NodeValue val); |
7682 | llvm::StringRef getOutputName(unsigned idx) const; |
7683 | bool hasSideEffects() const { return 0; } |
7684 | bool isCanonical() const { return 1; } |
7685 | bool isDataParallel() const { return 1; } |
7686 | std::string getDebugDesc() const; |
7687 | bool isEqual(const SoftPlusNode &other) const; |
7688 | llvm::hash_code getHash() const; |
7689 | void visit(Node *parent, NodeWalker *visitor); |
7690 | Node* clone() const; |
7691 | bool verify() const; |
7692 | }; |
7693 | } // namespace glow |
7694 | |
7695 | |
7696 | namespace glow { |
7697 | /// Reshape the Input tensor to shape Dims. |
7698 | class ReshapeNode final : public Node { |
7699 | NodeHandle Input_; |
7700 | std::vector<dim_t> Dims_; |
7701 | std::string Layout_; |
7702 | |
7703 | public: |
7704 | enum InputIndices { |
7705 | InputIdx = 0, |
7706 | }; |
7707 | |
7708 | enum ResultIndices { |
7709 | ResultIdx = 0, |
7710 | }; |
7711 | |
7712 | ReshapeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<dim_t> Dims, std::string Layout) |
7713 | : Node(Kinded::Kind::ReshapeNodeKind, name), Input_(this, Input), Dims_(Dims), Layout_(Layout) { |
7714 | addResult(Result); |
7715 | } |
7716 | const NodeValue getInput() const { return Input_; } |
7717 | NodeValue getResult() { return getNthResult(0); } |
7718 | const NodeValue getResult() const { return getNthResult(0); } |
7719 | llvm::ArrayRef<dim_t> getDims() const { return Dims_; } |
7720 | std::string getLayout() const { return Layout_; } |
7721 | |
7722 | static bool classof(const Kinded *k) { |
7723 | return k->getKind() == Kinded::Kind::ReshapeNodeKind; |
7724 | } |
7725 | |
7726 | |
7727 | bool isOverwrittenNthInput(unsigned idx) const { |
7728 | return false; |
7729 | } |
7730 | |
7731 | unsigned getNumInputs() const; |
7732 | std::string getInputName(unsigned idx) const; |
7733 | NodeValue getNthInput(unsigned idx); |
7734 | void setNthInput(unsigned idx, NodeValue val); |
7735 | llvm::StringRef getOutputName(unsigned idx) const; |
7736 | bool hasSideEffects() const { return 0; } |
7737 | bool isCanonical() const { return 1; } |
7738 | bool isDataParallel() const { return 0; } |
7739 | std::string getDebugDesc() const; |
7740 | bool isEqual(const ReshapeNode &other) const; |
7741 | llvm::hash_code getHash() const; |
7742 | void visit(Node *parent, NodeWalker *visitor); |
7743 | Node* clone() const; |
7744 | bool verify() const; |
7745 | }; |
7746 | } // namespace glow |
7747 | |
7748 | |
7749 | namespace glow { |
7750 | /// Transpose the Input tensor based on the vector Shuffle, which assigns a new axis for each dimension in Input. |
7751 | class TransposeNode final : public Node { |
7752 | NodeHandle Input_; |
7753 | std::vector<unsigned_t> Shuffle_; |
7754 | std::string Layout_; |
7755 | |
7756 | public: |
7757 | enum InputIndices { |
7758 | InputIdx = 0, |
7759 | }; |
7760 | |
7761 | enum ResultIndices { |
7762 | ResultIdx = 0, |
7763 | }; |
7764 | |
7765 | TransposeNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<unsigned_t> Shuffle, std::string Layout) |
7766 | : Node(Kinded::Kind::TransposeNodeKind, name), Input_(this, Input), Shuffle_(Shuffle), Layout_(Layout) { |
7767 | addResult(Result); |
7768 | } |
7769 | const NodeValue getInput() const { return Input_; } |
7770 | NodeValue getResult() { return getNthResult(0); } |
7771 | const NodeValue getResult() const { return getNthResult(0); } |
7772 | llvm::ArrayRef<unsigned_t> getShuffle() const { return Shuffle_; } |
7773 | std::string getLayout() const { return Layout_; } |
7774 | |
7775 | static bool classof(const Kinded *k) { |
7776 | return k->getKind() == Kinded::Kind::TransposeNodeKind; |
7777 | } |
7778 | |
7779 | |
7780 | bool isOverwrittenNthInput(unsigned idx) const { |
7781 | return false; |
7782 | } |
7783 | |
7784 | unsigned getNumInputs() const; |
7785 | std::string getInputName(unsigned idx) const; |
7786 | NodeValue getNthInput(unsigned idx); |
7787 | void setNthInput(unsigned idx, NodeValue val); |
7788 | llvm::StringRef getOutputName(unsigned idx) const; |
7789 | bool hasSideEffects() const { return 0; } |
7790 | bool isCanonical() const { return 1; } |
7791 | bool isDataParallel() const { return 0; } |
7792 | std::string getDebugDesc() const; |
7793 | bool isEqual(const TransposeNode &other) const; |
7794 | llvm::hash_code getHash() const; |
7795 | void visit(Node *parent, NodeWalker *visitor); |
7796 | Node* clone() const; |
7797 | bool verify() const; |
7798 | }; |
7799 | } // namespace glow |
7800 | |
7801 | |
7802 | namespace glow { |
7803 | /// The concat operator adds two tensors together. |
7804 | /// The parameter 'dim' specifies the dimension to use when joining the tensors. |
7805 | class ConcatNode final : public Node { |
7806 | std::vector<NodeHandle> Inputs_; |
7807 | unsigned_t Dim_; |
7808 | |
7809 | public: |
7810 | enum InputIndices { |
7811 | }; |
7812 | |
7813 | enum ResultIndices { |
7814 | ResultIdx = 0, |
7815 | }; |
7816 | |
7817 | ConcatNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> Inputs, unsigned_t Dim) |
7818 | : Node(Kinded::Kind::ConcatNodeKind, name), Dim_(Dim) { |
7819 | addResult(Result); |
7820 | Inputs_.resize(Inputs.size()); |
7821 | for (size_t idx = 0, e = Inputs.size(); idx < e; ++idx) { |
7822 | Inputs_[idx] = Inputs[idx]; |
7823 | Inputs_[idx].setParent(this); |
7824 | } |
7825 | } |
7826 | NodeValue getResult() { return getNthResult(0); } |
7827 | const NodeValue getResult() const { return getNthResult(0); } |
7828 | NodeValueArrayRef getInputs() const { return Inputs_; } |
7829 | unsigned_t getDim() const { return Dim_; } |
7830 | |
7831 | static bool classof(const Kinded *k) { |
7832 | return k->getKind() == Kinded::Kind::ConcatNodeKind; |
7833 | } |
7834 | |
7835 | |
7836 | bool isOverwrittenNthInput(unsigned idx) const { |
7837 | return false; |
7838 | } |
7839 | |
7840 | unsigned getNumInputs() const; |
7841 | std::string getInputName(unsigned idx) const; |
7842 | NodeValue getNthInput(unsigned idx); |
7843 | void setNthInput(unsigned idx, NodeValue val); |
7844 | llvm::StringRef getOutputName(unsigned idx) const; |
7845 | bool hasSideEffects() const { return 0; } |
7846 | bool isCanonical() const { return 1; } |
7847 | bool isDataParallel() const { return 0; } |
7848 | std::string getDebugDesc() const; |
7849 | bool isEqual(const ConcatNode &other) const; |
7850 | llvm::hash_code getHash() const; |
7851 | void visit(Node *parent, NodeWalker *visitor); |
7852 | Node* clone() const; |
7853 | bool verify() const; |
7854 | }; |
7855 | } // namespace glow |
7856 | |
7857 | |
7858 | namespace glow { |
7859 | /// Produces a slice of the Input tensor. The Start vector defines the starting indices for each dimension from which the slice should be taken. The end index for each dimension is determined from the input type's shape. |
7860 | class SliceNode final : public Node { |
7861 | NodeHandle Input_; |
7862 | std::vector<dim_t> Start_; |
7863 | |
7864 | public: |
7865 | enum InputIndices { |
7866 | InputIdx = 0, |
7867 | }; |
7868 | |
7869 | enum ResultIndices { |
7870 | ResultIdx = 0, |
7871 | }; |
7872 | |
7873 | SliceNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<dim_t> Start) |
7874 | : Node(Kinded::Kind::SliceNodeKind, name), Input_(this, Input), Start_(Start) { |
7875 | addResult(Result); |
7876 | } |
7877 | const NodeValue getInput() const { return Input_; } |
7878 | NodeValue getResult() { return getNthResult(0); } |
7879 | const NodeValue getResult() const { return getNthResult(0); } |
7880 | llvm::ArrayRef<dim_t> getStart() const { return Start_; } |
7881 | |
7882 | static bool classof(const Kinded *k) { |
7883 | return k->getKind() == Kinded::Kind::SliceNodeKind; |
7884 | } |
7885 | |
7886 | |
7887 | bool isOverwrittenNthInput(unsigned idx) const { |
7888 | return false; |
7889 | } |
7890 | |
7891 | unsigned getNumInputs() const; |
7892 | std::string getInputName(unsigned idx) const; |
7893 | NodeValue getNthInput(unsigned idx); |
7894 | void setNthInput(unsigned idx, NodeValue val); |
7895 | llvm::StringRef getOutputName(unsigned idx) const; |
7896 | bool hasSideEffects() const { return 0; } |
7897 | bool isCanonical() const { return 1; } |
7898 | bool isDataParallel() const { return 0; } |
7899 | std::string getDebugDesc() const; |
7900 | bool isEqual(const SliceNode &other) const; |
7901 | llvm::hash_code getHash() const; |
7902 | void visit(Node *parent, NodeWalker *visitor); |
7903 | Node* clone() const; |
7904 | bool verify() const; |
7905 | }; |
7906 | } // namespace glow |
7907 | |
7908 | |
7909 | namespace glow { |
7910 | /// Insert tensor Small into tensor Big given indices Start. Small is inserted Count times along Axis. The resulting Tensor will have the same type as the input Big tensor. |
7911 | class InsertTensorNode final : public Node { |
7912 | NodeHandle Big_; |
7913 | NodeHandle Small_; |
7914 | std::vector<dim_t> Start_; |
7915 | unsigned_t Count_; |
7916 | unsigned_t Axis_; |
7917 | |
7918 | public: |
7919 | enum InputIndices { |
7920 | BigIdx = 0, |
7921 | SmallIdx = 1, |
7922 | }; |
7923 | |
7924 | enum ResultIndices { |
7925 | ResultIdx = 0, |
7926 | }; |
7927 | |
7928 | InsertTensorNode(llvm::StringRef name, NodeValue Big, NodeValue Small, std::vector<dim_t> Start, unsigned_t Count, unsigned_t Axis) |
7929 | : Node(Kinded::Kind::InsertTensorNodeKind, name), Big_(this, Big), Small_(this, Small), Start_(Start), Count_(Count), Axis_(Axis) { |
7930 | addResult(Big.getType()); |
7931 | } |
7932 | const NodeValue getBig() const { return Big_; } |
7933 | const NodeValue getSmall() const { return Small_; } |
7934 | NodeValue getResult() { return getNthResult(0); } |
7935 | const NodeValue getResult() const { return getNthResult(0); } |
7936 | llvm::ArrayRef<dim_t> getStart() const { return Start_; } |
7937 | unsigned_t getCount() const { return Count_; } |
7938 | unsigned_t getAxis() const { return Axis_; } |
7939 | |
7940 | static bool classof(const Kinded *k) { |
7941 | return k->getKind() == Kinded::Kind::InsertTensorNodeKind; |
7942 | } |
7943 | |
7944 | |
7945 | bool isOverwrittenNthInput(unsigned idx) const { |
7946 | return false; |
7947 | } |
7948 | |
7949 | unsigned getNumInputs() const; |
7950 | std::string getInputName(unsigned idx) const; |
7951 | NodeValue getNthInput(unsigned idx); |
7952 | void setNthInput(unsigned idx, NodeValue val); |
7953 | llvm::StringRef getOutputName(unsigned idx) const; |
7954 | bool hasSideEffects() const { return 0; } |
7955 | bool isCanonical() const { return 1; } |
7956 | bool isDataParallel() const { return 0; } |
7957 | std::string getDebugDesc() const; |
7958 | bool isEqual(const InsertTensorNode &other) const; |
7959 | llvm::hash_code getHash() const; |
7960 | void visit(Node *parent, NodeWalker *visitor); |
7961 | Node* clone() const; |
7962 | bool verify() const; |
7963 | }; |
7964 | } // namespace glow |
7965 | |
7966 | |
7967 | namespace glow { |
7968 | /// Gathers entries of the outer-most dimension of Data indexed by Indices, and concatenates them. Output tensor will have dimensions: {I_0, I_1, ... I_n, D_1, D_2, ... D_m}, where D_i and I_j denote Data and Indices dimensions respectively. If axis is not zero, the gather operator will treat the first axis as the batch and will concat the result of the gather operation on each sample in the batch. |
7969 | class GatherNode final : public Node { |
7970 | NodeHandle Data_; |
7971 | NodeHandle Indices_; |
7972 | unsigned_t BatchDims_; |
7973 | |
7974 | public: |
7975 | enum InputIndices { |
7976 | DataIdx = 0, |
7977 | IndicesIdx = 1, |
7978 | }; |
7979 | |
7980 | enum ResultIndices { |
7981 | ResultIdx = 0, |
7982 | }; |
7983 | |
7984 | GatherNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, unsigned_t BatchDims) |
7985 | : Node(Kinded::Kind::GatherNodeKind, name), Data_(this, Data), Indices_(this, Indices), BatchDims_(BatchDims) { |
7986 | addResult(Result); |
7987 | } |
7988 | const NodeValue getData() const { return Data_; } |
7989 | const NodeValue getIndices() const { return Indices_; } |
7990 | NodeValue getResult() { return getNthResult(0); } |
7991 | const NodeValue getResult() const { return getNthResult(0); } |
7992 | unsigned_t getBatchDims() const { return BatchDims_; } |
7993 | |
7994 | static bool classof(const Kinded *k) { |
7995 | return k->getKind() == Kinded::Kind::GatherNodeKind; |
7996 | } |
7997 | |
7998 | |
7999 | bool isOverwrittenNthInput(unsigned idx) const { |
8000 | return false; |
8001 | } |
8002 | |
8003 | unsigned getNumInputs() const; |
8004 | std::string getInputName(unsigned idx) const; |
8005 | NodeValue getNthInput(unsigned idx); |
8006 | void setNthInput(unsigned idx, NodeValue val); |
8007 | llvm::StringRef getOutputName(unsigned idx) const; |
8008 | bool hasSideEffects() const { return 0; } |
8009 | bool isCanonical() const { return 1; } |
8010 | bool isDataParallel() const { return 0; } |
8011 | std::string getDebugDesc() const; |
8012 | bool isEqual(const GatherNode &other) const; |
8013 | llvm::hash_code getHash() const; |
8014 | void visit(Node *parent, NodeWalker *visitor); |
8015 | Node* clone() const; |
8016 | bool verify() const; |
8017 | }; |
8018 | } // namespace glow |
8019 | |
8020 | |
8021 | namespace glow { |
8022 | /// Given Data tensor of rank r >= 1, Indices tensor of rank q >= 1 This operator gathers slices of Data into an output tensor of rank q + r - Indices_shape[-1] - 1 . |
8023 | class GatherNDNode final : public Node { |
8024 | NodeHandle Data_; |
8025 | NodeHandle Indices_; |
8026 | unsigned_t BatchDims_; |
8027 | |
8028 | public: |
8029 | enum InputIndices { |
8030 | DataIdx = 0, |
8031 | IndicesIdx = 1, |
8032 | }; |
8033 | |
8034 | enum ResultIndices { |
8035 | ResultIdx = 0, |
8036 | }; |
8037 | |
8038 | GatherNDNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, unsigned_t BatchDims) |
8039 | : Node(Kinded::Kind::GatherNDNodeKind, name), Data_(this, Data), Indices_(this, Indices), BatchDims_(BatchDims) { |
8040 | addResult(Result); |
8041 | } |
8042 | const NodeValue getData() const { return Data_; } |
8043 | const NodeValue getIndices() const { return Indices_; } |
8044 | NodeValue getResult() { return getNthResult(0); } |
8045 | const NodeValue getResult() const { return getNthResult(0); } |
8046 | unsigned_t getBatchDims() const { return BatchDims_; } |
8047 | |
8048 | static bool classof(const Kinded *k) { |
8049 | return k->getKind() == Kinded::Kind::GatherNDNodeKind; |
8050 | } |
8051 | |
8052 | |
8053 | bool isOverwrittenNthInput(unsigned idx) const { |
8054 | return false; |
8055 | } |
8056 | |
8057 | unsigned getNumInputs() const; |
8058 | std::string getInputName(unsigned idx) const; |
8059 | NodeValue getNthInput(unsigned idx); |
8060 | void setNthInput(unsigned idx, NodeValue val); |
8061 | llvm::StringRef getOutputName(unsigned idx) const; |
8062 | bool hasSideEffects() const { return 0; } |
8063 | bool isCanonical() const { return 1; } |
8064 | bool isDataParallel() const { return 0; } |
8065 | std::string getDebugDesc() const; |
8066 | bool isEqual(const GatherNDNode &other) const; |
8067 | llvm::hash_code getHash() const; |
8068 | void visit(Node *parent, NodeWalker *visitor); |
8069 | Node* clone() const; |
8070 | bool verify() const; |
8071 | }; |
8072 | } // namespace glow |
8073 | |
8074 | |
8075 | namespace glow { |
8076 | /// GatherElements takes inputs data and indices of the same rank r >= 1 and an attribute axis specified by dim. It is an indexingoperation that produces its output by indexing into the input data tensor at index positions determined by elements of the indices tensor. Its output shape is the same as the shape of indices and consists of one value (gathered from the data) for each element in indices. |
8077 | class GatherElementsNode final : public Node { |
8078 | NodeHandle Data_; |
8079 | NodeHandle Indices_; |
8080 | unsigned_t Dim_; |
8081 | |
8082 | public: |
8083 | enum InputIndices { |
8084 | DataIdx = 0, |
8085 | IndicesIdx = 1, |
8086 | }; |
8087 | |
8088 | enum ResultIndices { |
8089 | ResultIdx = 0, |
8090 | }; |
8091 | |
8092 | GatherElementsNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Indices, unsigned_t Dim) |
8093 | : Node(Kinded::Kind::GatherElementsNodeKind, name), Data_(this, Data), Indices_(this, Indices), Dim_(Dim) { |
8094 | addResult(Result); |
8095 | } |
8096 | const NodeValue getData() const { return Data_; } |
8097 | const NodeValue getIndices() const { return Indices_; } |
8098 | NodeValue getResult() { return getNthResult(0); } |
8099 | const NodeValue getResult() const { return getNthResult(0); } |
8100 | unsigned_t getDim() const { return Dim_; } |
8101 | |
8102 | static bool classof(const Kinded *k) { |
8103 | return k->getKind() == Kinded::Kind::GatherElementsNodeKind; |
8104 | } |
8105 | |
8106 | |
8107 | bool isOverwrittenNthInput(unsigned idx) const { |
8108 | return false; |
8109 | } |
8110 | |
8111 | unsigned getNumInputs() const; |
8112 | std::string getInputName(unsigned idx) const; |
8113 | NodeValue getNthInput(unsigned idx); |
8114 | void setNthInput(unsigned idx, NodeValue val); |
8115 | llvm::StringRef getOutputName(unsigned idx) const; |
8116 | bool hasSideEffects() const { return 0; } |
8117 | bool isCanonical() const { return 1; } |
8118 | bool isDataParallel() const { return 0; } |
8119 | std::string getDebugDesc() const; |
8120 | bool isEqual(const GatherElementsNode &other) const; |
8121 | llvm::hash_code getHash() const; |
8122 | void visit(Node *parent, NodeWalker *visitor); |
8123 | Node* clone() const; |
8124 | bool verify() const; |
8125 | }; |
8126 | } // namespace glow |
8127 | |
8128 | |
8129 | namespace glow { |
8130 | /// Gathers entries of Data into Output in groups specified by the elements of Ranges. Each element of Ranges contains a list of pairs of indices of the form (index, length) which specify which entries of data to gather. The ordering of elements in Ranges and of pairs within an element is preserved in Output. Lengths contains the lengths of the ranges gathered by each list of pairs in Ranges. |
8131 | class GatherRangesNode final : public Node { |
8132 | NodeHandle Data_; |
8133 | NodeHandle Ranges_; |
8134 | |
8135 | public: |
8136 | enum InputIndices { |
8137 | DataIdx = 0, |
8138 | RangesIdx = 1, |
8139 | }; |
8140 | |
8141 | enum ResultIndices { |
8142 | OutputIdx = 0, |
8143 | LengthsIdx = 1, |
8144 | }; |
8145 | |
8146 | GatherRangesNode(llvm::StringRef name, TypeRef Output , TypeRef Lengths , NodeValue Data, NodeValue Ranges) |
8147 | : Node(Kinded::Kind::GatherRangesNodeKind, name), Data_(this, Data), Ranges_(this, Ranges) { |
8148 | addResult(Output); |
8149 | addResult(Lengths); |
8150 | } |
8151 | const NodeValue getData() const { return Data_; } |
8152 | const NodeValue getRanges() const { return Ranges_; } |
8153 | NodeValue getOutput() { return getNthResult(0); } |
8154 | const NodeValue getOutput() const { return getNthResult(0); } |
8155 | NodeValue getLengths() { return getNthResult(1); } |
8156 | const NodeValue getLengths() const { return getNthResult(1); } |
8157 | |
8158 | static bool classof(const Kinded *k) { |
8159 | return k->getKind() == Kinded::Kind::GatherRangesNodeKind; |
8160 | } |
8161 | |
8162 | |
8163 | bool isOverwrittenNthInput(unsigned idx) const { |
8164 | return false; |
8165 | } |
8166 | |
8167 | unsigned getNumInputs() const; |
8168 | std::string getInputName(unsigned idx) const; |
8169 | NodeValue getNthInput(unsigned idx); |
8170 | void setNthInput(unsigned idx, NodeValue val); |
8171 | llvm::StringRef getOutputName(unsigned idx) const; |
8172 | bool hasSideEffects() const { return 0; } |
8173 | bool isCanonical() const { return 1; } |
8174 | bool isDataParallel() const { return 0; } |
8175 | std::string getDebugDesc() const; |
8176 | bool isEqual(const GatherRangesNode &other) const; |
8177 | llvm::hash_code getHash() const; |
8178 | void visit(Node *parent, NodeWalker *visitor); |
8179 | Node* clone() const; |
8180 | bool verify() const; |
8181 | }; |
8182 | } // namespace glow |
8183 | |
8184 | |
8185 | namespace glow { |
8186 | /// Copies each slice from Slices into Data at the corresponding index in Indices. For example, given input Data {{1,2},{3,4},{5,6}}, Slices {{-3,-4}}, and Indices {{1}}, the result is {{1,2},{-3,-4},{5,6}}. It also supports multi-dimensional indices. For example, given input Data {{1,2},{3,4},{5,6}}, Slices {-3,-4}, and Indices {{1,0},{1,1}} also produces {{1,2},{-3,-4},{5,6}}. If Cumulative is true, the node adds values from Slices to Data instead of copying. For example, given input Data {{1,2},{3,4},{5,6}}, Slices {{-3,-4}}, and Indices {1}, the result is {{1,2},{0,0},{5,6}}. If an index is specified several times, its updates will be added several times as well. |
8187 | class ScatterDataNode final : public Node { |
8188 | NodeHandle Data_; |
8189 | NodeHandle Indices_; |
8190 | NodeHandle Slices_; |
8191 | bool Cumulative_; |
8192 | |
8193 | public: |
8194 | enum InputIndices { |
8195 | DataIdx = 0, |
8196 | IndicesIdx = 1, |
8197 | SlicesIdx = 2, |
8198 | }; |
8199 | |
8200 | enum ResultIndices { |
8201 | ResultIdx = 0, |
8202 | }; |
8203 | |
8204 | ScatterDataNode(llvm::StringRef name, NodeValue Data, NodeValue Indices, NodeValue Slices, bool Cumulative) |
8205 | : Node(Kinded::Kind::ScatterDataNodeKind, name), Data_(this, Data), Indices_(this, Indices), Slices_(this, Slices), Cumulative_(Cumulative) { |
8206 | addResult(Data.getType()); |
8207 | } |
8208 | const NodeValue getData() const { return Data_; } |
8209 | const NodeValue getIndices() const { return Indices_; } |
8210 | const NodeValue getSlices() const { return Slices_; } |
8211 | NodeValue getResult() { return getNthResult(0); } |
8212 | const NodeValue getResult() const { return getNthResult(0); } |
8213 | bool getCumulative() const { return Cumulative_; } |
8214 | |
8215 | static bool classof(const Kinded *k) { |
8216 | return k->getKind() == Kinded::Kind::ScatterDataNodeKind; |
8217 | } |
8218 | |
8219 | |
8220 | bool isOverwrittenNthInput(unsigned idx) const { |
8221 | return false; |
8222 | } |
8223 | |
8224 | unsigned getNumInputs() const; |
8225 | std::string getInputName(unsigned idx) const; |
8226 | NodeValue getNthInput(unsigned idx); |
8227 | void setNthInput(unsigned idx, NodeValue val); |
8228 | llvm::StringRef getOutputName(unsigned idx) const; |
8229 | bool hasSideEffects() const { return 0; } |
8230 | bool isCanonical() const { return 1; } |
8231 | bool isDataParallel() const { return 0; } |
8232 | std::string getDebugDesc() const; |
8233 | bool isEqual(const ScatterDataNode &other) const; |
8234 | llvm::hash_code getHash() const; |
8235 | void visit(Node *parent, NodeWalker *visitor); |
8236 | Node* clone() const; |
8237 | bool verify() const; |
8238 | }; |
8239 | } // namespace glow |
8240 | |
8241 | |
8242 | namespace glow { |
8243 | /// Tile an Input tensor Count times along Axis. |
8244 | class TileNode final : public Node { |
8245 | NodeHandle Input_; |
8246 | unsigned_t Count_; |
8247 | unsigned_t Axis_; |
8248 | |
8249 | public: |
8250 | enum InputIndices { |
8251 | InputIdx = 0, |
8252 | }; |
8253 | |
8254 | enum ResultIndices { |
8255 | ResultIdx = 0, |
8256 | }; |
8257 | |
8258 | TileNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Count, unsigned_t Axis) |
8259 | : Node(Kinded::Kind::TileNodeKind, name), Input_(this, Input), Count_(Count), Axis_(Axis) { |
8260 | addResult(Result); |
8261 | } |
8262 | const NodeValue getInput() const { return Input_; } |
8263 | NodeValue getResult() { return getNthResult(0); } |
8264 | const NodeValue getResult() const { return getNthResult(0); } |
8265 | unsigned_t getCount() const { return Count_; } |
8266 | unsigned_t getAxis() const { return Axis_; } |
8267 | |
8268 | static bool classof(const Kinded *k) { |
8269 | return k->getKind() == Kinded::Kind::TileNodeKind; |
8270 | } |
8271 | |
8272 | |
8273 | bool isOverwrittenNthInput(unsigned idx) const { |
8274 | return false; |
8275 | } |
8276 | |
8277 | unsigned getNumInputs() const; |
8278 | std::string getInputName(unsigned idx) const; |
8279 | NodeValue getNthInput(unsigned idx); |
8280 | void setNthInput(unsigned idx, NodeValue val); |
8281 | llvm::StringRef getOutputName(unsigned idx) const; |
8282 | bool hasSideEffects() const { return 0; } |
8283 | bool isCanonical() const { return 1; } |
8284 | bool isDataParallel() const { return 0; } |
8285 | std::string getDebugDesc() const; |
8286 | bool isEqual(const TileNode &other) const; |
8287 | llvm::hash_code getHash() const; |
8288 | void visit(Node *parent, NodeWalker *visitor); |
8289 | Node* clone() const; |
8290 | bool verify() const; |
8291 | }; |
8292 | } // namespace glow |
8293 | |
8294 | |
8295 | namespace glow { |
8296 | /// Expands each row of the Data to a row of zeros and ones, according to One Hot Encoding. i-th element of Result's row is one iff Values[i] equals to the corresponding element of Data. |
8297 | class BatchOneHotNode final : public Node { |
8298 | NodeHandle Data_; |
8299 | NodeHandle Lengths_; |
8300 | NodeHandle Values_; |
8301 | |
8302 | public: |
8303 | enum InputIndices { |
8304 | DataIdx = 0, |
8305 | LengthsIdx = 1, |
8306 | ValuesIdx = 2, |
8307 | }; |
8308 | |
8309 | enum ResultIndices { |
8310 | ResultIdx = 0, |
8311 | }; |
8312 | |
8313 | BatchOneHotNode(llvm::StringRef name, TypeRef Result , NodeValue Data, NodeValue Lengths, NodeValue Values) |
8314 | : Node(Kinded::Kind::BatchOneHotNodeKind, name), Data_(this, Data), Lengths_(this, Lengths), Values_(this, Values) { |
8315 | addResult(Result); |
8316 | } |
8317 | const NodeValue getData() const { return Data_; } |
8318 | const NodeValue getLengths() const { return Lengths_; } |
8319 | const NodeValue getValues() const { return Values_; } |
8320 | NodeValue getResult() { return getNthResult(0); } |
8321 | const NodeValue getResult() const { return getNthResult(0); } |
8322 | |
8323 | static bool classof(const Kinded *k) { |
8324 | return k->getKind() == Kinded::Kind::BatchOneHotNodeKind; |
8325 | } |
8326 | |
8327 | |
8328 | bool isOverwrittenNthInput(unsigned idx) const { |
8329 | return false; |
8330 | } |
8331 | |
8332 | unsigned getNumInputs() const; |
8333 | std::string getInputName(unsigned idx) const; |
8334 | NodeValue getNthInput(unsigned idx); |
8335 | void setNthInput(unsigned idx, NodeValue val); |
8336 | llvm::StringRef getOutputName(unsigned idx) const; |
8337 | bool hasSideEffects() const { return 0; } |
8338 | bool isCanonical() const { return 1; } |
8339 | bool isDataParallel() const { return 0; } |
8340 | std::string getDebugDesc() const; |
8341 | bool isEqual(const BatchOneHotNode &other) const; |
8342 | llvm::hash_code getHash() const; |
8343 | void visit(Node *parent, NodeWalker *visitor); |
8344 | Node* clone() const; |
8345 | bool verify() const; |
8346 | }; |
8347 | } // namespace glow |
8348 | |
8349 | |
8350 | namespace glow { |
8351 | /// Given Input tensor of [N,H,W,C], where N is the batch axis, C is the channel or depth, H is the height and W is the width. This produces Output tensor of [N, H/BlockSize, W/BlockSize, C * BlockSize * BlockSize]. |
8352 | class SpaceToDepthNode final : public Node { |
8353 | NodeHandle Input_; |
8354 | unsigned_t BlockSize_; |
8355 | |
8356 | public: |
8357 | enum InputIndices { |
8358 | InputIdx = 0, |
8359 | }; |
8360 | |
8361 | enum ResultIndices { |
8362 | ResultIdx = 0, |
8363 | }; |
8364 | |
8365 | SpaceToDepthNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t BlockSize) |
8366 | : Node(Kinded::Kind::SpaceToDepthNodeKind, name), Input_(this, Input), BlockSize_(BlockSize) { |
8367 | addResult(Result); |
8368 | } |
8369 | const NodeValue getInput() const { return Input_; } |
8370 | NodeValue getResult() { return getNthResult(0); } |
8371 | const NodeValue getResult() const { return getNthResult(0); } |
8372 | unsigned_t getBlockSize() const { return BlockSize_; } |
8373 | |
8374 | static bool classof(const Kinded *k) { |
8375 | return k->getKind() == Kinded::Kind::SpaceToDepthNodeKind; |
8376 | } |
8377 | |
8378 | |
8379 | bool isOverwrittenNthInput(unsigned idx) const { |
8380 | return false; |
8381 | } |
8382 | |
8383 | unsigned getNumInputs() const; |
8384 | std::string getInputName(unsigned idx) const; |
8385 | NodeValue getNthInput(unsigned idx); |
8386 | void setNthInput(unsigned idx, NodeValue val); |
8387 | llvm::StringRef getOutputName(unsigned idx) const; |
8388 | bool hasSideEffects() const { return 0; } |
8389 | bool isCanonical() const { return 1; } |
8390 | bool isDataParallel() const { return 0; } |
8391 | std::string getDebugDesc() const; |
8392 | bool isEqual(const SpaceToDepthNode &other) const; |
8393 | llvm::hash_code getHash() const; |
8394 | void visit(Node *parent, NodeWalker *visitor); |
8395 | Node* clone() const; |
8396 | bool verify() const; |
8397 | }; |
8398 | } // namespace glow |
8399 | |
8400 | |
8401 | namespace glow { |
8402 | /// Given Input tensor of 3D, 4D, 5D or 6D, generates an Output tensor with resized spatial dimensions using nearest neighbor interpolation. The Output tensor is of shape floor(input_dimension * scale) |
8403 | class ResizeNearestNode final : public Node { |
8404 | NodeHandle Input_; |
8405 | std::vector<float> Scale_; |
8406 | |
8407 | public: |
8408 | enum InputIndices { |
8409 | InputIdx = 0, |
8410 | }; |
8411 | |
8412 | enum ResultIndices { |
8413 | ResultIdx = 0, |
8414 | }; |
8415 | |
8416 | ResizeNearestNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<float> Scale) |
8417 | : Node(Kinded::Kind::ResizeNearestNodeKind, name), Input_(this, Input), Scale_(Scale) { |
8418 | addResult(Result); |
8419 | } |
8420 | const NodeValue getInput() const { return Input_; } |
8421 | NodeValue getResult() { return getNthResult(0); } |
8422 | const NodeValue getResult() const { return getNthResult(0); } |
8423 | llvm::ArrayRef<float> getScale() const { return Scale_; } |
8424 | |
8425 | static bool classof(const Kinded *k) { |
8426 | return k->getKind() == Kinded::Kind::ResizeNearestNodeKind; |
8427 | } |
8428 | |
8429 | |
8430 | bool isOverwrittenNthInput(unsigned idx) const { |
8431 | return false; |
8432 | } |
8433 | |
8434 | unsigned getNumInputs() const; |
8435 | std::string getInputName(unsigned idx) const; |
8436 | NodeValue getNthInput(unsigned idx); |
8437 | void setNthInput(unsigned idx, NodeValue val); |
8438 | llvm::StringRef getOutputName(unsigned idx) const; |
8439 | bool hasSideEffects() const { return 0; } |
8440 | bool isCanonical() const { return 1; } |
8441 | bool isDataParallel() const { return 0; } |
8442 | std::string getDebugDesc() const; |
8443 | bool isEqual(const ResizeNearestNode &other) const; |
8444 | llvm::hash_code getHash() const; |
8445 | void visit(Node *parent, NodeWalker *visitor); |
8446 | Node* clone() const; |
8447 | bool verify() const; |
8448 | }; |
8449 | } // namespace glow |
8450 | |
8451 | |
8452 | namespace glow { |
8453 | /// Given Input tensor of [N,H,W,C], where N is the batch, C is the channel or depth, H is the height and W is the width, Generates an Output tensor with resized spatial dimensions using bilinear neighbor interpolation. The Output tensor is of shape floor(input_dimension * scale) |
8454 | class ResizeBilinearNode final : public Node { |
8455 | NodeHandle Input_; |
8456 | std::vector<float> Scale_; |
8457 | |
8458 | public: |
8459 | enum InputIndices { |
8460 | InputIdx = 0, |
8461 | }; |
8462 | |
8463 | enum ResultIndices { |
8464 | ResultIdx = 0, |
8465 | }; |
8466 | |
8467 | ResizeBilinearNode(llvm::StringRef name, TypeRef Result , NodeValue Input, std::vector<float> Scale) |
8468 | : Node(Kinded::Kind::ResizeBilinearNodeKind, name), Input_(this, Input), Scale_(Scale) { |
8469 | addResult(Result); |
8470 | } |
8471 | const NodeValue getInput() const { return Input_; } |
8472 | NodeValue getResult() { return getNthResult(0); } |
8473 | const NodeValue getResult() const { return getNthResult(0); } |
8474 | llvm::ArrayRef<float> getScale() const { return Scale_; } |
8475 | |
8476 | static bool classof(const Kinded *k) { |
8477 | return k->getKind() == Kinded::Kind::ResizeBilinearNodeKind; |
8478 | } |
8479 | |
8480 | |
8481 | bool isOverwrittenNthInput(unsigned idx) const { |
8482 | return false; |
8483 | } |
8484 | |
8485 | unsigned getNumInputs() const; |
8486 | std::string getInputName(unsigned idx) const; |
8487 | NodeValue getNthInput(unsigned idx); |
8488 | void setNthInput(unsigned idx, NodeValue val); |
8489 | llvm::StringRef getOutputName(unsigned idx) const; |
8490 | bool hasSideEffects() const { return 0; } |
8491 | bool isCanonical() const { return 1; } |
8492 | bool isDataParallel() const { return 0; } |
8493 | std::string getDebugDesc() const; |
8494 | bool isEqual(const ResizeBilinearNode &other) const; |
8495 | llvm::hash_code getHash() const; |
8496 | void visit(Node *parent, NodeWalker *visitor); |
8497 | Node* clone() const; |
8498 | bool verify() const; |
8499 | }; |
8500 | } // namespace glow |
8501 | |
8502 | |
8503 | namespace glow { |
8504 | /// Broadcast the Input tensor to TargetDim using Axis to indicate the offset between Input dimension and TargetDim |
8505 | class BroadcastNode final : public Node { |
8506 | NodeHandle Input_; |
8507 | unsigned_t Axis_; |
8508 | std::vector<dim_t> TargetDim_; |
8509 | |
8510 | public: |
8511 | enum InputIndices { |
8512 | InputIdx = 0, |
8513 | }; |
8514 | |
8515 | enum ResultIndices { |
8516 | ResultIdx = 0, |
8517 | }; |
8518 | |
8519 | BroadcastNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis, std::vector<dim_t> TargetDim) |
8520 | : Node(Kinded::Kind::BroadcastNodeKind, name), Input_(this, Input), Axis_(Axis), TargetDim_(TargetDim) { |
8521 | addResult(Result); |
8522 | } |
8523 | const NodeValue getInput() const { return Input_; } |
8524 | NodeValue getResult() { return getNthResult(0); } |
8525 | const NodeValue getResult() const { return getNthResult(0); } |
8526 | unsigned_t getAxis() const { return Axis_; } |
8527 | llvm::ArrayRef<dim_t> getTargetDim() const { return TargetDim_; } |
8528 | |
8529 | static bool classof(const Kinded *k) { |
8530 | return k->getKind() == Kinded::Kind::BroadcastNodeKind; |
8531 | } |
8532 | |
8533 | |
8534 | bool isOverwrittenNthInput(unsigned idx) const { |
8535 | return false; |
8536 | } |
8537 | |
8538 | unsigned getNumInputs() const; |
8539 | std::string getInputName(unsigned idx) const; |
8540 | NodeValue getNthInput(unsigned idx); |
8541 | void setNthInput(unsigned idx, NodeValue val); |
8542 | llvm::StringRef getOutputName(unsigned idx) const; |
8543 | bool hasSideEffects() const { return 0; } |
8544 | bool isCanonical() const { return 1; } |
8545 | bool isDataParallel() const { return 0; } |
8546 | std::string getDebugDesc() const; |
8547 | bool isEqual(const BroadcastNode &other) const; |
8548 | llvm::hash_code getHash() const; |
8549 | void visit(Node *parent, NodeWalker *visitor); |
8550 | Node* clone() const; |
8551 | bool verify() const; |
8552 | }; |
8553 | } // namespace glow |
8554 | |
8555 | |
8556 | namespace glow { |
8557 | /// TODO |
8558 | class SparseLabelSplitNode final : public Node { |
8559 | NodeHandle Lengths_; |
8560 | NodeHandle Indices_; |
8561 | NodeHandle Values_; |
8562 | unsigned_t NumLabels_; |
8563 | |
8564 | public: |
8565 | enum InputIndices { |
8566 | LengthsIdx = 0, |
8567 | IndicesIdx = 1, |
8568 | ValuesIdx = 2, |
8569 | }; |
8570 | |
8571 | enum ResultIndices { |
8572 | LabelValuesIdx = 0, |
8573 | ExampleIdsIdx = 1, |
8574 | GradientOffsetMapIdx = 2, |
8575 | }; |
8576 | |
8577 | SparseLabelSplitNode(llvm::StringRef name, TypeRef LabelValues , TypeRef ExampleIds , TypeRef GradientOffsetMap , NodeValue Lengths, NodeValue Indices, NodeValue Values, unsigned_t NumLabels) |
8578 | : Node(Kinded::Kind::SparseLabelSplitNodeKind, name), Lengths_(this, Lengths), Indices_(this, Indices), Values_(this, Values), NumLabels_(NumLabels) { |
8579 | addResult(LabelValues); |
8580 | addResult(ExampleIds); |
8581 | addResult(GradientOffsetMap); |
8582 | } |
8583 | const NodeValue getLengths() const { return Lengths_; } |
8584 | const NodeValue getIndices() const { return Indices_; } |
8585 | const NodeValue getValues() const { return Values_; } |
8586 | NodeValue getLabelValues() { return getNthResult(0); } |
8587 | const NodeValue getLabelValues() const { return getNthResult(0); } |
8588 | NodeValue getExampleIds() { return getNthResult(1); } |
8589 | const NodeValue getExampleIds() const { return getNthResult(1); } |
8590 | NodeValue getGradientOffsetMap() { return getNthResult(2); } |
8591 | const NodeValue getGradientOffsetMap() const { return getNthResult(2); } |
8592 | unsigned_t getNumLabels() const { return NumLabels_; } |
8593 | |
8594 | static bool classof(const Kinded *k) { |
8595 | return k->getKind() == Kinded::Kind::SparseLabelSplitNodeKind; |
8596 | } |
8597 | |
8598 | |
8599 | bool isOverwrittenNthInput(unsigned idx) const { |
8600 | return false; |
8601 | } |
8602 | |
8603 | unsigned getNumInputs() const; |
8604 | std::string getInputName(unsigned idx) const; |
8605 | NodeValue getNthInput(unsigned idx); |
8606 | void setNthInput(unsigned idx, NodeValue val); |
8607 | llvm::StringRef getOutputName(unsigned idx) const; |
8608 | bool hasSideEffects() const { return 0; } |
8609 | bool isCanonical() const { return 1; } |
8610 | bool isDataParallel() const { return 0; } |
8611 | std::string getDebugDesc() const; |
8612 | bool isEqual(const SparseLabelSplitNode &other) const; |
8613 | llvm::hash_code getHash() const; |
8614 | void visit(Node *parent, NodeWalker *visitor); |
8615 | Node* clone() const; |
8616 | bool verify() const; |
8617 | }; |
8618 | } // namespace glow |
8619 | |
8620 | |
8621 | namespace glow { |
8622 | /// Reverse the order of elements in a tensor along the given axis. The shape of the tensor is preserved, but the elements are reordered. The node is inspired from Python numpy. |
8623 | class FlipNode final : public Node { |
8624 | NodeHandle Input_; |
8625 | unsigned_t Axis_; |
8626 | |
8627 | public: |
8628 | enum InputIndices { |
8629 | InputIdx = 0, |
8630 | }; |
8631 | |
8632 | enum ResultIndices { |
8633 | ResultIdx = 0, |
8634 | }; |
8635 | |
8636 | FlipNode(llvm::StringRef name, TypeRef Result , NodeValue Input, unsigned_t Axis) |
8637 | : Node(Kinded::Kind::FlipNodeKind, name), Input_(this, Input), Axis_(Axis) { |
8638 | addResult(Result); |
8639 | } |
8640 | const NodeValue getInput() const { return Input_; } |
8641 | NodeValue getResult() { return getNthResult(0); } |
8642 | const NodeValue getResult() const { return getNthResult(0); } |
8643 | unsigned_t getAxis() const { return Axis_; } |
8644 | |
8645 | static bool classof(const Kinded *k) { |
8646 | return k->getKind() == Kinded::Kind::FlipNodeKind; |
8647 | } |
8648 | |
8649 | |
8650 | bool isOverwrittenNthInput(unsigned idx) const { |
8651 | return false; |
8652 | } |
8653 | |
8654 | unsigned getNumInputs() const; |
8655 | std::string getInputName(unsigned idx) const; |
8656 | NodeValue getNthInput(unsigned idx); |
8657 | void setNthInput(unsigned idx, NodeValue val); |
8658 | llvm::StringRef getOutputName(unsigned idx) const; |
8659 | bool hasSideEffects() const { return 0; } |
8660 | bool isCanonical() const { return 1; } |
8661 | bool isDataParallel() const { return 0; } |
8662 | std::string getDebugDesc() const; |
8663 | bool isEqual(const FlipNode &other) const; |
8664 | llvm::hash_code getHash() const; |
8665 | void visit(Node *parent, NodeWalker *visitor); |
8666 | Node* clone() const; |
8667 | bool verify() const; |
8668 | }; |
8669 | } // namespace glow |
8670 | |
8671 | |
8672 | namespace glow { |
8673 | /// Generate a tensor of a specific type filled with 'Value'.Splat always keep floating point value internally but canquantize it based on the output type. |
8674 | class SplatNode final : public Node { |
8675 | float Value_; |
8676 | |
8677 | public: |
8678 | enum InputIndices { |
8679 | }; |
8680 | |
8681 | enum ResultIndices { |
8682 | ResultIdx = 0, |
8683 | }; |
8684 | |
8685 | SplatNode(llvm::StringRef name, TypeRef Result , float Value) |
8686 | : Node(Kinded::Kind::SplatNodeKind, name), Value_(Value) { |
8687 | addResult(Result); |
8688 | } |
8689 | NodeValue getResult() { return getNthResult(0); } |
8690 | const NodeValue getResult() const { return getNthResult(0); } |
8691 | float getValue() const { return Value_; } |
8692 | |
8693 | static bool classof(const Kinded *k) { |
8694 | return k->getKind() == Kinded::Kind::SplatNodeKind; |
8695 | } |
8696 | |
8697 | |
8698 | bool isOverwrittenNthInput(unsigned idx) const { |
8699 | return false; |
8700 | } |
8701 | |
8702 | unsigned getNumInputs() const; |
8703 | std::string getInputName(unsigned idx) const; |
8704 | NodeValue getNthInput(unsigned idx); |
8705 | void setNthInput(unsigned idx, NodeValue val); |
8706 | llvm::StringRef getOutputName(unsigned idx) const; |
8707 | bool hasSideEffects() const { return 0; } |
8708 | bool isCanonical() const { return 1; } |
8709 | bool isDataParallel() const { return 0; } |
8710 | std::string getDebugDesc() const; |
8711 | bool isEqual(const SplatNode &other) const; |
8712 | llvm::hash_code getHash() const; |
8713 | void visit(Node *parent, NodeWalker *visitor); |
8714 | Node* clone() const; |
8715 | bool verify() const; |
8716 | }; |
8717 | } // namespace glow |
8718 | |
8719 | |
8720 | namespace glow { |
8721 | /// Generate a tensor of a specific type without initializing it. This is useful when filling a big tensor entirely with multiple small slices using InsertTensor nodes such that the big tensor is not required to be initialized (filled) with some value prior to insertion. This node is intended to remove the overhead associated with the initialization in situations where it is not required. |
8722 | class TouchNode final : public Node { |
8723 | |
8724 | public: |
8725 | enum InputIndices { |
8726 | }; |
8727 | |
8728 | enum ResultIndices { |
8729 | ResultIdx = 0, |
8730 | }; |
8731 | |
8732 | TouchNode(llvm::StringRef name, TypeRef Result ) |
8733 | : Node(Kinded::Kind::TouchNodeKind, name) { |
8734 | addResult(Result); |
8735 | } |
8736 | NodeValue getResult() { return getNthResult(0); } |
8737 | const NodeValue getResult() const { return getNthResult(0); } |
8738 | |
8739 | static bool classof(const Kinded *k) { |
8740 | return k->getKind() == Kinded::Kind::TouchNodeKind; |
8741 | } |
8742 | |
8743 | |
8744 | bool isOverwrittenNthInput(unsigned idx) const { |
8745 | return false; |
8746 | } |
8747 | |
8748 | unsigned getNumInputs() const; |
8749 | std::string getInputName(unsigned idx) const; |
8750 | NodeValue getNthInput(unsigned idx); |
8751 | void setNthInput(unsigned idx, NodeValue val); |
8752 | llvm::StringRef getOutputName(unsigned idx) const; |
8753 | bool hasSideEffects() const { return 0; } |
8754 | bool isCanonical() const { return 1; } |
8755 | bool isDataParallel() const { return 0; } |
8756 | std::string getDebugDesc() const; |
8757 | bool isEqual(const TouchNode &other) const; |
8758 | llvm::hash_code getHash() const; |
8759 | void visit(Node *parent, NodeWalker *visitor); |
8760 | Node* clone() const; |
8761 | bool verify() const; |
8762 | }; |
8763 | } // namespace glow |
8764 | |
8765 | |
8766 | namespace glow { |
8767 | /// Stochastic Gradient Descent node used during training. Produces the updated weight that needs to be used instead of Weight for the next iteration. |
8768 | class SGDNode final : public Node { |
8769 | NodeHandle Gradient_; |
8770 | NodeHandle Weight_; |
8771 | float L1Decay_; |
8772 | float L2Decay_; |
8773 | float LearningRate_; |
8774 | float Momentum_; |
8775 | unsigned_t BatchSize_; |
8776 | |
8777 | public: |
8778 | enum InputIndices { |
8779 | GradientIdx = 0, |
8780 | WeightIdx = 1, |
8781 | }; |
8782 | |
8783 | enum ResultIndices { |
8784 | UpdatedWeightIdx = 0, |
8785 | }; |
8786 | |
8787 | SGDNode(llvm::StringRef name, NodeValue Gradient, NodeValue Weight, float L1Decay, float L2Decay, float LearningRate, float Momentum, unsigned_t BatchSize) |
8788 | : Node(Kinded::Kind::SGDNodeKind, name), Gradient_(this, Gradient), Weight_(this, Weight), L1Decay_(L1Decay), L2Decay_(L2Decay), LearningRate_(LearningRate), Momentum_(Momentum), BatchSize_(BatchSize) { |
8789 | addResult(Weight.getType()); |
8790 | } |
8791 | const NodeValue getGradient() const { return Gradient_; } |
8792 | const NodeValue getWeight() const { return Weight_; } |
8793 | NodeValue getUpdatedWeight() { return getNthResult(0); } |
8794 | const NodeValue getUpdatedWeight() const { return getNthResult(0); } |
8795 | float getL1Decay() const { return L1Decay_; } |
8796 | float getL2Decay() const { return L2Decay_; } |
8797 | float getLearningRate() const { return LearningRate_; } |
8798 | float getMomentum() const { return Momentum_; } |
8799 | unsigned_t getBatchSize() const { return BatchSize_; } |
8800 | |
8801 | static bool classof(const Kinded *k) { |
8802 | return k->getKind() == Kinded::Kind::SGDNodeKind; |
8803 | } |
8804 | |
8805 | |
8806 | bool isOverwrittenNthInput(unsigned idx) const { |
8807 | return false; |
8808 | } |
8809 | |
8810 | unsigned getNumInputs() const; |
8811 | std::string getInputName(unsigned idx) const; |
8812 | NodeValue getNthInput(unsigned idx); |
8813 | void setNthInput(unsigned idx, NodeValue val); |
8814 | llvm::StringRef getOutputName(unsigned idx) const; |
8815 | bool hasSideEffects() const { return 1; } |
8816 | bool isCanonical() const { return 1; } |
8817 | bool isDataParallel() const { return 0; } |
8818 | std::string getDebugDesc() const; |
8819 | bool isEqual(const SGDNode &other) const; |
8820 | llvm::hash_code getHash() const; |
8821 | void visit(Node *parent, NodeWalker *visitor); |
8822 | Node* clone() const; |
8823 | bool verify() const; |
8824 | }; |
8825 | } // namespace glow |
8826 | |
8827 | |
8828 | namespace glow { |
8829 | /// Inserts a TraceEvent for profiling. |
8830 | class TraceEventNode final : public Node { |
8831 | NodeHandle Data_; |
8832 | std::string EventName_; |
8833 | std::string EventType_; |
8834 | unsigned_t Index_; |
8835 | |
8836 | public: |
8837 | enum InputIndices { |
8838 | DataIdx = 0, |
8839 | }; |
8840 | |
8841 | enum ResultIndices { |
8842 | }; |
8843 | |
8844 | TraceEventNode(llvm::StringRef name, NodeValue Data, std::string EventName, std::string EventType, unsigned_t Index) |
8845 | : Node(Kinded::Kind::TraceEventNodeKind, name), Data_(this, Data), EventName_(EventName), EventType_(EventType), Index_(Index) { |
8846 | } |
8847 | const NodeValue getData() const { return Data_; } |
8848 | std::string getEventName() const { return EventName_; } |
8849 | std::string getEventType() const { return EventType_; } |
8850 | unsigned_t getIndex() const { return Index_; } |
8851 | |
8852 | static bool classof(const Kinded *k) { |
8853 | return k->getKind() == Kinded::Kind::TraceEventNodeKind; |
8854 | } |
8855 | |
8856 | |
8857 | bool isOverwrittenNthInput(unsigned idx) const { |
8858 | return false; |
8859 | } |
8860 | |
8861 | unsigned getNumInputs() const; |
8862 | std::string getInputName(unsigned idx) const; |
8863 | NodeValue getNthInput(unsigned idx); |
8864 | void setNthInput(unsigned idx, NodeValue val); |
8865 | llvm::StringRef getOutputName(unsigned idx) const; |
8866 | bool hasSideEffects() const { return 1; } |
8867 | bool isCanonical() const { return 1; } |
8868 | bool isDataParallel() const { return 0; } |
8869 | std::string getDebugDesc() const; |
8870 | bool isEqual(const TraceEventNode &other) const; |
8871 | llvm::hash_code getHash() const; |
8872 | void visit(Node *parent, NodeWalker *visitor); |
8873 | Node* clone() const; |
8874 | bool verify() const; |
8875 | }; |
8876 | } // namespace glow |
8877 | |
8878 | |
8879 | namespace glow { |
8880 | /// Generate profile (distribution of values) of the Input tensor. This data is used for quantization of the tensor later on. ProfiledNodeName contains the name of the node which is profiled by the QuantizationProfile node. ProfiledNodeName is helpful as lowering might transform the original graph. ProfiledOutputNumber contains the position of the node's output which gets profiled. |
8881 | class QuantizationProfileNode final : public Node { |
8882 | NodeHandle Input_; |
8883 | NodeHandle Histogram_; |
8884 | NodeHandle ComputationInfo_; |
8885 | std::string ProfiledNodeName_; |
8886 | unsigned_t ProfiledOutputNumber_; |
8887 | |
8888 | public: |
8889 | enum InputIndices { |
8890 | InputIdx = 0, |
8891 | HistogramIdx = 1, |
8892 | ComputationInfoIdx = 2, |
8893 | }; |
8894 | |
8895 | enum ResultIndices { |
8896 | }; |
8897 | |
8898 | QuantizationProfileNode(llvm::StringRef name, NodeValue Input, NodeValue Histogram, NodeValue ComputationInfo, std::string ProfiledNodeName, unsigned_t ProfiledOutputNumber) |
8899 | : Node(Kinded::Kind::QuantizationProfileNodeKind, name), Input_(this, Input), Histogram_(this, Histogram), ComputationInfo_(this, ComputationInfo), ProfiledNodeName_(ProfiledNodeName), ProfiledOutputNumber_(ProfiledOutputNumber) { |
8900 | } |
8901 | const NodeValue getInput() const { return Input_; } |
8902 | const NodeValue getHistogram() const { return Histogram_; } |
8903 | const NodeValue getComputationInfo() const { return ComputationInfo_; } |
8904 | std::string getProfiledNodeName() const { return ProfiledNodeName_; } |
8905 | unsigned_t getProfiledOutputNumber() const { return ProfiledOutputNumber_; } |
8906 | |
8907 | static bool classof(const Kinded *k) { |
8908 | return k->getKind() == Kinded::Kind::QuantizationProfileNodeKind; |
8909 | } |
8910 | |
8911 | |
8912 | bool isOverwrittenNthInput(unsigned idx) const { |
8913 | if (idx == 2) return true; |
8914 | if (idx == 1) return true; |
8915 | return false; |
8916 | } |
8917 | |
8918 | unsigned getNumInputs() const; |
8919 | std::string getInputName(unsigned idx) const; |
8920 | NodeValue getNthInput(unsigned idx); |
8921 | void setNthInput(unsigned idx, NodeValue val); |
8922 | llvm::StringRef getOutputName(unsigned idx) const; |
8923 | bool hasSideEffects() const { return 1; } |
8924 | bool isCanonical() const { return 1; } |
8925 | bool isDataParallel() const { return 0; } |
8926 | std::string getDebugDesc() const; |
8927 | bool isEqual(const QuantizationProfileNode &other) const; |
8928 | llvm::hash_code getHash() const; |
8929 | void visit(Node *parent, NodeWalker *visitor); |
8930 | Node* clone() const; |
8931 | bool verify() const; |
8932 | Placeholder *getHistogramPlaceholder() const ; |
8933 | Placeholder *getComputationInfoPlaceholder() const; |
8934 | }; |
8935 | } // namespace glow |
8936 | |
8937 | |
8938 | namespace glow { |
8939 | /// Simple mapping between quantized numbers.This can be used as quantized sigmoid or tanh functions. |
8940 | class IntLookupTableNode final : public Node { |
8941 | NodeHandle Input_; |
8942 | NodeHandle Mapping_; |
8943 | |
8944 | public: |
8945 | enum InputIndices { |
8946 | InputIdx = 0, |
8947 | MappingIdx = 1, |
8948 | }; |
8949 | |
8950 | enum ResultIndices { |
8951 | ResultIdx = 0, |
8952 | }; |
8953 | |
8954 | IntLookupTableNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Mapping) |
8955 | : Node(Kinded::Kind::IntLookupTableNodeKind, name), Input_(this, Input), Mapping_(this, Mapping) { |
8956 | addResult(Result); |
8957 | } |
8958 | const NodeValue getInput() const { return Input_; } |
8959 | const NodeValue getMapping() const { return Mapping_; } |
8960 | NodeValue getResult() { return getNthResult(0); } |
8961 | const NodeValue getResult() const { return getNthResult(0); } |
8962 | |
8963 | static bool classof(const Kinded *k) { |
8964 | return k->getKind() == Kinded::Kind::IntLookupTableNodeKind; |
8965 | } |
8966 | |
8967 | |
8968 | bool isOverwrittenNthInput(unsigned idx) const { |
8969 | return false; |
8970 | } |
8971 | |
8972 | unsigned getNumInputs() const; |
8973 | std::string getInputName(unsigned idx) const; |
8974 | NodeValue getNthInput(unsigned idx); |
8975 | void setNthInput(unsigned idx, NodeValue val); |
8976 | llvm::StringRef getOutputName(unsigned idx) const; |
8977 | bool hasSideEffects() const { return 0; } |
8978 | bool isCanonical() const { return 1; } |
8979 | bool isDataParallel() const { return 1; } |
8980 | std::string getDebugDesc() const; |
8981 | bool isEqual(const IntLookupTableNode &other) const; |
8982 | llvm::hash_code getHash() const; |
8983 | void visit(Node *parent, NodeWalker *visitor); |
8984 | Node* clone() const; |
8985 | bool verify() const; |
8986 | }; |
8987 | } // namespace glow |
8988 | |
8989 | |
8990 | namespace glow { |
8991 | /// Quantize floating point tensor. This operation converts floating point numbers to integers based on the given Scale and Offset. Scale and Offset are deduced from the type of the output.x_q = clip(round(x/Scale) + Offset, -128, 127) |
8992 | class QuantizeNode final : public Node { |
8993 | NodeHandle Input_; |
8994 | |
8995 | public: |
8996 | enum InputIndices { |
8997 | InputIdx = 0, |
8998 | }; |
8999 | |
9000 | enum ResultIndices { |
9001 | ResultIdx = 0, |
9002 | }; |
9003 | |
9004 | QuantizeNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
9005 | : Node(Kinded::Kind::QuantizeNodeKind, name), Input_(this, Input) { |
9006 | addResult(Result); |
9007 | } |
9008 | const NodeValue getInput() const { return Input_; } |
9009 | NodeValue getResult() { return getNthResult(0); } |
9010 | const NodeValue getResult() const { return getNthResult(0); } |
9011 | |
9012 | static bool classof(const Kinded *k) { |
9013 | return k->getKind() == Kinded::Kind::QuantizeNodeKind; |
9014 | } |
9015 | |
9016 | |
9017 | bool isOverwrittenNthInput(unsigned idx) const { |
9018 | return false; |
9019 | } |
9020 | |
9021 | unsigned getNumInputs() const; |
9022 | std::string getInputName(unsigned idx) const; |
9023 | NodeValue getNthInput(unsigned idx); |
9024 | void setNthInput(unsigned idx, NodeValue val); |
9025 | llvm::StringRef getOutputName(unsigned idx) const; |
9026 | bool hasSideEffects() const { return 0; } |
9027 | bool isCanonical() const { return 1; } |
9028 | bool isDataParallel() const { return 1; } |
9029 | std::string getDebugDesc() const; |
9030 | bool isEqual(const QuantizeNode &other) const; |
9031 | llvm::hash_code getHash() const; |
9032 | void visit(Node *parent, NodeWalker *visitor); |
9033 | Node* clone() const; |
9034 | bool verify() const; |
9035 | }; |
9036 | } // namespace glow |
9037 | |
9038 | |
9039 | namespace glow { |
9040 | /// Convert quantized input tensor into the float representation. x = Scale * (x_q - Offset). |
9041 | class DequantizeNode final : public Node { |
9042 | NodeHandle Input_; |
9043 | |
9044 | public: |
9045 | enum InputIndices { |
9046 | InputIdx = 0, |
9047 | }; |
9048 | |
9049 | enum ResultIndices { |
9050 | ResultIdx = 0, |
9051 | }; |
9052 | |
9053 | DequantizeNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
9054 | : Node(Kinded::Kind::DequantizeNodeKind, name), Input_(this, Input) { |
9055 | addResult(Result); |
9056 | } |
9057 | const NodeValue getInput() const { return Input_; } |
9058 | NodeValue getResult() { return getNthResult(0); } |
9059 | const NodeValue getResult() const { return getNthResult(0); } |
9060 | |
9061 | static bool classof(const Kinded *k) { |
9062 | return k->getKind() == Kinded::Kind::DequantizeNodeKind; |
9063 | } |
9064 | |
9065 | |
9066 | bool isOverwrittenNthInput(unsigned idx) const { |
9067 | return false; |
9068 | } |
9069 | |
9070 | unsigned getNumInputs() const; |
9071 | std::string getInputName(unsigned idx) const; |
9072 | NodeValue getNthInput(unsigned idx); |
9073 | void setNthInput(unsigned idx, NodeValue val); |
9074 | llvm::StringRef getOutputName(unsigned idx) const; |
9075 | bool hasSideEffects() const { return 0; } |
9076 | bool isCanonical() const { return 1; } |
9077 | bool isDataParallel() const { return 1; } |
9078 | std::string getDebugDesc() const; |
9079 | bool isEqual(const DequantizeNode &other) const; |
9080 | llvm::hash_code getHash() const; |
9081 | void visit(Node *parent, NodeWalker *visitor); |
9082 | Node* clone() const; |
9083 | bool verify() const; |
9084 | }; |
9085 | } // namespace glow |
9086 | |
9087 | |
9088 | namespace glow { |
9089 | /// Rescale the input quantized tensor to a new Scale and Offset. The new Scale and Offset are specified by the output type passed to the constructor |
9090 | class RescaleQuantizedNode final : public Node { |
9091 | NodeHandle Input_; |
9092 | |
9093 | public: |
9094 | enum InputIndices { |
9095 | InputIdx = 0, |
9096 | }; |
9097 | |
9098 | enum ResultIndices { |
9099 | ResultIdx = 0, |
9100 | }; |
9101 | |
9102 | RescaleQuantizedNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
9103 | : Node(Kinded::Kind::RescaleQuantizedNodeKind, name), Input_(this, Input) { |
9104 | addResult(Result); |
9105 | } |
9106 | const NodeValue getInput() const { return Input_; } |
9107 | NodeValue getResult() { return getNthResult(0); } |
9108 | const NodeValue getResult() const { return getNthResult(0); } |
9109 | |
9110 | static bool classof(const Kinded *k) { |
9111 | return k->getKind() == Kinded::Kind::RescaleQuantizedNodeKind; |
9112 | } |
9113 | |
9114 | |
9115 | bool isOverwrittenNthInput(unsigned idx) const { |
9116 | return false; |
9117 | } |
9118 | |
9119 | unsigned getNumInputs() const; |
9120 | std::string getInputName(unsigned idx) const; |
9121 | NodeValue getNthInput(unsigned idx); |
9122 | void setNthInput(unsigned idx, NodeValue val); |
9123 | llvm::StringRef getOutputName(unsigned idx) const; |
9124 | bool hasSideEffects() const { return 0; } |
9125 | bool isCanonical() const { return 1; } |
9126 | bool isDataParallel() const { return 1; } |
9127 | std::string getDebugDesc() const; |
9128 | bool isEqual(const RescaleQuantizedNode &other) const; |
9129 | llvm::hash_code getHash() const; |
9130 | void visit(Node *parent, NodeWalker *visitor); |
9131 | Node* clone() const; |
9132 | bool verify() const; |
9133 | }; |
9134 | } // namespace glow |
9135 | |
9136 | |
9137 | namespace glow { |
9138 | /// Finds the top K maximal elements for each vector in the tensor. Vectors are defined as the last dimension in the tensor. The input shape {D_0, D_1, ... D_n} results in the outputs {D_0, D_1, ... D_n-1, K}, sorted in non-decreasing order. |
9139 | class TopKNode final : public Node { |
9140 | NodeHandle Input_; |
9141 | unsigned_t K_; |
9142 | |
9143 | public: |
9144 | enum InputIndices { |
9145 | InputIdx = 0, |
9146 | }; |
9147 | |
9148 | enum ResultIndices { |
9149 | ValuesIdx = 0, |
9150 | IndicesIdx = 1, |
9151 | }; |
9152 | |
9153 | TopKNode(llvm::StringRef name, TypeRef Values , TypeRef Indices , NodeValue Input, unsigned_t K) |
9154 | : Node(Kinded::Kind::TopKNodeKind, name), Input_(this, Input), K_(K) { |
9155 | addResult(Values); |
9156 | addResult(Indices); |
9157 | } |
9158 | const NodeValue getInput() const { return Input_; } |
9159 | NodeValue getValues() { return getNthResult(0); } |
9160 | const NodeValue getValues() const { return getNthResult(0); } |
9161 | NodeValue getIndices() { return getNthResult(1); } |
9162 | const NodeValue getIndices() const { return getNthResult(1); } |
9163 | unsigned_t getK() const { return K_; } |
9164 | |
9165 | static bool classof(const Kinded *k) { |
9166 | return k->getKind() == Kinded::Kind::TopKNodeKind; |
9167 | } |
9168 | |
9169 | |
9170 | bool isOverwrittenNthInput(unsigned idx) const { |
9171 | return false; |
9172 | } |
9173 | |
9174 | unsigned getNumInputs() const; |
9175 | std::string getInputName(unsigned idx) const; |
9176 | NodeValue getNthInput(unsigned idx); |
9177 | void setNthInput(unsigned idx, NodeValue val); |
9178 | llvm::StringRef getOutputName(unsigned idx) const; |
9179 | bool hasSideEffects() const { return 0; } |
9180 | bool isCanonical() const { return 1; } |
9181 | bool isDataParallel() const { return 0; } |
9182 | std::string getDebugDesc() const; |
9183 | bool isEqual(const TopKNode &other) const; |
9184 | llvm::hash_code getHash() const; |
9185 | void visit(Node *parent, NodeWalker *visitor); |
9186 | Node* clone() const; |
9187 | bool verify() const; |
9188 | }; |
9189 | } // namespace glow |
9190 | |
9191 | |
9192 | namespace glow { |
9193 | /// A LSTM unit node, take Input as I, F, G, O,takes F from forget gate, I from input gate,O from output gate, G from cell gate and C from cell state. Calulates newC = sigmoid(F) * C + sigmoid(I) * tanh(G), newH = tanh(newC) * sigmoid(O). |
9194 | class LSTMUnitNode final : public Node { |
9195 | NodeHandle Input_; |
9196 | NodeHandle C_; |
9197 | |
9198 | public: |
9199 | enum InputIndices { |
9200 | InputIdx = 0, |
9201 | CIdx = 1, |
9202 | }; |
9203 | |
9204 | enum ResultIndices { |
9205 | newCIdx = 0, |
9206 | newHIdx = 1, |
9207 | }; |
9208 | |
9209 | LSTMUnitNode(llvm::StringRef name, NodeValue Input, NodeValue C) |
9210 | : Node(Kinded::Kind::LSTMUnitNodeKind, name), Input_(this, Input), C_(this, C) { |
9211 | addResult(C.getType()); |
9212 | addResult(C.getType()); |
9213 | } |
9214 | const NodeValue getInput() const { return Input_; } |
9215 | const NodeValue getC() const { return C_; } |
9216 | NodeValue getnewC() { return getNthResult(0); } |
9217 | const NodeValue getnewC() const { return getNthResult(0); } |
9218 | NodeValue getnewH() { return getNthResult(1); } |
9219 | const NodeValue getnewH() const { return getNthResult(1); } |
9220 | |
9221 | static bool classof(const Kinded *k) { |
9222 | return k->getKind() == Kinded::Kind::LSTMUnitNodeKind; |
9223 | } |
9224 | |
9225 | |
9226 | bool isOverwrittenNthInput(unsigned idx) const { |
9227 | return false; |
9228 | } |
9229 | |
9230 | unsigned getNumInputs() const; |
9231 | std::string getInputName(unsigned idx) const; |
9232 | NodeValue getNthInput(unsigned idx); |
9233 | void setNthInput(unsigned idx, NodeValue val); |
9234 | llvm::StringRef getOutputName(unsigned idx) const; |
9235 | bool hasSideEffects() const { return 0; } |
9236 | bool isCanonical() const { return 1; } |
9237 | bool isDataParallel() const { return 0; } |
9238 | std::string getDebugDesc() const; |
9239 | bool isEqual(const LSTMUnitNode &other) const; |
9240 | llvm::hash_code getHash() const; |
9241 | void visit(Node *parent, NodeWalker *visitor); |
9242 | Node* clone() const; |
9243 | bool verify() const; |
9244 | }; |
9245 | } // namespace glow |
9246 | |
9247 | |
9248 | namespace glow { |
9249 | /// Convert the input from its current type to the destination type. The input and output types must have the same shapes. Moreover the input and output types must not be quantized types. Quantized types should use the appropriate Quantize, Dequantize, and Rescale nodes. |
9250 | class ConvertToNode final : public Node { |
9251 | NodeHandle Input_; |
9252 | |
9253 | public: |
9254 | enum InputIndices { |
9255 | InputIdx = 0, |
9256 | }; |
9257 | |
9258 | enum ResultIndices { |
9259 | ResultIdx = 0, |
9260 | }; |
9261 | |
9262 | ConvertToNode(llvm::StringRef name, TypeRef Result , NodeValue Input) |
9263 | : Node(Kinded::Kind::ConvertToNodeKind, name), Input_(this, Input) { |
9264 | addResult(Result); |
9265 | } |
9266 | const NodeValue getInput() const { return Input_; } |
9267 | NodeValue getResult() { return getNthResult(0); } |
9268 | const NodeValue getResult() const { return getNthResult(0); } |
9269 | |
9270 | static bool classof(const Kinded *k) { |
9271 | return k->getKind() == Kinded::Kind::ConvertToNodeKind; |
9272 | } |
9273 | |
9274 | |
9275 | bool isOverwrittenNthInput(unsigned idx) const { |
9276 | return false; |
9277 | } |
9278 | |
9279 | unsigned getNumInputs() const; |
9280 | std::string getInputName(unsigned idx) const; |
9281 | NodeValue getNthInput(unsigned idx); |
9282 | void setNthInput(unsigned idx, NodeValue val); |
9283 | llvm::StringRef getOutputName(unsigned idx) const; |
9284 | bool hasSideEffects() const { return 0; } |
9285 | bool isCanonical() const { return 1; } |
9286 | bool isDataParallel() const { return 1; } |
9287 | std::string getDebugDesc() const; |
9288 | bool isEqual(const ConvertToNode &other) const; |
9289 | llvm::hash_code getHash() const; |
9290 | void visit(Node *parent, NodeWalker *visitor); |
9291 | Node* clone() const; |
9292 | bool verify() const; |
9293 | }; |
9294 | } // namespace glow |
9295 | |
9296 | |
9297 | namespace glow { |
9298 | /// This is a node representing an external function call. One possible use of this capability is to pass a source code for a function/kernel. When processing this node, a backend can compile and execute the source code. This node can also be used to pass binary or pointers to executable code. The semantics and implementation of this node not standardized and is very backend-specific. |
9299 | class ExternalFunctionCallNode final : public Node { |
9300 | std::vector<NodeHandle> Inputs_; |
9301 | std::string FunctionName_; |
9302 | std::string FunctionImpl_; |
9303 | std::string FunctionKind_; |
9304 | |
9305 | public: |
9306 | enum InputIndices { |
9307 | }; |
9308 | |
9309 | enum ResultIndices { |
9310 | ResultIdx = 0, |
9311 | }; |
9312 | |
9313 | ExternalFunctionCallNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> Inputs, std::string FunctionName, std::string FunctionImpl, std::string FunctionKind) |
9314 | : Node(Kinded::Kind::ExternalFunctionCallNodeKind, name), FunctionName_(FunctionName), FunctionImpl_(FunctionImpl), FunctionKind_(FunctionKind) { |
9315 | addResult(Result); |
9316 | Inputs_.resize(Inputs.size()); |
9317 | for (size_t idx = 0, e = Inputs.size(); idx < e; ++idx) { |
9318 | Inputs_[idx] = Inputs[idx]; |
9319 | Inputs_[idx].setParent(this); |
9320 | } |
9321 | } |
9322 | NodeValue getResult() { return getNthResult(0); } |
9323 | const NodeValue getResult() const { return getNthResult(0); } |
9324 | NodeValueArrayRef getInputs() const { return Inputs_; } |
9325 | std::string getFunctionName() const { return FunctionName_; } |
9326 | std::string getFunctionImpl() const { return FunctionImpl_; } |
9327 | std::string getFunctionKind() const { return FunctionKind_; } |
9328 | |
9329 | static bool classof(const Kinded *k) { |
9330 | return k->getKind() == Kinded::Kind::ExternalFunctionCallNodeKind; |
9331 | } |
9332 | |
9333 | |
9334 | bool isOverwrittenNthInput(unsigned idx) const { |
9335 | return false; |
9336 | } |
9337 | |
9338 | unsigned getNumInputs() const; |
9339 | std::string getInputName(unsigned idx) const; |
9340 | NodeValue getNthInput(unsigned idx); |
9341 | void setNthInput(unsigned idx, NodeValue val); |
9342 | llvm::StringRef getOutputName(unsigned idx) const; |
9343 | bool hasSideEffects() const { return 1; } |
9344 | bool isCanonical() const { return 1; } |
9345 | bool isDataParallel() const { return 0; } |
9346 | std::string getDebugDesc() const; |
9347 | bool isEqual(const ExternalFunctionCallNode &other) const; |
9348 | llvm::hash_code getHash() const; |
9349 | void visit(Node *parent, NodeWalker *visitor); |
9350 | Node* clone() const; |
9351 | bool verify() const; |
9352 | }; |
9353 | } // namespace glow |
9354 | |
9355 | |
9356 | namespace glow { |
9357 | /// Computes the spectrogram of a mono audio signal using given window size and stride. The FFT length used to compute the spectrogram is the next power of 2 (for a window size of 640 the FFT length is 1024). The length of each spectrogram window is FFT_length / 2 + 1. This node is inspired from TensorFlow. |
9358 | class AudioSpectrogramNode final : public Node { |
9359 | NodeHandle Input_; |
9360 | NodeHandle Window_; |
9361 | NodeHandle TwiddleFactors_; |
9362 | NodeHandle BitReverseIndices_; |
9363 | NodeHandle ComplexToRealWeights_; |
9364 | unsigned_t WindowSize_; |
9365 | unsigned_t WindowStride_; |
9366 | bool MagnitudeSquared_; |
9367 | |
9368 | public: |
9369 | enum InputIndices { |
9370 | InputIdx = 0, |
9371 | WindowIdx = 1, |
9372 | TwiddleFactorsIdx = 2, |
9373 | BitReverseIndicesIdx = 3, |
9374 | ComplexToRealWeightsIdx = 4, |
9375 | }; |
9376 | |
9377 | enum ResultIndices { |
9378 | SpectrogramIdx = 0, |
9379 | }; |
9380 | |
9381 | AudioSpectrogramNode(llvm::StringRef name, TypeRef Spectrogram , NodeValue Input, NodeValue Window, NodeValue TwiddleFactors, NodeValue BitReverseIndices, NodeValue ComplexToRealWeights, unsigned_t WindowSize, unsigned_t WindowStride, bool MagnitudeSquared) |
9382 | : Node(Kinded::Kind::AudioSpectrogramNodeKind, name), Input_(this, Input), Window_(this, Window), TwiddleFactors_(this, TwiddleFactors), BitReverseIndices_(this, BitReverseIndices), ComplexToRealWeights_(this, ComplexToRealWeights), WindowSize_(WindowSize), WindowStride_(WindowStride), MagnitudeSquared_(MagnitudeSquared) { |
9383 | addResult(Spectrogram); |
9384 | } |
9385 | const NodeValue getInput() const { return Input_; } |
9386 | const NodeValue getWindow() const { return Window_; } |
9387 | const NodeValue getTwiddleFactors() const { return TwiddleFactors_; } |
9388 | const NodeValue getBitReverseIndices() const { return BitReverseIndices_; } |
9389 | const NodeValue getComplexToRealWeights() const { return ComplexToRealWeights_; } |
9390 | NodeValue getSpectrogram() { return getNthResult(0); } |
9391 | const NodeValue getSpectrogram() const { return getNthResult(0); } |
9392 | unsigned_t getWindowSize() const { return WindowSize_; } |
9393 | unsigned_t getWindowStride() const { return WindowStride_; } |
9394 | bool getMagnitudeSquared() const { return MagnitudeSquared_; } |
9395 | |
9396 | static bool classof(const Kinded *k) { |
9397 | return k->getKind() == Kinded::Kind::AudioSpectrogramNodeKind; |
9398 | } |
9399 | |
9400 | |
9401 | bool isOverwrittenNthInput(unsigned idx) const { |
9402 | return false; |
9403 | } |
9404 | |
9405 | unsigned getNumInputs() const; |
9406 | std::string getInputName(unsigned idx) const; |
9407 | NodeValue getNthInput(unsigned idx); |
9408 | void setNthInput(unsigned idx, NodeValue val); |
9409 | llvm::StringRef getOutputName(unsigned idx) const; |
9410 | bool hasSideEffects() const { return 0; } |
9411 | bool isCanonical() const { return 1; } |
9412 | bool isDataParallel() const { return 0; } |
9413 | std::string getDebugDesc() const; |
9414 | bool isEqual(const AudioSpectrogramNode &other) const; |
9415 | llvm::hash_code getHash() const; |
9416 | void visit(Node *parent, NodeWalker *visitor); |
9417 | Node* clone() const; |
9418 | bool verify() const; |
9419 | }; |
9420 | } // namespace glow |
9421 | |
9422 | |
9423 | namespace glow { |
9424 | /// Computes the MFCC (Mel Frequency Cepstral Coefficient) for the given spectrogram. This node is mostly used as feature extractor for voice/speech audio data in voice command or keyword spotting applications. The input is assumed to be a power spectrogram and not a magnitude.This node is inspired from TensorFlow. |
9425 | class MFCCNode final : public Node { |
9426 | NodeHandle Spectrogram_; |
9427 | NodeHandle MelWeights_; |
9428 | NodeHandle MelRanges_; |
9429 | NodeHandle DctMat_; |
9430 | float SampleRate_; |
9431 | float LowerFrequency_; |
9432 | float UpperFrequency_; |
9433 | unsigned_t FilterBankCount_; |
9434 | unsigned_t NumCoefficients_; |
9435 | |
9436 | public: |
9437 | enum InputIndices { |
9438 | SpectrogramIdx = 0, |
9439 | MelWeightsIdx = 1, |
9440 | MelRangesIdx = 2, |
9441 | DctMatIdx = 3, |
9442 | }; |
9443 | |
9444 | enum ResultIndices { |
9445 | CoefficientsIdx = 0, |
9446 | }; |
9447 | |
9448 | MFCCNode(llvm::StringRef name, TypeRef Coefficients , NodeValue Spectrogram, NodeValue MelWeights, NodeValue MelRanges, NodeValue DctMat, float SampleRate, float LowerFrequency, float UpperFrequency, unsigned_t FilterBankCount, unsigned_t NumCoefficients) |
9449 | : Node(Kinded::Kind::MFCCNodeKind, name), Spectrogram_(this, Spectrogram), MelWeights_(this, MelWeights), MelRanges_(this, MelRanges), DctMat_(this, DctMat), SampleRate_(SampleRate), LowerFrequency_(LowerFrequency), UpperFrequency_(UpperFrequency), FilterBankCount_(FilterBankCount), NumCoefficients_(NumCoefficients) { |
9450 | addResult(Coefficients); |
9451 | } |
9452 | const NodeValue getSpectrogram() const { return Spectrogram_; } |
9453 | const NodeValue getMelWeights() const { return MelWeights_; } |
9454 | const NodeValue getMelRanges() const { return MelRanges_; } |
9455 | const NodeValue getDctMat() const { return DctMat_; } |
9456 | NodeValue getCoefficients() { return getNthResult(0); } |
9457 | const NodeValue getCoefficients() const { return getNthResult(0); } |
9458 | float getSampleRate() const { return SampleRate_; } |
9459 | float getLowerFrequency() const { return LowerFrequency_; } |
9460 | float getUpperFrequency() const { return UpperFrequency_; } |
9461 | unsigned_t getFilterBankCount() const { return FilterBankCount_; } |
9462 | unsigned_t getNumCoefficients() const { return NumCoefficients_; } |
9463 | |
9464 | static bool classof(const Kinded *k) { |
9465 | return k->getKind() == Kinded::Kind::MFCCNodeKind; |
9466 | } |
9467 | |
9468 | |
9469 | bool isOverwrittenNthInput(unsigned idx) const { |
9470 | return false; |
9471 | } |
9472 | |
9473 | unsigned getNumInputs() const; |
9474 | std::string getInputName(unsigned idx) const; |
9475 | NodeValue getNthInput(unsigned idx); |
9476 | void setNthInput(unsigned idx, NodeValue val); |
9477 | llvm::StringRef getOutputName(unsigned idx) const; |
9478 | bool hasSideEffects() const { return 0; } |
9479 | bool isCanonical() const { return 1; } |
9480 | bool isDataParallel() const { return 0; } |
9481 | std::string getDebugDesc() const; |
9482 | bool isEqual(const MFCCNode &other) const; |
9483 | llvm::hash_code getHash() const; |
9484 | void visit(Node *parent, NodeWalker *visitor); |
9485 | Node* clone() const; |
9486 | bool verify() const; |
9487 | }; |
9488 | } // namespace glow |
9489 | |
9490 | |
9491 | namespace glow { |
9492 | /// This is a mix of ONNX and TF NMSv4. It supports multiple classes and does per class NMS. It also supports TF NMS V4 by outputting indices and scalar tensor with number of valid indices. It pads the rest with global MIN box. |
9493 | class NonMaxSuppressionNode final : public Node { |
9494 | NodeHandle Boxes_; |
9495 | NodeHandle Scores_; |
9496 | unsigned_t CenterPointBox_; |
9497 | unsigned_t MaxOutputBoxesPerClass_; |
9498 | float IouThreshold_; |
9499 | float ScoreThreshold_; |
9500 | bool IsTFVersion4_; |
9501 | |
9502 | public: |
9503 | enum InputIndices { |
9504 | BoxesIdx = 0, |
9505 | ScoresIdx = 1, |
9506 | }; |
9507 | |
9508 | enum ResultIndices { |
9509 | IndicesIdx = 0, |
9510 | NumberOfSelectedIndicesIdx = 1, |
9511 | }; |
9512 | |
9513 | NonMaxSuppressionNode(llvm::StringRef name, TypeRef Indices , TypeRef NumberOfSelectedIndices , NodeValue Boxes, NodeValue Scores, unsigned_t CenterPointBox, unsigned_t MaxOutputBoxesPerClass, float IouThreshold, float ScoreThreshold, bool IsTFVersion4) |
9514 | : Node(Kinded::Kind::NonMaxSuppressionNodeKind, name), Boxes_(this, Boxes), Scores_(this, Scores), CenterPointBox_(CenterPointBox), MaxOutputBoxesPerClass_(MaxOutputBoxesPerClass), IouThreshold_(IouThreshold), ScoreThreshold_(ScoreThreshold), IsTFVersion4_(IsTFVersion4) { |
9515 | addResult(Indices); |
9516 | addResult(NumberOfSelectedIndices); |
9517 | } |
9518 | const NodeValue getBoxes() const { return Boxes_; } |
9519 | const NodeValue getScores() const { return Scores_; } |
9520 | NodeValue getIndices() { return getNthResult(0); } |
9521 | const NodeValue getIndices() const { return getNthResult(0); } |
9522 | NodeValue getNumberOfSelectedIndices() { return getNthResult(1); } |
9523 | const NodeValue getNumberOfSelectedIndices() const { return getNthResult(1); } |
9524 | unsigned_t getCenterPointBox() const { return CenterPointBox_; } |
9525 | unsigned_t getMaxOutputBoxesPerClass() const { return MaxOutputBoxesPerClass_; } |
9526 | float getIouThreshold() const { return IouThreshold_; } |
9527 | float getScoreThreshold() const { return ScoreThreshold_; } |
9528 | bool getIsTFVersion4() const { return IsTFVersion4_; } |
9529 | |
9530 | static bool classof(const Kinded *k) { |
9531 | return k->getKind() == Kinded::Kind::NonMaxSuppressionNodeKind; |
9532 | } |
9533 | |
9534 | |
9535 | bool isOverwrittenNthInput(unsigned idx) const { |
9536 | return false; |
9537 | } |
9538 | |
9539 | unsigned getNumInputs() const; |
9540 | std::string getInputName(unsigned idx) const; |
9541 | NodeValue getNthInput(unsigned idx); |
9542 | void setNthInput(unsigned idx, NodeValue val); |
9543 | llvm::StringRef getOutputName(unsigned idx) const; |
9544 | bool hasSideEffects() const { return 0; } |
9545 | bool isCanonical() const { return 1; } |
9546 | bool isDataParallel() const { return 0; } |
9547 | std::string getDebugDesc() const; |
9548 | bool isEqual(const NonMaxSuppressionNode &other) const; |
9549 | llvm::hash_code getHash() const; |
9550 | void visit(Node *parent, NodeWalker *visitor); |
9551 | Node* clone() const; |
9552 | bool verify() const; |
9553 | }; |
9554 | } // namespace glow |
9555 | |
9556 | |
9557 | namespace glow { |
9558 | /// This node is a TensorFlowLite version of NonMaxSuppresion. The node has the following inputs: Boxes with size [N, B, 4], Scores with size [N, B, C] and Anchors with size [B, 4] where N is the batch size, B is the number of boxes and C is the number of classes. The node has the following attributes (parameters): NumClasses - Number of classes (without the background class). MaxDetections - The maximum number of detections. MaxClassesPerDetection - Maximum classes per detection (Fast NMS). MaxDetectionsPerClass - Maximum detections per class (Regular NMS). IouThreshold - Detection threshold for IoU metric. ScoreThreshold - Detection threshold for scores. XScale - X scale used for decoding the boxes. YScale - Y scale used for decoding the boxes. HScale - H scale used for decoding the boxes. WScale - W scale used for decoding the boxes. RegularNMS - Whether the NMS is 'Regular' or 'Fast'. The node will have the following outputs: DetectionBoxes - the chosen boxes (float). DetectionClasses - the classes of the chosen boxes (int32). DetectionScores - the scores of the chosen boxes (float). NumDetections - number of chose boxes (int32). The first three output tensors will be allocated using the maximum number of possible detections (worst case scenario) but the actual usage will be given by the 'NumDetections' output. |
9559 | class TFLiteDetectionPostProcessNode final : public Node { |
9560 | NodeHandle Boxes_; |
9561 | NodeHandle Scores_; |
9562 | NodeHandle Anchors_; |
9563 | unsigned_t NumClasses_; |
9564 | unsigned_t MaxDetections_; |
9565 | unsigned_t MaxClassesPerDetection_; |
9566 | unsigned_t MaxDetectionsPerClass_; |
9567 | float IouThreshold_; |
9568 | float ScoreThreshold_; |
9569 | float XScale_; |
9570 | float YScale_; |
9571 | float HScale_; |
9572 | float WScale_; |
9573 | bool RegularNMS_; |
9574 | |
9575 | public: |
9576 | enum InputIndices { |
9577 | BoxesIdx = 0, |
9578 | ScoresIdx = 1, |
9579 | AnchorsIdx = 2, |
9580 | }; |
9581 | |
9582 | enum ResultIndices { |
9583 | DetectionBoxesIdx = 0, |
9584 | DetectionClassesIdx = 1, |
9585 | DetectionScoresIdx = 2, |
9586 | NumDetectionsIdx = 3, |
9587 | }; |
9588 | |
9589 | TFLiteDetectionPostProcessNode(llvm::StringRef name, TypeRef DetectionBoxes , TypeRef DetectionClasses , TypeRef DetectionScores , TypeRef NumDetections , NodeValue Boxes, NodeValue Scores, NodeValue Anchors, unsigned_t NumClasses, unsigned_t MaxDetections, unsigned_t MaxClassesPerDetection, unsigned_t MaxDetectionsPerClass, float IouThreshold, float ScoreThreshold, float XScale, float YScale, float HScale, float WScale, bool RegularNMS) |
9590 | : Node(Kinded::Kind::TFLiteDetectionPostProcessNodeKind, name), Boxes_(this, Boxes), Scores_(this, Scores), Anchors_(this, Anchors), NumClasses_(NumClasses), MaxDetections_(MaxDetections), MaxClassesPerDetection_(MaxClassesPerDetection), MaxDetectionsPerClass_(MaxDetectionsPerClass), IouThreshold_(IouThreshold), ScoreThreshold_(ScoreThreshold), XScale_(XScale), YScale_(YScale), HScale_(HScale), WScale_(WScale), RegularNMS_(RegularNMS) { |
9591 | addResult(DetectionBoxes); |
9592 | addResult(DetectionClasses); |
9593 | addResult(DetectionScores); |
9594 | addResult(NumDetections); |
9595 | } |
9596 | const NodeValue getBoxes() const { return Boxes_; } |
9597 | const NodeValue getScores() const { return Scores_; } |
9598 | const NodeValue getAnchors() const { return Anchors_; } |
9599 | NodeValue getDetectionBoxes() { return getNthResult(0); } |
9600 | const NodeValue getDetectionBoxes() const { return getNthResult(0); } |
9601 | NodeValue getDetectionClasses() { return getNthResult(1); } |
9602 | const NodeValue getDetectionClasses() const { return getNthResult(1); } |
9603 | NodeValue getDetectionScores() { return getNthResult(2); } |
9604 | const NodeValue getDetectionScores() const { return getNthResult(2); } |
9605 | NodeValue getNumDetections() { return getNthResult(3); } |
9606 | const NodeValue getNumDetections() const { return getNthResult(3); } |
9607 | unsigned_t getNumClasses() const { return NumClasses_; } |
9608 | unsigned_t getMaxDetections() const { return MaxDetections_; } |
9609 | unsigned_t getMaxClassesPerDetection() const { return MaxClassesPerDetection_; } |
9610 | unsigned_t getMaxDetectionsPerClass() const { return MaxDetectionsPerClass_; } |
9611 | float getIouThreshold() const { return IouThreshold_; } |
9612 | float getScoreThreshold() const { return ScoreThreshold_; } |
9613 | float getXScale() const { return XScale_; } |
9614 | float getYScale() const { return YScale_; } |
9615 | float getHScale() const { return HScale_; } |
9616 | float getWScale() const { return WScale_; } |
9617 | bool getRegularNMS() const { return RegularNMS_; } |
9618 | |
9619 | static bool classof(const Kinded *k) { |
9620 | return k->getKind() == Kinded::Kind::TFLiteDetectionPostProcessNodeKind; |
9621 | } |
9622 | |
9623 | |
9624 | bool isOverwrittenNthInput(unsigned idx) const { |
9625 | return false; |
9626 | } |
9627 | |
9628 | unsigned getNumInputs() const; |
9629 | std::string getInputName(unsigned idx) const; |
9630 | NodeValue getNthInput(unsigned idx); |
9631 | void setNthInput(unsigned idx, NodeValue val); |
9632 | llvm::StringRef getOutputName(unsigned idx) const; |
9633 | bool hasSideEffects() const { return 0; } |
9634 | bool isCanonical() const { return 1; } |
9635 | bool isDataParallel() const { return 0; } |
9636 | std::string getDebugDesc() const; |
9637 | bool isEqual(const TFLiteDetectionPostProcessNode &other) const; |
9638 | llvm::hash_code getHash() const; |
9639 | void visit(Node *parent, NodeWalker *visitor); |
9640 | Node* clone() const; |
9641 | bool verify() const; |
9642 | }; |
9643 | } // namespace glow |
9644 | |
9645 | |
9646 | namespace glow { |
9647 | /// Performs region of interest align (ROI) operator. FeatureMap - a tensor of [N,H,W,C]. N is the batch, C is the channel, H is the height, W is the width. Boxes - a tensor of [K,4] or [K,5] with format [[optinal_batch_index] x0, y0, x1, y1]. K is the number of boxes. BatchIndices - a tensor of [K,]. If N > 1 and Box shape is [K,4], BatchIndices must be valid. Output is a tensor with shape [K, OutputHeight, OutputWidth, C]. Aligned - if true, coordinates are aligned to a center of a pixel. |
9648 | class ROIAlignNode final : public Node { |
9649 | NodeHandle FeatureMap_; |
9650 | NodeHandle Boxes_; |
9651 | NodeHandle BatchIndices_; |
9652 | unsigned_t Mode_; |
9653 | unsigned_t OutputHeight_; |
9654 | unsigned_t OutputWidth_; |
9655 | unsigned_t SamplingRatio_; |
9656 | float SpatialScale_; |
9657 | bool Aligned_; |
9658 | bool Rotated_; |
9659 | |
9660 | public: |
9661 | enum InputIndices { |
9662 | FeatureMapIdx = 0, |
9663 | BoxesIdx = 1, |
9664 | BatchIndicesIdx = 2, |
9665 | }; |
9666 | |
9667 | enum ResultIndices { |
9668 | ResultIdx = 0, |
9669 | }; |
9670 | |
9671 | ROIAlignNode(llvm::StringRef name, TypeRef Result , NodeValue FeatureMap, NodeValue Boxes, NodeValue BatchIndices, unsigned_t Mode, unsigned_t OutputHeight, unsigned_t OutputWidth, unsigned_t SamplingRatio, float SpatialScale, bool Aligned, bool Rotated) |
9672 | : Node(Kinded::Kind::ROIAlignNodeKind, name), FeatureMap_(this, FeatureMap), Boxes_(this, Boxes), BatchIndices_(this, BatchIndices), Mode_(Mode), OutputHeight_(OutputHeight), OutputWidth_(OutputWidth), SamplingRatio_(SamplingRatio), SpatialScale_(SpatialScale), Aligned_(Aligned), Rotated_(Rotated) { |
9673 | addResult(Result); |
9674 | } |
9675 | const NodeValue getFeatureMap() const { return FeatureMap_; } |
9676 | const NodeValue getBoxes() const { return Boxes_; } |
9677 | const NodeValue getBatchIndices() const { return BatchIndices_; } |
9678 | NodeValue getResult() { return getNthResult(0); } |
9679 | const NodeValue getResult() const { return getNthResult(0); } |
9680 | unsigned_t getMode() const { return Mode_; } |
9681 | unsigned_t getOutputHeight() const { return OutputHeight_; } |
9682 | unsigned_t getOutputWidth() const { return OutputWidth_; } |
9683 | unsigned_t getSamplingRatio() const { return SamplingRatio_; } |
9684 | float getSpatialScale() const { return SpatialScale_; } |
9685 | bool getAligned() const { return Aligned_; } |
9686 | bool getRotated() const { return Rotated_; } |
9687 | |
9688 | static bool classof(const Kinded *k) { |
9689 | return k->getKind() == Kinded::Kind::ROIAlignNodeKind; |
9690 | } |
9691 | |
9692 | |
9693 | bool isOverwrittenNthInput(unsigned idx) const { |
9694 | return false; |
9695 | } |
9696 | |
9697 | unsigned getNumInputs() const; |
9698 | std::string getInputName(unsigned idx) const; |
9699 | NodeValue getNthInput(unsigned idx); |
9700 | void setNthInput(unsigned idx, NodeValue val); |
9701 | llvm::StringRef getOutputName(unsigned idx) const; |
9702 | bool hasSideEffects() const { return 0; } |
9703 | bool isCanonical() const { return 1; } |
9704 | bool isDataParallel() const { return 0; } |
9705 | std::string getDebugDesc() const; |
9706 | bool isEqual(const ROIAlignNode &other) const; |
9707 | llvm::hash_code getHash() const; |
9708 | void visit(Node *parent, NodeWalker *visitor); |
9709 | Node* clone() const; |
9710 | bool verify() const; |
9711 | }; |
9712 | } // namespace glow |
9713 | |
9714 | |
9715 | namespace glow { |
9716 | /// Transform proposal bounding boxes to target bounding box using bounding box regression deltas. Rois tensor's format is: <[optional_batch_index], x1, y1, x2, y2>, shape (M, 4) or (M, 5) where M is the number of Rois. For rotated boxes, this would have an additional angle (in degrees) in the format <[optional_batch_id], ctr_x, ctr_y, w, h, angle> Deltas are of shape (M, K*4) with format <dx, dy, dw, dh>, where K is the number of classes. For rotated Rois: shape (M, K*5), format <dx, dy, dw, dh, da>. ImInfo is of shape <batch_size, 3> with format <img_height, img_width, img_scale>.If proposals from multiple images in a batch are present, they should be grouped sequentially and in incremental order. |
9717 | class BBoxTransformNode final : public Node { |
9718 | NodeHandle Rois_; |
9719 | NodeHandle Deltas_; |
9720 | NodeHandle ImInfo_; |
9721 | std::vector<float> Weights_; |
9722 | bool ApplyScale_; |
9723 | bool Rotated_; |
9724 | bool AngleBoundOn_; |
9725 | int64_t AngleBoundLo_; |
9726 | int64_t AngleBoundHi_; |
9727 | float ClipAngleThresh_; |
9728 | bool LegacyPlusOne_; |
9729 | |
9730 | public: |
9731 | enum InputIndices { |
9732 | RoisIdx = 0, |
9733 | DeltasIdx = 1, |
9734 | ImInfoIdx = 2, |
9735 | }; |
9736 | |
9737 | enum ResultIndices { |
9738 | BoxOutIdx = 0, |
9739 | RoiBatchSplitsIdx = 1, |
9740 | }; |
9741 | |
9742 | BBoxTransformNode(llvm::StringRef name, TypeRef BoxOut , TypeRef RoiBatchSplits , NodeValue Rois, NodeValue Deltas, NodeValue ImInfo, std::vector<float> Weights, bool ApplyScale, bool Rotated, bool AngleBoundOn, int64_t AngleBoundLo, int64_t AngleBoundHi, float ClipAngleThresh, bool LegacyPlusOne) |
9743 | : Node(Kinded::Kind::BBoxTransformNodeKind, name), Rois_(this, Rois), Deltas_(this, Deltas), ImInfo_(this, ImInfo), Weights_(Weights), ApplyScale_(ApplyScale), Rotated_(Rotated), AngleBoundOn_(AngleBoundOn), AngleBoundLo_(AngleBoundLo), AngleBoundHi_(AngleBoundHi), ClipAngleThresh_(ClipAngleThresh), LegacyPlusOne_(LegacyPlusOne) { |
9744 | addResult(BoxOut); |
9745 | addResult(RoiBatchSplits); |
9746 | } |
9747 | const NodeValue getRois() const { return Rois_; } |
9748 | const NodeValue getDeltas() const { return Deltas_; } |
9749 | const NodeValue getImInfo() const { return ImInfo_; } |
9750 | NodeValue getBoxOut() { return getNthResult(0); } |
9751 | const NodeValue getBoxOut() const { return getNthResult(0); } |
9752 | NodeValue getRoiBatchSplits() { return getNthResult(1); } |
9753 | const NodeValue getRoiBatchSplits() const { return getNthResult(1); } |
9754 | llvm::ArrayRef<float> getWeights() const { return Weights_; } |
9755 | bool getApplyScale() const { return ApplyScale_; } |
9756 | bool getRotated() const { return Rotated_; } |
9757 | bool getAngleBoundOn() const { return AngleBoundOn_; } |
9758 | int64_t getAngleBoundLo() const { return AngleBoundLo_; } |
9759 | int64_t getAngleBoundHi() const { return AngleBoundHi_; } |
9760 | float getClipAngleThresh() const { return ClipAngleThresh_; } |
9761 | bool getLegacyPlusOne() const { return LegacyPlusOne_; } |
9762 | |
9763 | static bool classof(const Kinded *k) { |
9764 | return k->getKind() == Kinded::Kind::BBoxTransformNodeKind; |
9765 | } |
9766 | |
9767 | |
9768 | bool isOverwrittenNthInput(unsigned idx) const { |
9769 | return false; |
9770 | } |
9771 | |
9772 | unsigned getNumInputs() const; |
9773 | std::string getInputName(unsigned idx) const; |
9774 | NodeValue getNthInput(unsigned idx); |
9775 | void setNthInput(unsigned idx, NodeValue val); |
9776 | llvm::StringRef getOutputName(unsigned idx) const; |
9777 | bool hasSideEffects() const { return 0; } |
9778 | bool isCanonical() const { return 1; } |
9779 | bool isDataParallel() const { return 0; } |
9780 | std::string getDebugDesc() const; |
9781 | bool isEqual(const BBoxTransformNode &other) const; |
9782 | llvm::hash_code getHash() const; |
9783 | void visit(Node *parent, NodeWalker *visitor); |
9784 | Node* clone() const; |
9785 | bool verify() const; |
9786 | }; |
9787 | } // namespace glow |
9788 | |
9789 | |
9790 | namespace glow { |
9791 | /// Given RpnMinLevel, RpnMaxLevel and RpnPostNmsTopN CollectRpnProposals merges RoisIn based on RoisProbsIn and returns top proposals limited to RpnPostNmsTopN total, size (n x B), where B is box dimensions and based on dimension of input rois. Format for upright boxes is (image_index, x1, y1, x2, y2).Format for rotated boxes (image_index, ctr_x, ctr_y, w, h, angle)RpnPostNmsTopN should be greater than zero |
9792 | class CollectRpnProposalsNode final : public Node { |
9793 | std::vector<NodeHandle> RoisIn_; |
9794 | std::vector<NodeHandle> RoisProbsIn_; |
9795 | int64_t RpnMaxLevel_; |
9796 | int64_t RpnMinLevel_; |
9797 | unsigned_t RpnPostNmsTopN_; |
9798 | |
9799 | public: |
9800 | enum InputIndices { |
9801 | }; |
9802 | |
9803 | enum ResultIndices { |
9804 | ResultIdx = 0, |
9805 | }; |
9806 | |
9807 | CollectRpnProposalsNode(llvm::StringRef name, TypeRef Result , std::vector<NodeValue> RoisIn, std::vector<NodeValue> RoisProbsIn, int64_t RpnMaxLevel, int64_t RpnMinLevel, unsigned_t RpnPostNmsTopN) |
9808 | : Node(Kinded::Kind::CollectRpnProposalsNodeKind, name), RpnMaxLevel_(RpnMaxLevel), RpnMinLevel_(RpnMinLevel), RpnPostNmsTopN_(RpnPostNmsTopN) { |
9809 | addResult(Result); |
9810 | RoisIn_.resize(RoisIn.size()); |
9811 | for (size_t idx = 0, e = RoisIn.size(); idx < e; ++idx) { |
9812 | RoisIn_[idx] = RoisIn[idx]; |
9813 | RoisIn_[idx].setParent(this); |
9814 | } |
9815 | RoisProbsIn_.resize(RoisProbsIn.size()); |
9816 | for (size_t idx = 0, e = RoisProbsIn.size(); idx < e; ++idx) { |
9817 | RoisProbsIn_[idx] = RoisProbsIn[idx]; |
9818 | RoisProbsIn_[idx].setParent(this); |
9819 | } |
9820 | } |
9821 | NodeValue getResult() { return getNthResult(0); } |
9822 | const NodeValue getResult() const { return getNthResult(0); } |
9823 | NodeValueArrayRef getRoisIn() const { return RoisIn_; } |
9824 | NodeValueArrayRef getRoisProbsIn() const { return RoisProbsIn_; } |
9825 | int64_t getRpnMaxLevel() const { return RpnMaxLevel_; } |
9826 | int64_t getRpnMinLevel() const { return RpnMinLevel_; } |
9827 | unsigned_t getRpnPostNmsTopN() const { return RpnPostNmsTopN_; } |
9828 | |
9829 | static bool classof(const Kinded *k) { |
9830 | return k->getKind() == Kinded::Kind::CollectRpnProposalsNodeKind; |
9831 | } |
9832 | |
9833 | |
9834 | bool isOverwrittenNthInput(unsigned idx) const { |
9835 | return false; |
9836 | } |
9837 | |
9838 | unsigned getNumInputs() const; |
9839 | std::string getInputName(unsigned idx) const; |
9840 | NodeValue getNthInput(unsigned idx); |
9841 | void setNthInput(unsigned idx, NodeValue val); |
9842 | llvm::StringRef getOutputName(unsigned idx) const; |
9843 | bool hasSideEffects() const { return 0; } |
9844 | bool isCanonical() const { return 1; } |
9845 | bool isDataParallel() const { return 0; } |
9846 | std::string getDebugDesc() const; |
9847 | bool isEqual(const CollectRpnProposalsNode &other) const; |
9848 | llvm::hash_code getHash() const; |
9849 | void visit(Node *parent, NodeWalker *visitor); |
9850 | Node* clone() const; |
9851 | bool verify() const; |
9852 | }; |
9853 | } // namespace glow |
9854 | |
9855 | |
9856 | namespace glow { |
9857 | /// LookupTable based data-parallel operation.Given an interpolation table and and index table, return interpolated approximations for arbitrary functions. |
9858 | class LookupTableNode final : public Node { |
9859 | NodeHandle Input_; |
9860 | NodeHandle Table_; |
9861 | NodeHandle TableIdx_; |
9862 | glow::LUTOperator Operator_; |
9863 | std::vector<float> OperatorArgs_; |
9864 | |
9865 | public: |
9866 | enum InputIndices { |
9867 | InputIdx = 0, |
9868 | TableIdx = 1, |
9869 | TableIdxIdx = 2, |
9870 | }; |
9871 | |
9872 | enum ResultIndices { |
9873 | ResultIdx = 0, |
9874 | }; |
9875 | |
9876 | LookupTableNode(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Table, NodeValue TableIdx, glow::LUTOperator Operator, std::vector<float> OperatorArgs) |
9877 | : Node(Kinded::Kind::LookupTableNodeKind, name), Input_(this, Input), Table_(this, Table), TableIdx_(this, TableIdx), Operator_(Operator), OperatorArgs_(OperatorArgs) { |
9878 | addResult(Result); |
9879 | } |
9880 | const NodeValue getInput() const { return Input_; } |
9881 | const NodeValue getTable() const { return Table_; } |
9882 | const NodeValue getTableIdx() const { return TableIdx_; } |
9883 | NodeValue getResult() { return getNthResult(0); } |
9884 | const NodeValue getResult() const { return getNthResult(0); } |
9885 | glow::LUTOperator getOperator() const { return Operator_; } |
9886 | llvm::ArrayRef<float> getOperatorArgs() const { return OperatorArgs_; } |
9887 | |
9888 | static bool classof(const Kinded *k) { |
9889 | return k->getKind() == Kinded::Kind::LookupTableNodeKind; |
9890 | } |
9891 | |
9892 | |
9893 | bool isOverwrittenNthInput(unsigned idx) const { |
9894 | return false; |
9895 | } |
9896 | |
9897 | unsigned getNumInputs() const; |
9898 | std::string getInputName(unsigned idx) const; |
9899 | NodeValue getNthInput(unsigned idx); |
9900 | void setNthInput(unsigned idx, NodeValue val); |
9901 | llvm::StringRef getOutputName(unsigned idx) const; |
9902 | bool hasSideEffects() const { return 0; } |
9903 | bool isCanonical() const { return 1; } |
9904 | bool isDataParallel() const { return 1; } |
9905 | std::string getDebugDesc() const; |
9906 | bool isEqual(const LookupTableNode &other) const; |
9907 | llvm::hash_code getHash() const; |
9908 | void visit(Node *parent, NodeWalker *visitor); |
9909 | Node* clone() const; |
9910 | bool verify() const; |
9911 | }; |
9912 | } // namespace glow |
9913 | |
9914 | |
9915 | namespace glow { |
9916 | /// A Max node with one splat input; CPU specific. |
9917 | class CPUMaxSplatNode final : public Node { |
9918 | NodeHandle Input_; |
9919 | float SplatValue_; |
9920 | |
9921 | public: |
9922 | enum InputIndices { |
9923 | InputIdx = 0, |
9924 | }; |
9925 | |
9926 | enum ResultIndices { |
9927 | ResultIdx = 0, |
9928 | }; |
9929 | |
9930 | CPUMaxSplatNode(llvm::StringRef name, NodeValue Input, float SplatValue) |
9931 | : Node(Kinded::Kind::CPUMaxSplatNodeKind, name), Input_(this, Input), SplatValue_(SplatValue) { |
9932 | addResult(Input.getType()); |
9933 | } |
9934 | const NodeValue getInput() const { return Input_; } |
9935 | NodeValue getResult() { return getNthResult(0); } |
9936 | const NodeValue getResult() const { return getNthResult(0); } |
9937 | float getSplatValue() const { return SplatValue_; } |
9938 | |
9939 | static bool classof(const Kinded *k) { |
9940 | return k->getKind() == Kinded::Kind::CPUMaxSplatNodeKind; |
9941 | } |
9942 | |
9943 | |
9944 | bool isOverwrittenNthInput(unsigned idx) const { |
9945 | return false; |
9946 | } |
9947 | |
9948 | unsigned getNumInputs() const; |
9949 | std::string getInputName(unsigned idx) const; |
9950 | NodeValue getNthInput(unsigned idx); |
9951 | void setNthInput(unsigned idx, NodeValue val); |
9952 | llvm::StringRef getOutputName(unsigned idx) const; |
9953 | bool hasSideEffects() const { return 0; } |
9954 | bool isCanonical() const { return 0; } |
9955 | bool isDataParallel() const { return 0; } |
9956 | std::string getDebugDesc() const; |
9957 | bool isEqual(const CPUMaxSplatNode &other) const; |
9958 | llvm::hash_code getHash() const; |
9959 | void visit(Node *parent, NodeWalker *visitor); |
9960 | Node* clone() const; |
9961 | bool verify() const; |
9962 | }; |
9963 | } // namespace glow |
9964 | |
9965 | |
9966 | namespace glow { |
9967 | /// This is a cpu-specific convolution implementation where the filter is transposed to the shape [D/8, K, K, C, 8] |
9968 | class CPUConvDKKC8Node final : public Node { |
9969 | NodeHandle Input_; |
9970 | NodeHandle Filter_; |
9971 | NodeHandle Bias_; |
9972 | std::vector<unsigned_t> Kernels_; |
9973 | std::vector<unsigned_t> Strides_; |
9974 | std::vector<unsigned_t> Pads_; |
9975 | unsigned_t Group_; |
9976 | |
9977 | public: |
9978 | enum InputIndices { |
9979 | InputIdx = 0, |
9980 | FilterIdx = 1, |
9981 | BiasIdx = 2, |
9982 | }; |
9983 | |
9984 | enum ResultIndices { |
9985 | ResultIdx = 0, |
9986 | }; |
9987 | |
9988 | CPUConvDKKC8Node(llvm::StringRef name, TypeRef Result , NodeValue Input, NodeValue Filter, NodeValue Bias, std::vector<unsigned_t> Kernels, std::vector<unsigned_t> Strides, std::vector<unsigned_t> Pads, unsigned_t Group) |
9989 | : Node(Kinded::Kind::CPUConvDKKC8NodeKind, name), Input_(this, Input), Filter_(this, Filter), Bias_(this, Bias), Kernels_(Kernels), Strides_(Strides), Pads_(Pads), Group_(Group) { |
9990 | addResult(Result); |
9991 | } |
9992 | const NodeValue getInput() const { return Input_; } |
9993 | const NodeValue getFilter() const { return Filter_; } |
9994 | const NodeValue getBias() const { return Bias_; } |
9995 | NodeValue getResult() { return getNthResult(0); } |
9996 | const NodeValue getResult() const { return getNthResult(0); } |
9997 | llvm::ArrayRef<unsigned_t> getKernels() const { return Kernels_; } |
9998 | llvm::ArrayRef<unsigned_t> getStrides() const { return Strides_; } |
9999 | llvm::ArrayRef<unsigned_t> getPads() const { return Pads_; } |
10000 | unsigned_t getGroup() const { return Group_; } |
10001 | |
10002 | static bool classof(const Kinded *k) { |
10003 | return k->getKind() == Kinded::Kind::CPUConvDKKC8NodeKind; |
10004 | } |
10005 | |
10006 | |
10007 | bool isOverwrittenNthInput(unsigned idx) const { |
10008 | return false; |
10009 | } |
10010 | |
10011 | unsigned getNumInputs() const; |
10012 | std::string getInputName(unsigned idx) const; |
10013 | NodeValue getNthInput(unsigned idx); |
10014 | void setNthInput(unsigned idx, NodeValue val); |
10015 | llvm::StringRef getOutputName(unsigned idx) const; |
10016 | bool hasSideEffects() const { return 0; } |
10017 | bool isCanonical() const { return 0; } |
10018 | bool isDataParallel() const { return 0; } |
10019 | std::string getDebugDesc() const; |
10020 | bool isEqual(const CPUConvDKKC8Node &other) const; |
10021 | llvm::hash_code getHash() const; |
10022 | void visit(Node *parent, NodeWalker *visitor); |
10023 | Node* clone() const; |
10024 | bool verify() const; |
10025 | }; |
10026 | } // namespace glow |
10027 | |