1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Graph/Nodes.h" |
18 | #include "glow/Base/Type.h" |
19 | #include "glow/Graph/Graph.h" |
20 | #include "glow/Graph/VerifierHelper.h" |
21 | #include "glow/Support/Support.h" |
22 | |
23 | using namespace glow; |
24 | |
25 | bool Storage::isEqual(const Storage &other) const { |
26 | /// A storage should be equal only to itself! |
27 | return this == &other; |
28 | } |
29 | |
30 | llvm::hash_code Constant::getHash() const { |
31 | return llvm::hash_combine(getName(), getType()); |
32 | } |
33 | |
34 | llvm::hash_code Placeholder::getHash() const { |
35 | return llvm::hash_combine(getName()); |
36 | } |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Visitor methods |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | void Storage::visit(Node *parent, NodeWalker *visitor) { |
43 | if (!visitor->shouldVisit(parent, this)) { |
44 | return; |
45 | } |
46 | visitor->pre(parent, this); |
47 | visitor->post(parent, this); |
48 | } |
49 | |
50 | void Storage::visit(const Node *parent, NodeWalker *visitor) const { |
51 | if (!visitor->shouldVisit(parent, this)) { |
52 | return; |
53 | } |
54 | visitor->pre(parent, this); |
55 | visitor->post(parent, this); |
56 | } |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // Edge getters methods |
60 | //===----------------------------------------------------------------------===// |
61 | unsigned Storage::getNumInputs() const { return 0; } |
62 | |
63 | std::string Storage::getInputName(unsigned idx) const { |
64 | llvm_unreachable("Invalid index" ); |
65 | } |
66 | |
67 | NodeValue Storage::getNthInput(unsigned idx) { |
68 | llvm_unreachable("Invalid index" ); |
69 | } |
70 | |
71 | llvm::StringRef Storage::getOutputName(unsigned idx) const { |
72 | if (idx == 0) { |
73 | return "Output" ; |
74 | } |
75 | llvm_unreachable("Invalid index" ); |
76 | } |
77 | |
78 | bool Storage::hasSideEffects() const { return false; } |
79 | |
80 | Node *Storage::clone() const { llvm_unreachable("Storage can't be cloned." ); } |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // Debug description methods |
84 | //===----------------------------------------------------------------------===// |
85 | |
86 | std::string Constant::getDebugDesc(bool skipUsers) const { |
87 | DescriptionBuilder db(getKindName()); |
88 | db.addParam("Name" , separateString(getName(), 100, "\n" )) |
89 | .addParam("Layout" , getLayout()) |
90 | .addParam("Output" , *getType()); |
91 | if (!skipUsers) { |
92 | db.addParam("Users" , getNumUsers()); |
93 | } |
94 | return db; |
95 | } |
96 | |
97 | std::string Placeholder::getDebugDesc(bool skipUsers) const { |
98 | DescriptionBuilder db(getKindName()); |
99 | db.addParam("Name" , separateString(getName(), 100, "\n" )) |
100 | .addParam("Layout" , getLayout()) |
101 | .addParam("Output" , *getType()) |
102 | .addParam("Trainable" , isTraining()) |
103 | .addParam("Static" , isStatic()); |
104 | if (!skipUsers) { |
105 | db.addParam("Users" , getNumUsers()); |
106 | } |
107 | return db; |
108 | } |
109 | |
110 | //===----------------------------------------------------------------------===// |
111 | // Nodes verification |
112 | //===----------------------------------------------------------------------===// |
113 | |
114 | static bool verifyConvFilter(const Node *parent, NodeValue filter, |
115 | const ShapeNHWC &idim, const ShapeNHWC &odim, |
116 | const ShapeHW &kdim, unsigned_t group) { |
117 | const dim_t filterDims[] = {odim.c, kdim.height, kdim.width, |
118 | idim.c / (dim_t)group}; |
119 | return expectCompareTrue("Invalid filter dimensions" , |
120 | filter.getType()->dims(), |
121 | llvm::makeArrayRef(filterDims), parent); |
122 | } |
123 | |
124 | static bool verifyConvFilter(const Node *parent, NodeValue filter, |
125 | const ShapeNCHW &idim, const ShapeNCHW &odim, |
126 | const ShapeHW &kdim, unsigned_t group) { |
127 | const dim_t filterDims[] = {odim.c, idim.c / (dim_t)group, kdim.height, |
128 | kdim.width}; |
129 | |
130 | return expectCompareTrue("Invalid filter dimensions" , |
131 | filter.getType()->dims(), |
132 | llvm::makeArrayRef(filterDims), parent); |
133 | } |
134 | |
135 | template <typename Shape> |
136 | static bool verifyConvolution(NodeValue src, NodeValue dest, NodeValue filter, |
137 | NodeValue bias, |
138 | llvm::ArrayRef<unsigned_t> kernels, |
139 | llvm::ArrayRef<unsigned_t> strides, |
140 | llvm::ArrayRef<unsigned_t> pads, unsigned_t group, |
141 | llvm::ArrayRef<unsigned_t> dilation, |
142 | bool checkBiasType = true) { |
143 | const Node *parent = dest.getNode(); |
144 | bool isValid = checkType(src, dest.getElementType(), parent); |
145 | isValid &= checkType(src, filter.getElementType(), parent); |
146 | if (checkBiasType) { |
147 | // Non quantization type check. |
148 | if (src.getElementType() == ElemKind::FloatTy) { |
149 | isValid &= checkType(bias, ElemKind::FloatTy, parent); |
150 | } |
151 | // Quantization type check. |
152 | if (src.getElementType() == ElemKind::Int8QTy) { |
153 | isValid &= |
154 | expectCompareTrue("Bias type should be float, Int8 or Int32 for Conv" , |
155 | bias.getElementType() == ElemKind::FloatTy || |
156 | bias.getElementType() == ElemKind::Int8QTy || |
157 | bias.getElementType() == ElemKind::Int32QTy, |
158 | true, parent); |
159 | } |
160 | } |
161 | Shape idim(src.getType()->dims()); |
162 | Shape odim(dest.getType()->dims()); |
163 | PaddingTLBR pdim(pads); |
164 | ShapeHW kdim(kernels); |
165 | isValid &= expectCompareTrue("buffer height too small for selected stride" , |
166 | idim.h + pdim.top + pdim.bottom, kdim.height, |
167 | parent, CompareOperatorGreaterEqual<dim_t>()); |
168 | isValid &= expectCompareTrue("buffer width too small for selected stride" , |
169 | idim.w + pdim.left + pdim.right, kdim.width, |
170 | parent, CompareOperatorGreaterEqual<dim_t>()); |
171 | isValid &= expectCompareTrue("channels number must be divisible by groups" , |
172 | idim.c % group, dim_t(0), parent); |
173 | isValid &= expectCompareTrue("Dilation should have same length as Stride" , |
174 | dilation.size(), strides.size(), parent); |
175 | |
176 | auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, |
177 | pads, dilation); |
178 | isValid &= |
179 | expectCompareTrue("Invalid output dimension N" , odim.n, idim.n, parent); |
180 | isValid &= expectCompareTrue("Invalid output dimension H" , odim.h, |
181 | outSz.first, parent); |
182 | isValid &= expectCompareTrue("Invalid output dimension W" , odim.w, |
183 | outSz.second, parent); |
184 | isValid &= expectCompareTrue("Invalid output dimension C" , odim.c % group, |
185 | dim_t(0), parent); |
186 | |
187 | isValid &= verifyConvFilter(parent, filter, idim, odim, kdim, group); |
188 | |
189 | const dim_t biasDims[] = {odim.c}; |
190 | |
191 | isValid &= |
192 | expectCompareTrue("Invalid bias dimensions" , bias.getType()->dims(), |
193 | llvm::makeArrayRef(biasDims), parent); |
194 | return isValid; |
195 | } |
196 | |
197 | static bool verifyConvolution3D(NodeValue src, NodeValue dest, NodeValue filter, |
198 | NodeValue bias, |
199 | llvm::ArrayRef<unsigned_t> kernels, |
200 | llvm::ArrayRef<unsigned_t> strides, |
201 | llvm::ArrayRef<unsigned_t> pads, |
202 | unsigned_t group) { |
203 | |
204 | const Node *parent = dest.getNode(); |
205 | bool isValid = checkType(src, dest.getElementType(), parent); |
206 | isValid &= checkType(src, filter.getElementType(), parent); |
207 | // Non quantization type check. |
208 | if (src.getElementType() == ElemKind::FloatTy) { |
209 | isValid &= checkType(bias, ElemKind::FloatTy, parent); |
210 | } |
211 | // Quantization type check. |
212 | if (src.getElementType() == ElemKind::Int8QTy) { |
213 | isValid &= |
214 | expectCompareTrue("Bias type should be Float, Int8 or Int32 for Conv3D" , |
215 | bias.getElementType() == ElemKind::FloatTy || |
216 | bias.getElementType() == ElemKind::Int8QTy || |
217 | bias.getElementType() == ElemKind::Int32QTy, |
218 | true, parent); |
219 | } |
220 | ShapeNTHWC idim(src.getType()->dims()); |
221 | ShapeNTHWC odim(dest.getType()->dims()); |
222 | PaddingNFTBLR pdim(pads); |
223 | ShapeTHW kdim(kernels); |
224 | isValid &= expectCompareTrue("buffer height too small for selected stride" , |
225 | idim.h + pdim.top + pdim.bottom, kdim.height, |
226 | parent, CompareOperatorGreaterEqual<dim_t>()); |
227 | isValid &= expectCompareTrue("buffer width too small for selected stride" , |
228 | idim.w + pdim.left + pdim.right, kdim.width, |
229 | parent, CompareOperatorGreaterEqual<dim_t>()); |
230 | isValid &= |
231 | expectCompareTrue("buffer time too small for selected stride" , |
232 | idim.t + pdim.near + pdim.far, kdim.temporal_frames, |
233 | parent, CompareOperatorGreaterEqual<dim_t>()); |
234 | isValid &= expectCompareTrue("channels number must be divisible by groups" , |
235 | idim.c % group, dim_t(0), parent); |
236 | |
237 | auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels, |
238 | strides, pads); |
239 | isValid &= |
240 | expectCompareTrue("Invalid output dimension N" , odim.n, idim.n, parent); |
241 | isValid &= expectCompareTrue("Invalid output dimension T" , odim.t, |
242 | outSz.temporal_frames, parent); |
243 | isValid &= expectCompareTrue("Invalid output dimension H" , odim.h, |
244 | outSz.height, parent); |
245 | isValid &= expectCompareTrue("Invalid output dimension W" , odim.w, |
246 | outSz.width, parent); |
247 | isValid &= expectCompareTrue("Invalid output dimension C" , odim.c % group, |
248 | dim_t(0), parent); |
249 | |
250 | const dim_t filterDims[] = {odim.c, kdim.temporal_frames, kdim.height, |
251 | kdim.width, idim.c / group}; |
252 | isValid &= |
253 | expectCompareTrue("Invalid filter dimensions" , filter.getType()->dims(), |
254 | llvm::makeArrayRef(filterDims), parent); |
255 | const dim_t biasDims[] = {odim.c}; |
256 | isValid &= |
257 | expectCompareTrue("Invalid bias dimensions" , bias.getType()->dims(), |
258 | llvm::makeArrayRef(biasDims), parent); |
259 | return isValid; |
260 | } |
261 | |
262 | static bool verifyConvTranspose(NodeValue src, NodeValue dest, NodeValue filter, |
263 | llvm::ArrayRef<unsigned_t> kernels, |
264 | llvm::ArrayRef<unsigned_t> strides, |
265 | llvm::ArrayRef<unsigned_t> pads, |
266 | unsigned_t group, |
267 | llvm::ArrayRef<unsigned_t> dilation) { |
268 | const Node *parent = dest.getNode(); |
269 | bool isValid = checkType(src, dest.getElementType(), parent); |
270 | isValid &= checkType(src, filter.getElementType(), parent); |
271 | ShapeNHWC idim(src.getType()->dims()); |
272 | ShapeNHWC odim(dest.getType()->dims()); |
273 | PaddingTLBR pdim(pads); |
274 | ShapeHW kdim(kernels); |
275 | // TODO: any kernel size check in respect to input ? In contrast to Conv, |
276 | // seems kernel can be any size. |
277 | |
278 | isValid &= expectCompareTrue("channels number must be divisible by groups" , |
279 | idim.c % group, dim_t(0), parent); |
280 | |
281 | isValid &= expectCompareTrue("Stride should be less than kernel." , |
282 | strides[0] <= kernels[0], true, parent); |
283 | |
284 | isValid &= expectCompareTrue("Stride should be less than kernel." , |
285 | strides[1] <= kernels[1], true, parent); |
286 | |
287 | isValid &= expectCompareTrue("channels number must be divisible by groups" , |
288 | idim.c % group, dim_t(0), parent); |
289 | |
290 | isValid &= expectCompareTrue("Dilation should have same length as Stride" , |
291 | dilation.size(), strides.size(), parent); |
292 | |
293 | auto outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels, |
294 | strides, pads, dilation); |
295 | (void)outSz; |
296 | isValid &= |
297 | expectCompareTrue("Invalid output dimension N" , odim.n, idim.n, parent); |
298 | |
299 | isValid &= |
300 | expectCompareTrue("Invalid output dimension HT" , odim.h, outSz.first, |
301 | parent, CompareOperatorGreaterEqual<dim_t>()); |
302 | isValid &= |
303 | expectCompareTrue("Invalid output dimension WT" , odim.w, outSz.second, |
304 | parent, CompareOperatorGreaterEqual<dim_t>()); |
305 | |
306 | isValid &= expectCompareTrue("Invalid output dimension CT" , odim.c % group, |
307 | dim_t(0), parent); |
308 | |
309 | const dim_t filterDims[] = {odim.c / (dim_t)group, kdim.height, kdim.width, |
310 | idim.c}; |
311 | isValid &= |
312 | expectCompareTrue("Invalid filter dimensions" , filter.getType()->dims(), |
313 | llvm::makeArrayRef(filterDims), parent); |
314 | return isValid; |
315 | } |
316 | |
317 | static bool verifyFullyConnected(NodeValue src, NodeValue weights, |
318 | NodeValue bias, NodeValue dest) { |
319 | const Node *parent = dest.getNode(); |
320 | bool isValid = expectCompareTrue("FC input must be 2D" , size_t(2), |
321 | src.dims().size(), parent); |
322 | isValid &= expectCompareTrue("FC weights must be 2D" , size_t(2), |
323 | weights.dims().size(), parent); |
324 | isValid &= expectCompareTrue("FC bias must be 1D" , size_t(1), |
325 | bias.dims().size(), parent); |
326 | isValid &= expectCompareTrue("Mismatch between source and dest dimensions" , |
327 | src.dims()[0], dest.dims()[0], parent); |
328 | isValid &= expectCompareTrue("Mismatch between source and weight dimensions" , |
329 | src.dims()[1], weights.dims()[0], parent); |
330 | isValid &= expectCompareTrue("Inconsistent bias/dest sizes" , bias.dims()[0], |
331 | weights.dims()[1], parent); |
332 | isValid &= expectCompareTrue("Inconsistent weights/dest sizes" , |
333 | weights.dims()[1], dest.dims()[1], parent); |
334 | |
335 | if (src.getElementType() == ElemKind::Int8QTy) { |
336 | isValid &= |
337 | expectCompareTrue("Bias type should be Int8, Int32 or FP32 for FC" , |
338 | bias.getElementType() == ElemKind::Int8QTy || |
339 | bias.getElementType() == ElemKind::Int32QTy || |
340 | bias.getElementType() == ElemKind::FloatTy, |
341 | true, parent); |
342 | } |
343 | return isValid; |
344 | } |
345 | |
346 | template <typename Shape> |
347 | static bool verifyPool(NodeValue src, NodeValue dest, |
348 | llvm::ArrayRef<unsigned_t> kernels, |
349 | llvm::ArrayRef<unsigned_t> strides, |
350 | llvm::ArrayRef<unsigned_t> pads, bool isAvgPool = true) { |
351 | const Node *parent = dest.getNode(); |
352 | Shape idim(src.getType()->dims()); |
353 | Shape odim(dest.getType()->dims()); |
354 | PaddingTLBR pdim(pads); |
355 | ShapeHW kdim(kernels); |
356 | |
357 | bool isValid = |
358 | expectCompareTrue("buffer height too small for selected stride" , |
359 | idim.h + pdim.top + pdim.bottom, kdim.height, parent, |
360 | CompareOperatorGreaterEqual<dim_t>()); |
361 | isValid &= expectCompareTrue("buffer width too small for selected stride" , |
362 | idim.w + pdim.left + pdim.right, kdim.width, |
363 | parent, CompareOperatorGreaterEqual<dim_t>()); |
364 | |
365 | auto outSz = |
366 | calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads); |
367 | Shape exp(idim); |
368 | exp.h = outSz.first; |
369 | exp.w = outSz.second; |
370 | isValid &= |
371 | expectCompareTrue("Unexpected output dimensions" , exp, odim, parent); |
372 | |
373 | // For quantized AvgPool, the scale and offset of its input and output could |
374 | // be different. But for quantized MaxPool, the scale and offset of its input |
375 | // and output should be the same. |
376 | isValid &= checkSameIsQuantized(src.getType(), dest.getType(), parent); |
377 | if (!isAvgPool) { |
378 | isValid &= checkTypeIgnoreShape(src, dest, parent); |
379 | } |
380 | return isValid; |
381 | } |
382 | |
383 | template <typename Shape> |
384 | static bool |
385 | verifyPool3D(NodeValue src, NodeValue dest, llvm::ArrayRef<unsigned_t> kernels, |
386 | llvm::ArrayRef<unsigned_t> strides, |
387 | llvm::ArrayRef<unsigned_t> pads, bool isAvgPool = true) { |
388 | const Node *parent = dest.getNode(); |
389 | Shape idim(src.getType()->dims()); |
390 | Shape odim(dest.getType()->dims()); |
391 | PaddingTLNBRF pdim(pads); |
392 | ShapeTHW kdim(kernels); |
393 | |
394 | bool isValid = |
395 | expectCompareTrue("buffer height too small for selected stride" , |
396 | idim.h + pdim.top + pdim.bottom, kdim.height, parent, |
397 | CompareOperatorGreaterEqual<dim_t>()); |
398 | isValid &= expectCompareTrue("buffer width too small for selected stride" , |
399 | idim.w + pdim.left + pdim.right, kdim.width, |
400 | parent, CompareOperatorGreaterEqual<dim_t>()); |
401 | isValid &= |
402 | expectCompareTrue("buffer temporal_frames are too small for " |
403 | "selected stride" , |
404 | idim.t + pdim.near + pdim.far, kdim.temporal_frames, |
405 | parent, CompareOperatorGreaterEqual<dim_t>()); |
406 | |
407 | auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels, |
408 | strides, pads); |
409 | Shape exp(idim); |
410 | exp.t = outSz.temporal_frames; |
411 | exp.h = outSz.height; |
412 | exp.w = outSz.width; |
413 | isValid &= |
414 | expectCompareTrue("Unexpected output dimensions" , exp, odim, parent); |
415 | |
416 | // For quantized AvgPool, the scale and offset of its input and output could |
417 | // be different. But for quantized MaxPool, the scale and offset of its input |
418 | // and output should be the same. |
419 | isValid = |
420 | isValid && checkSameIsQuantized(src.getType(), dest.getType(), parent); |
421 | if (!isAvgPool) { |
422 | isValid = isValid && checkTypeIgnoreShape(src, dest, parent); |
423 | } |
424 | |
425 | return isValid; |
426 | } |
427 | |
428 | static bool verifyBatchNormalization(NodeValue src, NodeValue dest, |
429 | NodeValue bias, NodeValue scale, |
430 | NodeValue mean, NodeValue var, |
431 | unsigned_t channel) { |
432 | const Node *parent = dest.getNode(); |
433 | |
434 | // Source and Dest can have different quantization params |
435 | // but need to match in shape and element type. |
436 | bool isValid = checkSameShape(dest, src, parent); |
437 | isValid = isValid && checkType(dest, src.getElementType(), parent); |
438 | |
439 | isValid = |
440 | isValid && |
441 | expectCompareTrue( |
442 | "Require at least two input dims i.e., batch and channel dimensions" , |
443 | src.dims().size(), (size_t)1, parent, |
444 | CompareOperatorGreaterThan<size_t>()); |
445 | |
446 | // Figure out how many channels are in the tensor. |
447 | dim_t channels = src.dims()[channel]; |
448 | |
449 | const dim_t expArray[] = {channels}; |
450 | auto exp = llvm::makeArrayRef(expArray); |
451 | isValid = isValid && expectCompareTrue("Invalid bias dimension" , |
452 | bias.getType()->dims(), exp, parent); |
453 | isValid = isValid && expectCompareTrue("Invalid scale dimension" , |
454 | scale.getType()->dims(), exp, parent); |
455 | isValid = isValid && expectCompareTrue("Invalid mean dimension" , |
456 | mean.getType()->dims(), exp, parent); |
457 | isValid = isValid && expectCompareTrue("Invalid var dimension" , |
458 | var.getType()->dims(), exp, parent); |
459 | return isValid; |
460 | } |
461 | |
462 | static bool verifyInstanceNormalization(NodeValue src, NodeValue dest, |
463 | NodeValue bias, NodeValue scale, |
464 | unsigned_t channel) { |
465 | const Node *parent = dest.getNode(); |
466 | bool isValid = true; |
467 | if (src.getType()->isQuantizedType()) { |
468 | isValid &= checkType(src, dest.getElementType(), dest.getNode()); |
469 | isValid &= checkSameShape(src, dest, parent); |
470 | } else { |
471 | isValid &= checkSameType(src, dest, parent); |
472 | } |
473 | |
474 | isValid &= expectCompareTrue( |
475 | "Require at least two input dims i.e., batch and channel dimensions" , |
476 | src.dims().size(), (size_t)1, parent, |
477 | CompareOperatorGreaterThan<size_t>()); |
478 | |
479 | // Figure out how many channels are in the tensor. |
480 | dim_t channels = src.dims()[channel]; |
481 | |
482 | const dim_t expArray[] = {channels}; |
483 | auto exp = llvm::makeArrayRef(expArray); |
484 | isValid &= expectCompareTrue("Invalid bias dimension" , bias.getType()->dims(), |
485 | exp, parent); |
486 | isValid &= expectCompareTrue("Invalid scale dimension" , |
487 | scale.getType()->dims(), exp, parent); |
488 | |
489 | return isValid; |
490 | } |
491 | |
492 | static bool verifyActivation(NodeValue src, NodeValue dest) { |
493 | const Node *parent = dest.getNode(); |
494 | bool isValid = checkSameIsQuantized(src.getType(), dest.getType(), parent); |
495 | if (src.getType()->isQuantizedType()) { |
496 | isValid &= checkType(src, dest.getElementType(), dest.getNode()); |
497 | isValid &= checkSameShape(src, dest, parent); |
498 | } else { |
499 | isValid &= checkSameType(src, dest, parent); |
500 | } |
501 | return isValid; |
502 | } |
503 | |
504 | static bool verifySoftMax(NodeValue src, NodeValue dest) { |
505 | const Node *parent = dest.getNode(); |
506 | if (src.getType()->isQuantizedType()) { |
507 | return checkType(src, dest.getElementType(), parent) && |
508 | checkSameShape(src, dest, parent); |
509 | } |
510 | return checkSameType(src, dest, parent); |
511 | } |
512 | |
513 | static bool verifyLogSoftMax(NodeValue src, NodeValue dest) { |
514 | const Node *parent = dest.getNode(); |
515 | if (src.getType()->isQuantizedType()) { |
516 | return checkType(src, dest.getElementType(), parent) && |
517 | checkSameShape(src, dest, parent); |
518 | } |
519 | return checkSameType(src, dest, parent); |
520 | } |
521 | |
522 | static bool verifyCrossEntropyLoss(NodeValue P, NodeValue CE, |
523 | NodeValue labels) { |
524 | const Node *parent = CE.getNode(); |
525 | bool isValid = checkType(P, CE.getElementType(), parent); |
526 | isValid &= expectCompareTrue("Mismatching shape" , P.dims()[0], |
527 | labels.dims()[0], parent); |
528 | return isValid; |
529 | } |
530 | |
531 | static bool verifyLocalResponseNormalization(NodeValue src, NodeValue dest) { |
532 | return checkSameType(src, dest, dest.getNode()); |
533 | } |
534 | |
535 | static bool verifyArithmetic(NodeValue LHS, NodeValue RHS, NodeValue res) { |
536 | return checkSameShape(res, LHS, res.getNode()) && |
537 | checkSameShape(LHS, RHS, res.getNode()); |
538 | } |
539 | |
540 | static bool verifyRelu(NodeValue result, NodeValue input) { |
541 | const Node *parent = result.getNode(); |
542 | if (input.getType()->isQuantizedType()) { |
543 | return checkSameIsQuantized(input.getType(), result.getType(), parent) && |
544 | checkSameShape(result, input, parent); |
545 | } |
546 | return checkSameType(result, input, parent); |
547 | } |
548 | |
549 | static bool verifyPRelu(NodeValue result, NodeValue input, NodeValue slope) { |
550 | const Node *parent = result.getNode(); |
551 | if (input.getType()->isQuantizedType()) { |
552 | return checkSameIsQuantized(input.getType(), result.getType(), parent) && |
553 | checkSameIsQuantized(input.getType(), slope.getType(), parent) && |
554 | checkSameShape(result, input, parent) && |
555 | checkSameShape(slope, input, parent); |
556 | } |
557 | return checkSameType(result, input, parent) && |
558 | checkSameType(slope, input, parent) && |
559 | checkSameShape(slope, input, parent); |
560 | } |
561 | |
562 | static bool verifyRegression(NodeValue src, NodeValue dest, |
563 | NodeValue expected) { |
564 | return checkSameType(src, dest, dest.getNode()) && |
565 | checkSameType(dest, expected, dest.getNode()); |
566 | } |
567 | |
568 | static bool verifySparseLengthsSum(NodeValue dest, NodeValue data, |
569 | NodeValue indices, NodeValue lengths) { |
570 | bool isValid = checkType(dest, data.getElementType(), dest.getNode()); |
571 | isValid &= checkType(indices, {ElemKind::Int64ITy, ElemKind::Int32ITy}, |
572 | dest.getNode()); |
573 | isValid &= checkType(lengths, ElemKind::Int32ITy, dest.getNode()); |
574 | isValid &= |
575 | expectCompareTrue("Indices must be a 1D vector" , indices.dims().size(), |
576 | size_t(1), dest.getNode()); |
577 | isValid &= |
578 | expectCompareTrue("Lengths must be a 1D vector" , lengths.dims().size(), |
579 | size_t(1), dest.getNode()); |
580 | return isValid; |
581 | } |
582 | |
583 | static bool verifySparseLengthsWeightedSum(NodeValue dest, NodeValue data, |
584 | NodeValue weights, NodeValue indices, |
585 | NodeValue lengths) { |
586 | bool isValid = checkType(dest, data.getElementType(), dest.getNode()); |
587 | isValid &= checkType(weights, data.getElementType(), dest.getNode()); |
588 | isValid &= checkType(indices, {ElemKind::Int64ITy, ElemKind::Int32ITy}, |
589 | dest.getNode()); |
590 | isValid &= checkType(lengths, ElemKind::Int32ITy, dest.getNode()); |
591 | isValid &= |
592 | expectCompareTrue("Indices must be a 1D vector" , indices.dims().size(), |
593 | size_t(1), dest.getNode()); |
594 | isValid &= |
595 | expectCompareTrue("Lengths must be a 1D vector" , lengths.dims().size(), |
596 | size_t(1), dest.getNode()); |
597 | isValid &= |
598 | expectCompareTrue("Weights must be a 1D vector" , weights.dims().size(), |
599 | size_t(1), dest.getNode()); |
600 | |
601 | isValid &= |
602 | expectCompareTrue("Weights and Indices must have the same size" , |
603 | weights.dims()[0], indices.dims()[0], dest.getNode()); |
604 | return isValid; |
605 | } |
606 | |
607 | static bool verifyEmbedding(NodeValue dest, NodeValue weights, |
608 | NodeValue indices) { |
609 | bool isValid = checkType(dest, weights.getElementType(), dest.getNode()); |
610 | isValid &= checkType( |
611 | indices, |
612 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
613 | dest.getNode()); |
614 | isValid &= |
615 | expectCompareTrue("Weights must be a 2D tensor" , weights.dims().size(), |
616 | size_t(2), weights.getNode()); |
617 | return isValid; |
618 | } |
619 | |
620 | static bool verifyEmbeddingBag(NodeValue dest, NodeValue data, |
621 | NodeValue weights, NodeValue indices, |
622 | NodeValue offsets) { |
623 | bool isValid = checkType(dest, data.getElementType(), dest.getNode()); |
624 | isValid &= checkType(weights, data.getElementType(), dest.getNode()); |
625 | isValid &= checkType( |
626 | indices, |
627 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
628 | dest.getNode()); |
629 | isValid &= checkType( |
630 | offsets, |
631 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
632 | dest.getNode()); |
633 | isValid &= |
634 | expectCompareTrue("Indices must be a 1D vector" , indices.dims().size(), |
635 | size_t(1), dest.getNode()); |
636 | isValid &= |
637 | expectCompareTrue("Offsets must be a 1D vector" , offsets.dims().size(), |
638 | size_t(1), dest.getNode()); |
639 | isValid &= |
640 | expectCompareTrue("Weights must be a 1D vector" , weights.dims().size(), |
641 | size_t(1), dest.getNode()); |
642 | |
643 | isValid &= |
644 | expectCompareTrue("Weights and Indices must have the same size" , |
645 | weights.dims()[0], indices.dims()[0], dest.getNode()); |
646 | return isValid; |
647 | } |
648 | |
649 | bool HardSwishNode::verify() const { |
650 | return checkSameType(getInput(), getResult(), this); |
651 | } |
652 | |
653 | bool PadNode::verify() const { |
654 | // Pad is currently only supported for constant padding. |
655 | return expectCompareTrue("only the 'constant' mode is currrently supported" , |
656 | getMode() == PaddingMode::CONSTANT, true, |
657 | getResult().getNode()); |
658 | } |
659 | |
660 | bool ConvolutionNode::verify() const { |
661 | if (getLayout() == NHWC) { |
662 | return verifyConvolution<ShapeNHWC>(getInput(), getResult(), getFilter(), |
663 | getBias(), Kernels_, Strides_, Pads_, |
664 | Group_, Dilation_); |
665 | } else { |
666 | return verifyConvolution<ShapeNCHW>(getInput(), getResult(), getFilter(), |
667 | getBias(), Kernels_, Strides_, Pads_, |
668 | Group_, Dilation_); |
669 | } |
670 | } |
671 | |
672 | bool ChannelwiseQuantizedConvolutionNode::verify() const { |
673 | auto input_dims = getInput().getType()->dims(); |
674 | bool isValid = false; |
675 | bool isConv3D = (input_dims.size() == 5); |
676 | if (isConv3D) { |
677 | isValid = verifyConvolution3D(getInput(), getResult(), getFilter(), |
678 | getBias(), Kernels_, Strides_, Pads_, Group_); |
679 | |
680 | if (!all_of(Dilation_.begin(), Dilation_.end(), |
681 | [](unsigned_t i) { return i == 1; })) { |
682 | report("For Conv3D dilation must be 1" ); |
683 | } |
684 | } else { |
685 | isValid = verifyConvolution<ShapeNHWC>( |
686 | getInput(), getResult(), getFilter(), getBias(), Kernels_, Strides_, |
687 | Pads_, Group_, Dilation_, /* checkBiasType */ false); |
688 | } |
689 | |
690 | isValid &= checkType(getResult(), ElemKind::Int8QTy, this); |
691 | isValid &= checkType(getInput(), ElemKind::Int8QTy, this); |
692 | isValid &= checkType(getFilter(), ElemKind::Int8QTy, this); |
693 | isValid &= checkType( |
694 | getBias(), {ElemKind::Int8QTy, ElemKind::Int32QTy, ElemKind::FloatTy}, |
695 | this); |
696 | |
697 | // Check qparam types. |
698 | isValid &= checkType(getFilterOffsets(), ElemKind::Int32ITy, this); |
699 | isValid &= checkType(getFilterScales(), ElemKind::FloatTy, this); |
700 | isValid &= checkType(getBiasOffsets(), ElemKind::Int32ITy, this); |
701 | isValid &= checkType(getBiasScales(), ElemKind::FloatTy, this); |
702 | |
703 | // Check qparam dimensions. |
704 | isValid &= |
705 | expectCompareTrue("Filter offsets must be a 1D vector" , |
706 | getFilterOffsets().dims().size(), size_t(1), this); |
707 | isValid &= |
708 | expectCompareTrue("Filter scales must be a 1D vector" , |
709 | getFilterScales().dims().size(), size_t(1), this); |
710 | isValid &= expectCompareTrue("Bias offsets must be a 1D vector" , |
711 | getBiasOffsets().dims().size(), size_t(1), this); |
712 | isValid &= expectCompareTrue("Bias scales must be a 1D vector" , |
713 | getBiasScales().dims().size(), size_t(1), this); |
714 | |
715 | // Check qparam sizes. |
716 | isValid &= expectCompareTrue( |
717 | "There must be one filter offset qparam per output channel" , |
718 | getFilterOffsets().dims()[0], dim_t(getResult().dims().back()), this); |
719 | isValid &= expectCompareTrue( |
720 | "There must be one filter scale qparam per output channel" , |
721 | getFilterScales().dims()[0], dim_t(getResult().dims().back()), this); |
722 | isValid &= expectCompareTrue( |
723 | "There must be one bias offset qparam per output channel" , |
724 | getBiasOffsets().dims()[0], dim_t(getResult().dims().back()), this); |
725 | isValid &= expectCompareTrue( |
726 | "There must be one bias scale qparam per output channel" , |
727 | getBiasScales().dims()[0], dim_t(getResult().dims().back()), this); |
728 | |
729 | return isValid; |
730 | } |
731 | |
732 | bool Convolution3DNode::verify() const { |
733 | return verifyConvolution3D(getInput(), getResult(), getFilter(), getBias(), |
734 | Kernels_, Strides_, Pads_, Group_); |
735 | } |
736 | |
737 | bool ConvTransposeNode::verify() const { |
738 | return verifyConvTranspose(getInput(), getResult(), getFilter(), Kernels_, |
739 | Strides_, Pads_, Group_, Dilation_); |
740 | } |
741 | |
742 | /// Verify that types of an input and its gradient are the same. |
743 | static bool verifyInputAndGradInputTypes(NodeValue input, NodeValue gradInput, |
744 | const Node *parent) { |
745 | return checkSameType(input, gradInput, parent); |
746 | } |
747 | |
748 | /// Verify that types of an output and its gradient are the same. |
749 | static bool verifyOutputAndGradOutputTypes(NodeValue output, |
750 | NodeValue gradOutput, |
751 | const Node *parent) { |
752 | return checkSameType(output, gradOutput, parent); |
753 | } |
754 | |
755 | bool Constant::verify() const { |
756 | return expectCompareTrue("Underlying tensor type doesn't match constant type" , |
757 | *getType(), getPayload().getType(), this); |
758 | } |
759 | |
760 | bool ConvolutionGradNode::verify() const { |
761 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
762 | getGradOfInputNamedInput(), this); |
763 | isValid &= verifyInputAndGradInputTypes(getFilter(), |
764 | getGradOfInputNamedFilter(), this); |
765 | isValid &= |
766 | verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), this); |
767 | isValid &= verifyOutputAndGradOutputTypes( |
768 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
769 | if (getLayout() == NHWC) { |
770 | isValid &= verifyConvolution<ShapeNHWC>( |
771 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), |
772 | getGradOfInputNamedFilter(), getGradOfInputNamedBias(), Kernels_, |
773 | Strides_, Pads_, Group_, Dilation_); |
774 | } else { |
775 | isValid &= verifyConvolution<ShapeNCHW>( |
776 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), |
777 | getGradOfInputNamedFilter(), getGradOfInputNamedBias(), Kernels_, |
778 | Strides_, Pads_, Group_, Dilation_); |
779 | } |
780 | |
781 | return isValid; |
782 | } |
783 | |
784 | bool Convolution3DGradNode::verify() const { |
785 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
786 | getGradOfInputNamedInput(), this); |
787 | isValid &= verifyInputAndGradInputTypes(getFilter(), |
788 | getGradOfInputNamedFilter(), this); |
789 | isValid &= |
790 | verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), this); |
791 | isValid &= verifyOutputAndGradOutputTypes( |
792 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
793 | isValid &= verifyConvolution3D( |
794 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), |
795 | getGradOfInputNamedFilter(), getGradOfInputNamedBias(), Kernels_, |
796 | Strides_, Pads_, Group_); |
797 | return isValid; |
798 | } |
799 | |
800 | /// \returns the number of columns of data from a fused \p type (i.e. not |
801 | /// considering the columns for per row scale/offsets). |
802 | static size_t getNumDataColumnsFromFused(TypeRef type) { |
803 | size_t n = type->dims()[1]; |
804 | switch (type->getElementType()) { |
805 | case ElemKind::UInt8FusedQTy: |
806 | return n - 2 * sizeof(float); |
807 | case ElemKind::UInt8FusedFP16QTy: |
808 | return n - 2 * sizeof(float16_t); |
809 | case ElemKind::UInt4FusedFP16QTy: |
810 | return (n - 2 * sizeof(float16_t)) * 2; |
811 | case ElemKind::UInt4FusedQTy: |
812 | return (n - 2 * sizeof(float)) * 2; |
813 | default: |
814 | llvm_unreachable("Not supported Fused ElemKind" ); |
815 | } |
816 | } |
817 | |
818 | bool ConvertToNode::verify() const { |
819 | TypeRef srcTy = getInput().getType(); |
820 | TypeRef dstTy = getResult().getType(); |
821 | const bool srcIsFused = isFusedQuantizedElemKind(srcTy->getElementType()); |
822 | const bool dstIsFused = isFusedQuantizedElemKind(dstTy->getElementType()); |
823 | |
824 | bool isValid = expectCompareTrue( |
825 | "Conversion of src and dst with mismatched fused property is not yet " |
826 | "implemented" , |
827 | (srcIsFused && dstIsFused) || (!srcIsFused && !dstIsFused), true, this); |
828 | |
829 | if (srcIsFused && dstIsFused) { |
830 | size_t srcNumCols = getNumDataColumnsFromFused(srcTy); |
831 | size_t dstNumCols = getNumDataColumnsFromFused(dstTy); |
832 | return expectCompareTrue("Shapes of data for fused kinds do not match" , |
833 | srcNumCols, dstNumCols, this); |
834 | } |
835 | |
836 | isValid &= checkSameShape(getInput(), getResult(), this); |
837 | isValid &= expectCompareTrue( |
838 | "Quantized conversion should use Dequantize, Quantize and Rescale" , |
839 | srcTy->isQuantizedType() || dstTy->isQuantizedType(), false, this); |
840 | return isValid; |
841 | } |
842 | |
843 | bool MaxPoolNode::verify() const { |
844 | switch (getLayout()) { |
845 | case NHWC: |
846 | return verifyPool<ShapeNHWC>(getInput(), getResult(), Kernels_, Strides_, |
847 | Pads_, |
848 | /* isAvgPool */ false); |
849 | case NCHW: |
850 | return verifyPool<ShapeNCHW>(getInput(), getResult(), Kernels_, Strides_, |
851 | Pads_, |
852 | /* isAvgPool */ false); |
853 | default: // MaxPool3D is unsupported |
854 | return false; |
855 | } |
856 | } |
857 | |
858 | bool AvgPoolNode::verify() const { |
859 | switch (getLayout()) { |
860 | case NHWC: |
861 | return verifyPool<ShapeNHWC>(getInput(), getResult(), Kernels_, Strides_, |
862 | Pads_); |
863 | case NCHW: |
864 | return verifyPool<ShapeNCHW>(getInput(), getResult(), Kernels_, Strides_, |
865 | Pads_); |
866 | case NTHWC: |
867 | return verifyPool3D<ShapeNTHWC>(getInput(), getResult(), Kernels_, Strides_, |
868 | Pads_); |
869 | case NCTHW: |
870 | return verifyPool3D<ShapeNCTHW>(getInput(), getResult(), Kernels_, Strides_, |
871 | Pads_); |
872 | default: |
873 | llvm_unreachable("Unsupported format" ); |
874 | } |
875 | } |
876 | |
877 | bool AdaptiveAvgPoolGradNode::verify() const { |
878 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
879 | getGradOfInputNamedInput(), this); |
880 | isValid &= verifyOutputAndGradOutputTypes( |
881 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
882 | |
883 | ShapeNHWC idim(getInput().getType()->dims()); |
884 | ShapeNHWC odim(getOriginalOutputForResult().getType()->dims()); |
885 | |
886 | isValid &= expectCompareTrue( |
887 | "expected the same number of channels for input and output" , odim.c, |
888 | idim.c, this); |
889 | |
890 | isValid &= expectCompareTrue( |
891 | "expected the same number of batches for input and output" , odim.n, |
892 | idim.n, this); |
893 | |
894 | isValid &= expectCompareTrue("height too small for averaging area" , odim.h, |
895 | idim.h, this, CompareOperatorLessEqual<dim_t>()); |
896 | |
897 | isValid &= expectCompareTrue("width too small for averaging area" , odim.w, |
898 | idim.w, this, CompareOperatorLessEqual<dim_t>()); |
899 | |
900 | return isValid; |
901 | } |
902 | |
903 | bool AdaptiveAvgPoolNode::verify() const { |
904 | bool isValid = checkTypeIgnoreShape(getInput(), getResult(), this); |
905 | |
906 | TypeRef inTy = getInput().getType(); |
907 | TypeRef outTy = getResult().getType(); |
908 | |
909 | isValid &= expectCompareTrue("Input should have 4 dimensions" , |
910 | inTy->dims().size(), (size_t)4, this); |
911 | |
912 | isValid &= expectCompareTrue("Output should have 4 dimensions" , |
913 | outTy->dims().size(), (size_t)4, this); |
914 | |
915 | if (!isValid) { |
916 | return false; |
917 | } |
918 | |
919 | isValid &= expectCompareTrue( |
920 | "Output should have the same number of batches as the input" , |
921 | inTy->dims()[0], outTy->dims()[0], this); |
922 | |
923 | isValid &= expectCompareTrue( |
924 | "Output should have the same number of channels as the input" , |
925 | inTy->dims()[3], outTy->dims()[3], this); |
926 | |
927 | isValid &= expectCompareTrue( |
928 | "Output should not have more height than the input" , inTy->dims()[1], |
929 | outTy->dims()[1], this, CompareOperatorGreaterEqual<dim_t>()); |
930 | |
931 | isValid &= expectCompareTrue( |
932 | "Output should not have more width than the input" , inTy->dims()[2], |
933 | outTy->dims()[2], this, CompareOperatorGreaterEqual<dim_t>()); |
934 | |
935 | return isValid; |
936 | } |
937 | |
938 | bool MaxPoolGradNode::verify() const { |
939 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
940 | getGradOfInputNamedInput(), this); |
941 | isValid &= verifyOutputAndGradOutputTypes( |
942 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
943 | |
944 | if (getLayout() == NHWC) { |
945 | isValid &= verifyPool<ShapeNHWC>( |
946 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), |
947 | Kernels_, Strides_, Pads_, /* isAvgPool */ false); |
948 | } else { |
949 | isValid &= verifyPool<ShapeNCHW>( |
950 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), |
951 | Kernels_, Strides_, Pads_, /* isAvgPool */ false); |
952 | } |
953 | return isValid; |
954 | } |
955 | |
956 | bool AvgPoolGradNode::verify() const { |
957 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
958 | getGradOfInputNamedInput(), this); |
959 | isValid &= verifyOutputAndGradOutputTypes( |
960 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
961 | |
962 | switch (getLayout()) { |
963 | case NHWC: |
964 | return isValid && |
965 | verifyPool<ShapeNHWC>(getGradOfInputNamedInput(), |
966 | getGradOfOriginalOutputNamedResult(), Kernels_, |
967 | Strides_, Pads_); |
968 | case NCHW: |
969 | return isValid && |
970 | verifyPool<ShapeNCHW>(getGradOfInputNamedInput(), |
971 | getGradOfOriginalOutputNamedResult(), Kernels_, |
972 | Strides_, Pads_); |
973 | case NTHWC: |
974 | return isValid && |
975 | verifyPool3D<ShapeNTHWC>(getGradOfInputNamedInput(), |
976 | getGradOfOriginalOutputNamedResult(), |
977 | Kernels_, Strides_, Pads_); |
978 | case NCTHW: |
979 | return isValid && |
980 | verifyPool3D<ShapeNCTHW>(getGradOfInputNamedInput(), |
981 | getGradOfOriginalOutputNamedResult(), |
982 | Kernels_, Strides_, Pads_); |
983 | default: |
984 | llvm_unreachable("Unsupported format" ); |
985 | } |
986 | } |
987 | |
988 | bool MatMulNode::verify() const { |
989 | auto lhs = getLHS(); |
990 | auto rhs = getRHS(); |
991 | auto dest = getResult(); |
992 | |
993 | auto LDims = lhs.dims(); |
994 | auto RDims = rhs.dims(); |
995 | auto DDims = dest.dims(); |
996 | bool isValid = expectCompareTrue("LHS input must be 2 dimensional." , |
997 | LDims.size(), size_t(2), this); |
998 | isValid &= expectCompareTrue("RHS input must be 2 dimensional." , RDims.size(), |
999 | size_t(2), this); |
1000 | isValid &= expectCompareTrue("Invalid MatMul dimensions" , DDims.size(), |
1001 | size_t(2), this); |
1002 | |
1003 | auto elem = dest.getType()->getElementType(); |
1004 | isValid &= checkType(lhs, elem, this); |
1005 | isValid &= checkType(rhs, elem, this); |
1006 | |
1007 | isValid &= |
1008 | expectCompareTrue("Invalid row dimensions" , LDims[0], DDims[0], this); |
1009 | isValid &= |
1010 | expectCompareTrue("Invalid column dimensions" , RDims[1], DDims[1], this); |
1011 | return isValid; |
1012 | } |
1013 | |
1014 | bool BatchMatMulNode::verify() const { |
1015 | auto LHS = getLHS(); |
1016 | auto RHS = getRHS(); |
1017 | auto dest = getResult(); |
1018 | |
1019 | bool isValid = expectCompareTrue("LHS input must be 3 dimensional." , |
1020 | LHS.dims().size(), size_t(3), this); |
1021 | isValid &= expectCompareTrue("RHS input must be 3 dimensional." , |
1022 | RHS.dims().size(), size_t(3), this); |
1023 | isValid &= expectCompareTrue("Result must be 3 dimensional." , |
1024 | dest.dims().size(), size_t(3), this); |
1025 | isValid &= expectCompareTrue("LHS and RHS inputs must have same batch size." , |
1026 | LHS.dims()[0], RHS.dims()[0], this); |
1027 | isValid &= expectCompareTrue("Result must have same batch size as inputs." , |
1028 | LHS.dims()[0], dest.dims()[0], this); |
1029 | |
1030 | const dim_t numBatches = LHS.dims()[0]; |
1031 | const dim_t N = LHS.dims()[1]; |
1032 | const dim_t M = LHS.dims()[2]; |
1033 | const dim_t P = RHS.dims()[2]; |
1034 | isValid &= expectCompareTrue("Inputs have invalid dimensions." , RHS.dims()[1], |
1035 | M, this); |
1036 | isValid &= expectCompareTrue("Result has invalid dimensions given inputs." , |
1037 | dest.dims(), {numBatches, N, P}, this); |
1038 | |
1039 | auto elemType = dest.getType()->getElementType(); |
1040 | isValid &= checkType(LHS, elemType, this); |
1041 | isValid &= checkType(RHS, elemType, this); |
1042 | |
1043 | return isValid; |
1044 | } |
1045 | |
1046 | bool SigmoidNode::verify() const { |
1047 | return verifyActivation(getInput(), getResult()); |
1048 | } |
1049 | |
1050 | bool SigmoidGradNode::verify() const { |
1051 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
1052 | getGradOfInputNamedInput(), this); |
1053 | isValid &= verifyOutputAndGradOutputTypes( |
1054 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
1055 | isValid &= verifyActivation(getGradOfInputNamedInput(), |
1056 | getGradOfOriginalOutputNamedResult()); |
1057 | return isValid; |
1058 | } |
1059 | |
1060 | bool SoftPlusNode::verify() const { |
1061 | return verifyActivation(getInput(), getResult()); |
1062 | } |
1063 | |
1064 | bool SwishNode::verify() const { |
1065 | return verifyActivation(getInput(), getResult()); |
1066 | } |
1067 | |
1068 | bool TanhNode::verify() const { |
1069 | return verifyActivation(getInput(), getResult()); |
1070 | } |
1071 | |
1072 | bool TanhGradNode::verify() const { |
1073 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
1074 | getGradOfInputNamedInput(), this); |
1075 | isValid &= verifyOutputAndGradOutputTypes( |
1076 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
1077 | isValid &= verifyActivation(getGradOfInputNamedInput(), |
1078 | getGradOfOriginalOutputNamedResult()); |
1079 | return isValid; |
1080 | } |
1081 | |
1082 | bool LogitNode::verify() const { |
1083 | const Node *parent = getResult().getNode(); |
1084 | bool isValid = checkSameType(getInput(), getResult(), parent); |
1085 | isValid &= checkSameShape(getInput(), getResult(), parent); |
1086 | isValid &= expectCompareTrue( |
1087 | "Clamping parameter eps must be strictly positive" , getEpsilon(), 0.0f, |
1088 | this, CompareOperatorGreaterThan<float>()); |
1089 | isValid &= |
1090 | expectCompareTrue("Clamping parameter eps must be less than 0.5" , |
1091 | getEpsilon(), 0.5f, this, CompareOperatorLess<float>()); |
1092 | return isValid; |
1093 | } |
1094 | |
1095 | bool ExpNode::verify() const { |
1096 | const Node *parent = getResult().getNode(); |
1097 | bool isValid = |
1098 | checkSameIsQuantized(getInput().getType(), getResult().getType(), parent); |
1099 | |
1100 | if (getInput().getType()->isQuantizedType()) { |
1101 | isValid &= checkType(getInput(), getResult().getElementType(), |
1102 | getResult().getNode()); |
1103 | isValid &= checkSameShape(getInput(), getResult(), parent); |
1104 | } else { |
1105 | isValid &= checkSameType(getInput(), getResult(), parent); |
1106 | } |
1107 | |
1108 | return isValid; |
1109 | } |
1110 | |
1111 | bool BucketizeNode::verify() const { |
1112 | bool isValid = checkSameShape(getInput(), getResult(), this); |
1113 | isValid &= !getBoundaries().empty(); |
1114 | isValid &= std::is_sorted(getBoundaries().begin(), getBoundaries().end()); |
1115 | return isValid; |
1116 | } |
1117 | |
1118 | bool SoftMaxNode::verify() const { |
1119 | return verifySoftMax(getInput(), getResult()); |
1120 | } |
1121 | |
1122 | bool SoftMaxGradNode::verify() const { |
1123 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
1124 | getGradOfInputNamedInput(), this); |
1125 | isValid &= verifyInputAndGradInputTypes(getSelected(), |
1126 | getGradOfInputNamedSelected(), this); |
1127 | isValid &= verifyOutputAndGradOutputTypes( |
1128 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
1129 | isValid &= verifySoftMax(getGradOfInputNamedInput(), |
1130 | getGradOfOriginalOutputNamedResult()); |
1131 | return isValid; |
1132 | } |
1133 | |
1134 | bool LogSoftMaxNode::verify() const { |
1135 | return verifyLogSoftMax(getInput(), getResult()); |
1136 | } |
1137 | |
1138 | bool LogSoftMaxGradNode::verify() const { |
1139 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
1140 | getGradOfInputNamedInput(), this); |
1141 | isValid &= ((verifyInputAndGradInputTypes( |
1142 | getSelected(), getGradOfInputNamedSelected(), this)) |
1143 | ? 1 |
1144 | : 0); |
1145 | isValid &= ((verifyOutputAndGradOutputTypes( |
1146 | getOriginalOutputForResult(), |
1147 | getGradOfOriginalOutputNamedResult(), this)) |
1148 | ? 1 |
1149 | : 0); |
1150 | isValid &= ((verifyLogSoftMax(getGradOfInputNamedInput(), |
1151 | getGradOfOriginalOutputNamedResult())) |
1152 | ? 1 |
1153 | : 0); |
1154 | return isValid; |
1155 | } |
1156 | |
1157 | bool CrossEntropyLossNode::verify() const { |
1158 | return verifyCrossEntropyLoss(getP(), getCE(), getLabels()); |
1159 | } |
1160 | |
1161 | bool CrossEntropyLossGradNode::verify() const { |
1162 | bool isValid = verifyInputAndGradInputTypes( |
1163 | getLabels(), getGradOfInputNamedLabels(), this); |
1164 | isValid &= verifyInputAndGradInputTypes(getP(), getGradOfInputNamedP(), this); |
1165 | isValid &= verifyOutputAndGradOutputTypes( |
1166 | getOriginalOutputForCE(), getGradOfOriginalOutputNamedCE(), this); |
1167 | isValid &= verifyCrossEntropyLoss(getGradOfInputNamedP(), |
1168 | getGradOfOriginalOutputNamedCE(), |
1169 | getGradOfInputNamedLabels()); |
1170 | return isValid; |
1171 | } |
1172 | |
1173 | bool ReshapeNode::verify() const { |
1174 | bool isValid = expectCompareTrue("Reshape into a different size" , |
1175 | getResult().getType()->size(), |
1176 | getInput().getType()->size(), this); |
1177 | isValid &= checkTypeIgnoreShape(getResult(), getInput(), this); |
1178 | return isValid; |
1179 | } |
1180 | |
1181 | bool TransposeNode::verify() const { |
1182 | auto dest = getResult(); |
1183 | auto src = getInput(); |
1184 | ShapeVector shape; |
1185 | |
1186 | auto dims = src.dims(); |
1187 | for (size_t i = 0; i < dims.size(); i++) { |
1188 | shape.push_back(dims[Shuffle_[i]]); |
1189 | } |
1190 | |
1191 | bool isValid = expectCompareTrue("Invalid transpose dims" , dest.dims(), |
1192 | llvm::makeArrayRef(shape), this); |
1193 | isValid &= checkTypeIgnoreShape(dest, src, this); |
1194 | return isValid; |
1195 | } |
1196 | |
1197 | bool FlipNode::verify() const { |
1198 | auto dest = getResult(); |
1199 | auto src = getInput(); |
1200 | dim_t axis = getAxis(); |
1201 | bool isValid = checkSameType(src, dest, this); |
1202 | isValid &= expectCompareTrue("Invalid axis" , axis, (dim_t)src.dims().size(), |
1203 | this, CompareOperatorLess<dim_t>()); |
1204 | return isValid; |
1205 | } |
1206 | |
1207 | bool ChannelShuffleNode::verify() const { |
1208 | bool isValid = expectCompareTrue("Channel shuffle into a different size." , |
1209 | getResult().getType()->size(), |
1210 | getInput().getType()->size(), this); |
1211 | isValid &= checkTypeIgnoreShape(getResult(), getInput(), this); |
1212 | return isValid; |
1213 | } |
1214 | |
1215 | bool SplatNode::verify() const { return true; } |
1216 | |
1217 | bool TouchNode::verify() const { return true; } |
1218 | |
1219 | bool TraceEventNode::verify() const { return true; } |
1220 | |
1221 | bool ClipNode::verify() const { |
1222 | bool isValid = |
1223 | expectCompareTrue("Clip max must be greater than min" , getMin(), getMax(), |
1224 | this, CompareOperatorLess<float>()); |
1225 | if (getInput().getType()->isQuantizedType()) { |
1226 | isValid &= |
1227 | checkSameIsQuantized(getInput().getType(), getResult().getType(), this); |
1228 | isValid &= checkSameShape(getInput(), getResult(), this); |
1229 | } else { |
1230 | isValid &= checkSameType(getInput(), getResult(), this); |
1231 | } |
1232 | return isValid; |
1233 | } |
1234 | |
1235 | bool InsertTensorNode::verify() const { |
1236 | auto dest = getBig(); |
1237 | auto src = getSmall(); |
1238 | auto offsets = getStart(); |
1239 | dim_t numDims = dest.dims().size(); |
1240 | dim_t axis = getAxis(); |
1241 | dim_t count = getCount(); |
1242 | |
1243 | bool isValid = expectCompareTrue("Invalid number of dimensions" , numDims, |
1244 | (dim_t)src.dims().size(), this); |
1245 | isValid &= expectCompareTrue("Invalid number of dimensions for offsets" , |
1246 | numDims, (dim_t)offsets.size(), this); |
1247 | |
1248 | if (!isValid) { |
1249 | // The following loop may be out-of-bound if the previous |
1250 | // comparisons failed. |
1251 | return false; |
1252 | } |
1253 | |
1254 | isValid &= checkType(dest, src.getType()->getElementType(), this); |
1255 | if (dest.getType()->isQuantizedType()) { |
1256 | isValid &= expectCompareTrue("Scales of Big and Small must match." , |
1257 | src.getType()->getScale(), |
1258 | dest.getType()->getScale(), this); |
1259 | isValid &= expectCompareTrue("Offsets of Big and Small must match." , |
1260 | src.getType()->getOffset(), |
1261 | dest.getType()->getOffset(), this); |
1262 | } |
1263 | |
1264 | for (unsigned i = 0; i < numDims; i++) { |
1265 | // TODO: We could come up with a mechanism to lazy compute that |
1266 | // string since it is going to be used only in case of an error. |
1267 | // However, this function is not performance critical so leave it |
1268 | // this way for now. |
1269 | std::string msg = std::to_string(i); |
1270 | msg = "out of bounds for index " + msg; |
1271 | isValid &= expectCompareTrue(msg.c_str(), src.dims()[i] + offsets[i], |
1272 | dest.dims()[i], this, |
1273 | CompareOperatorLessEqual<dim_t>()); |
1274 | } |
1275 | |
1276 | isValid &= expectCompareTrue("Invalid axis" , axis, (dim_t)src.dims().size(), |
1277 | this, CompareOperatorLessEqual<dim_t>()); |
1278 | for (dim_t i = 0; i < src.dims().size(); i++) { |
1279 | dim_t mul = (i == axis) ? count : 1; |
1280 | std::string msg = std::to_string(i); |
1281 | msg = "Small does not fit inside Big for index " + msg; |
1282 | isValid &= |
1283 | expectCompareTrue(msg.c_str(), src.dims()[i] * mul, dest.dims()[i], |
1284 | this, CompareOperatorLessEqual<dim_t>()); |
1285 | } |
1286 | return isValid; |
1287 | } |
1288 | |
1289 | bool SliceNode::verify() const { |
1290 | auto dest = getResult(); |
1291 | auto src = getInput(); |
1292 | auto offsets = getStart(); |
1293 | size_t numDims = dest.dims().size(); |
1294 | bool isValid = expectCompareTrue("Invalid number of dimensions" , numDims, |
1295 | src.dims().size(), this); |
1296 | isValid &= expectCompareTrue("Invalid number of dimensions" , numDims, |
1297 | offsets.size(), this); |
1298 | |
1299 | if (!isValid) { |
1300 | // The following loop may be out-of-bound if the previous |
1301 | // comparisons failed. |
1302 | return false; |
1303 | } |
1304 | |
1305 | for (unsigned i = 0; i < numDims; i++) { |
1306 | std::string msg = std::to_string(i); |
1307 | msg = "out of bounds for index " + msg; |
1308 | isValid &= expectCompareTrue(msg.c_str(), dest.dims()[i] + offsets[i], |
1309 | src.dims()[i], this, |
1310 | CompareOperatorLessEqual<dim_t>()); |
1311 | } |
1312 | isValid &= checkNotQuantizedOrSameParams(dest.getType(), src.getType(), this); |
1313 | return isValid; |
1314 | } |
1315 | |
1316 | bool TileNode::verify() const { |
1317 | auto dest = getResult(); |
1318 | auto src = getInput(); |
1319 | size_t axis = getAxis(); |
1320 | unsigned count = getCount(); |
1321 | |
1322 | bool isValid = expectCompareTrue("Invalid axis" , axis, src.dims().size(), |
1323 | this, CompareOperatorLessEqual<size_t>()); |
1324 | |
1325 | for (dim_t i = 0; i < src.dims().size(); i++) { |
1326 | dim_t mul = (i == axis) ? count : 1; |
1327 | std::string msg = std::to_string(i); |
1328 | msg = "Incorrect output shape for dim " + msg; |
1329 | isValid &= expectCompareTrue(msg.c_str(), src.dims()[i] * mul, |
1330 | dest.dims()[i], this); |
1331 | } |
1332 | isValid &= checkTypeIgnoreShape(src, dest, this); |
1333 | return isValid; |
1334 | } |
1335 | |
1336 | bool BatchNormalizationNode::verify() const { |
1337 | return verifyBatchNormalization(getInput(), getResult(), getBias(), |
1338 | getScale(), getMean(), getVar(), ChannelIdx_); |
1339 | } |
1340 | |
1341 | bool InstanceNormalizationNode::verify() const { |
1342 | return verifyInstanceNormalization(getInput(), getResult(), getBias(), |
1343 | getScale(), ChannelIdx_); |
1344 | } |
1345 | |
1346 | bool LayerNormalizationNode::verify() const { |
1347 | auto dest = getResult(); |
1348 | auto src = getInput(); |
1349 | auto scale = getScale(); |
1350 | auto bias = getBias(); |
1351 | |
1352 | // Check input and output have same ElemKind. |
1353 | bool isValid = checkType(src, dest.getElementType(), this); |
1354 | |
1355 | // Check scale and bias have same ElemKind |
1356 | isValid &= checkType(bias, scale.getElementType(), this); |
1357 | |
1358 | // Check inputs/outputs and scale/bias match shapes. |
1359 | isValid &= checkSameShape(src, dest, this); |
1360 | isValid &= checkSameShape(scale, bias, this); |
1361 | |
1362 | // Check that the dims of scale and bias match the end of src. |
1363 | auto srcDims = src.getType()->dims(); |
1364 | auto scaleDims = scale.getType()->dims(); |
1365 | isValid &= expectCompareTrue("Expected input to have more dims than scale" , |
1366 | srcDims.size(), scaleDims.size(), this, |
1367 | CompareOperatorGreaterThan<size_t>()); |
1368 | for (size_t i = 0; i < scaleDims.size(); ++i) { |
1369 | size_t scaleI = scaleDims.size() - i - 1; |
1370 | size_t srcI = srcDims.size() - i - 1; |
1371 | isValid &= |
1372 | expectCompareTrue("Expected scale dims to match the end of src dims" , |
1373 | scaleDims[scaleI], srcDims[srcI], this); |
1374 | } |
1375 | |
1376 | return isValid; |
1377 | } |
1378 | |
1379 | bool BatchNormalizationGradNode::verify() const { |
1380 | bool isValid = |
1381 | verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), this); |
1382 | isValid &= verifyInputAndGradInputTypes(getInput(), |
1383 | getGradOfInputNamedInput(), this); |
1384 | isValid &= |
1385 | verifyInputAndGradInputTypes(getMean(), getGradOfInputNamedMean(), this); |
1386 | isValid &= verifyInputAndGradInputTypes(getScale(), |
1387 | getGradOfInputNamedScale(), this); |
1388 | isValid &= |
1389 | verifyInputAndGradInputTypes(getVar(), getGradOfInputNamedVar(), this); |
1390 | isValid &= verifyOutputAndGradOutputTypes( |
1391 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
1392 | isValid &= verifyBatchNormalization( |
1393 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), |
1394 | getGradOfInputNamedBias(), getGradOfInputNamedScale(), |
1395 | getGradOfInputNamedMean(), getGradOfInputNamedVar(), ChannelIdx_); |
1396 | return isValid; |
1397 | } |
1398 | |
1399 | bool MeanVarNormalizationNode::verify() const { |
1400 | return checkType(getMean(), ElemKind::FloatTy, this) && |
1401 | checkSameType(getMean(), getVar(), this); |
1402 | } |
1403 | |
1404 | bool LocalResponseNormalizationNode::verify() const { |
1405 | return verifyLocalResponseNormalization(getInput(), getResult()); |
1406 | } |
1407 | |
1408 | bool LocalResponseNormalizationGradNode::verify() const { |
1409 | bool isValid = verifyInputAndGradInputTypes(getInput(), |
1410 | getGradOfInputNamedInput(), this); |
1411 | isValid &= verifyOutputAndGradOutputTypes( |
1412 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this); |
1413 | isValid &= verifyLocalResponseNormalization( |
1414 | getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult()); |
1415 | return isValid; |
1416 | } |
1417 | |
1418 | #define VERIFY_UNARY_LOGICAL(NODE_NAME_) \ |
1419 | bool NODE_NAME_##Node::verify() const { \ |
1420 | bool isValid = checkSameShape(getInput(), getResult(), this); \ |
1421 | isValid &= checkType(getInput(), ElemKind::BoolTy, this); \ |
1422 | isValid &= checkType(getResult(), ElemKind::BoolTy, this); \ |
1423 | return isValid; \ |
1424 | } |
1425 | VERIFY_UNARY_LOGICAL(Not) |
1426 | #undef VERIFY_UNARY_LOGICAL |
1427 | |
1428 | bool SignNode::verify() const { |
1429 | if (getResult().getType()->isQuantizedType()) { |
1430 | bool isValid = checkSameShape(getInput(), getResult(), this); |
1431 | isValid &= |
1432 | checkType(getResult(), getInput().getType()->getElementType(), this); |
1433 | return isValid; |
1434 | } |
1435 | return checkSameType(getInput(), getResult(), this); |
1436 | } |
1437 | |
1438 | #define VERIFY_BINARY_LOGICAL(NODE_NAME_) \ |
1439 | bool NODE_NAME_##Node::verify() const { \ |
1440 | bool isValid = checkSameShape(getLHS(), getResult(), this); \ |
1441 | isValid &= checkSameShape(getRHS(), getResult(), this); \ |
1442 | isValid &= checkType(getLHS(), ElemKind::BoolTy, this); \ |
1443 | isValid &= checkType(getRHS(), ElemKind::BoolTy, this); \ |
1444 | isValid &= checkType(getResult(), ElemKind::BoolTy, this); \ |
1445 | return isValid; \ |
1446 | } |
1447 | VERIFY_BINARY_LOGICAL(And) |
1448 | VERIFY_BINARY_LOGICAL(Or) |
1449 | VERIFY_BINARY_LOGICAL(Xor) |
1450 | #undef VERIFY_BINARY_LOGICAL |
1451 | |
1452 | #define VERIFY_BINARY(NODE_NAME_) \ |
1453 | bool NODE_NAME_##Node::verify() const { \ |
1454 | bool isValid = checkSameShape(getLHS(), getResult(), this); \ |
1455 | isValid &= checkSameShape(getRHS(), getResult(), this); \ |
1456 | isValid &= checkSameType(getLHS(), getResult(), this); \ |
1457 | isValid &= checkSameType(getRHS(), getResult(), this); \ |
1458 | return isValid; \ |
1459 | } |
1460 | VERIFY_BINARY(BitwiseAnd) |
1461 | VERIFY_BINARY(BitwiseOr) |
1462 | VERIFY_BINARY(BitwiseXor) |
1463 | #undef VERIFY_BINARY |
1464 | |
1465 | #define VERIFY_UNARY_ARITHMETIC(NODE_NAME_) \ |
1466 | bool NODE_NAME_##Node::verify() const { \ |
1467 | return checkSameShape(getInput(), getResult(), this); \ |
1468 | } |
1469 | VERIFY_UNARY_ARITHMETIC(Abs); |
1470 | VERIFY_UNARY_ARITHMETIC(Neg); |
1471 | VERIFY_UNARY_ARITHMETIC(Floor); |
1472 | VERIFY_UNARY_ARITHMETIC(Ceil); |
1473 | VERIFY_UNARY_ARITHMETIC(Round); |
1474 | VERIFY_UNARY_ARITHMETIC(Sqrt); |
1475 | VERIFY_UNARY_ARITHMETIC(Rsqrt); |
1476 | VERIFY_UNARY_ARITHMETIC(Reciprocal); |
1477 | VERIFY_UNARY_ARITHMETIC(Sin); |
1478 | VERIFY_UNARY_ARITHMETIC(Cos); |
1479 | VERIFY_UNARY_ARITHMETIC(Erf); |
1480 | VERIFY_UNARY_ARITHMETIC(Truncate); |
1481 | VERIFY_UNARY_ARITHMETIC(BitwiseNot); |
1482 | #undef VERIFY_UNARY_ARITHMETIC |
1483 | |
1484 | #define VERIFY_ARITHMETIC(NODE_NAME_) \ |
1485 | bool NODE_NAME_##Node::verify() const { \ |
1486 | return verifyArithmetic(getLHS(), getRHS(), getResult()); \ |
1487 | } |
1488 | VERIFY_ARITHMETIC(Add); |
1489 | VERIFY_ARITHMETIC(Mul); |
1490 | VERIFY_ARITHMETIC(Sub); |
1491 | VERIFY_ARITHMETIC(Div); |
1492 | VERIFY_ARITHMETIC(FloorDiv); |
1493 | VERIFY_ARITHMETIC(Max); |
1494 | VERIFY_ARITHMETIC(Min); |
1495 | VERIFY_ARITHMETIC(Pow); |
1496 | #undef VERIFY_ARITHMETIC |
1497 | |
1498 | #define VERIFY_ARITHMETIC(NODE_NAME_) \ |
1499 | bool NODE_NAME_##Node::verify() const { \ |
1500 | bool isValid = verifyInputAndGradInputTypes( \ |
1501 | getLHS(), getGradOfInputNamedLHS(), this); \ |
1502 | isValid &= verifyInputAndGradInputTypes(getRHS(), \ |
1503 | getGradOfInputNamedRHS(), this); \ |
1504 | isValid &= verifyOutputAndGradOutputTypes( \ |
1505 | getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), \ |
1506 | this); \ |
1507 | isValid &= \ |
1508 | verifyArithmetic(getGradOfInputNamedLHS(), getGradOfInputNamedRHS(), \ |
1509 | getGradOfOriginalOutputNamedResult()); \ |
1510 | return isValid; \ |
1511 | } |
1512 | VERIFY_ARITHMETIC(AddGrad); |
1513 | VERIFY_ARITHMETIC(MulGrad); |
1514 | VERIFY_ARITHMETIC(SubGrad); |
1515 | VERIFY_ARITHMETIC(DivGrad); |
1516 | #undef VERIFY_ARITHMETIC |
1517 | |
1518 | #define VERIFY_CMP(NODE_NAME_) \ |
1519 | bool NODE_NAME_##Node::verify() const { \ |
1520 | bool isValid = checkSameShape(getLHS(), getRHS(), this); \ |
1521 | isValid &= checkSameShape(getResult(), getLHS(), this); \ |
1522 | isValid &= checkType(getLHS(), getRHS().getElementType(), this); \ |
1523 | isValid &= checkType(getResult(), ElemKind::BoolTy, this); \ |
1524 | return isValid; \ |
1525 | } |
1526 | |
1527 | VERIFY_CMP(CmpEQ) |
1528 | VERIFY_CMP(CmpNEQ) |
1529 | VERIFY_CMP(CmpLT) |
1530 | VERIFY_CMP(CmpLTE) |
1531 | #undef VERIFY_CMP |
1532 | |
1533 | // Trigonometric Ops |
1534 | #define VERIFY_TRIGONOMERTRIC_OPS(NODE_NAME_) \ |
1535 | bool NODE_NAME_##Node::verify() const { \ |
1536 | return checkSameShape(getInput(), getResult(), this); \ |
1537 | } |
1538 | VERIFY_TRIGONOMERTRIC_OPS(Acos); |
1539 | VERIFY_TRIGONOMERTRIC_OPS(Asin); |
1540 | VERIFY_TRIGONOMERTRIC_OPS(Atan); |
1541 | #undef VERIFY_UNARY_ARITHMETIC |
1542 | |
1543 | bool FmodNode::verify() const { |
1544 | auto res = getResult(); |
1545 | auto LHS = getLHS(); |
1546 | auto RHS = getRHS(); |
1547 | return checkSameShape(res, LHS, res.getNode()) && |
1548 | checkSameShape(LHS, RHS, res.getNode()) && |
1549 | LHS.getElementType() != ElemKind::Int8QTy && |
1550 | RHS.getElementType() != ElemKind::Int8QTy; |
1551 | } |
1552 | |
1553 | bool BatchedPairwiseDotProductNode::verify() const { |
1554 | auto inputs = getInputs(); |
1555 | |
1556 | bool isValid = inputs.size() > 1; |
1557 | |
1558 | if (isValid) { |
1559 | auto firstInput = inputs[0]; |
1560 | |
1561 | isValid &= firstInput.getElementType() == ElemKind::FloatTy; |
1562 | isValid &= firstInput.getType()->dims().size() == 2; |
1563 | |
1564 | for (auto &in : inputs) { |
1565 | isValid &= checkSameType(in, firstInput, this); |
1566 | } |
1567 | |
1568 | isValid &= getResult().getElementType() == ElemKind::FloatTy; |
1569 | isValid &= |
1570 | getResult().getType()->dims()[0] == firstInput.getType()->dims()[0]; |
1571 | isValid &= getResult().getType()->dims()[1] == |
1572 | inputs.size() * (inputs.size() - 1) / 2; |
1573 | } |
1574 | |
1575 | return isValid; |
1576 | } |
1577 | |
1578 | bool BatchedPairwiseDotProductGradNode::verify() const { return true; } |
1579 | |
1580 | bool BatchedAddNode::verify() const { |
1581 | auto batchShape = getBatch().dims(); |
1582 | auto rhsShape = getSlice().dims(); |
1583 | bool isValid = expectCompareTrue("Invalid shape" , batchShape.drop_front(), |
1584 | rhsShape, this); |
1585 | isValid &= checkSameShape(getBatch(), getResult(), this); |
1586 | |
1587 | if (getBatch().getType()->isQuantizedType()) { |
1588 | expectCompareTrue("Mismatched slice element types" , |
1589 | getSlice().getType()->isQuantizedType(), true, this); |
1590 | } else { |
1591 | isValid &= |
1592 | checkType(getBatch(), getSlice().getType()->getElementType(), this); |
1593 | } |
1594 | return isValid; |
1595 | } |
1596 | |
1597 | bool BatchedMulNode::verify() const { |
1598 | auto batchShape = getBatch().dims(); |
1599 | auto rhsShape = getSlice().dims(); |
1600 | bool isValid = expectCompareTrue("Invalid shape" , batchShape.drop_front(), |
1601 | rhsShape, this); |
1602 | isValid &= checkSameShape(getBatch(), getResult(), this); |
1603 | |
1604 | if (getBatch().getType()->isQuantizedType()) { |
1605 | expectCompareTrue("Mismatched slice element types" , |
1606 | getSlice().getType()->isQuantizedType(), true, this); |
1607 | } else { |
1608 | isValid &= |
1609 | checkType(getBatch(), getSlice().getType()->getElementType(), this); |
1610 | } |
1611 | return isValid; |
1612 | } |
1613 | |
1614 | bool BatchedReduceSumSquareNode::verify() const { |
1615 | bool isValid = checkType(getResult(), getBatch().getElementType(), this); |
1616 | |
1617 | isValid &= |
1618 | expectCompareTrue("Invalid shape" , getBatch().dims().size(), size_t(0), |
1619 | this, CompareOperatorGreaterThan<size_t>()); |
1620 | return isValid; |
1621 | } |
1622 | |
1623 | bool CumSumNode::verify() const { |
1624 | return checkSameType(getResult(), getInput(), this); |
1625 | } |
1626 | |
1627 | bool LengthsSumNode::verify() const { |
1628 | return expectCompareTrue("Lengths must be a 1D vector" , |
1629 | getLengths().dims().size(), size_t(1), this); |
1630 | } |
1631 | |
1632 | // Define verification for Reduction operations. |
1633 | #define DEFINE_BATCHED_REDUCTION_VERIFICATION(name) \ |
1634 | bool name##Node::verify() const { \ |
1635 | bool isValid = checkType(getResult(), getBatch().getElementType(), this); \ |
1636 | isValid &= expectCompareTrue("Invalid shape", getBatch().dims().size(), \ |
1637 | size_t(0), this, \ |
1638 | CompareOperatorGreaterThan<size_t>()); \ |
1639 | return isValid; \ |
1640 | } |
1641 | |
1642 | DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceAdd) |
1643 | DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceMean) |
1644 | DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceMin) |
1645 | DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceMax) |
1646 | DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceProd) |
1647 | |
1648 | #undef DEFINE_BATCHED_REDUCTION_VERIFICATION |
1649 | |
1650 | bool SparseLengthsSumNode::verify() const { |
1651 | return verifySparseLengthsSum(getResult(), getData(), getIndices(), |
1652 | getLengths()); |
1653 | } |
1654 | |
1655 | bool SparseLengthsSumGradNode::verify() const { |
1656 | // Same checks as SparseLengthsSumNode. |
1657 | bool isValid = verifySparseLengthsSum(getOriginalOutputForResult(), getData(), |
1658 | getIndices(), getLengths()); |
1659 | |
1660 | // Checks on gradient inputs/outputs. |
1661 | isValid &= checkSameType(getGradOfOriginalOutputNamedResult(), |
1662 | getOriginalOutputForResult(), this); |
1663 | isValid &= checkSameType(getGradOfInputNamedData(), getData(), this); |
1664 | isValid &= checkSameType(getGradOfInputNamedIndices(), getIndices(), this); |
1665 | isValid &= checkSameType(getGradOfInputNamedLengths(), getLengths(), this); |
1666 | return isValid; |
1667 | } |
1668 | |
1669 | bool SparseLengthsWeightedSumNode::verify() const { |
1670 | return verifySparseLengthsWeightedSum(getResult(), getData(), getWeights(), |
1671 | getIndices(), getLengths()); |
1672 | } |
1673 | |
1674 | bool SparseLengthsWeightedSumGradNode::verify() const { |
1675 | // Same checks as SparseLengthsWeightedSumNode. |
1676 | bool isValid = |
1677 | verifySparseLengthsWeightedSum(getOriginalOutputForResult(), getData(), |
1678 | getWeights(), getIndices(), getLengths()); |
1679 | |
1680 | // Checks on gradient inputs/outputs. |
1681 | isValid &= checkSameType(getGradOfOriginalOutputNamedResult(), |
1682 | getOriginalOutputForResult(), this); |
1683 | isValid &= checkSameType(getGradOfInputNamedData(), getData(), this); |
1684 | isValid &= checkSameType(getGradOfInputNamedWeights(), getWeights(), this); |
1685 | isValid &= checkSameType(getGradOfInputNamedIndices(), getIndices(), this); |
1686 | isValid &= checkSameType(getGradOfInputNamedLengths(), getLengths(), this); |
1687 | return isValid; |
1688 | } |
1689 | |
1690 | bool EmbeddingBagNode::verify() const { |
1691 | return verifyEmbeddingBag(getResult(), getData(), getWeights(), getIndices(), |
1692 | getOffsets()); |
1693 | } |
1694 | |
1695 | bool EmbeddingNode::verify() const { |
1696 | return verifyEmbedding(getResult(), getWeights(), getIndices()); |
1697 | } |
1698 | |
1699 | bool RowwiseQuantizedSparseLengthsWeightedSumNode::verify() const { |
1700 | bool isValid = checkType(getData(), ElemKind::UInt8QTy, this); |
1701 | isValid &= expectCompareTrue("Indices must be a 1D vector" , |
1702 | getIndices().dims().size(), size_t(1), this); |
1703 | isValid &= expectCompareTrue("Lengths must be a 1D vector" , |
1704 | getLengths().dims().size(), size_t(1), this); |
1705 | isValid &= expectCompareTrue("Weights must be a 1D vector" , |
1706 | getWeights().dims().size(), size_t(1), this); |
1707 | isValid &= expectCompareTrue("Scales must be a 1D vector" , |
1708 | getScales().dims().size(), size_t(1), this); |
1709 | isValid &= expectCompareTrue("Offsets must be a 1D vector" , |
1710 | getOffsets().dims().size(), size_t(1), this); |
1711 | isValid &= |
1712 | expectCompareTrue("Weights and Indices must have the same size" , |
1713 | getWeights().dims()[0], getIndices().dims()[0], this); |
1714 | isValid &= expectCompareTrue( |
1715 | "Scales and Data must have the same first dimension size" , |
1716 | getData().dims()[0], getScales().dims()[0], this); |
1717 | isValid &= expectCompareTrue( |
1718 | "Offsets and Data must have the same first dimension size" , |
1719 | getData().dims()[0], getOffsets().dims()[0], this); |
1720 | if (getUseFP16Accumulation()) { |
1721 | isValid &= expectCompareTrue( |
1722 | "Only use FP16 accumulation with FP16 version of Fused-RWQ-SLWS." , |
1723 | getResult().getType()->getElementType(), ElemKind::Float16Ty, this); |
1724 | } |
1725 | return isValid; |
1726 | } |
1727 | |
1728 | static bool verifyFusedRowwiseQuantizedSparseLengthsSum( |
1729 | NodeValue result, NodeValue data, NodeValue indices, NodeValue lengths, |
1730 | NodeValue weights, bool useFP16Accumulation, |
1731 | bool isEmbeddingBagByteRowwiseOffsets = false) { |
1732 | const Node *parent = result.getNode(); |
1733 | bool isValid = expectCompareTrue( |
1734 | "Input data must be Fused Quantized type" , |
1735 | isFusedQuantizedElemKind(data.getType()->getElementType()), true, parent); |
1736 | dim_t ; |
1737 | if (data.getType()->getElementType() == ElemKind::UInt8FusedQTy || |
1738 | data.getType()->getElementType() == ElemKind::UInt4FusedQTy) { |
1739 | extraCols = 2 * sizeof(float); |
1740 | } else { |
1741 | extraCols = 2 * sizeof(float16_t); |
1742 | } |
1743 | if (useFP16Accumulation) { |
1744 | isValid &= expectCompareTrue( |
1745 | "Only use FP16 accumulation with FP16 version of RWQ-SLWS." , |
1746 | result.getType()->getElementType(), ElemKind::Float16Ty, parent); |
1747 | } |
1748 | isValid &= checkType( |
1749 | indices, |
1750 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
1751 | parent); |
1752 | // For EmbeddingBagByteRowwiseOffsets lengths are really offsets and |
1753 | // can be either Int64ITy or Int64ITy. |
1754 | if (isEmbeddingBagByteRowwiseOffsets) { |
1755 | isValid &= checkType( |
1756 | lengths, |
1757 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
1758 | parent); |
1759 | } else { |
1760 | isValid &= checkType(lengths, ElemKind::Int32ITy, parent); |
1761 | } |
1762 | |
1763 | isValid &= expectCompareTrue("Indices must be a 1D vector" , |
1764 | indices.dims().size(), size_t(1), parent); |
1765 | isValid &= expectCompareTrue("Lengths must be a 1D vector" , |
1766 | lengths.dims().size(), size_t(1), parent); |
1767 | isValid &= expectCompareTrue("Data must be 2 dimensional." , |
1768 | data.dims().size(), size_t(2), parent); |
1769 | isValid &= expectCompareTrue("Data must have extra columns for scale/offset." , |
1770 | data.dims()[1], extraCols, parent, |
1771 | CompareOperatorGreaterEqual<dim_t>()); |
1772 | isValid &= expectCompareTrue("Result must be 2 dimensional." , |
1773 | result.dims().size(), size_t(2), parent); |
1774 | |
1775 | if (weights.getNode()) { |
1776 | isValid &= expectCompareTrue("Weights must be a 1D vector" , |
1777 | weights.dims().size(), size_t(1), parent); |
1778 | isValid &= expectCompareTrue("Weights and Indices must have the same size" , |
1779 | weights.dims()[0], indices.dims()[0], parent); |
1780 | } |
1781 | |
1782 | // Wrap this in isValid to prevent potential segfault if the result is |
1783 | // incorrectly shaped. |
1784 | if (isValid) { |
1785 | // If using 4-bit quantization for embeddings then the input is packed into |
1786 | // two elements per byte. |
1787 | dim_t finalSize = result.dims()[1]; |
1788 | if (data.getType()->getElementType() == ElemKind::UInt4FusedFP16QTy || |
1789 | data.getType()->getElementType() == ElemKind::UInt4FusedQTy) { |
1790 | finalSize /= 2; |
1791 | } |
1792 | isValid &= |
1793 | expectCompareTrue("Result output shape should have second dim without " |
1794 | "extra columns from scale/offset in Data." , |
1795 | finalSize + extraCols, data.dims()[1], parent); |
1796 | } |
1797 | return isValid; |
1798 | } |
1799 | |
1800 | bool EmbeddingBagByteRowwiseOffsetsNode::verify() const { |
1801 | return verifyFusedRowwiseQuantizedSparseLengthsSum( |
1802 | getResult(), getData(), getIndices(), getOffsets(), getWeights(), |
1803 | getUseFP16Accumulation(), /*isEmbeddingBagByteRowwiseOffsets*/ true); |
1804 | } |
1805 | |
1806 | bool FusedRowwiseQuantizedSparseLengthsWeightedSumNode::verify() const { |
1807 | return verifyFusedRowwiseQuantizedSparseLengthsSum( |
1808 | getResult(), getData(), getIndices(), getLengths(), getWeights(), |
1809 | getUseFP16Accumulation()); |
1810 | } |
1811 | |
1812 | bool FusedRowwiseQuantizedSparseLengthsSumNode::verify() const { |
1813 | return verifyFusedRowwiseQuantizedSparseLengthsSum( |
1814 | getResult(), getData(), getIndices(), getLengths(), nullptr, |
1815 | getUseFP16Accumulation()); |
1816 | } |
1817 | |
1818 | bool LengthsToRangesNode::verify() const { |
1819 | bool isValid = checkType(getResult(), getLengths().getElementType(), this); |
1820 | isValid &= checkType(getLengths(), ElemKind::Int32ITy, this); |
1821 | isValid &= expectCompareTrue("Lengths must be a 1D vector" , |
1822 | getLengths().dims().size(), size_t(1), this); |
1823 | isValid &= expectCompareTrue("Ranges must be a 2D vector" , |
1824 | getResult().dims().size(), size_t(2), this); |
1825 | isValid &= expectCompareTrue( |
1826 | "Lengths and Ranges must have the same outer dimensions" , |
1827 | getResult().dims()[0], getLengths().dims()[0], this); |
1828 | isValid &= expectCompareTrue("Inner dimension of Ranges must be 2" , |
1829 | getResult().dims()[1], dim_t(2), this); |
1830 | return isValid; |
1831 | } |
1832 | |
1833 | bool LengthsRangeFillNode::verify() const { |
1834 | bool isValid = checkType(getLengths(), ElemKind::Int32ITy, this); |
1835 | isValid &= checkType(getResult(), getLengths().getElementType(), this); |
1836 | isValid &= expectCompareTrue("Lengths must be a 1D vector" , |
1837 | getLengths().dims().size(), size_t(1), this); |
1838 | isValid &= expectCompareTrue("Result must be a 1D vector" , |
1839 | getResult().dims().size(), size_t(1), this); |
1840 | return isValid; |
1841 | } |
1842 | |
1843 | bool BatchSparseToDenseNode::verify() const { |
1844 | bool isValid = checkType(getResult(), getValues().getElementType(), this); |
1845 | isValid &= |
1846 | checkType(getIndices(), {ElemKind::Int64ITy, ElemKind::Int32ITy}, this); |
1847 | isValid &= |
1848 | checkType(getLengths(), {ElemKind::Int64ITy, ElemKind::Int32ITy}, this); |
1849 | isValid &= expectCompareTrue("Lengths must be a 1D vector" , |
1850 | getLengths().dims().size(), size_t(1), this); |
1851 | isValid &= expectCompareTrue("Indices must be a 1D vector" , |
1852 | getIndices().dims().size(), size_t(1), this); |
1853 | isValid &= expectCompareTrue("Indices and Values must have the same shape" , |
1854 | getIndices().dims(), getValues().dims(), this); |
1855 | isValid &= expectCompareTrue( |
1856 | "The size of Lengths and batches in the result should be the same" , |
1857 | getLengths().dims()[0], getResult().dims()[0], this); |
1858 | isValid &= expectCompareTrue( |
1859 | "The second dimension of the result should be equal to dense_last_dim" , |
1860 | getDenseLastDim(), (unsigned)getResult().dims()[1], this); |
1861 | return isValid; |
1862 | } |
1863 | |
1864 | bool FillExamplesWithIndicatorNode::verify() const { |
1865 | bool isValid = checkType(getResult(), getData().getElementType(), this); |
1866 | isValid &= checkType( |
1867 | getIndicator(), |
1868 | {ElemKind::Int64ITy, ElemKind::Int32ITy, ElemKind::BoolTy}, this); |
1869 | isValid &= expectCompareTrue("Indicator must be a 1D vector" , |
1870 | getIndicator().dims().size(), size_t(1), this); |
1871 | isValid &= expectCompareTrue("Data must have at least one dimension" , |
1872 | getData().dims().size(), size_t(1), this, |
1873 | CompareOperatorGreaterEqual<size_t>()); |
1874 | return isValid; |
1875 | } |
1876 | |
1877 | bool SparseToDenseMaskNode::verify() const { |
1878 | bool isValid = checkType(getResult(), getValues().getElementType(), this); |
1879 | isValid &= checkType(getResult(), getDefaultValue().getElementType(), this); |
1880 | isValid &= checkType(getIndices(), ElemKind::Int64ITy, this); |
1881 | isValid &= checkType(getLengths(), ElemKind::Int32ITy, this); |
1882 | isValid &= expectCompareTrue("Indices must be a 1D vector" , |
1883 | getIndices().dims().size(), size_t(1), this); |
1884 | isValid &= expectCompareTrue("Lengths must be a scalar or 1D vector" , |
1885 | getLengths().dims().size(), {0, 1}, this); |
1886 | isValid &= |
1887 | expectCompareTrue("Indices and Values must have the same first dimension" , |
1888 | getIndices().dims()[0], getValues().dims()[0], this); |
1889 | isValid &= expectCompareTrue( |
1890 | "Values[i] must have the same dimensions as DefaultValue" , |
1891 | getValues().dims().slice(1), getDefaultValue().dims(), this); |
1892 | return isValid; |
1893 | } |
1894 | |
1895 | bool SparseLabelSplitNode::verify() const { |
1896 | bool isValid = |
1897 | checkType("Input and output values must be of the same type" , |
1898 | getLabelValues(), getValues().getElementType(), this); |
1899 | isValid &= checkType("Lengths must be of type int32" , getLengths(), |
1900 | ElemKind::Int32ITy, this); |
1901 | isValid &= checkType("Indices must be of type int64" , getIndices(), |
1902 | ElemKind::Int64ITy, this); |
1903 | isValid &= checkType("ExampleIds must be of type int32" , getExampleIds(), |
1904 | ElemKind::Int32ITy, this); |
1905 | isValid &= checkType("GradientOffsetMap must be of type in32" , |
1906 | getGradientOffsetMap(), ElemKind::Int32ITy, this); |
1907 | isValid &= expectCompareTrue("Lengths must be a 1D vector" , |
1908 | getLengths().dims().size(), size_t(1), this); |
1909 | isValid &= expectCompareTrue("Indices must be a 1D vector" , |
1910 | getIndices().dims().size(), size_t(1), this); |
1911 | isValid &= expectCompareTrue("Values must be a 1D vector" , |
1912 | getValues().dims().size(), size_t(1), this); |
1913 | isValid &= expectCompareTrue("Indices and values must have the same shape" , |
1914 | getIndices().dims(), getValues().dims(), this); |
1915 | return isValid; |
1916 | } |
1917 | |
1918 | bool SGDNode::verify() const { |
1919 | return checkSameType(getGradient(), getWeight(), this); |
1920 | } |
1921 | |
1922 | bool QuantizationProfileNode::verify() const { |
1923 | // Make sure that input tensor is a floating point type. |
1924 | bool isValid = checkType(getInput(), ElemKind::FloatTy, this); |
1925 | |
1926 | // Check computation info has proper size. |
1927 | isValid &= |
1928 | expectCompareTrue("Computation info should be 1 dimensional" , |
1929 | getComputationInfo().dims().size(), size_t(1), this); |
1930 | isValid &= expectCompareTrue( |
1931 | "Computation info should contain Min and Max value only" , |
1932 | getComputationInfo().dims()[0], (dim_t)(2), this); |
1933 | return isValid; |
1934 | } |
1935 | |
1936 | bool IntLookupTableNode::verify() const { |
1937 | bool isValid = |
1938 | expectCompareTrue("Input should be quantized type" , |
1939 | getInput().getType()->isQuantizedType(), true, this); |
1940 | isValid &= |
1941 | expectCompareTrue("Result should be quantized type" , |
1942 | getResult().getType()->isQuantizedType(), true, this); |
1943 | isValid &= expectCompareTrue("Mapping should be 1 dimensional" , |
1944 | getMapping().dims().size(), size_t(1), this); |
1945 | isValid &= expectCompareTrue( |
1946 | "Mapping should cover the whole input quantized range" , |
1947 | getMapping().dims()[0], |
1948 | (dim_t)(getInput().getType()->getQuantizedValueCount()), this); |
1949 | isValid &= expectCompareTrue("Mapping and result type must be the same" , |
1950 | getMapping().getType()->getElementType(), |
1951 | getResult().getType()->getElementType(), this); |
1952 | return isValid; |
1953 | } |
1954 | |
1955 | bool LookupTableNode::verify() const { |
1956 | bool isValid = true; |
1957 | return isValid; |
1958 | } |
1959 | |
1960 | bool QuantizeNode::verify() const { |
1961 | bool isValid = |
1962 | expectCompareTrue("Dest must be quantized" , |
1963 | getResult().getType()->isQuantizedType(), true, this); |
1964 | isValid &= expectCompareTrue("Src must be an FP type" , |
1965 | getInput().getType()->isFPType(), true, this); |
1966 | isValid &= checkSameShape(getResult(), getInput(), this); |
1967 | return isValid; |
1968 | } |
1969 | |
1970 | bool DequantizeNode::verify() const { |
1971 | bool isValid = expectCompareTrue( |
1972 | "Dest must be an FP type" , getResult().getType()->isFPType(), true, this); |
1973 | isValid &= |
1974 | expectCompareTrue("Src must be quantized" , |
1975 | getInput().getType()->isQuantizedType(), true, this); |
1976 | if (getInput().getElementType() == ElemKind::UInt8FusedQTy) { |
1977 | isValid &= expectCompareTrue("Fused tensors should be 2D" , |
1978 | getInput().dims().size(), size_t(2), this); |
1979 | isValid &= expectCompareTrue( |
1980 | "Expected space for per-row scale/offset" , getInput().dims()[1], |
1981 | (dim_t)(2 * sizeof(float)), this, CompareOperatorGreaterThan<dim_t>()); |
1982 | } else { |
1983 | isValid &= checkSameShape(getResult(), getInput(), this); |
1984 | } |
1985 | return isValid; |
1986 | } |
1987 | |
1988 | bool RescaleQuantizedNode::verify() const { |
1989 | bool isValid = |
1990 | expectCompareTrue("Dest must be quantized" , |
1991 | getResult().getType()->isQuantizedType(), true, this); |
1992 | isValid &= |
1993 | checkType(getResult(), getInput().getType()->getElementType(), this); |
1994 | isValid &= checkSameShape(getResult(), getInput(), this); |
1995 | return isValid; |
1996 | } |
1997 | |
1998 | bool CollectRpnProposalsNode::verify() const { |
1999 | auto result = getResult(); |
2000 | auto rois = getRoisIn(); |
2001 | auto probs = getRoisProbsIn(); |
2002 | bool isValid = true; |
2003 | |
2004 | isValid &= expectCompareTrue("rpnPostNmsTopN should be greater than zero" , |
2005 | getRpnPostNmsTopN() > 0, true, this); |
2006 | |
2007 | isValid &= expectCompareTrue( |
2008 | "RPN min level should be less than or equal to RPN max level" , |
2009 | getRpnMinLevel() <= getRpnMaxLevel(), true, this); |
2010 | |
2011 | dim_t rpnLevels = getRpnMaxLevel() - getRpnMinLevel() + 1; |
2012 | |
2013 | isValid &= expectCompareTrue("Invalid number of inputs" , |
2014 | rpnLevels == rois.size(), true, this); |
2015 | isValid &= expectCompareTrue("Invalid number of inputs" , |
2016 | rpnLevels == probs.size(), true, this); |
2017 | |
2018 | for (dim_t i = 0; i < rpnLevels; i++) { |
2019 | auto roi = rois[i]; |
2020 | auto prob = probs[i]; |
2021 | isValid &= checkType(result, roi.getElementType(), this); |
2022 | isValid &= checkType(result, prob.getElementType(), this); |
2023 | isValid &= |
2024 | expectCompareTrue("Rois and result must have same second dimension" , |
2025 | roi.dims()[1], result.dims()[1], this); |
2026 | isValid &= expectCompareTrue( |
2027 | "Rois and respective probability scores must have same first dimension" , |
2028 | roi.dims()[0], prob.dims()[0], this); |
2029 | } |
2030 | |
2031 | isValid &= |
2032 | expectCompareTrue("Result is capped to rpnPostNmsTopN" , |
2033 | result.dims()[0] == getRpnPostNmsTopN(), true, this); |
2034 | |
2035 | return isValid; |
2036 | } |
2037 | |
2038 | bool TopKNode::verify() const { |
2039 | bool isValid = checkSameShape(getValues(), getIndices(), this); |
2040 | isValid &= checkNotQuantizedOrSameParams(getInput().getType(), |
2041 | getValues().getType(), this); |
2042 | return isValid; |
2043 | } |
2044 | |
2045 | bool ArgMaxNode::verify() const { |
2046 | bool isValid = true; |
2047 | |
2048 | // Check output type. |
2049 | isValid &= checkType( |
2050 | getResult(), |
2051 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this); |
2052 | |
2053 | // Check output shape. |
2054 | ShapeVector expDstDims = |
2055 | reduceDims(getInput().dims(), {getAxis()}, getKeepDims()); |
2056 | isValid &= expectCompareTrue("Invalid output dims" , getResult().dims(), |
2057 | llvm::makeArrayRef(expDstDims), this); |
2058 | return isValid; |
2059 | } |
2060 | |
2061 | bool ArgMinNode::verify() const { |
2062 | bool isValid = true; |
2063 | |
2064 | // Check output type. |
2065 | isValid &= checkType( |
2066 | getResult(), |
2067 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this); |
2068 | |
2069 | // Check output shape. |
2070 | ShapeVector expDstDims = |
2071 | reduceDims(getInput().dims(), {getAxis()}, getKeepDims()); |
2072 | isValid &= expectCompareTrue("Invalid output dims" , getResult().dims(), |
2073 | llvm::makeArrayRef(expDstDims), this); |
2074 | return isValid; |
2075 | } |
2076 | |
2077 | bool VectorNormNode::verify() const { |
2078 | bool isValid = true; |
2079 | |
2080 | isValid &= expectCompareTrue("Only support Frobenius, p should be 2" , getP(), |
2081 | (unsigned)2, this); |
2082 | // Check output shape. |
2083 | ShapeVector expDstDims = reduceDims(getInput().dims(), {getAxis()}, false); |
2084 | isValid &= expectCompareTrue("Invalid output dims" , getResult().dims(), |
2085 | llvm::makeArrayRef(expDstDims), this); |
2086 | return isValid; |
2087 | } |
2088 | |
2089 | bool GaussianFillNode::verify() const { |
2090 | auto dest = getResult(); |
2091 | bool isValid = dest.getElementType() == ElemKind::Float16Ty; |
2092 | isValid &= checkSameShape(getInput(), dest, this); |
2093 | return isValid; |
2094 | } |
2095 | |
2096 | bool DynamicQuantizedFullyConnectedNode::verify() const { |
2097 | auto src = getInput(); |
2098 | auto weights = getWeights(); |
2099 | auto bias = getBias(); |
2100 | auto dest = getResult(); |
2101 | auto isPerBatchElement = getIsPerBatchElement(); |
2102 | auto isSymmetric = getIsSymmetric(); |
2103 | |
2104 | bool isValid = expectCompareTrue("Inputs should be 2D tensor" , |
2105 | src.dims().size(), size_t(2), this); |
2106 | isValid &= expectCompareTrue( |
2107 | "Only per batch quantized input DynQuantizedFC is supported now" , |
2108 | isPerBatchElement, true, this); |
2109 | isValid &= expectCompareTrue( |
2110 | "Only symmetric quantized DynQuantizedFC is supported now" , isSymmetric, |
2111 | true, this); |
2112 | isValid &= expectCompareTrue("Weights should be 2D tensor" , |
2113 | weights.dims().size(), size_t(2), this); |
2114 | isValid &= expectCompareTrue("Result should be 2D tensor" , dest.dims().size(), |
2115 | size_t(2), this); |
2116 | isValid &= expectCompareTrue("Bias should be 1D tensor" , bias.dims().size(), |
2117 | size_t(1), this); |
2118 | |
2119 | isValid &= expectCompareTrue("Mismatch on expected source dimension 0" , |
2120 | src.dims()[0], dest.dims()[0], this); |
2121 | isValid &= expectCompareTrue("Mismatch on expected source dimension 1" , |
2122 | src.dims()[1], weights.dims()[0], this); |
2123 | |
2124 | isValid &= expectCompareTrue("Inconsistent bias/weights sizes" , |
2125 | bias.dims()[0], weights.dims()[1], this); |
2126 | isValid &= expectCompareTrue("Inconsistent bias/dest sizes" , bias.dims()[0], |
2127 | dest.dims()[1], this); |
2128 | |
2129 | return isValid; |
2130 | } |
2131 | |
2132 | bool DynamicRowwiseQuantizedFullyConnectedNode::verify() const { |
2133 | auto src = getInput(); |
2134 | auto weights = getWeights(); |
2135 | auto bias = getBias(); |
2136 | auto dest = getResult(); |
2137 | auto scales = getScales(); |
2138 | auto offsets = getOffsets(); |
2139 | auto isPerBatchElement = getIsPerBatchElement(); |
2140 | auto isSymmetric = getIsSymmetric(); |
2141 | |
2142 | bool isValid = expectCompareTrue("Inputs should be 2D tensor" , |
2143 | src.dims().size(), size_t(2), this); |
2144 | isValid &= expectCompareTrue( |
2145 | "Only per batch quantized input DynQuantizedFC is supported now" , |
2146 | isPerBatchElement, true, this); |
2147 | isValid &= expectCompareTrue( |
2148 | "Only symmetric quantized DynQuantizedFC is supported now" , isSymmetric, |
2149 | true, this); |
2150 | isValid &= expectCompareTrue("Weights should be 2D tensor" , |
2151 | weights.dims().size(), size_t(2), this); |
2152 | isValid &= expectCompareTrue("Result should be 2D tensor" , dest.dims().size(), |
2153 | size_t(2), this); |
2154 | isValid &= expectCompareTrue("Bias should be 1D tensor" , bias.dims().size(), |
2155 | size_t(1), this); |
2156 | isValid &= expectCompareTrue("Offsets should be 1D tensor" , |
2157 | offsets.dims().size(), size_t(1), this); |
2158 | isValid &= expectCompareTrue("Scales should be 1D tensor" , |
2159 | scales.dims().size(), size_t(1), this); |
2160 | |
2161 | isValid &= expectCompareTrue("Mismatch on expected source dimension 0" , |
2162 | src.dims()[0], dest.dims()[0], this); |
2163 | isValid &= expectCompareTrue("Mismatch on expected source dimension 1" , |
2164 | src.dims()[1], weights.dims()[0], this); |
2165 | |
2166 | isValid &= expectCompareTrue("Inconsistent bias/weights sizes" , |
2167 | bias.dims()[0], weights.dims()[1], this); |
2168 | isValid &= expectCompareTrue("Inconsistent bias/dest sizes" , bias.dims()[0], |
2169 | dest.dims()[1], this); |
2170 | isValid &= expectCompareTrue("Inconsistent scales/offsets sizes" , |
2171 | scales.dims()[0], offsets.dims()[0], this); |
2172 | isValid &= expectCompareTrue("Inconsistent scales/weights sizes" , |
2173 | scales.dims()[0], weights.dims()[1], this); |
2174 | |
2175 | return isValid; |
2176 | } |
2177 | |
2178 | bool RowwiseQuantizedFullyConnectedNode::verify() const { |
2179 | auto src = getInput(); |
2180 | auto weights = getWeights(); |
2181 | auto scales = getScales(); |
2182 | auto offsets = getOffsets(); |
2183 | auto bias = getBias(); |
2184 | auto dest = getResult(); |
2185 | |
2186 | bool isValid = expectCompareTrue("Inputs should be 2D tensor" , |
2187 | src.dims().size(), size_t(2), this); |
2188 | isValid &= expectCompareTrue("Weights should be 2D tensor" , |
2189 | weights.dims().size(), size_t(2), this); |
2190 | isValid &= expectCompareTrue("Result should be 2D tensor" , dest.dims().size(), |
2191 | size_t(2), this); |
2192 | isValid &= expectCompareTrue("Offsets should be 1D tensor" , |
2193 | offsets.dims().size(), size_t(1), this); |
2194 | isValid &= expectCompareTrue("Scales should be 1D tensor" , |
2195 | scales.dims().size(), size_t(1), this); |
2196 | isValid &= expectCompareTrue("Bias should be 1D tensor" , bias.dims().size(), |
2197 | size_t(1), this); |
2198 | |
2199 | isValid &= expectCompareTrue("Mismatch on expected source dimension 0" , |
2200 | src.dims()[0], dest.dims()[0], this); |
2201 | isValid &= expectCompareTrue("Mismatch on expected source dimension 1" , |
2202 | src.dims()[1], weights.dims()[1], this); |
2203 | |
2204 | isValid &= expectCompareTrue("Inconsistent bias/dest sizes" , bias.dims()[0], |
2205 | weights.dims()[0], this); |
2206 | isValid &= expectCompareTrue("Inconsistent weights/dest sizes" , |
2207 | weights.dims()[0], dest.dims()[1], this); |
2208 | |
2209 | isValid &= expectCompareTrue("Inconsistent scales/offsets sizes" , |
2210 | scales.dims()[0], offsets.dims()[0], this); |
2211 | isValid &= expectCompareTrue("Inconsistent scales/weights sizes" , |
2212 | scales.dims()[0], weights.dims()[0], this); |
2213 | return isValid; |
2214 | } |
2215 | |
2216 | bool GatherNode::verify() const { |
2217 | bool isValid = checkType(getResult(), getData().getElementType(), this); |
2218 | isValid &= checkType( |
2219 | getIndices(), |
2220 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this); |
2221 | isValid &= expectCompareTrue( |
2222 | "Mismatching number of dimensions" , getResult().dims().size(), |
2223 | getData().dims().size() + getIndices().dims().size() - 1, this); |
2224 | isValid &= checkNotQuantizedOrSameParams(getResult().getType(), |
2225 | getData().getType(), this); |
2226 | return isValid; |
2227 | } |
2228 | |
2229 | bool GatherElementsNode::verify() const { |
2230 | bool isValid = checkType(getResult(), getData().getElementType(), this); |
2231 | isValid &= checkType( |
2232 | getIndices(), |
2233 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this); |
2234 | isValid &= expectCompareTrue("Mismatching number of dimensions" , |
2235 | getResult().dims().size(), |
2236 | getIndices().dims().size(), this); |
2237 | return isValid; |
2238 | } |
2239 | |
2240 | bool GatherNDNode::verify() const { |
2241 | bool isValid = checkType(getResult(), getData().getElementType(), this); |
2242 | isValid &= checkType( |
2243 | getIndices(), |
2244 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this); |
2245 | isValid &= expectCompareTrue( |
2246 | "Mismatching number of dimensions" , getResult().dims().size(), |
2247 | getData().dims().size() + getIndices().dims().size() - |
2248 | getIndices().dims().back() - 1 - getBatchDims(), |
2249 | this); |
2250 | isValid &= checkNotQuantizedOrSameParams(getResult().getType(), |
2251 | getData().getType(), this); |
2252 | return isValid; |
2253 | } |
2254 | |
2255 | bool GatherRangesNode::verify() const { |
2256 | bool isValid = expectCompareTrue("Data must be 1D" , getData().dims().size(), |
2257 | size_t(1), this); |
2258 | isValid &= expectCompareTrue("Ranges must be 3D" , getRanges().dims().size(), |
2259 | size_t(3), this); |
2260 | isValid &= expectCompareTrue("Last dimension of Ranges must be equal to 2" , |
2261 | getRanges().dims()[2], dim_t(2), this); |
2262 | isValid &= expectCompareTrue("Output must be 1D" , getOutput().dims().size(), |
2263 | size_t(1), this); |
2264 | isValid &= expectCompareTrue("Lengths must be 1D" , getLengths().dims().size(), |
2265 | size_t(1), this); |
2266 | isValid &= |
2267 | expectCompareTrue("Number of examples must match number of lengths" , |
2268 | getRanges().dims()[0], getLengths().dims()[0], this); |
2269 | |
2270 | isValid &= checkTypeIgnoreShape(getOutput(), getData(), this); |
2271 | isValid &= checkTypeIgnoreShape(getRanges(), getLengths(), this); |
2272 | |
2273 | return isValid; |
2274 | } |
2275 | |
2276 | bool ScatterDataNode::verify() const { |
2277 | const auto &slicesDims = getSlices().dims(); |
2278 | const auto &dataDims = getData().dims(); |
2279 | const auto &indicesDims = getIndices().dims(); |
2280 | bool isValid = true; |
2281 | isValid &= expectCompareTrue("Type mismatch" , |
2282 | getSlices().getType()->getElementType(), |
2283 | getData().getType()->getElementType(), this); |
2284 | if (!isValid) { |
2285 | return false; |
2286 | } |
2287 | // TODO: Do we need support for different quant params of copy? |
2288 | if (getSlices().getType()->isQuantizedType() && !getCumulative()) { |
2289 | isValid &= |
2290 | expectCompareTrue("Scale mismatch" , getSlices().getType()->getScale(), |
2291 | getData().getType()->getScale(), this); |
2292 | isValid &= |
2293 | expectCompareTrue("Offset mismatch" , getSlices().getType()->getOffset(), |
2294 | getData().getType()->getOffset(), this); |
2295 | } |
2296 | isValid &= expectCompareTrue("There should be an index for each slice" , |
2297 | indicesDims[0], slicesDims[0], this); |
2298 | isValid &= expectCompareTrue("Indices should be a 2D tensor" , |
2299 | indicesDims.size(), size_t(2), this); |
2300 | // The code below may crash if these conditions are not met. |
2301 | if (!isValid) { |
2302 | return false; |
2303 | } |
2304 | const size_t indexSize = indicesDims[1]; |
2305 | isValid &= expectCompareTrue("Dimensions of Data should be equal to " |
2306 | "dimensions of indices + dimensions of updates" , |
2307 | slicesDims.size() - 1 + indexSize, |
2308 | dataDims.size(), this); |
2309 | if (dataDims.size() > 1) { |
2310 | for (size_t i = indexSize; i < dataDims.size(); i++) { |
2311 | std::string msg = std::to_string(i); |
2312 | msg = "Slice shape should equal data shape for dim " + msg; |
2313 | isValid &= expectCompareTrue(msg.c_str(), dataDims[i], |
2314 | slicesDims[i - indexSize + 1], this); |
2315 | } |
2316 | } |
2317 | |
2318 | return isValid; |
2319 | } |
2320 | |
2321 | bool BatchOneHotNode::verify() const { |
2322 | const auto &dataDims = getData().dims(); |
2323 | const auto &lengthsDims = getLengths().dims(); |
2324 | const auto &valuesDims = getValues().dims(); |
2325 | |
2326 | bool isValid = expectCompareTrue("Data should be a two dimensional matrix" , |
2327 | dataDims.size(), size_t(2), this); |
2328 | |
2329 | isValid &= expectCompareTrue("Lengths should be a single dimensional vectors" , |
2330 | lengthsDims.size(), size_t(1), this); |
2331 | isValid &= checkType(getLengths(), ElemKind::Int32ITy, this); |
2332 | |
2333 | isValid &= expectCompareTrue("Values should be a single dimensional vectors" , |
2334 | valuesDims.size(), size_t(1), this); |
2335 | |
2336 | isValid &= |
2337 | expectCompareTrue("Size of Lengths should be equal to width of Data" , |
2338 | lengthsDims[0], dataDims[1], this); |
2339 | return isValid; |
2340 | } |
2341 | |
2342 | bool SpaceToDepthNode::verify() const { |
2343 | auto inputN = getInput(); |
2344 | auto resultN = getResult(); |
2345 | auto inputDims = inputN.dims(); |
2346 | auto outputDims = resultN.dims(); |
2347 | unsigned blockSize = getBlockSize(); |
2348 | |
2349 | bool sameType = checkTypeIgnoreShape(inputN, resultN, this); |
2350 | bool dimTransform = inputDims[0] == outputDims[0] && |
2351 | inputDims[1] == outputDims[1] * blockSize && |
2352 | inputDims[2] == outputDims[2] * blockSize && |
2353 | inputDims[3] * blockSize * blockSize == outputDims[3]; |
2354 | |
2355 | return sameType && dimTransform; |
2356 | } |
2357 | |
2358 | bool ResizeNearestNode::verify() const { |
2359 | auto input = getInput(); |
2360 | auto scale = getScale(); |
2361 | auto result = getResult(); |
2362 | auto inputDims = input.dims(); |
2363 | auto outputDims = result.dims(); |
2364 | |
2365 | bool isValid = checkTypeIgnoreShape(input, result, this); |
2366 | isValid &= |
2367 | expectCompareTrue("Input size must be greater than 2" , inputDims.size(), |
2368 | size_t(2), this, CompareOperatorGreaterThan<size_t>()); |
2369 | isValid &= |
2370 | expectCompareTrue("Output size must be greater than 2" , outputDims.size(), |
2371 | size_t(2), this, CompareOperatorGreaterThan<size_t>()); |
2372 | isValid &= expectCompareTrue("Input size must be equal to the output size" , |
2373 | inputDims.size(), outputDims.size(), this); |
2374 | |
2375 | for (size_t i = 0, e = scale.size(); i < e; i++) { |
2376 | isValid &= expectCompareTrue("Unexpected output" , |
2377 | dim_t(std::floor(inputDims[i] * scale[i])), |
2378 | outputDims[i], this); |
2379 | isValid &= expectCompareTrue("Invalid scale" , scale[i], float(0.0), this, |
2380 | CompareOperatorGreaterThan<float>()); |
2381 | } |
2382 | |
2383 | return isValid; |
2384 | } |
2385 | |
2386 | bool ResizeBilinearNode::verify() const { |
2387 | auto input = getInput(); |
2388 | auto scale = getScale(); |
2389 | auto result = getResult(); |
2390 | auto inputDims = input.dims(); |
2391 | auto outputDims = result.dims(); |
2392 | |
2393 | bool isValid = checkTypeIgnoreShape(input, result, this); |
2394 | isValid &= expectCompareTrue("Input must be a 4D tensor" , inputDims.size(), |
2395 | size_t(4), this); |
2396 | isValid &= expectCompareTrue("Output must be a 4D tensor" , outputDims.size(), |
2397 | size_t(4), this); |
2398 | |
2399 | for (size_t i = 0, e = scale.size(); i < e; i++) { |
2400 | isValid &= expectCompareTrue("Unexpected output" , |
2401 | dim_t(std::floor(inputDims[i] * scale[i])), |
2402 | outputDims[i], this); |
2403 | isValid &= expectCompareTrue("Invalid scale" , scale[i], float(0.0), this, |
2404 | CompareOperatorGreaterThan<float>()); |
2405 | } |
2406 | |
2407 | return isValid; |
2408 | } |
2409 | |
2410 | bool NonMaxSuppressionNode::verify() const { |
2411 | NodeValue boxes = getBoxes(); |
2412 | NodeValue scores = getScores(); |
2413 | auto boxesDims = boxes.dims(); |
2414 | auto scoresDims = scores.dims(); |
2415 | bool isV4 = getIsTFVersion4(); |
2416 | |
2417 | size_t scoresBoxDim = scores.dims().size() - 1; |
2418 | size_t scoresBatchDim = scores.dims().size() - 3; |
2419 | |
2420 | size_t boxesBoxDim = boxes.dims().size() - 2; |
2421 | size_t boxesBatchDim = boxes.dims().size() - 3; |
2422 | |
2423 | bool isValid = true; |
2424 | if (isV4) { |
2425 | isValid &= expectCompareTrue( |
2426 | "Number of boxes doesn't match number of confidence scores." , |
2427 | boxesDims[boxesBoxDim], scoresDims[scoresBoxDim], this, |
2428 | CompareOperatorEqual<dim_t>()); |
2429 | } |
2430 | |
2431 | // checking layout matching. See ONNX spec for details. |
2432 | if (!isV4) { |
2433 | isValid &= expectCompareTrue( |
2434 | "Batch dimension doesn't match." , boxesDims[boxesBatchDim], |
2435 | scoresDims[scoresBatchDim], this, CompareOperatorEqual<dim_t>()); |
2436 | |
2437 | isValid &= expectCompareTrue( |
2438 | "Number of boxes doesn't match number of confidence scores." , |
2439 | boxesDims[boxesBoxDim], scoresDims[scoresBoxDim], this, |
2440 | CompareOperatorEqual<dim_t>()); |
2441 | } |
2442 | |
2443 | isValid &= checkType(boxes, scores.getElementType(), this); |
2444 | |
2445 | return isValid; |
2446 | } |
2447 | |
2448 | bool TFLiteDetectionPostProcessNode::verify() const { |
2449 | NodeValue boxes = getBoxes(); |
2450 | NodeValue scores = getScores(); |
2451 | NodeValue anchors = getAnchors(); |
2452 | |
2453 | auto boxesDims = boxes.dims(); |
2454 | auto scoresDims = scores.dims(); |
2455 | auto anchorsDims = anchors.dims(); |
2456 | |
2457 | bool isValid = true; |
2458 | |
2459 | // Validate input tensor sizes. |
2460 | isValid &= expectCompareTrue("Input boxes must be a 3D tensor!" , |
2461 | boxesDims.size(), size_t(3), this); |
2462 | isValid &= expectCompareTrue("Input scores must be a 3D tensor!" , |
2463 | scoresDims.size(), size_t(3), this); |
2464 | isValid &= expectCompareTrue("Input anchors must be a 2D tensor!" , |
2465 | anchorsDims.size(), size_t(2), this); |
2466 | dim_t numBoxes = boxesDims[1]; |
2467 | dim_t numTotClasses = scoresDims[2]; |
2468 | isValid &= expectCompareTrue("Input boxes size invalid!" , boxesDims[1], |
2469 | numBoxes, this); |
2470 | isValid &= expectCompareTrue("Input boxes size invalid!" , boxesDims[2], |
2471 | dim_t(4), this); |
2472 | isValid &= expectCompareTrue("Input scores size invalid!" , scoresDims[0], |
2473 | boxesDims[0], this); |
2474 | isValid &= expectCompareTrue("Input scores size invalid!" , scoresDims[1], |
2475 | numBoxes, this); |
2476 | isValid &= expectCompareTrue("Input scores size invalid!" , scoresDims[2], |
2477 | numTotClasses, this); |
2478 | isValid &= expectCompareTrue("Input anchors size invalid!" , anchorsDims[0], |
2479 | numBoxes, this); |
2480 | isValid &= expectCompareTrue("Input anchors size invalid!" , anchorsDims[1], |
2481 | dim_t(4), this); |
2482 | |
2483 | // Validate parameters. |
2484 | isValid &= |
2485 | expectCompareTrue("Invalid IOU threshold!" , getIouThreshold(), float(0.0), |
2486 | this, CompareOperatorGreaterThan<float>()); |
2487 | isValid &= |
2488 | expectCompareTrue("Invalid IOU threshold!" , getIouThreshold(), float(1.0), |
2489 | this, CompareOperatorLessEqual<float>()); |
2490 | isValid &= |
2491 | expectCompareTrue("Invalid score threshold!" , getScoreThreshold(), |
2492 | float(0.0), this, CompareOperatorGreaterThan<float>()); |
2493 | isValid &= |
2494 | expectCompareTrue("Invalid score threshold!" , getScoreThreshold(), |
2495 | float(1.0), this, CompareOperatorLessEqual<float>()); |
2496 | isValid &= |
2497 | expectCompareTrue("Invalid number of classes!" , dim_t(getNumClasses()), |
2498 | numTotClasses, this, CompareOperatorLessEqual<dim_t>()); |
2499 | isValid &= |
2500 | expectCompareTrue("Invalid max detections!" , dim_t(getMaxDetections()), |
2501 | dim_t(0), this, CompareOperatorGreaterThan<dim_t>()); |
2502 | return isValid; |
2503 | } |
2504 | |
2505 | bool AudioSpectrogramNode::verify() const { |
2506 | NodeValue input = getInput(); |
2507 | NodeValue spectrogram = getSpectrogram(); |
2508 | auto inputLength = input.getType()->size(); |
2509 | auto windowSize = getWindowSize(); |
2510 | auto windowStride = getWindowStride(); |
2511 | auto windowCount = std::floor((inputLength - windowSize) / windowStride) + 1; |
2512 | auto fftLen = 1 << (int)std::ceil(std::log2((double)windowSize)); |
2513 | |
2514 | bool isValid = true; |
2515 | isValid &= expectCompareTrue("Input audio is too short for given window size" , |
2516 | dim_t(windowCount), dim_t(0), this, |
2517 | CompareOperatorGreaterThan<dim_t>()); |
2518 | isValid &= expectCompareTrue("Output spectrogram must be a 2D tensor" , |
2519 | spectrogram.dims().size(), size_t(2), this); |
2520 | isValid &= expectCompareTrue("Output spectrogram size is invalid" , |
2521 | spectrogram.dims()[0], dim_t(windowCount), this, |
2522 | CompareOperatorEqual<dim_t>()); |
2523 | isValid &= expectCompareTrue("Output spectrogram size is invalid" , |
2524 | spectrogram.dims()[1], dim_t(fftLen / 2 + 1), |
2525 | this, CompareOperatorEqual<dim_t>()); |
2526 | return isValid; |
2527 | } |
2528 | |
2529 | bool MFCCNode::verify() const { |
2530 | NodeValue spectrogram = getSpectrogram(); |
2531 | NodeValue coefficients = getCoefficients(); |
2532 | float sampleRate = getSampleRate(); |
2533 | float lowerFrequency = getLowerFrequency(); |
2534 | float upperFrequency = getUpperFrequency(); |
2535 | auto filterBankCount = getFilterBankCount(); |
2536 | auto numCoefficients = getNumCoefficients(); |
2537 | auto fftLen = (spectrogram.dims()[1] - 1) * 2; |
2538 | int exp; |
2539 | |
2540 | bool isValid = true; |
2541 | isValid &= expectCompareTrue("Input spectrogram must be a 2D tensor" , |
2542 | spectrogram.dims().size(), size_t(2), this); |
2543 | isValid &= expectCompareTrue( |
2544 | "Input spectrogram size is invalid. Should be of the form 2^N/2+1." , |
2545 | std::abs(std::frexp((float)(fftLen), &exp)), float(0.5), this, |
2546 | CompareOperatorEqual<float>()); |
2547 | isValid &= expectCompareTrue("Output coefficients must be a 2D tensor" , |
2548 | coefficients.dims().size(), size_t(2), this); |
2549 | isValid &= expectCompareTrue("Output coefficients size is invalid" , |
2550 | coefficients.dims()[1], dim_t(numCoefficients), |
2551 | this, CompareOperatorEqual<dim_t>()); |
2552 | isValid &= expectCompareTrue( |
2553 | "Number of windows should be same for both input and output" , |
2554 | spectrogram.dims()[0], coefficients.dims()[0], this, |
2555 | CompareOperatorEqual<dim_t>()); |
2556 | isValid &= expectCompareTrue("Lower frequency should be greater than 0" , |
2557 | lowerFrequency, float(0.0), this, |
2558 | CompareOperatorGreaterThan<float>()); |
2559 | isValid &= expectCompareTrue("Upper frequency should be greater than 0" , |
2560 | upperFrequency, float(0.0), this, |
2561 | CompareOperatorGreaterThan<float>()); |
2562 | isValid &= expectCompareTrue( |
2563 | "Upper frequency must be greater than lower frequency" , upperFrequency, |
2564 | lowerFrequency, this, CompareOperatorGreaterThan<float>()); |
2565 | isValid &= expectCompareTrue( |
2566 | "Upper frequency must be lower than half the sample rate" , sampleRate, |
2567 | float(2.0 * upperFrequency), this, CompareOperatorGreaterThan<float>()); |
2568 | isValid &= expectCompareTrue( |
2569 | "Number of coefficients should be smaller or equal than the filter bank" , |
2570 | dim_t(filterBankCount), dim_t(numCoefficients), this, |
2571 | CompareOperatorGreaterEqual<dim_t>()); |
2572 | return isValid; |
2573 | } |
2574 | |
2575 | bool ROIAlignNode::verify() const { |
2576 | auto featureMap = getFeatureMap(); |
2577 | auto boxes = getBoxes(); |
2578 | auto batchIndices = getBatchIndices(); |
2579 | auto result = getResult(); |
2580 | auto featureMapDims = featureMap.dims(); |
2581 | auto boxesDims = boxes.dims(); |
2582 | auto outputDims = result.dims(); |
2583 | |
2584 | bool isValid = checkTypeIgnoreShape(featureMap, result, this); |
2585 | isValid &= checkTypeIgnoreShape(boxes, result, this); |
2586 | isValid &= |
2587 | checkType(featureMap, {ElemKind::FloatTy, ElemKind::Float16Ty}, this); |
2588 | isValid &= expectCompareTrue("FeatureMap must be a 4D tensor" , |
2589 | featureMapDims.size(), size_t(4), this); |
2590 | isValid &= expectCompareTrue("Boxes must be a 2D tensor" , boxesDims.size(), |
2591 | size_t(2), this); |
2592 | isValid &= expectCompareTrue("Output must be a 4D tensor" , outputDims.size(), |
2593 | size_t(4), this); |
2594 | // If batch size > 1 batch indices must be provided. |
2595 | if (featureMapDims[0] > 1) { |
2596 | // Caffe2 gets indices using boxes tensor |
2597 | bool indicesInBoxesTensor = boxesDims[1] == (getRotated() ? 6 : 5); |
2598 | // Onnx requires batchIndices to be valid |
2599 | if (!indicesInBoxesTensor) { |
2600 | auto batchIndicesDims = batchIndices.dims(); |
2601 | isValid &= checkType(batchIndices, |
2602 | {ElemKind::Int64ITy, ElemKind::Int32ITy}, this); |
2603 | isValid &= expectCompareTrue("BatchIndices must be a 1D tensor" , |
2604 | batchIndicesDims.size(), size_t(1), this); |
2605 | isValid &= |
2606 | expectCompareTrue("BatchIndices must have same length as Boxes" , |
2607 | batchIndicesDims[0], boxesDims[0], this); |
2608 | } |
2609 | } |
2610 | return isValid; |
2611 | } |
2612 | |
2613 | bool BBoxTransformNode::verify() const { |
2614 | auto rois = getRois(); |
2615 | auto deltas = getDeltas(); |
2616 | auto imInfo = getImInfo(); |
2617 | auto boxOut = getBoxOut(); |
2618 | auto weights = getWeights(); |
2619 | auto period = getAngleBoundHi() - getAngleBoundLo(); |
2620 | |
2621 | auto roisDims = rois.dims(); |
2622 | auto deltasDims = deltas.dims(); |
2623 | auto imInfoDims = imInfo.dims(); |
2624 | |
2625 | bool rotated = getRotated(); |
2626 | bool angleBoundOn = getAngleBoundOn(); |
2627 | // BoxDim is of the format |
2628 | // <x1, y1, x2, y2, [optional_angle]> |
2629 | dim_t expectedBoxDim = rotated ? 5 : 4; |
2630 | |
2631 | // Rois row is of the format |
2632 | // <[optinal_batch_index], x1, y1, x2, y2, [optional_angle]> |
2633 | bool validRoiDim = |
2634 | roisDims[1] == expectedBoxDim || roisDims[1] == expectedBoxDim + 1; |
2635 | |
2636 | bool isValid = checkTypeIgnoreShape(rois, boxOut, this); |
2637 | isValid &= checkSameType(deltas, boxOut, this); |
2638 | isValid &= checkTypeIgnoreShape(imInfo, boxOut, this); |
2639 | // ROIs can be float32 or float16. |
2640 | isValid &= checkType(rois, {ElemKind::FloatTy, ElemKind::Float16Ty}, this); |
2641 | isValid &= expectCompareTrue("Rois must be a 2D tensor" , roisDims.size(), |
2642 | size_t(2), this); |
2643 | isValid &= |
2644 | expectCompareTrue("Rois must have with equals boxDim or larger in 1" , |
2645 | validRoiDim, true, this); |
2646 | isValid &= expectCompareTrue("Deltas must be a 2D tensor" , deltasDims.size(), |
2647 | size_t(2), this); |
2648 | isValid &= expectCompareTrue("ImInfo must be a 2D tensor" , imInfoDims.size(), |
2649 | size_t(2), this); |
2650 | isValid &= expectCompareTrue("ImInfo must be a {batch_size, 3} tensor" , |
2651 | imInfoDims[1], dim_t(3), this); |
2652 | isValid &= expectCompareTrue("Rois and Deltas must have same 0 dimension" , |
2653 | roisDims[0], deltasDims[0], this); |
2654 | isValid &= expectCompareTrue( |
2655 | "Number of rois must be <= 2048 to be represented in FP16." , roisDims[0], |
2656 | dim_t(2048), this, CompareOperatorLessEqual<dim_t>()); |
2657 | isValid &= expectCompareTrue("Deltas must be divisible by box dimensions" , |
2658 | deltasDims[1] % expectedBoxDim, dim_t(0), this); |
2659 | isValid &= expectCompareTrue("Weights must be a 1D vector of length 4" , |
2660 | weights.size(), size_t(4), this); |
2661 | if (roisDims[1] == expectedBoxDim) { |
2662 | isValid &= expectCompareTrue( |
2663 | "The batch size should be 1 if there's no batch index in rois" , |
2664 | imInfoDims[0], dim_t(1), this); |
2665 | } |
2666 | if (rotated && angleBoundOn) { |
2667 | isValid &= expectCompareTrue( |
2668 | "The difference between angleBoundHi and angleBoundLo " |
2669 | "should be greater than 0 and divisible by 180" , |
2670 | period > 0 && period % 180 == 0, true, this); |
2671 | } |
2672 | |
2673 | return isValid; |
2674 | } |
2675 | |
2676 | bool SaveNode::verify() const { |
2677 | return checkSameType(getInput(), getOutput(), this); |
2678 | } |
2679 | |
2680 | bool LogNode::verify() const { |
2681 | if (getResult().getType()->isQuantizedType()) { |
2682 | return checkSameShape(getInput(), getResult(), this); |
2683 | } |
2684 | return checkSameType(getInput(), getResult(), this); |
2685 | } |
2686 | |
2687 | bool IsNaNNode::verify() const { |
2688 | bool isValid = checkSameShape(getResult(), getInput(), this); |
2689 | isValid &= checkType(getResult(), ElemKind::BoolTy, this); |
2690 | return isValid; |
2691 | } |
2692 | |
2693 | bool ReplaceNaNNode::verify() const { |
2694 | return checkSameType(getResult(), getInput(), this); |
2695 | } |
2696 | |
2697 | bool NonZeroNode::verify() const { |
2698 | return checkType(getCond(), ElemKind::BoolTy, this) && |
2699 | checkType(getResult(), ElemKind::Int32ITy, this); |
2700 | } |
2701 | |
2702 | bool SelectNode::verify() const { |
2703 | bool isValid = checkSameShape(getResult(), getLHS(), this); |
2704 | isValid &= checkSameShape(getResult(), getRHS(), this); |
2705 | isValid &= checkSameShape(getResult(), getCond(), this); |
2706 | isValid &= checkType(getLHS(), getRHS().getElementType(), this); |
2707 | isValid &= checkType(getLHS(), getResult().getElementType(), this); |
2708 | isValid &= checkType(getCond(), ElemKind::BoolTy, this); |
2709 | return isValid; |
2710 | } |
2711 | |
2712 | bool ReluNode::verify() const { return verifyRelu(getResult(), getInput()); } |
2713 | |
2714 | bool GeluNode::verify() const { |
2715 | const Node *parent = getResult().getNode(); |
2716 | return checkSameType(getResult(), getInput(), parent); |
2717 | } |
2718 | |
2719 | bool ReluGradNode::verify() const { |
2720 | return verifyInputAndGradInputTypes(getInput(), getGradOfInputNamedInput(), |
2721 | this) && |
2722 | verifyOutputAndGradOutputTypes(getOriginalOutputForResult(), |
2723 | getGradOfOriginalOutputNamedResult(), |
2724 | this) && |
2725 | verifyRelu(getGradOfOriginalOutputNamedResult(), getInput()); |
2726 | } |
2727 | |
2728 | bool LeakyReluNode::verify() const { |
2729 | return verifyRelu(getResult(), getInput()); |
2730 | } |
2731 | |
2732 | bool PReluNode::verify() const { |
2733 | return verifyPRelu(getResult(), getInput(), getSlope()); |
2734 | } |
2735 | |
2736 | bool RegressionNode::verify() const { |
2737 | return verifyRegression(getInput(), getResult(), getExpected()); |
2738 | } |
2739 | |
2740 | bool RegressionGradNode::verify() const { |
2741 | return verifyInputAndGradInputTypes(getExpected(), |
2742 | getGradOfInputNamedExpected(), this) && |
2743 | verifyInputAndGradInputTypes(getInput(), getGradOfInputNamedInput(), |
2744 | this) && |
2745 | verifyOutputAndGradOutputTypes(getOriginalOutputForResult(), |
2746 | getGradOfOriginalOutputNamedResult(), |
2747 | this) && |
2748 | verifyRegression(getGradOfInputNamedInput(), |
2749 | getGradOfOriginalOutputNamedResult(), |
2750 | getGradOfInputNamedExpected()); |
2751 | } |
2752 | |
2753 | bool SigmoidCrossEntropyWithLogitsNode::verify() const { |
2754 | bool isValid = checkType(getResult(), getLogits().getElementType(), this); |
2755 | isValid &= checkSameType(getLogits(), getTargets(), this); |
2756 | return isValid; |
2757 | } |
2758 | |
2759 | bool GemmNode::verify() const { |
2760 | NodeValue A = getA(); |
2761 | NodeValue B = getB(); |
2762 | NodeValue C = getC(); |
2763 | NodeValue Y = getResult(); |
2764 | bool transA = getTransposeA(); |
2765 | bool transB = getTransposeB(); |
2766 | const Node *parent = Y.getNode(); |
2767 | |
2768 | // Check types. |
2769 | bool isValid = checkType(B, A.getElementType(), this); |
2770 | // Check for element kind of bias |
2771 | if (C.getNode()) { |
2772 | // Non quantization type check. |
2773 | if (A.getElementType() == ElemKind::FloatTy || |
2774 | A.getElementType() == ElemKind::Float16Ty) { |
2775 | isValid &= checkType(C, A.getElementType(), parent); |
2776 | } |
2777 | // Quantization type check. |
2778 | if (A.getElementType() == ElemKind::Int8QTy) { |
2779 | isValid &= expectCompareTrue("Bias type should be Int8 or Int32 for Gemm" , |
2780 | C.getElementType() == ElemKind::Int8QTy || |
2781 | C.getElementType() == ElemKind::Int32QTy, |
2782 | true, parent); |
2783 | } |
2784 | } |
2785 | isValid &= checkType(Y, A.getElementType(), this); |
2786 | |
2787 | // Check shapes. |
2788 | isValid &= |
2789 | expectCompareTrue("Input A must be 2D" , A.dims().size(), size_t(2), this); |
2790 | isValid &= |
2791 | expectCompareTrue("Input B must be 2D" , B.dims().size(), size_t(2), this); |
2792 | if (C.getNode()) { |
2793 | isValid &= |
2794 | expectCompareTrue("Input C must be 1D or 2D" , C.dims().size(), |
2795 | size_t(2), this, CompareOperatorLessEqual<size_t>()); |
2796 | } |
2797 | isValid &= |
2798 | expectCompareTrue("Output must be 2D" , Y.dims().size(), size_t(2), this); |
2799 | std::vector<dim_t> dimsA = A.dims(); |
2800 | std::vector<dim_t> dimsB = B.dims(); |
2801 | if (transA) { |
2802 | dimsA[0] = A.dims()[1]; |
2803 | dimsA[1] = A.dims()[0]; |
2804 | } |
2805 | if (transB) { |
2806 | dimsB[0] = B.dims()[1]; |
2807 | dimsB[1] = B.dims()[0]; |
2808 | } |
2809 | isValid &= expectCompareTrue("Input A (transposed) dimension 0 size invalid" , |
2810 | dimsA[0], Y.dims()[0], this, |
2811 | CompareOperatorEqual<dim_t>()); |
2812 | isValid &= expectCompareTrue("Input A (transposed) dimension 1 size invalid" , |
2813 | dimsA[1], dimsB[0], this, |
2814 | CompareOperatorEqual<dim_t>()); |
2815 | isValid &= expectCompareTrue("Input B (transposed) dimension 1 size invalid" , |
2816 | dimsB[1], Y.dims()[1], this, |
2817 | CompareOperatorEqual<dim_t>()); |
2818 | if (C.getNode()) { |
2819 | if (C.dims().size() == 1) { |
2820 | isValid &= |
2821 | expectCompareTrue("Input C size invalid" , C.dims()[0], Y.dims()[1], |
2822 | this, CompareOperatorEqual<dim_t>()); |
2823 | } else { |
2824 | isValid &= |
2825 | expectCompareTrue("Input C dimension 0 size invalid" , C.dims()[0], |
2826 | Y.dims()[0], this, CompareOperatorEqual<dim_t>()); |
2827 | isValid &= |
2828 | expectCompareTrue("Input C dimension 1 size invalid" , C.dims()[1], |
2829 | Y.dims()[1], this, CompareOperatorEqual<dim_t>()); |
2830 | } |
2831 | } |
2832 | return isValid; |
2833 | } |
2834 | |
2835 | bool LSTMUnitNode::verify() const { |
2836 | bool isValid = true; |
2837 | NodeValue C = getC(); |
2838 | auto cDim = C.dims(); |
2839 | NodeValue Input = getInput(); |
2840 | auto inputDim = Input.dims(); |
2841 | |
2842 | isValid &= |
2843 | expectCompareTrue("Input must be 2D" , inputDim.size(), size_t(2), this); |
2844 | isValid &= |
2845 | expectCompareTrue("Cell State must be 2D" , cDim.size(), size_t(2), this); |
2846 | isValid &= expectCompareTrue("Input dims[1] must be 4 * C dims[1]" , |
2847 | inputDim[1], 4 * cDim[1], this); |
2848 | isValid &= |
2849 | expectCompareTrue("Input dims[0] must be must be the same to C dims[0]" , |
2850 | inputDim[0], cDim[0], this); |
2851 | |
2852 | return isValid; |
2853 | } |
2854 | |
2855 | bool FullyConnectedNode::verify() const { |
2856 | return verifyFullyConnected(getInput(), getWeights(), getBias(), getResult()); |
2857 | } |
2858 | |
2859 | bool FullyConnectedGradNode::verify() const { |
2860 | return verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), |
2861 | this) && |
2862 | verifyInputAndGradInputTypes(getInput(), getGradOfInputNamedInput(), |
2863 | this) && |
2864 | verifyInputAndGradInputTypes(getWeights(), |
2865 | getGradOfInputNamedWeights(), this) && |
2866 | verifyOutputAndGradOutputTypes(getOriginalOutputForResult(), |
2867 | getGradOfOriginalOutputNamedResult(), |
2868 | this) && |
2869 | verifyFullyConnected( |
2870 | getGradOfInputNamedInput(), getGradOfInputNamedWeights(), |
2871 | getGradOfInputNamedBias(), getGradOfOriginalOutputNamedResult()); |
2872 | } |
2873 | |
2874 | bool ConcatNode::verify() const { |
2875 | auto inputs = getInputs(); |
2876 | auto dimension = getDim(); |
2877 | if (!expectCompareTrue("Empty concat?!" , inputs.empty(), false, this)) { |
2878 | return false; |
2879 | } |
2880 | bool isValid = expectCompareTrue("concat on invalid dimension" , |
2881 | inputs[0].dims().size(), size_t(dimension), |
2882 | this, CompareOperatorGreaterThan<size_t>()); |
2883 | |
2884 | size_t nbDims = inputs[0].dims().size(); |
2885 | for (size_t i = 1; i < inputs.size(); i++) { |
2886 | std::string istr = std::to_string(i); |
2887 | std::string msg = |
2888 | "input " + istr + "#dims are incompatible between elements" ; |
2889 | if (!expectCompareTrue(msg.c_str(), nbDims, inputs[i].dims().size(), |
2890 | this)) { |
2891 | isValid = false; |
2892 | // The following loop depends on this condition being true. |
2893 | continue; |
2894 | } |
2895 | for (size_t j = 0; j < nbDims; j++) { |
2896 | if (j == dimension) { |
2897 | continue; |
2898 | } |
2899 | std::string innerMsg = std::to_string(j); |
2900 | innerMsg = |
2901 | "Mismatching dimension " + innerMsg + " for input 0 and " + istr; |
2902 | isValid &= expectCompareTrue(innerMsg.c_str(), inputs[0].dims()[j], |
2903 | inputs[i].dims()[j], this); |
2904 | } |
2905 | |
2906 | for (size_t i = 0; i < inputs.size(); i++) { |
2907 | isValid &= checkType(inputs[i], getResult().getElementType(), this); |
2908 | isValid &= checkNotQuantizedOrSameParams(getResult().getType(), |
2909 | inputs[i].getType(), this); |
2910 | } |
2911 | } |
2912 | return isValid; |
2913 | } |
2914 | |
2915 | bool BatchBoxCoxNode::verify() const { |
2916 | auto result = getResult(); |
2917 | auto data = getInput(); |
2918 | auto lambda1 = getLambda1(); |
2919 | auto lambda2 = getLambda2(); |
2920 | bool isValid = checkSameType(lambda1, lambda2, this); |
2921 | isValid &= checkSameType(data, result, this); |
2922 | isValid &= checkType(data, lambda1.getElementType(), this); |
2923 | isValid &= checkType( |
2924 | data, {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty}, |
2925 | this); |
2926 | isValid &= expectCompareTrue("Input must be a 2D tensor" , data.dims().size(), |
2927 | size_t(2), this); |
2928 | isValid &= expectCompareTrue("Lambda1 must be a 1D vector" , |
2929 | lambda1.dims().size(), size_t(1), this); |
2930 | if (isValid) { |
2931 | isValid &= expectCompareTrue("Data dim 1 must equal lambda dim" , |
2932 | data.dims()[1], lambda1.dims()[0], this); |
2933 | } |
2934 | return isValid; |
2935 | } |
2936 | |
2937 | bool BroadcastNode::verify() const { |
2938 | const auto inputDims = getInput().dims(); |
2939 | const auto axis = getAxis(); |
2940 | const auto targetDims = getTargetDim(); |
2941 | bool isValid = (axis + inputDims.size() <= targetDims.size()); |
2942 | |
2943 | // Iterate over the new shape; if the original shape had a dimension here |
2944 | // (when considering the axis) then verify the dimension either matches the |
2945 | // new shape (no action taken) or == 1 (broadcast in that direction). |
2946 | for (dim_t i = 0; i < targetDims.size(); i++) { |
2947 | if (i >= axis && i < inputDims.size() + axis) { |
2948 | const int origIdx = i - axis; |
2949 | isValid &= |
2950 | (inputDims[origIdx] == targetDims[i] || inputDims[origIdx] == 1); |
2951 | } |
2952 | } |
2953 | isValid &= checkTypeIgnoreShape(getInput(), getResult(), this); |
2954 | |
2955 | return isValid; |
2956 | } |
2957 | |
2958 | bool ModuloNode::verify() const { return getDivisor() >= 1; } |
2959 | |
2960 | bool ExternalFunctionCallNode::verify() const { return true; } |
2961 | |
2962 | static bool verifyBatchedUnaryEmbeddingsBags(NodeValue dest, NodeValue weights, |
2963 | NodeValue indices, |
2964 | NodeValue offsets) { |
2965 | bool isValid = checkType(dest, weights.getElementType(), dest.getNode()); |
2966 | isValid &= checkType( |
2967 | indices, |
2968 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
2969 | dest.getNode()); |
2970 | isValid &= checkType( |
2971 | offsets, |
2972 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
2973 | dest.getNode()); |
2974 | isValid &= |
2975 | expectCompareTrue("Indices must be a 1D vector" , indices.dims().size(), |
2976 | size_t(1), dest.getNode()); |
2977 | isValid &= |
2978 | expectCompareTrue("Offsets must be a 1D vector" , offsets.dims().size(), |
2979 | size_t(1), dest.getNode()); |
2980 | isValid &= |
2981 | expectCompareTrue("Weights must be a 3D vector" , weights.dims().size(), |
2982 | size_t(3), dest.getNode()); |
2983 | return isValid; |
2984 | } |
2985 | |
2986 | bool BatchedUnaryEmbeddingsBagsNode::verify() const { |
2987 | return verifyBatchedUnaryEmbeddingsBags(getResult(), getWeights(), |
2988 | getIndices(), getOffsets()); |
2989 | } |
2990 | |
2991 | static bool verifyIntNBitSplitEmbeddingBagsNode(NodeValue dest, |
2992 | NodeValue devWeights, |
2993 | NodeValue weightsOffsets, |
2994 | NodeValue weightsPlacements, |
2995 | NodeValue weightsTys) { |
2996 | bool isValid = checkType(dest, devWeights.getElementType(), dest.getNode()); |
2997 | isValid &= checkType(devWeights, ElemKind::UInt8ITy, devWeights.getNode()); |
2998 | isValid &= checkSameShape(weightsPlacements, weightsTys, |
2999 | weightsPlacements.getNode()); |
3000 | isValid &= checkType( |
3001 | weightsOffsets, |
3002 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
3003 | dest.getNode()); |
3004 | return isValid; |
3005 | } |
3006 | |
3007 | bool IntNBitSplitEmbeddingBagsNode::verify() const { |
3008 | return verifyIntNBitSplitEmbeddingBagsNode( |
3009 | getResult(), getDevWeights(), getWeightsOffsets(), getWeightsPlacements(), |
3010 | getWeightsTys()); |
3011 | } |
3012 | |
3013 | static bool verifyIntNBitSplitEmbeddingWeightedBagsNode( |
3014 | NodeValue dest, NodeValue devWeights, NodeValue weightsOffsets, |
3015 | NodeValue weightsPlacements, NodeValue weightsTys, NodeValue indices, |
3016 | NodeValue indiceWeights) { |
3017 | bool isValid = checkType(dest, devWeights.getElementType(), dest.getNode()); |
3018 | isValid &= checkType(devWeights, ElemKind::UInt8ITy, devWeights.getNode()); |
3019 | isValid &= checkSameShape(weightsPlacements, weightsTys, |
3020 | weightsPlacements.getNode()); |
3021 | isValid &= checkType( |
3022 | weightsOffsets, |
3023 | llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), |
3024 | dest.getNode()); |
3025 | isValid &= checkSameShape(indices, indiceWeights, indiceWeights.getNode()); |
3026 | return isValid; |
3027 | } |
3028 | |
3029 | bool IntNBitSplitEmbeddingWeightedBagsNode::verify() const { |
3030 | return verifyIntNBitSplitEmbeddingWeightedBagsNode( |
3031 | getResult(), getDevWeights(), getWeightsOffsets(), getWeightsPlacements(), |
3032 | getWeightsTys(), getIndices(), getIndiceWeight()); |
3033 | } |
3034 | |
3035 | //===----------------------------------------------------------------------===// |
3036 | // Node hashing support |
3037 | //===----------------------------------------------------------------------===// |
3038 | |
3039 | /// These hash functions are required for using llvm::hash_combine. |
3040 | /// hash_value functions should be defined in the same namespace as |
3041 | /// the types they apply to. |
3042 | namespace glow { |
3043 | /// Convert a float into an unsigned integer binary representation. |
3044 | size_t toBinary(float f) { |
3045 | // Convert floating-point binary representation to integer. memcpy compiles |
3046 | // to a simple asm move on platforms we support. |
3047 | static_assert(sizeof(size_t) >= sizeof(float), |
3048 | "size_t is too small on this platform" ); |
3049 | size_t ret = 0; |
3050 | memcpy(&ret, &f, sizeof(float)); |
3051 | return ret; |
3052 | } |
3053 | /// Convert a collection of floats into a vector of |
3054 | /// unsigned integer binary representation. |
3055 | std::vector<size_t> toBinary(llvm::ArrayRef<float> vec) { |
3056 | std::vector<size_t> sizeVec(vec.size()); |
3057 | std::for_each(vec.begin(), vec.end(), [&sizeVec](float f) -> void { |
3058 | sizeVec.push_back(toBinary(f)); |
3059 | }); |
3060 | return sizeVec; |
3061 | } |
3062 | |
3063 | llvm::hash_code hash_value(const glow::Tensor &T) { return T.size(); } |
3064 | |
3065 | // Types are uniqued, so just a pointer can be used. |
3066 | llvm::hash_code hash_value(const glow::Type *T) { |
3067 | return llvm::hash_value((void *)(T)); |
3068 | } |
3069 | |
3070 | llvm::hash_code hash_value(glow::Node *N) { return N->getHash(); } |
3071 | |
3072 | llvm::hash_code hash_value(const glow::NodeValue &NV) { |
3073 | return llvm::hash_combine(NV.getNode(), NV.getResNo()); |
3074 | } |
3075 | |
3076 | llvm::hash_code hash_value(const glow::NodeHandle &NV) { |
3077 | return llvm::hash_combine(NV.getNode(), NV.getResNo()); |
3078 | } |
3079 | |
3080 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
3081 | FusedActivation fusedActivation) { |
3082 | switch (fusedActivation) { |
3083 | case FusedActivation::NONE: |
3084 | os << "NONE" ; |
3085 | break; |
3086 | case FusedActivation::RELU: |
3087 | os << "RELU" ; |
3088 | break; |
3089 | case FusedActivation::CLIP: |
3090 | os << "CLIP" ; |
3091 | break; |
3092 | case FusedActivation::SIGMOID: |
3093 | os << "SIGMOID" ; |
3094 | break; |
3095 | case FusedActivation::TANH: |
3096 | os << "TANH" ; |
3097 | break; |
3098 | case FusedActivation::LEAKY_RELU: |
3099 | os << "LEAKY_RELU" ; |
3100 | break; |
3101 | } |
3102 | return os; |
3103 | } |
3104 | |
3105 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, LUTOperator lutOperator) { |
3106 | switch (lutOperator) { |
3107 | case LUTOperator::NONE: |
3108 | os << "NONE" ; |
3109 | break; |
3110 | case LUTOperator::RELU: |
3111 | os << "RELU" ; |
3112 | break; |
3113 | case LUTOperator::CLIP: |
3114 | os << "CLIP" ; |
3115 | break; |
3116 | case LUTOperator::SIGMOID: |
3117 | os << "SIGMOID" ; |
3118 | break; |
3119 | case LUTOperator::TANH: |
3120 | os << "TANH" ; |
3121 | break; |
3122 | case LUTOperator::LEAKY_RELU: |
3123 | os << "LEAKY_RELU" ; |
3124 | break; |
3125 | } |
3126 | return os; |
3127 | } |
3128 | |
3129 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ConvolutionLayout layout) { |
3130 | switch (layout) { |
3131 | case ConvolutionLayout::NCHW: |
3132 | os << "NCHW" ; |
3133 | break; |
3134 | case ConvolutionLayout::NHWC: |
3135 | os << "NHWC" ; |
3136 | break; |
3137 | case ConvolutionLayout::NCTHW: |
3138 | os << "NCTHW" ; |
3139 | break; |
3140 | case ConvolutionLayout::NTHWC: |
3141 | os << "NTHWC" ; |
3142 | break; |
3143 | default: |
3144 | llvm_unreachable("Unknown format" ); |
3145 | } |
3146 | return os; |
3147 | } |
3148 | |
3149 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, LengthsMode lengthsMode) { |
3150 | switch (lengthsMode) { |
3151 | case LengthsMode::AllOne: |
3152 | os << "AllOne" ; |
3153 | break; |
3154 | case LengthsMode::Variable: |
3155 | os << "Variable" ; |
3156 | break; |
3157 | } |
3158 | return os; |
3159 | } |
3160 | } // namespace glow |
3161 | |