1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Graph/Nodes.h"
18#include "glow/Base/Type.h"
19#include "glow/Graph/Graph.h"
20#include "glow/Graph/VerifierHelper.h"
21#include "glow/Support/Support.h"
22
23using namespace glow;
24
25bool Storage::isEqual(const Storage &other) const {
26 /// A storage should be equal only to itself!
27 return this == &other;
28}
29
30llvm::hash_code Constant::getHash() const {
31 return llvm::hash_combine(getName(), getType());
32}
33
34llvm::hash_code Placeholder::getHash() const {
35 return llvm::hash_combine(getName());
36}
37
38//===----------------------------------------------------------------------===//
39// Visitor methods
40//===----------------------------------------------------------------------===//
41
42void Storage::visit(Node *parent, NodeWalker *visitor) {
43 if (!visitor->shouldVisit(parent, this)) {
44 return;
45 }
46 visitor->pre(parent, this);
47 visitor->post(parent, this);
48}
49
50void Storage::visit(const Node *parent, NodeWalker *visitor) const {
51 if (!visitor->shouldVisit(parent, this)) {
52 return;
53 }
54 visitor->pre(parent, this);
55 visitor->post(parent, this);
56}
57
58//===----------------------------------------------------------------------===//
59// Edge getters methods
60//===----------------------------------------------------------------------===//
/// Storage nodes are graph roots; they consume no inputs.
unsigned Storage::getNumInputs() const { return 0; }
62
/// Storage has no inputs, so every input index is out of range.
std::string Storage::getInputName(unsigned idx) const {
  llvm_unreachable("Invalid index");
}
66
/// Storage has no inputs, so every input index is out of range.
NodeValue Storage::getNthInput(unsigned idx) {
  llvm_unreachable("Invalid index");
}
70
71llvm::StringRef Storage::getOutputName(unsigned idx) const {
72 if (idx == 0) {
73 return "Output";
74 }
75 llvm_unreachable("Invalid index");
76}
77
/// Storage only holds data; it performs no side-effecting computation.
bool Storage::hasSideEffects() const { return false; }
79
/// Cloning is disallowed: storage nodes have identity semantics.
Node *Storage::clone() const { llvm_unreachable("Storage can't be cloned."); }
81
82//===----------------------------------------------------------------------===//
83// Debug description methods
84//===----------------------------------------------------------------------===//
85
86std::string Constant::getDebugDesc(bool skipUsers) const {
87 DescriptionBuilder db(getKindName());
88 db.addParam("Name", separateString(getName(), 100, "\n"))
89 .addParam("Layout", getLayout())
90 .addParam("Output", *getType());
91 if (!skipUsers) {
92 db.addParam("Users", getNumUsers());
93 }
94 return db;
95}
96
97std::string Placeholder::getDebugDesc(bool skipUsers) const {
98 DescriptionBuilder db(getKindName());
99 db.addParam("Name", separateString(getName(), 100, "\n"))
100 .addParam("Layout", getLayout())
101 .addParam("Output", *getType())
102 .addParam("Trainable", isTraining())
103 .addParam("Static", isStatic());
104 if (!skipUsers) {
105 db.addParam("Users", getNumUsers());
106 }
107 return db;
108}
109
110//===----------------------------------------------------------------------===//
111// Nodes verification
112//===----------------------------------------------------------------------===//
113
114static bool verifyConvFilter(const Node *parent, NodeValue filter,
115 const ShapeNHWC &idim, const ShapeNHWC &odim,
116 const ShapeHW &kdim, unsigned_t group) {
117 const dim_t filterDims[] = {odim.c, kdim.height, kdim.width,
118 idim.c / (dim_t)group};
119 return expectCompareTrue("Invalid filter dimensions",
120 filter.getType()->dims(),
121 llvm::makeArrayRef(filterDims), parent);
122}
123
124static bool verifyConvFilter(const Node *parent, NodeValue filter,
125 const ShapeNCHW &idim, const ShapeNCHW &odim,
126 const ShapeHW &kdim, unsigned_t group) {
127 const dim_t filterDims[] = {odim.c, idim.c / (dim_t)group, kdim.height,
128 kdim.width};
129
130 return expectCompareTrue("Invalid filter dimensions",
131 filter.getType()->dims(),
132 llvm::makeArrayRef(filterDims), parent);
133}
134
/// Verify a 2D convolution \p dest = Conv(\p src, \p filter, \p bias) with
/// the given \p kernels, \p strides, \p pads, \p group count and \p dilation.
/// \p Shape selects the layout interpretation (ShapeNHWC or ShapeNCHW). When
/// \p checkBiasType is false the bias element type is not validated here
/// (used by channelwise-quantized conv, which checks the bias itself).
/// \returns true iff all checks pass; failures are reported against the
/// parent node of \p dest.
template <typename Shape>
static bool verifyConvolution(NodeValue src, NodeValue dest, NodeValue filter,
                              NodeValue bias,
                              llvm::ArrayRef<unsigned_t> kernels,
                              llvm::ArrayRef<unsigned_t> strides,
                              llvm::ArrayRef<unsigned_t> pads, unsigned_t group,
                              llvm::ArrayRef<unsigned_t> dilation,
                              bool checkBiasType = true) {
  const Node *parent = dest.getNode();
  // Input, output and filter must all share one element type.
  bool isValid = checkType(src, dest.getElementType(), parent);
  isValid &= checkType(src, filter.getElementType(), parent);
  if (checkBiasType) {
    // Non quantization type check.
    if (src.getElementType() == ElemKind::FloatTy) {
      isValid &= checkType(bias, ElemKind::FloatTy, parent);
    }
    // Quantization type check.
    if (src.getElementType() == ElemKind::Int8QTy) {
      isValid &=
          expectCompareTrue("Bias type should be float, Int8 or Int32 for Conv",
                            bias.getElementType() == ElemKind::FloatTy ||
                                bias.getElementType() == ElemKind::Int8QTy ||
                                bias.getElementType() == ElemKind::Int32QTy,
                            true, parent);
    }
  }
  Shape idim(src.getType()->dims());
  Shape odim(dest.getType()->dims());
  PaddingTLBR pdim(pads);
  ShapeHW kdim(kernels);
  // The padded input must be at least as large as the kernel window.
  isValid &= expectCompareTrue("buffer height too small for selected stride",
                               idim.h + pdim.top + pdim.bottom, kdim.height,
                               parent, CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("buffer width too small for selected stride",
                               idim.w + pdim.left + pdim.right, kdim.width,
                               parent, CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("channels number must be divisible by groups",
                               idim.c % group, dim_t(0), parent);
  isValid &= expectCompareTrue("Dilation should have same length as Stride",
                               dilation.size(), strides.size(), parent);

  // The output shape must match the computed convolution output size.
  auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides,
                                           pads, dilation);
  isValid &=
      expectCompareTrue("Invalid output dimension N", odim.n, idim.n, parent);
  isValid &= expectCompareTrue("Invalid output dimension H", odim.h,
                               outSz.first, parent);
  isValid &= expectCompareTrue("Invalid output dimension W", odim.w,
                               outSz.second, parent);
  isValid &= expectCompareTrue("Invalid output dimension C", odim.c % group,
                               dim_t(0), parent);

  // Layout-specific filter shape check (NHWC vs NCHW overloads above).
  isValid &= verifyConvFilter(parent, filter, idim, odim, kdim, group);

  // Bias holds one value per output channel.
  const dim_t biasDims[] = {odim.c};

  isValid &=
      expectCompareTrue("Invalid bias dimensions", bias.getType()->dims(),
                        llvm::makeArrayRef(biasDims), parent);
  return isValid;
}
196
/// Verify a 3D convolution \p dest = Conv3D(\p src, \p filter, \p bias) in
/// NTHWC layout with the given \p kernels, \p strides, \p pads and \p group
/// count. \returns true iff all checks pass; failures are reported against
/// the parent node of \p dest.
static bool verifyConvolution3D(NodeValue src, NodeValue dest, NodeValue filter,
                                NodeValue bias,
                                llvm::ArrayRef<unsigned_t> kernels,
                                llvm::ArrayRef<unsigned_t> strides,
                                llvm::ArrayRef<unsigned_t> pads,
                                unsigned_t group) {

  const Node *parent = dest.getNode();
  // Input, output and filter must all share one element type.
  bool isValid = checkType(src, dest.getElementType(), parent);
  isValid &= checkType(src, filter.getElementType(), parent);
  // Non quantization type check.
  if (src.getElementType() == ElemKind::FloatTy) {
    isValid &= checkType(bias, ElemKind::FloatTy, parent);
  }
  // Quantization type check.
  if (src.getElementType() == ElemKind::Int8QTy) {
    isValid &=
        expectCompareTrue("Bias type should be Float, Int8 or Int32 for Conv3D",
                          bias.getElementType() == ElemKind::FloatTy ||
                              bias.getElementType() == ElemKind::Int8QTy ||
                              bias.getElementType() == ElemKind::Int32QTy,
                          true, parent);
  }
  ShapeNTHWC idim(src.getType()->dims());
  ShapeNTHWC odim(dest.getType()->dims());
  PaddingNFTBLR pdim(pads);
  ShapeTHW kdim(kernels);
  // The padded input must be at least as large as the kernel window in the
  // height, width and temporal dimensions.
  isValid &= expectCompareTrue("buffer height too small for selected stride",
                               idim.h + pdim.top + pdim.bottom, kdim.height,
                               parent, CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("buffer width too small for selected stride",
                               idim.w + pdim.left + pdim.right, kdim.width,
                               parent, CompareOperatorGreaterEqual<dim_t>());
  isValid &=
      expectCompareTrue("buffer time too small for selected stride",
                        idim.t + pdim.near + pdim.far, kdim.temporal_frames,
                        parent, CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("channels number must be divisible by groups",
                               idim.c % group, dim_t(0), parent);

  // The output shape must match the computed convolution output size.
  auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
                                             strides, pads);
  isValid &=
      expectCompareTrue("Invalid output dimension N", odim.n, idim.n, parent);
  isValid &= expectCompareTrue("Invalid output dimension T", odim.t,
                               outSz.temporal_frames, parent);
  isValid &= expectCompareTrue("Invalid output dimension H", odim.h,
                               outSz.height, parent);
  isValid &= expectCompareTrue("Invalid output dimension W", odim.w,
                               outSz.width, parent);
  isValid &= expectCompareTrue("Invalid output dimension C", odim.c % group,
                               dim_t(0), parent);

  // Filter is [outC, kT, kH, kW, inC/group]; bias is one value per out
  // channel.
  const dim_t filterDims[] = {odim.c, kdim.temporal_frames, kdim.height,
                              kdim.width, idim.c / group};
  isValid &=
      expectCompareTrue("Invalid filter dimensions", filter.getType()->dims(),
                        llvm::makeArrayRef(filterDims), parent);
  const dim_t biasDims[] = {odim.c};
  isValid &=
      expectCompareTrue("Invalid bias dimensions", bias.getType()->dims(),
                        llvm::makeArrayRef(biasDims), parent);
  return isValid;
}
261
262static bool verifyConvTranspose(NodeValue src, NodeValue dest, NodeValue filter,
263 llvm::ArrayRef<unsigned_t> kernels,
264 llvm::ArrayRef<unsigned_t> strides,
265 llvm::ArrayRef<unsigned_t> pads,
266 unsigned_t group,
267 llvm::ArrayRef<unsigned_t> dilation) {
268 const Node *parent = dest.getNode();
269 bool isValid = checkType(src, dest.getElementType(), parent);
270 isValid &= checkType(src, filter.getElementType(), parent);
271 ShapeNHWC idim(src.getType()->dims());
272 ShapeNHWC odim(dest.getType()->dims());
273 PaddingTLBR pdim(pads);
274 ShapeHW kdim(kernels);
275 // TODO: any kernel size check in respect to input ? In contrast to Conv,
276 // seems kernel can be any size.
277
278 isValid &= expectCompareTrue("channels number must be divisible by groups",
279 idim.c % group, dim_t(0), parent);
280
281 isValid &= expectCompareTrue("Stride should be less than kernel.",
282 strides[0] <= kernels[0], true, parent);
283
284 isValid &= expectCompareTrue("Stride should be less than kernel.",
285 strides[1] <= kernels[1], true, parent);
286
287 isValid &= expectCompareTrue("channels number must be divisible by groups",
288 idim.c % group, dim_t(0), parent);
289
290 isValid &= expectCompareTrue("Dilation should have same length as Stride",
291 dilation.size(), strides.size(), parent);
292
293 auto outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels,
294 strides, pads, dilation);
295 (void)outSz;
296 isValid &=
297 expectCompareTrue("Invalid output dimension N", odim.n, idim.n, parent);
298
299 isValid &=
300 expectCompareTrue("Invalid output dimension HT", odim.h, outSz.first,
301 parent, CompareOperatorGreaterEqual<dim_t>());
302 isValid &=
303 expectCompareTrue("Invalid output dimension WT", odim.w, outSz.second,
304 parent, CompareOperatorGreaterEqual<dim_t>());
305
306 isValid &= expectCompareTrue("Invalid output dimension CT", odim.c % group,
307 dim_t(0), parent);
308
309 const dim_t filterDims[] = {odim.c / (dim_t)group, kdim.height, kdim.width,
310 idim.c};
311 isValid &=
312 expectCompareTrue("Invalid filter dimensions", filter.getType()->dims(),
313 llvm::makeArrayRef(filterDims), parent);
314 return isValid;
315}
316
/// Verify a fully-connected layer \p dest = \p src * \p weights + \p bias.
/// Shapes must satisfy: src [N, K], weights [K, M], bias [M], dest [N, M].
/// \returns true iff all checks pass; failures are reported against the
/// parent node of \p dest.
static bool verifyFullyConnected(NodeValue src, NodeValue weights,
                                 NodeValue bias, NodeValue dest) {
  const Node *parent = dest.getNode();
  bool isValid = expectCompareTrue("FC input must be 2D", size_t(2),
                                   src.dims().size(), parent);
  isValid &= expectCompareTrue("FC weights must be 2D", size_t(2),
                               weights.dims().size(), parent);
  isValid &= expectCompareTrue("FC bias must be 1D", size_t(1),
                               bias.dims().size(), parent);
  // Batch dimension is preserved from input to output.
  isValid &= expectCompareTrue("Mismatch between source and dest dimensions",
                               src.dims()[0], dest.dims()[0], parent);
  // Inner (reduction) dimension of src must match weights' rows.
  isValid &= expectCompareTrue("Mismatch between source and weight dimensions",
                               src.dims()[1], weights.dims()[0], parent);
  // Bias is compared against weights' output dimension (== dest's column
  // count per the following check).
  isValid &= expectCompareTrue("Inconsistent bias/dest sizes", bias.dims()[0],
                               weights.dims()[1], parent);
  isValid &= expectCompareTrue("Inconsistent weights/dest sizes",
                               weights.dims()[1], dest.dims()[1], parent);

  // For quantized FC the bias may be held in several element types.
  if (src.getElementType() == ElemKind::Int8QTy) {
    isValid &=
        expectCompareTrue("Bias type should be Int8, Int32 or FP32 for FC",
                          bias.getElementType() == ElemKind::Int8QTy ||
                              bias.getElementType() == ElemKind::Int32QTy ||
                              bias.getElementType() == ElemKind::FloatTy,
                          true, parent);
  }
  return isValid;
}
345
/// Verify a 2D pooling operation \p dest = Pool(\p src) with the given
/// \p kernels, \p strides and \p pads. \p Shape selects the layout
/// (ShapeNHWC or ShapeNCHW). \p isAvgPool distinguishes AvgPool from
/// MaxPool: quantized MaxPool must preserve the input's quantization
/// parameters, AvgPool need not. \returns true iff all checks pass.
template <typename Shape>
static bool verifyPool(NodeValue src, NodeValue dest,
                       llvm::ArrayRef<unsigned_t> kernels,
                       llvm::ArrayRef<unsigned_t> strides,
                       llvm::ArrayRef<unsigned_t> pads, bool isAvgPool = true) {
  const Node *parent = dest.getNode();
  Shape idim(src.getType()->dims());
  Shape odim(dest.getType()->dims());
  PaddingTLBR pdim(pads);
  ShapeHW kdim(kernels);

  // The padded input must be at least as large as the kernel window.
  bool isValid =
      expectCompareTrue("buffer height too small for selected stride",
                        idim.h + pdim.top + pdim.bottom, kdim.height, parent,
                        CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("buffer width too small for selected stride",
                               idim.w + pdim.left + pdim.right, kdim.width,
                               parent, CompareOperatorGreaterEqual<dim_t>());

  // The expected output keeps N and C, with recomputed H and W.
  auto outSz =
      calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
  Shape exp(idim);
  exp.h = outSz.first;
  exp.w = outSz.second;
  isValid &=
      expectCompareTrue("Unexpected output dimensions", exp, odim, parent);

  // For quantized AvgPool, the scale and offset of its input and output could
  // be different. But for quantized MaxPool, the scale and offset of its input
  // and output should be the same.
  isValid &= checkSameIsQuantized(src.getType(), dest.getType(), parent);
  if (!isAvgPool) {
    isValid &= checkTypeIgnoreShape(src, dest, parent);
  }
  return isValid;
}
382
/// Verify a 3D pooling operation \p dest = Pool3D(\p src) with the given
/// \p kernels, \p strides and \p pads. \p Shape selects the layout
/// (ShapeNTHWC or ShapeNCTHW). \p isAvgPool distinguishes AvgPool from
/// MaxPool: quantized MaxPool must preserve the input's quantization
/// parameters, AvgPool need not. \returns true iff all checks pass.
template <typename Shape>
static bool
verifyPool3D(NodeValue src, NodeValue dest, llvm::ArrayRef<unsigned_t> kernels,
             llvm::ArrayRef<unsigned_t> strides,
             llvm::ArrayRef<unsigned_t> pads, bool isAvgPool = true) {
  const Node *parent = dest.getNode();
  Shape idim(src.getType()->dims());
  Shape odim(dest.getType()->dims());
  // NOTE(review): this uses PaddingTLNBRF while verifyConvolution3D uses
  // PaddingNFTBLR for its pads — confirm the two pad orderings are intended.
  PaddingTLNBRF pdim(pads);
  ShapeTHW kdim(kernels);

  // The padded input must be at least as large as the kernel window in the
  // height, width and temporal dimensions.
  bool isValid =
      expectCompareTrue("buffer height too small for selected stride",
                        idim.h + pdim.top + pdim.bottom, kdim.height, parent,
                        CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("buffer width too small for selected stride",
                               idim.w + pdim.left + pdim.right, kdim.width,
                               parent, CompareOperatorGreaterEqual<dim_t>());
  isValid &=
      expectCompareTrue("buffer temporal_frames are too small for "
                        "selected stride",
                        idim.t + pdim.near + pdim.far, kdim.temporal_frames,
                        parent, CompareOperatorGreaterEqual<dim_t>());

  // The expected output keeps N and C, with recomputed T, H and W.
  auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
                                             strides, pads);
  Shape exp(idim);
  exp.t = outSz.temporal_frames;
  exp.h = outSz.height;
  exp.w = outSz.width;
  isValid &=
      expectCompareTrue("Unexpected output dimensions", exp, odim, parent);

  // For quantized AvgPool, the scale and offset of its input and output could
  // be different. But for quantized MaxPool, the scale and offset of its input
  // and output should be the same.
  isValid =
      isValid && checkSameIsQuantized(src.getType(), dest.getType(), parent);
  if (!isAvgPool) {
    isValid = isValid && checkTypeIgnoreShape(src, dest, parent);
  }

  return isValid;
}
427
/// Verify batch normalization of \p src into \p dest over dimension
/// \p channel, with per-channel \p bias, \p scale, \p mean and \p var inputs
/// (each 1D of the channel's size). \returns true iff all checks pass;
/// failures are reported against the parent node of \p dest.
static bool verifyBatchNormalization(NodeValue src, NodeValue dest,
                                     NodeValue bias, NodeValue scale,
                                     NodeValue mean, NodeValue var,
                                     unsigned_t channel) {
  const Node *parent = dest.getNode();

  // Source and Dest can have different quantization params
  // but need to match in shape and element type.
  bool isValid = checkSameShape(dest, src, parent);
  isValid = isValid && checkType(dest, src.getElementType(), parent);

  isValid =
      isValid &&
      expectCompareTrue(
          "Require at least two input dims i.e., batch and channel dimensions",
          src.dims().size(), (size_t)1, parent,
          CompareOperatorGreaterThan<size_t>());

  // Figure out how many channels are in the tensor.
  dim_t channels = src.dims()[channel];

  // Each per-channel operand must be a 1D vector of length `channels`.
  const dim_t expArray[] = {channels};
  auto exp = llvm::makeArrayRef(expArray);
  isValid = isValid && expectCompareTrue("Invalid bias dimension",
                                         bias.getType()->dims(), exp, parent);
  isValid = isValid && expectCompareTrue("Invalid scale dimension",
                                         scale.getType()->dims(), exp, parent);
  isValid = isValid && expectCompareTrue("Invalid mean dimension",
                                         mean.getType()->dims(), exp, parent);
  isValid = isValid && expectCompareTrue("Invalid var dimension",
                                         var.getType()->dims(), exp, parent);
  return isValid;
}
461
/// Verify instance normalization of \p src into \p dest over dimension
/// \p channel, with per-channel \p bias and \p scale inputs (each 1D of the
/// channel's size). Quantized src/dest may differ in quantization params but
/// must match in shape and element type; otherwise types must be identical.
/// \returns true iff all checks pass.
static bool verifyInstanceNormalization(NodeValue src, NodeValue dest,
                                        NodeValue bias, NodeValue scale,
                                        unsigned_t channel) {
  const Node *parent = dest.getNode();
  bool isValid = true;
  if (src.getType()->isQuantizedType()) {
    isValid &= checkType(src, dest.getElementType(), dest.getNode());
    isValid &= checkSameShape(src, dest, parent);
  } else {
    isValid &= checkSameType(src, dest, parent);
  }

  isValid &= expectCompareTrue(
      "Require at least two input dims i.e., batch and channel dimensions",
      src.dims().size(), (size_t)1, parent,
      CompareOperatorGreaterThan<size_t>());

  // Figure out how many channels are in the tensor.
  dim_t channels = src.dims()[channel];

  // Bias and scale must each be a 1D vector of length `channels`.
  const dim_t expArray[] = {channels};
  auto exp = llvm::makeArrayRef(expArray);
  isValid &= expectCompareTrue("Invalid bias dimension", bias.getType()->dims(),
                               exp, parent);
  isValid &= expectCompareTrue("Invalid scale dimension",
                               scale.getType()->dims(), exp, parent);

  return isValid;
}
491
492static bool verifyActivation(NodeValue src, NodeValue dest) {
493 const Node *parent = dest.getNode();
494 bool isValid = checkSameIsQuantized(src.getType(), dest.getType(), parent);
495 if (src.getType()->isQuantizedType()) {
496 isValid &= checkType(src, dest.getElementType(), dest.getNode());
497 isValid &= checkSameShape(src, dest, parent);
498 } else {
499 isValid &= checkSameType(src, dest, parent);
500 }
501 return isValid;
502}
503
504static bool verifySoftMax(NodeValue src, NodeValue dest) {
505 const Node *parent = dest.getNode();
506 if (src.getType()->isQuantizedType()) {
507 return checkType(src, dest.getElementType(), parent) &&
508 checkSameShape(src, dest, parent);
509 }
510 return checkSameType(src, dest, parent);
511}
512
513static bool verifyLogSoftMax(NodeValue src, NodeValue dest) {
514 const Node *parent = dest.getNode();
515 if (src.getType()->isQuantizedType()) {
516 return checkType(src, dest.getElementType(), parent) &&
517 checkSameShape(src, dest, parent);
518 }
519 return checkSameType(src, dest, parent);
520}
521
522static bool verifyCrossEntropyLoss(NodeValue P, NodeValue CE,
523 NodeValue labels) {
524 const Node *parent = CE.getNode();
525 bool isValid = checkType(P, CE.getElementType(), parent);
526 isValid &= expectCompareTrue("Mismatching shape", P.dims()[0],
527 labels.dims()[0], parent);
528 return isValid;
529}
530
531static bool verifyLocalResponseNormalization(NodeValue src, NodeValue dest) {
532 return checkSameType(src, dest, dest.getNode());
533}
534
535static bool verifyArithmetic(NodeValue LHS, NodeValue RHS, NodeValue res) {
536 return checkSameShape(res, LHS, res.getNode()) &&
537 checkSameShape(LHS, RHS, res.getNode());
538}
539
540static bool verifyRelu(NodeValue result, NodeValue input) {
541 const Node *parent = result.getNode();
542 if (input.getType()->isQuantizedType()) {
543 return checkSameIsQuantized(input.getType(), result.getType(), parent) &&
544 checkSameShape(result, input, parent);
545 }
546 return checkSameType(result, input, parent);
547}
548
549static bool verifyPRelu(NodeValue result, NodeValue input, NodeValue slope) {
550 const Node *parent = result.getNode();
551 if (input.getType()->isQuantizedType()) {
552 return checkSameIsQuantized(input.getType(), result.getType(), parent) &&
553 checkSameIsQuantized(input.getType(), slope.getType(), parent) &&
554 checkSameShape(result, input, parent) &&
555 checkSameShape(slope, input, parent);
556 }
557 return checkSameType(result, input, parent) &&
558 checkSameType(slope, input, parent) &&
559 checkSameShape(slope, input, parent);
560}
561
562static bool verifyRegression(NodeValue src, NodeValue dest,
563 NodeValue expected) {
564 return checkSameType(src, dest, dest.getNode()) &&
565 checkSameType(dest, expected, dest.getNode());
566}
567
/// Verify SparseLengthsSum: \p dest and \p data share an element type,
/// \p indices is a 1D Int64/Int32 vector and \p lengths is a 1D Int32
/// vector. \returns true iff all checks pass.
static bool verifySparseLengthsSum(NodeValue dest, NodeValue data,
                                   NodeValue indices, NodeValue lengths) {
  bool isValid = checkType(dest, data.getElementType(), dest.getNode());
  isValid &= checkType(indices, {ElemKind::Int64ITy, ElemKind::Int32ITy},
                       dest.getNode());
  isValid &= checkType(lengths, ElemKind::Int32ITy, dest.getNode());
  isValid &=
      expectCompareTrue("Indices must be a 1D vector", indices.dims().size(),
                        size_t(1), dest.getNode());
  isValid &=
      expectCompareTrue("Lengths must be a 1D vector", lengths.dims().size(),
                        size_t(1), dest.getNode());
  return isValid;
}
582
/// Verify SparseLengthsWeightedSum: like \ref verifySparseLengthsSum, plus a
/// 1D \p weights vector of \p data's element type with one weight per index.
/// \returns true iff all checks pass.
static bool verifySparseLengthsWeightedSum(NodeValue dest, NodeValue data,
                                           NodeValue weights, NodeValue indices,
                                           NodeValue lengths) {
  bool isValid = checkType(dest, data.getElementType(), dest.getNode());
  isValid &= checkType(weights, data.getElementType(), dest.getNode());
  isValid &= checkType(indices, {ElemKind::Int64ITy, ElemKind::Int32ITy},
                       dest.getNode());
  isValid &= checkType(lengths, ElemKind::Int32ITy, dest.getNode());
  isValid &=
      expectCompareTrue("Indices must be a 1D vector", indices.dims().size(),
                        size_t(1), dest.getNode());
  isValid &=
      expectCompareTrue("Lengths must be a 1D vector", lengths.dims().size(),
                        size_t(1), dest.getNode());
  isValid &=
      expectCompareTrue("Weights must be a 1D vector", weights.dims().size(),
                        size_t(1), dest.getNode());

  // One weight per gathered row.
  isValid &=
      expectCompareTrue("Weights and Indices must have the same size",
                        weights.dims()[0], indices.dims()[0], dest.getNode());
  return isValid;
}
606
/// Verify Embedding: \p dest and \p weights share an element type, \p indices
/// is Int64/Int32 and \p weights is a 2D embedding table. \returns true iff
/// all checks pass.
static bool verifyEmbedding(NodeValue dest, NodeValue weights,
                            NodeValue indices) {
  bool isValid = checkType(dest, weights.getElementType(), dest.getNode());
  isValid &= checkType(
      indices,
      llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
      dest.getNode());
  isValid &=
      expectCompareTrue("Weights must be a 2D tensor", weights.dims().size(),
                        size_t(2), weights.getNode());
  return isValid;
}
619
/// Verify EmbeddingBag: \p dest, \p data and \p weights share an element
/// type; \p indices, \p offsets and \p weights are 1D vectors, with indices
/// and offsets in Int64/Int32 and one weight per index. \returns true iff all
/// checks pass.
static bool verifyEmbeddingBag(NodeValue dest, NodeValue data,
                               NodeValue weights, NodeValue indices,
                               NodeValue offsets) {
  bool isValid = checkType(dest, data.getElementType(), dest.getNode());
  isValid &= checkType(weights, data.getElementType(), dest.getNode());
  isValid &= checkType(
      indices,
      llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
      dest.getNode());
  isValid &= checkType(
      offsets,
      llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
      dest.getNode());
  isValid &=
      expectCompareTrue("Indices must be a 1D vector", indices.dims().size(),
                        size_t(1), dest.getNode());
  isValid &=
      expectCompareTrue("Offsets must be a 1D vector", offsets.dims().size(),
                        size_t(1), dest.getNode());
  isValid &=
      expectCompareTrue("Weights must be a 1D vector", weights.dims().size(),
                        size_t(1), dest.getNode());

  // One weight per gathered row.
  isValid &=
      expectCompareTrue("Weights and Indices must have the same size",
                        weights.dims()[0], indices.dims()[0], dest.getNode());
  return isValid;
}
648
649bool HardSwishNode::verify() const {
650 return checkSameType(getInput(), getResult(), this);
651}
652
653bool PadNode::verify() const {
654 // Pad is currently only supported for constant padding.
655 return expectCompareTrue("only the 'constant' mode is currrently supported",
656 getMode() == PaddingMode::CONSTANT, true,
657 getResult().getNode());
658}
659
660bool ConvolutionNode::verify() const {
661 if (getLayout() == NHWC) {
662 return verifyConvolution<ShapeNHWC>(getInput(), getResult(), getFilter(),
663 getBias(), Kernels_, Strides_, Pads_,
664 Group_, Dilation_);
665 } else {
666 return verifyConvolution<ShapeNCHW>(getInput(), getResult(), getFilter(),
667 getBias(), Kernels_, Strides_, Pads_,
668 Group_, Dilation_);
669 }
670}
671
/// Verify a channelwise-quantized convolution. Shape checks are delegated to
/// the 2D/3D conv verifiers (bias element type excluded for the 2D case),
/// then the Int8 operand types and the per-output-channel quantization
/// parameter vectors (filter/bias scales and offsets) are validated.
bool ChannelwiseQuantizedConvolutionNode::verify() const {
  auto input_dims = getInput().getType()->dims();
  bool isValid = false;
  // A 5D input (NTHWC) means this is a 3D convolution.
  bool isConv3D = (input_dims.size() == 5);
  if (isConv3D) {
    isValid = verifyConvolution3D(getInput(), getResult(), getFilter(),
                                  getBias(), Kernels_, Strides_, Pads_, Group_);

    // Conv3D has no dilation support; every dilation entry must be 1.
    if (!all_of(Dilation_.begin(), Dilation_.end(),
                [](unsigned_t i) { return i == 1; })) {
      report("For Conv3D dilation must be 1");
    }
  } else {
    // Bias type is validated separately below against the wider set of
    // allowed kinds, hence checkBiasType=false here.
    isValid = verifyConvolution<ShapeNHWC>(
        getInput(), getResult(), getFilter(), getBias(), Kernels_, Strides_,
        Pads_, Group_, Dilation_, /* checkBiasType */ false);
  }

  isValid &= checkType(getResult(), ElemKind::Int8QTy, this);
  isValid &= checkType(getInput(), ElemKind::Int8QTy, this);
  isValid &= checkType(getFilter(), ElemKind::Int8QTy, this);
  isValid &= checkType(
      getBias(), {ElemKind::Int8QTy, ElemKind::Int32QTy, ElemKind::FloatTy},
      this);

  // Check qparam types.
  isValid &= checkType(getFilterOffsets(), ElemKind::Int32ITy, this);
  isValid &= checkType(getFilterScales(), ElemKind::FloatTy, this);
  isValid &= checkType(getBiasOffsets(), ElemKind::Int32ITy, this);
  isValid &= checkType(getBiasScales(), ElemKind::FloatTy, this);

  // Check qparam dimensions.
  isValid &=
      expectCompareTrue("Filter offsets must be a 1D vector",
                        getFilterOffsets().dims().size(), size_t(1), this);
  isValid &=
      expectCompareTrue("Filter scales must be a 1D vector",
                        getFilterScales().dims().size(), size_t(1), this);
  isValid &= expectCompareTrue("Bias offsets must be a 1D vector",
                               getBiasOffsets().dims().size(), size_t(1), this);
  isValid &= expectCompareTrue("Bias scales must be a 1D vector",
                               getBiasScales().dims().size(), size_t(1), this);

  // Check qparam sizes: one entry per output channel (the last result dim).
  isValid &= expectCompareTrue(
      "There must be one filter offset qparam per output channel",
      getFilterOffsets().dims()[0], dim_t(getResult().dims().back()), this);
  isValid &= expectCompareTrue(
      "There must be one filter scale qparam per output channel",
      getFilterScales().dims()[0], dim_t(getResult().dims().back()), this);
  isValid &= expectCompareTrue(
      "There must be one bias offset qparam per output channel",
      getBiasOffsets().dims()[0], dim_t(getResult().dims().back()), this);
  isValid &= expectCompareTrue(
      "There must be one bias scale qparam per output channel",
      getBiasScales().dims()[0], dim_t(getResult().dims().back()), this);

  return isValid;
}
731
/// Verify Conv3D by delegating to the shared 3D convolution checker.
bool Convolution3DNode::verify() const {
  return verifyConvolution3D(getInput(), getResult(), getFilter(), getBias(),
                             Kernels_, Strides_, Pads_, Group_);
}
736
/// Verify ConvTranspose by delegating to the shared transposed-conv checker.
bool ConvTransposeNode::verify() const {
  return verifyConvTranspose(getInput(), getResult(), getFilter(), Kernels_,
                             Strides_, Pads_, Group_, Dilation_);
}
741
/// Verify that the types of an input and its gradient are the same; gradient
/// tensors always mirror the shape/type of the value they differentiate.
/// Failures are reported against \p parent.
static bool verifyInputAndGradInputTypes(NodeValue input, NodeValue gradInput,
                                         const Node *parent) {
  return checkSameType(input, gradInput, parent);
}
747
/// Verify that the types of an output and its gradient are the same; gradient
/// tensors always mirror the shape/type of the value they differentiate.
/// Failures are reported against \p parent.
static bool verifyOutputAndGradOutputTypes(NodeValue output,
                                           NodeValue gradOutput,
                                           const Node *parent) {
  return checkSameType(output, gradOutput, parent);
}
754
755bool Constant::verify() const {
756 return expectCompareTrue("Underlying tensor type doesn't match constant type",
757 *getType(), getPayload().getType(), this);
758}
759
/// Verify ConvolutionGrad: each gradient tensor must match the type of the
/// value it differentiates, and the gradient tensors together must satisfy
/// the same layout/shape constraints as a forward convolution.
bool ConvolutionGradNode::verify() const {
  bool isValid = verifyInputAndGradInputTypes(getInput(),
                                              getGradOfInputNamedInput(), this);
  isValid &= verifyInputAndGradInputTypes(getFilter(),
                                          getGradOfInputNamedFilter(), this);
  isValid &=
      verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), this);
  isValid &= verifyOutputAndGradOutputTypes(
      getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
  // Reuse the forward-conv checker on the gradient tensors.
  if (getLayout() == NHWC) {
    isValid &= verifyConvolution<ShapeNHWC>(
        getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(),
        getGradOfInputNamedFilter(), getGradOfInputNamedBias(), Kernels_,
        Strides_, Pads_, Group_, Dilation_);
  } else {
    isValid &= verifyConvolution<ShapeNCHW>(
        getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(),
        getGradOfInputNamedFilter(), getGradOfInputNamedBias(), Kernels_,
        Strides_, Pads_, Group_, Dilation_);
  }

  return isValid;
}
783
/// Verify Convolution3DGrad: each gradient tensor must match the type of the
/// value it differentiates, and the gradient tensors together must satisfy
/// the same constraints as a forward 3D convolution.
bool Convolution3DGradNode::verify() const {
  bool isValid = verifyInputAndGradInputTypes(getInput(),
                                              getGradOfInputNamedInput(), this);
  isValid &= verifyInputAndGradInputTypes(getFilter(),
                                          getGradOfInputNamedFilter(), this);
  isValid &=
      verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), this);
  isValid &= verifyOutputAndGradOutputTypes(
      getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
  // Reuse the forward-conv checker on the gradient tensors.
  isValid &= verifyConvolution3D(
      getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(),
      getGradOfInputNamedFilter(), getGradOfInputNamedBias(), Kernels_,
      Strides_, Pads_, Group_);
  return isValid;
}
799
800/// \returns the number of columns of data from a fused \p type (i.e. not
801/// considering the columns for per row scale/offsets).
802static size_t getNumDataColumnsFromFused(TypeRef type) {
803 size_t n = type->dims()[1];
804 switch (type->getElementType()) {
805 case ElemKind::UInt8FusedQTy:
806 return n - 2 * sizeof(float);
807 case ElemKind::UInt8FusedFP16QTy:
808 return n - 2 * sizeof(float16_t);
809 case ElemKind::UInt4FusedFP16QTy:
810 return (n - 2 * sizeof(float16_t)) * 2;
811 case ElemKind::UInt4FusedQTy:
812 return (n - 2 * sizeof(float)) * 2;
813 default:
814 llvm_unreachable("Not supported Fused ElemKind");
815 }
816}
817
/// Verify ConvertTo (element-type conversion). Fused-quantized source and
/// destination must agree on the number of real data columns; non-fused
/// conversions must preserve shape and must not involve quantized types
/// (those go through Dequantize/Quantize/Rescale instead).
bool ConvertToNode::verify() const {
  TypeRef srcTy = getInput().getType();
  TypeRef dstTy = getResult().getType();
  const bool srcIsFused = isFusedQuantizedElemKind(srcTy->getElementType());
  const bool dstIsFused = isFusedQuantizedElemKind(dstTy->getElementType());

  bool isValid = expectCompareTrue(
      "Conversion of src and dst with mismatched fused property is not yet "
      "implemented",
      (srcIsFused && dstIsFused) || (!srcIsFused && !dstIsFused), true, this);

  if (srcIsFused && dstIsFused) {
    // Fused-to-fused: compare data widths excluding scale/offset columns.
    size_t srcNumCols = getNumDataColumnsFromFused(srcTy);
    size_t dstNumCols = getNumDataColumnsFromFused(dstTy);
    return expectCompareTrue("Shapes of data for fused kinds do not match",
                             srcNumCols, dstNumCols, this);
  }

  isValid &= checkSameShape(getInput(), getResult(), this);
  isValid &= expectCompareTrue(
      "Quantized conversion should use Dequantize, Quantize and Rescale",
      srcTy->isQuantizedType() || dstTy->isQuantizedType(), false, this);
  return isValid;
}
842
843bool MaxPoolNode::verify() const {
844 switch (getLayout()) {
845 case NHWC:
846 return verifyPool<ShapeNHWC>(getInput(), getResult(), Kernels_, Strides_,
847 Pads_,
848 /* isAvgPool */ false);
849 case NCHW:
850 return verifyPool<ShapeNCHW>(getInput(), getResult(), Kernels_, Strides_,
851 Pads_,
852 /* isAvgPool */ false);
853 default: // MaxPool3D is unsupported
854 return false;
855 }
856}
857
858bool AvgPoolNode::verify() const {
859 switch (getLayout()) {
860 case NHWC:
861 return verifyPool<ShapeNHWC>(getInput(), getResult(), Kernels_, Strides_,
862 Pads_);
863 case NCHW:
864 return verifyPool<ShapeNCHW>(getInput(), getResult(), Kernels_, Strides_,
865 Pads_);
866 case NTHWC:
867 return verifyPool3D<ShapeNTHWC>(getInput(), getResult(), Kernels_, Strides_,
868 Pads_);
869 case NCTHW:
870 return verifyPool3D<ShapeNCTHW>(getInput(), getResult(), Kernels_, Strides_,
871 Pads_);
872 default:
873 llvm_unreachable("Unsupported format");
874 }
875}
876
877bool AdaptiveAvgPoolGradNode::verify() const {
878 bool isValid = verifyInputAndGradInputTypes(getInput(),
879 getGradOfInputNamedInput(), this);
880 isValid &= verifyOutputAndGradOutputTypes(
881 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
882
883 ShapeNHWC idim(getInput().getType()->dims());
884 ShapeNHWC odim(getOriginalOutputForResult().getType()->dims());
885
886 isValid &= expectCompareTrue(
887 "expected the same number of channels for input and output", odim.c,
888 idim.c, this);
889
890 isValid &= expectCompareTrue(
891 "expected the same number of batches for input and output", odim.n,
892 idim.n, this);
893
894 isValid &= expectCompareTrue("height too small for averaging area", odim.h,
895 idim.h, this, CompareOperatorLessEqual<dim_t>());
896
897 isValid &= expectCompareTrue("width too small for averaging area", odim.w,
898 idim.w, this, CompareOperatorLessEqual<dim_t>());
899
900 return isValid;
901}
902
903bool AdaptiveAvgPoolNode::verify() const {
904 bool isValid = checkTypeIgnoreShape(getInput(), getResult(), this);
905
906 TypeRef inTy = getInput().getType();
907 TypeRef outTy = getResult().getType();
908
909 isValid &= expectCompareTrue("Input should have 4 dimensions",
910 inTy->dims().size(), (size_t)4, this);
911
912 isValid &= expectCompareTrue("Output should have 4 dimensions",
913 outTy->dims().size(), (size_t)4, this);
914
915 if (!isValid) {
916 return false;
917 }
918
919 isValid &= expectCompareTrue(
920 "Output should have the same number of batches as the input",
921 inTy->dims()[0], outTy->dims()[0], this);
922
923 isValid &= expectCompareTrue(
924 "Output should have the same number of channels as the input",
925 inTy->dims()[3], outTy->dims()[3], this);
926
927 isValid &= expectCompareTrue(
928 "Output should not have more height than the input", inTy->dims()[1],
929 outTy->dims()[1], this, CompareOperatorGreaterEqual<dim_t>());
930
931 isValid &= expectCompareTrue(
932 "Output should not have more width than the input", inTy->dims()[2],
933 outTy->dims()[2], this, CompareOperatorGreaterEqual<dim_t>());
934
935 return isValid;
936}
937
938bool MaxPoolGradNode::verify() const {
939 bool isValid = verifyInputAndGradInputTypes(getInput(),
940 getGradOfInputNamedInput(), this);
941 isValid &= verifyOutputAndGradOutputTypes(
942 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
943
944 if (getLayout() == NHWC) {
945 isValid &= verifyPool<ShapeNHWC>(
946 getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(),
947 Kernels_, Strides_, Pads_, /* isAvgPool */ false);
948 } else {
949 isValid &= verifyPool<ShapeNCHW>(
950 getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(),
951 Kernels_, Strides_, Pads_, /* isAvgPool */ false);
952 }
953 return isValid;
954}
955
956bool AvgPoolGradNode::verify() const {
957 bool isValid = verifyInputAndGradInputTypes(getInput(),
958 getGradOfInputNamedInput(), this);
959 isValid &= verifyOutputAndGradOutputTypes(
960 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
961
962 switch (getLayout()) {
963 case NHWC:
964 return isValid &&
965 verifyPool<ShapeNHWC>(getGradOfInputNamedInput(),
966 getGradOfOriginalOutputNamedResult(), Kernels_,
967 Strides_, Pads_);
968 case NCHW:
969 return isValid &&
970 verifyPool<ShapeNCHW>(getGradOfInputNamedInput(),
971 getGradOfOriginalOutputNamedResult(), Kernels_,
972 Strides_, Pads_);
973 case NTHWC:
974 return isValid &&
975 verifyPool3D<ShapeNTHWC>(getGradOfInputNamedInput(),
976 getGradOfOriginalOutputNamedResult(),
977 Kernels_, Strides_, Pads_);
978 case NCTHW:
979 return isValid &&
980 verifyPool3D<ShapeNCTHW>(getGradOfInputNamedInput(),
981 getGradOfOriginalOutputNamedResult(),
982 Kernels_, Strides_, Pads_);
983 default:
984 llvm_unreachable("Unsupported format");
985 }
986}
987
988bool MatMulNode::verify() const {
989 auto lhs = getLHS();
990 auto rhs = getRHS();
991 auto dest = getResult();
992
993 auto LDims = lhs.dims();
994 auto RDims = rhs.dims();
995 auto DDims = dest.dims();
996 bool isValid = expectCompareTrue("LHS input must be 2 dimensional.",
997 LDims.size(), size_t(2), this);
998 isValid &= expectCompareTrue("RHS input must be 2 dimensional.", RDims.size(),
999 size_t(2), this);
1000 isValid &= expectCompareTrue("Invalid MatMul dimensions", DDims.size(),
1001 size_t(2), this);
1002
1003 auto elem = dest.getType()->getElementType();
1004 isValid &= checkType(lhs, elem, this);
1005 isValid &= checkType(rhs, elem, this);
1006
1007 isValid &=
1008 expectCompareTrue("Invalid row dimensions", LDims[0], DDims[0], this);
1009 isValid &=
1010 expectCompareTrue("Invalid column dimensions", RDims[1], DDims[1], this);
1011 return isValid;
1012}
1013
1014bool BatchMatMulNode::verify() const {
1015 auto LHS = getLHS();
1016 auto RHS = getRHS();
1017 auto dest = getResult();
1018
1019 bool isValid = expectCompareTrue("LHS input must be 3 dimensional.",
1020 LHS.dims().size(), size_t(3), this);
1021 isValid &= expectCompareTrue("RHS input must be 3 dimensional.",
1022 RHS.dims().size(), size_t(3), this);
1023 isValid &= expectCompareTrue("Result must be 3 dimensional.",
1024 dest.dims().size(), size_t(3), this);
1025 isValid &= expectCompareTrue("LHS and RHS inputs must have same batch size.",
1026 LHS.dims()[0], RHS.dims()[0], this);
1027 isValid &= expectCompareTrue("Result must have same batch size as inputs.",
1028 LHS.dims()[0], dest.dims()[0], this);
1029
1030 const dim_t numBatches = LHS.dims()[0];
1031 const dim_t N = LHS.dims()[1];
1032 const dim_t M = LHS.dims()[2];
1033 const dim_t P = RHS.dims()[2];
1034 isValid &= expectCompareTrue("Inputs have invalid dimensions.", RHS.dims()[1],
1035 M, this);
1036 isValid &= expectCompareTrue("Result has invalid dimensions given inputs.",
1037 dest.dims(), {numBatches, N, P}, this);
1038
1039 auto elemType = dest.getType()->getElementType();
1040 isValid &= checkType(LHS, elemType, this);
1041 isValid &= checkType(RHS, elemType, this);
1042
1043 return isValid;
1044}
1045
1046bool SigmoidNode::verify() const {
1047 return verifyActivation(getInput(), getResult());
1048}
1049
1050bool SigmoidGradNode::verify() const {
1051 bool isValid = verifyInputAndGradInputTypes(getInput(),
1052 getGradOfInputNamedInput(), this);
1053 isValid &= verifyOutputAndGradOutputTypes(
1054 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
1055 isValid &= verifyActivation(getGradOfInputNamedInput(),
1056 getGradOfOriginalOutputNamedResult());
1057 return isValid;
1058}
1059
1060bool SoftPlusNode::verify() const {
1061 return verifyActivation(getInput(), getResult());
1062}
1063
1064bool SwishNode::verify() const {
1065 return verifyActivation(getInput(), getResult());
1066}
1067
1068bool TanhNode::verify() const {
1069 return verifyActivation(getInput(), getResult());
1070}
1071
1072bool TanhGradNode::verify() const {
1073 bool isValid = verifyInputAndGradInputTypes(getInput(),
1074 getGradOfInputNamedInput(), this);
1075 isValid &= verifyOutputAndGradOutputTypes(
1076 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
1077 isValid &= verifyActivation(getGradOfInputNamedInput(),
1078 getGradOfOriginalOutputNamedResult());
1079 return isValid;
1080}
1081
1082bool LogitNode::verify() const {
1083 const Node *parent = getResult().getNode();
1084 bool isValid = checkSameType(getInput(), getResult(), parent);
1085 isValid &= checkSameShape(getInput(), getResult(), parent);
1086 isValid &= expectCompareTrue(
1087 "Clamping parameter eps must be strictly positive", getEpsilon(), 0.0f,
1088 this, CompareOperatorGreaterThan<float>());
1089 isValid &=
1090 expectCompareTrue("Clamping parameter eps must be less than 0.5",
1091 getEpsilon(), 0.5f, this, CompareOperatorLess<float>());
1092 return isValid;
1093}
1094
1095bool ExpNode::verify() const {
1096 const Node *parent = getResult().getNode();
1097 bool isValid =
1098 checkSameIsQuantized(getInput().getType(), getResult().getType(), parent);
1099
1100 if (getInput().getType()->isQuantizedType()) {
1101 isValid &= checkType(getInput(), getResult().getElementType(),
1102 getResult().getNode());
1103 isValid &= checkSameShape(getInput(), getResult(), parent);
1104 } else {
1105 isValid &= checkSameType(getInput(), getResult(), parent);
1106 }
1107
1108 return isValid;
1109}
1110
1111bool BucketizeNode::verify() const {
1112 bool isValid = checkSameShape(getInput(), getResult(), this);
1113 isValid &= !getBoundaries().empty();
1114 isValid &= std::is_sorted(getBoundaries().begin(), getBoundaries().end());
1115 return isValid;
1116}
1117
1118bool SoftMaxNode::verify() const {
1119 return verifySoftMax(getInput(), getResult());
1120}
1121
1122bool SoftMaxGradNode::verify() const {
1123 bool isValid = verifyInputAndGradInputTypes(getInput(),
1124 getGradOfInputNamedInput(), this);
1125 isValid &= verifyInputAndGradInputTypes(getSelected(),
1126 getGradOfInputNamedSelected(), this);
1127 isValid &= verifyOutputAndGradOutputTypes(
1128 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
1129 isValid &= verifySoftMax(getGradOfInputNamedInput(),
1130 getGradOfOriginalOutputNamedResult());
1131 return isValid;
1132}
1133
1134bool LogSoftMaxNode::verify() const {
1135 return verifyLogSoftMax(getInput(), getResult());
1136}
1137
1138bool LogSoftMaxGradNode::verify() const {
1139 bool isValid = verifyInputAndGradInputTypes(getInput(),
1140 getGradOfInputNamedInput(), this);
1141 isValid &= ((verifyInputAndGradInputTypes(
1142 getSelected(), getGradOfInputNamedSelected(), this))
1143 ? 1
1144 : 0);
1145 isValid &= ((verifyOutputAndGradOutputTypes(
1146 getOriginalOutputForResult(),
1147 getGradOfOriginalOutputNamedResult(), this))
1148 ? 1
1149 : 0);
1150 isValid &= ((verifyLogSoftMax(getGradOfInputNamedInput(),
1151 getGradOfOriginalOutputNamedResult()))
1152 ? 1
1153 : 0);
1154 return isValid;
1155}
1156
1157bool CrossEntropyLossNode::verify() const {
1158 return verifyCrossEntropyLoss(getP(), getCE(), getLabels());
1159}
1160
1161bool CrossEntropyLossGradNode::verify() const {
1162 bool isValid = verifyInputAndGradInputTypes(
1163 getLabels(), getGradOfInputNamedLabels(), this);
1164 isValid &= verifyInputAndGradInputTypes(getP(), getGradOfInputNamedP(), this);
1165 isValid &= verifyOutputAndGradOutputTypes(
1166 getOriginalOutputForCE(), getGradOfOriginalOutputNamedCE(), this);
1167 isValid &= verifyCrossEntropyLoss(getGradOfInputNamedP(),
1168 getGradOfOriginalOutputNamedCE(),
1169 getGradOfInputNamedLabels());
1170 return isValid;
1171}
1172
1173bool ReshapeNode::verify() const {
1174 bool isValid = expectCompareTrue("Reshape into a different size",
1175 getResult().getType()->size(),
1176 getInput().getType()->size(), this);
1177 isValid &= checkTypeIgnoreShape(getResult(), getInput(), this);
1178 return isValid;
1179}
1180
1181bool TransposeNode::verify() const {
1182 auto dest = getResult();
1183 auto src = getInput();
1184 ShapeVector shape;
1185
1186 auto dims = src.dims();
1187 for (size_t i = 0; i < dims.size(); i++) {
1188 shape.push_back(dims[Shuffle_[i]]);
1189 }
1190
1191 bool isValid = expectCompareTrue("Invalid transpose dims", dest.dims(),
1192 llvm::makeArrayRef(shape), this);
1193 isValid &= checkTypeIgnoreShape(dest, src, this);
1194 return isValid;
1195}
1196
1197bool FlipNode::verify() const {
1198 auto dest = getResult();
1199 auto src = getInput();
1200 dim_t axis = getAxis();
1201 bool isValid = checkSameType(src, dest, this);
1202 isValid &= expectCompareTrue("Invalid axis", axis, (dim_t)src.dims().size(),
1203 this, CompareOperatorLess<dim_t>());
1204 return isValid;
1205}
1206
1207bool ChannelShuffleNode::verify() const {
1208 bool isValid = expectCompareTrue("Channel shuffle into a different size.",
1209 getResult().getType()->size(),
1210 getInput().getType()->size(), this);
1211 isValid &= checkTypeIgnoreShape(getResult(), getInput(), this);
1212 return isValid;
1213}
1214
// Splat fills its result with a constant; there are no constraints to check.
bool SplatNode::verify() const { return true; }
1216
// Touch produces an uninitialized tensor; any result type is acceptable.
bool TouchNode::verify() const { return true; }
1218
// Trace events carry no verifiable tensor constraints.
bool TraceEventNode::verify() const { return true; }
1220
1221bool ClipNode::verify() const {
1222 bool isValid =
1223 expectCompareTrue("Clip max must be greater than min", getMin(), getMax(),
1224 this, CompareOperatorLess<float>());
1225 if (getInput().getType()->isQuantizedType()) {
1226 isValid &=
1227 checkSameIsQuantized(getInput().getType(), getResult().getType(), this);
1228 isValid &= checkSameShape(getInput(), getResult(), this);
1229 } else {
1230 isValid &= checkSameType(getInput(), getResult(), this);
1231 }
1232 return isValid;
1233}
1234
bool InsertTensorNode::verify() const {
  // Verifies that Small can be written into Big `count` times along `axis`,
  // starting at the given offsets, without running out of bounds.
  auto dest = getBig();
  auto src = getSmall();
  auto offsets = getStart();
  dim_t numDims = dest.dims().size();
  dim_t axis = getAxis();
  dim_t count = getCount();

  // Big, Small, and the offsets vector must all agree on rank.
  bool isValid = expectCompareTrue("Invalid number of dimensions", numDims,
                                   (dim_t)src.dims().size(), this);
  isValid &= expectCompareTrue("Invalid number of dimensions for offsets",
                               numDims, (dim_t)offsets.size(), this);

  if (!isValid) {
    // The following loop may be out-of-bound if the previous
    // comparisons failed.
    return false;
  }

  isValid &= checkType(dest, src.getType()->getElementType(), this);
  // Quantized insert copies Small's raw integer values into Big unchanged,
  // so the quantization parameters must be identical.
  if (dest.getType()->isQuantizedType()) {
    isValid &= expectCompareTrue("Scales of Big and Small must match.",
                                 src.getType()->getScale(),
                                 dest.getType()->getScale(), this);
    isValid &= expectCompareTrue("Offsets of Big and Small must match.",
                                 src.getType()->getOffset(),
                                 dest.getType()->getOffset(), this);
  }

  // A single copy of Small, shifted by its offset, must fit inside Big in
  // every dimension.
  for (unsigned i = 0; i < numDims; i++) {
    // TODO: We could come up with a mechanism to lazy compute that
    // string since it is going to be used only in case of an error.
    // However, this function is not performance critical so leave it
    // this way for now.
    std::string msg = std::to_string(i);
    msg = "out of bounds for index " + msg;
    isValid &= expectCompareTrue(msg.c_str(), src.dims()[i] + offsets[i],
                                 dest.dims()[i], this,
                                 CompareOperatorLessEqual<dim_t>());
  }

  // When tiled `count` times along `axis`, all copies together must still
  // fit inside Big.
  isValid &= expectCompareTrue("Invalid axis", axis, (dim_t)src.dims().size(),
                               this, CompareOperatorLessEqual<dim_t>());
  for (dim_t i = 0; i < src.dims().size(); i++) {
    dim_t mul = (i == axis) ? count : 1;
    std::string msg = std::to_string(i);
    msg = "Small does not fit inside Big for index " + msg;
    isValid &=
        expectCompareTrue(msg.c_str(), src.dims()[i] * mul, dest.dims()[i],
                          this, CompareOperatorLessEqual<dim_t>());
  }
  return isValid;
}
1288
bool SliceNode::verify() const {
  // The result is a contiguous sub-tensor of the input, beginning at Start.
  auto dest = getResult();
  auto src = getInput();
  auto offsets = getStart();
  size_t numDims = dest.dims().size();
  // Input, result, and offsets must agree on rank.
  bool isValid = expectCompareTrue("Invalid number of dimensions", numDims,
                                   src.dims().size(), this);
  isValid &= expectCompareTrue("Invalid number of dimensions", numDims,
                               offsets.size(), this);

  if (!isValid) {
    // The following loop may be out-of-bound if the previous
    // comparisons failed.
    return false;
  }

  // In every dimension, the slice shifted by its offset must fit inside the
  // source tensor.
  for (unsigned i = 0; i < numDims; i++) {
    std::string msg = std::to_string(i);
    msg = "out of bounds for index " + msg;
    isValid &= expectCompareTrue(msg.c_str(), dest.dims()[i] + offsets[i],
                                 src.dims()[i], this,
                                 CompareOperatorLessEqual<dim_t>());
  }
  // Quantized slices must keep the same scale/offset as their source.
  isValid &= checkNotQuantizedOrSameParams(dest.getType(), src.getType(), this);
  return isValid;
}
1315
1316bool TileNode::verify() const {
1317 auto dest = getResult();
1318 auto src = getInput();
1319 size_t axis = getAxis();
1320 unsigned count = getCount();
1321
1322 bool isValid = expectCompareTrue("Invalid axis", axis, src.dims().size(),
1323 this, CompareOperatorLessEqual<size_t>());
1324
1325 for (dim_t i = 0; i < src.dims().size(); i++) {
1326 dim_t mul = (i == axis) ? count : 1;
1327 std::string msg = std::to_string(i);
1328 msg = "Incorrect output shape for dim " + msg;
1329 isValid &= expectCompareTrue(msg.c_str(), src.dims()[i] * mul,
1330 dest.dims()[i], this);
1331 }
1332 isValid &= checkTypeIgnoreShape(src, dest, this);
1333 return isValid;
1334}
1335
1336bool BatchNormalizationNode::verify() const {
1337 return verifyBatchNormalization(getInput(), getResult(), getBias(),
1338 getScale(), getMean(), getVar(), ChannelIdx_);
1339}
1340
1341bool InstanceNormalizationNode::verify() const {
1342 return verifyInstanceNormalization(getInput(), getResult(), getBias(),
1343 getScale(), ChannelIdx_);
1344}
1345
bool LayerNormalizationNode::verify() const {
  // Layer norm normalizes over the trailing dimensions of the input; scale
  // and bias are broadcast over the remaining leading dimensions.
  auto dest = getResult();
  auto src = getInput();
  auto scale = getScale();
  auto bias = getBias();

  // Check input and output have same ElemKind.
  bool isValid = checkType(src, dest.getElementType(), this);

  // Check scale and bias have same ElemKind
  isValid &= checkType(bias, scale.getElementType(), this);

  // Check inputs/outputs and scale/bias match shapes.
  isValid &= checkSameShape(src, dest, this);
  isValid &= checkSameShape(scale, bias, this);

  // Check that the dims of scale and bias match the end of src.
  auto srcDims = src.getType()->dims();
  auto scaleDims = scale.getType()->dims();
  isValid &= expectCompareTrue("Expected input to have more dims than scale",
                               srcDims.size(), scaleDims.size(), this,
                               CompareOperatorGreaterThan<size_t>());
  // Walk both shapes from their last dimension backwards; every trailing
  // dimension of src must equal the corresponding dimension of scale.
  for (size_t i = 0; i < scaleDims.size(); ++i) {
    size_t scaleI = scaleDims.size() - i - 1;
    size_t srcI = srcDims.size() - i - 1;
    isValid &=
        expectCompareTrue("Expected scale dims to match the end of src dims",
                          scaleDims[scaleI], srcDims[srcI], this);
  }

  return isValid;
}
1378
1379bool BatchNormalizationGradNode::verify() const {
1380 bool isValid =
1381 verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(), this);
1382 isValid &= verifyInputAndGradInputTypes(getInput(),
1383 getGradOfInputNamedInput(), this);
1384 isValid &=
1385 verifyInputAndGradInputTypes(getMean(), getGradOfInputNamedMean(), this);
1386 isValid &= verifyInputAndGradInputTypes(getScale(),
1387 getGradOfInputNamedScale(), this);
1388 isValid &=
1389 verifyInputAndGradInputTypes(getVar(), getGradOfInputNamedVar(), this);
1390 isValid &= verifyOutputAndGradOutputTypes(
1391 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
1392 isValid &= verifyBatchNormalization(
1393 getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(),
1394 getGradOfInputNamedBias(), getGradOfInputNamedScale(),
1395 getGradOfInputNamedMean(), getGradOfInputNamedVar(), ChannelIdx_);
1396 return isValid;
1397}
1398
1399bool MeanVarNormalizationNode::verify() const {
1400 return checkType(getMean(), ElemKind::FloatTy, this) &&
1401 checkSameType(getMean(), getVar(), this);
1402}
1403
1404bool LocalResponseNormalizationNode::verify() const {
1405 return verifyLocalResponseNormalization(getInput(), getResult());
1406}
1407
1408bool LocalResponseNormalizationGradNode::verify() const {
1409 bool isValid = verifyInputAndGradInputTypes(getInput(),
1410 getGradOfInputNamedInput(), this);
1411 isValid &= verifyOutputAndGradOutputTypes(
1412 getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(), this);
1413 isValid &= verifyLocalResponseNormalization(
1414 getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult());
1415 return isValid;
1416}
1417
/// Defines verify() for unary logical nodes: the input and result must be
/// boolean tensors with identical shapes. (Comments cannot appear inside the
/// macro body because of the line continuations.)
#define VERIFY_UNARY_LOGICAL(NODE_NAME_)                                       \
  bool NODE_NAME_##Node::verify() const {                                      \
    bool isValid = checkSameShape(getInput(), getResult(), this);              \
    isValid &= checkType(getInput(), ElemKind::BoolTy, this);                  \
    isValid &= checkType(getResult(), ElemKind::BoolTy, this);                 \
    return isValid;                                                            \
  }
VERIFY_UNARY_LOGICAL(Not)
#undef VERIFY_UNARY_LOGICAL
1427
1428bool SignNode::verify() const {
1429 if (getResult().getType()->isQuantizedType()) {
1430 bool isValid = checkSameShape(getInput(), getResult(), this);
1431 isValid &=
1432 checkType(getResult(), getInput().getType()->getElementType(), this);
1433 return isValid;
1434 }
1435 return checkSameType(getInput(), getResult(), this);
1436}
1437
/// Defines verify() for binary logical nodes: both operands and the result
/// must be boolean tensors of the same shape.
#define VERIFY_BINARY_LOGICAL(NODE_NAME_)                                      \
  bool NODE_NAME_##Node::verify() const {                                      \
    bool isValid = checkSameShape(getLHS(), getResult(), this);                \
    isValid &= checkSameShape(getRHS(), getResult(), this);                    \
    isValid &= checkType(getLHS(), ElemKind::BoolTy, this);                    \
    isValid &= checkType(getRHS(), ElemKind::BoolTy, this);                    \
    isValid &= checkType(getResult(), ElemKind::BoolTy, this);                 \
    return isValid;                                                           \
  }
VERIFY_BINARY_LOGICAL(And)
VERIFY_BINARY_LOGICAL(Or)
VERIFY_BINARY_LOGICAL(Xor)
#undef VERIFY_BINARY_LOGICAL
1451
/// Defines verify() for binary bitwise nodes: both operands and the result
/// must have identical types and shapes.
#define VERIFY_BINARY(NODE_NAME_)                                              \
  bool NODE_NAME_##Node::verify() const {                                      \
    bool isValid = checkSameShape(getLHS(), getResult(), this);                \
    isValid &= checkSameShape(getRHS(), getResult(), this);                    \
    isValid &= checkSameType(getLHS(), getResult(), this);                     \
    isValid &= checkSameType(getRHS(), getResult(), this);                     \
    return isValid;                                                           \
  }
VERIFY_BINARY(BitwiseAnd)
VERIFY_BINARY(BitwiseOr)
VERIFY_BINARY(BitwiseXor)
#undef VERIFY_BINARY
1464
/// Defines verify() for unary arithmetic nodes. Only the shape is checked;
/// element types are not constrained here (e.g. Truncate converts kinds).
#define VERIFY_UNARY_ARITHMETIC(NODE_NAME_)                                    \
  bool NODE_NAME_##Node::verify() const {                                      \
    return checkSameShape(getInput(), getResult(), this);                      \
  }
VERIFY_UNARY_ARITHMETIC(Abs);
VERIFY_UNARY_ARITHMETIC(Neg);
VERIFY_UNARY_ARITHMETIC(Floor);
VERIFY_UNARY_ARITHMETIC(Ceil);
VERIFY_UNARY_ARITHMETIC(Round);
VERIFY_UNARY_ARITHMETIC(Sqrt);
VERIFY_UNARY_ARITHMETIC(Rsqrt);
VERIFY_UNARY_ARITHMETIC(Reciprocal);
VERIFY_UNARY_ARITHMETIC(Sin);
VERIFY_UNARY_ARITHMETIC(Cos);
VERIFY_UNARY_ARITHMETIC(Erf);
VERIFY_UNARY_ARITHMETIC(Truncate);
VERIFY_UNARY_ARITHMETIC(BitwiseNot);
#undef VERIFY_UNARY_ARITHMETIC
1483
/// Defines verify() for binary arithmetic nodes by delegating to the shared
/// verifyArithmetic helper.
#define VERIFY_ARITHMETIC(NODE_NAME_)                                          \
  bool NODE_NAME_##Node::verify() const {                                      \
    return verifyArithmetic(getLHS(), getRHS(), getResult());                  \
  }
VERIFY_ARITHMETIC(Add);
VERIFY_ARITHMETIC(Mul);
VERIFY_ARITHMETIC(Sub);
VERIFY_ARITHMETIC(Div);
VERIFY_ARITHMETIC(FloorDiv);
VERIFY_ARITHMETIC(Max);
VERIFY_ARITHMETIC(Min);
VERIFY_ARITHMETIC(Pow);
#undef VERIFY_ARITHMETIC
1497
/// Defines verify() for the gradient counterparts of the binary arithmetic
/// nodes: each gradient must mirror its forward tensor's type, and the
/// gradient triple must itself satisfy the arithmetic contract.
#define VERIFY_ARITHMETIC(NODE_NAME_)                                          \
  bool NODE_NAME_##Node::verify() const {                                      \
    bool isValid = verifyInputAndGradInputTypes(                               \
        getLHS(), getGradOfInputNamedLHS(), this);                             \
    isValid &= verifyInputAndGradInputTypes(getRHS(),                          \
                                            getGradOfInputNamedRHS(), this);   \
    isValid &= verifyOutputAndGradOutputTypes(                                 \
        getOriginalOutputForResult(), getGradOfOriginalOutputNamedResult(),    \
        this);                                                                 \
    isValid &=                                                                 \
        verifyArithmetic(getGradOfInputNamedLHS(), getGradOfInputNamedRHS(),   \
                         getGradOfOriginalOutputNamedResult());                \
    return isValid;                                                            \
  }
VERIFY_ARITHMETIC(AddGrad);
VERIFY_ARITHMETIC(MulGrad);
VERIFY_ARITHMETIC(SubGrad);
VERIFY_ARITHMETIC(DivGrad);
#undef VERIFY_ARITHMETIC
1517
/// Defines verify() for comparison nodes: operands share shape and element
/// kind, while the result is a boolean tensor of the same shape.
#define VERIFY_CMP(NODE_NAME_)                                                 \
  bool NODE_NAME_##Node::verify() const {                                      \
    bool isValid = checkSameShape(getLHS(), getRHS(), this);                   \
    isValid &= checkSameShape(getResult(), getLHS(), this);                    \
    isValid &= checkType(getLHS(), getRHS().getElementType(), this);           \
    isValid &= checkType(getResult(), ElemKind::BoolTy, this);                 \
    return isValid;                                                            \
  }

VERIFY_CMP(CmpEQ)
VERIFY_CMP(CmpNEQ)
VERIFY_CMP(CmpLT)
VERIFY_CMP(CmpLTE)
#undef VERIFY_CMP
1532
1533// Trigonometric Ops
1534#define VERIFY_TRIGONOMERTRIC_OPS(NODE_NAME_) \
1535 bool NODE_NAME_##Node::verify() const { \
1536 return checkSameShape(getInput(), getResult(), this); \
1537 }
1538VERIFY_TRIGONOMERTRIC_OPS(Acos);
1539VERIFY_TRIGONOMERTRIC_OPS(Asin);
1540VERIFY_TRIGONOMERTRIC_OPS(Atan);
1541#undef VERIFY_UNARY_ARITHMETIC
1542
1543bool FmodNode::verify() const {
1544 auto res = getResult();
1545 auto LHS = getLHS();
1546 auto RHS = getRHS();
1547 return checkSameShape(res, LHS, res.getNode()) &&
1548 checkSameShape(LHS, RHS, res.getNode()) &&
1549 LHS.getElementType() != ElemKind::Int8QTy &&
1550 RHS.getElementType() != ElemKind::Int8QTy;
1551}
1552
bool BatchedPairwiseDotProductNode::verify() const {
  // Computes dot products between all pairs of inputs, row by row.
  auto inputs = getInputs();

  // At least two inputs are needed to form a pair.
  bool isValid = inputs.size() > 1;

  if (isValid) {
    auto firstInput = inputs[0];

    // All inputs must be 2D float tensors of identical type.
    isValid &= firstInput.getElementType() == ElemKind::FloatTy;
    isValid &= firstInput.getType()->dims().size() == 2;

    for (auto &in : inputs) {
      isValid &= checkSameType(in, firstInput, this);
    }

    // Result: one row per batch entry, one column per unordered input pair,
    // i.e. n*(n-1)/2 columns for n inputs.
    // NOTE(review): dims()[0]/dims()[1] are read without first validating
    // the result's rank — presumably the result is always 2D; confirm.
    isValid &= getResult().getElementType() == ElemKind::FloatTy;
    isValid &=
        getResult().getType()->dims()[0] == firstInput.getType()->dims()[0];
    isValid &= getResult().getType()->dims()[1] ==
               inputs.size() * (inputs.size() - 1) / 2;
  }

  return isValid;
}
1577
// No constraints are checked for the pairwise dot product gradient.
bool BatchedPairwiseDotProductGradNode::verify() const { return true; }
1579
1580bool BatchedAddNode::verify() const {
1581 auto batchShape = getBatch().dims();
1582 auto rhsShape = getSlice().dims();
1583 bool isValid = expectCompareTrue("Invalid shape", batchShape.drop_front(),
1584 rhsShape, this);
1585 isValid &= checkSameShape(getBatch(), getResult(), this);
1586
1587 if (getBatch().getType()->isQuantizedType()) {
1588 expectCompareTrue("Mismatched slice element types",
1589 getSlice().getType()->isQuantizedType(), true, this);
1590 } else {
1591 isValid &=
1592 checkType(getBatch(), getSlice().getType()->getElementType(), this);
1593 }
1594 return isValid;
1595}
1596
1597bool BatchedMulNode::verify() const {
1598 auto batchShape = getBatch().dims();
1599 auto rhsShape = getSlice().dims();
1600 bool isValid = expectCompareTrue("Invalid shape", batchShape.drop_front(),
1601 rhsShape, this);
1602 isValid &= checkSameShape(getBatch(), getResult(), this);
1603
1604 if (getBatch().getType()->isQuantizedType()) {
1605 expectCompareTrue("Mismatched slice element types",
1606 getSlice().getType()->isQuantizedType(), true, this);
1607 } else {
1608 isValid &=
1609 checkType(getBatch(), getSlice().getType()->getElementType(), this);
1610 }
1611 return isValid;
1612}
1613
1614bool BatchedReduceSumSquareNode::verify() const {
1615 bool isValid = checkType(getResult(), getBatch().getElementType(), this);
1616
1617 isValid &=
1618 expectCompareTrue("Invalid shape", getBatch().dims().size(), size_t(0),
1619 this, CompareOperatorGreaterThan<size_t>());
1620 return isValid;
1621}
1622
1623bool CumSumNode::verify() const {
1624 return checkSameType(getResult(), getInput(), this);
1625}
1626
1627bool LengthsSumNode::verify() const {
1628 return expectCompareTrue("Lengths must be a 1D vector",
1629 getLengths().dims().size(), size_t(1), this);
1630}
1631
// Define verification for Reduction operations.
/// The reduced result keeps the batch's element kind, and the batch must
/// have at least one dimension to reduce over.
#define DEFINE_BATCHED_REDUCTION_VERIFICATION(name)                            \
  bool name##Node::verify() const {                                           \
    bool isValid = checkType(getResult(), getBatch().getElementType(), this); \
    isValid &= expectCompareTrue("Invalid shape", getBatch().dims().size(),   \
                                 size_t(0), this,                             \
                                 CompareOperatorGreaterThan<size_t>());       \
    return isValid;                                                           \
  }

DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceAdd)
DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceMean)
DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceMin)
DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceMax)
DEFINE_BATCHED_REDUCTION_VERIFICATION(BatchedReduceProd)

#undef DEFINE_BATCHED_REDUCTION_VERIFICATION
1649
1650bool SparseLengthsSumNode::verify() const {
1651 return verifySparseLengthsSum(getResult(), getData(), getIndices(),
1652 getLengths());
1653}
1654
1655bool SparseLengthsSumGradNode::verify() const {
1656 // Same checks as SparseLengthsSumNode.
1657 bool isValid = verifySparseLengthsSum(getOriginalOutputForResult(), getData(),
1658 getIndices(), getLengths());
1659
1660 // Checks on gradient inputs/outputs.
1661 isValid &= checkSameType(getGradOfOriginalOutputNamedResult(),
1662 getOriginalOutputForResult(), this);
1663 isValid &= checkSameType(getGradOfInputNamedData(), getData(), this);
1664 isValid &= checkSameType(getGradOfInputNamedIndices(), getIndices(), this);
1665 isValid &= checkSameType(getGradOfInputNamedLengths(), getLengths(), this);
1666 return isValid;
1667}
1668
1669bool SparseLengthsWeightedSumNode::verify() const {
1670 return verifySparseLengthsWeightedSum(getResult(), getData(), getWeights(),
1671 getIndices(), getLengths());
1672}
1673
1674bool SparseLengthsWeightedSumGradNode::verify() const {
1675 // Same checks as SparseLengthsWeightedSumNode.
1676 bool isValid =
1677 verifySparseLengthsWeightedSum(getOriginalOutputForResult(), getData(),
1678 getWeights(), getIndices(), getLengths());
1679
1680 // Checks on gradient inputs/outputs.
1681 isValid &= checkSameType(getGradOfOriginalOutputNamedResult(),
1682 getOriginalOutputForResult(), this);
1683 isValid &= checkSameType(getGradOfInputNamedData(), getData(), this);
1684 isValid &= checkSameType(getGradOfInputNamedWeights(), getWeights(), this);
1685 isValid &= checkSameType(getGradOfInputNamedIndices(), getIndices(), this);
1686 isValid &= checkSameType(getGradOfInputNamedLengths(), getLengths(), this);
1687 return isValid;
1688}
1689
1690bool EmbeddingBagNode::verify() const {
1691 return verifyEmbeddingBag(getResult(), getData(), getWeights(), getIndices(),
1692 getOffsets());
1693}
1694
1695bool EmbeddingNode::verify() const {
1696 return verifyEmbedding(getResult(), getWeights(), getIndices());
1697}
1698
1699bool RowwiseQuantizedSparseLengthsWeightedSumNode::verify() const {
1700 bool isValid = checkType(getData(), ElemKind::UInt8QTy, this);
1701 isValid &= expectCompareTrue("Indices must be a 1D vector",
1702 getIndices().dims().size(), size_t(1), this);
1703 isValid &= expectCompareTrue("Lengths must be a 1D vector",
1704 getLengths().dims().size(), size_t(1), this);
1705 isValid &= expectCompareTrue("Weights must be a 1D vector",
1706 getWeights().dims().size(), size_t(1), this);
1707 isValid &= expectCompareTrue("Scales must be a 1D vector",
1708 getScales().dims().size(), size_t(1), this);
1709 isValid &= expectCompareTrue("Offsets must be a 1D vector",
1710 getOffsets().dims().size(), size_t(1), this);
1711 isValid &=
1712 expectCompareTrue("Weights and Indices must have the same size",
1713 getWeights().dims()[0], getIndices().dims()[0], this);
1714 isValid &= expectCompareTrue(
1715 "Scales and Data must have the same first dimension size",
1716 getData().dims()[0], getScales().dims()[0], this);
1717 isValid &= expectCompareTrue(
1718 "Offsets and Data must have the same first dimension size",
1719 getData().dims()[0], getOffsets().dims()[0], this);
1720 if (getUseFP16Accumulation()) {
1721 isValid &= expectCompareTrue(
1722 "Only use FP16 accumulation with FP16 version of Fused-RWQ-SLWS.",
1723 getResult().getType()->getElementType(), ElemKind::Float16Ty, this);
1724 }
1725 return isValid;
1726}
1727
/// Shared verifier for the fused row-wise quantized SLS family
/// (FusedRWQ-SLS/SLWS and EmbeddingBagByteRowwiseOffsets). \p data must use a
/// fused quantized element kind, where each row stores its own scale/offset
/// at the end of the row. \p weights may be a null NodeValue for the
/// unweighted variant. When \p isEmbeddingBagByteRowwiseOffsets is set, the
/// \p lengths operand actually holds offsets and follows different type rules.
static bool verifyFusedRowwiseQuantizedSparseLengthsSum(
    NodeValue result, NodeValue data, NodeValue indices, NodeValue lengths,
    NodeValue weights, bool useFP16Accumulation,
    bool isEmbeddingBagByteRowwiseOffsets = false) {
  const Node *parent = result.getNode();
  bool isValid = expectCompareTrue(
      "Input data must be Fused Quantized type",
      isFusedQuantizedElemKind(data.getType()->getElementType()), true, parent);
  // Number of trailing bytes per row reserved for the fused scale/offset
  // pair: two floats for UInt8FusedQTy/UInt4FusedQTy, two float16s for the
  // FP16-fused kinds.
  dim_t extraCols;
  if (data.getType()->getElementType() == ElemKind::UInt8FusedQTy ||
      data.getType()->getElementType() == ElemKind::UInt4FusedQTy) {
    extraCols = 2 * sizeof(float);
  } else {
    extraCols = 2 * sizeof(float16_t);
  }
  if (useFP16Accumulation) {
    isValid &= expectCompareTrue(
        "Only use FP16 accumulation with FP16 version of RWQ-SLWS.",
        result.getType()->getElementType(), ElemKind::Float16Ty, parent);
  }
  isValid &= checkType(
      indices,
      llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
      parent);
  // For EmbeddingBagByteRowwiseOffsets lengths are really offsets and
  // can be either Int64ITy or Int32ITy.
  if (isEmbeddingBagByteRowwiseOffsets) {
    isValid &= checkType(
        lengths,
        llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
        parent);
  } else {
    isValid &= checkType(lengths, ElemKind::Int32ITy, parent);
  }

  isValid &= expectCompareTrue("Indices must be a 1D vector",
                               indices.dims().size(), size_t(1), parent);
  isValid &= expectCompareTrue("Lengths must be a 1D vector",
                               lengths.dims().size(), size_t(1), parent);
  isValid &= expectCompareTrue("Data must be 2 dimensional.",
                               data.dims().size(), size_t(2), parent);
  isValid &= expectCompareTrue("Data must have extra columns for scale/offset.",
                               data.dims()[1], extraCols, parent,
                               CompareOperatorGreaterEqual<dim_t>());
  isValid &= expectCompareTrue("Result must be 2 dimensional.",
                               result.dims().size(), size_t(2), parent);

  // Weights are optional; when present they pair 1:1 with indices.
  if (weights.getNode()) {
    isValid &= expectCompareTrue("Weights must be a 1D vector",
                                 weights.dims().size(), size_t(1), parent);
    isValid &= expectCompareTrue("Weights and Indices must have the same size",
                                 weights.dims()[0], indices.dims()[0], parent);
  }

  // Wrap this in isValid to prevent potential segfault if the result is
  // incorrectly shaped.
  if (isValid) {
    // If using 4-bit quantization for embeddings then the input is packed into
    // two elements per byte.
    dim_t finalSize = result.dims()[1];
    if (data.getType()->getElementType() == ElemKind::UInt4FusedFP16QTy ||
        data.getType()->getElementType() == ElemKind::UInt4FusedQTy) {
      finalSize /= 2;
    }
    // Result columns + the per-row scale/offset payload must add up to the
    // Data row width.
    isValid &=
        expectCompareTrue("Result output shape should have second dim without "
                          "extra columns from scale/offset in Data.",
                          finalSize + extraCols, data.dims()[1], parent);
  }
  return isValid;
}
1799
1800bool EmbeddingBagByteRowwiseOffsetsNode::verify() const {
1801 return verifyFusedRowwiseQuantizedSparseLengthsSum(
1802 getResult(), getData(), getIndices(), getOffsets(), getWeights(),
1803 getUseFP16Accumulation(), /*isEmbeddingBagByteRowwiseOffsets*/ true);
1804}
1805
1806bool FusedRowwiseQuantizedSparseLengthsWeightedSumNode::verify() const {
1807 return verifyFusedRowwiseQuantizedSparseLengthsSum(
1808 getResult(), getData(), getIndices(), getLengths(), getWeights(),
1809 getUseFP16Accumulation());
1810}
1811
1812bool FusedRowwiseQuantizedSparseLengthsSumNode::verify() const {
1813 return verifyFusedRowwiseQuantizedSparseLengthsSum(
1814 getResult(), getData(), getIndices(), getLengths(), nullptr,
1815 getUseFP16Accumulation());
1816}
1817
1818bool LengthsToRangesNode::verify() const {
1819 bool isValid = checkType(getResult(), getLengths().getElementType(), this);
1820 isValid &= checkType(getLengths(), ElemKind::Int32ITy, this);
1821 isValid &= expectCompareTrue("Lengths must be a 1D vector",
1822 getLengths().dims().size(), size_t(1), this);
1823 isValid &= expectCompareTrue("Ranges must be a 2D vector",
1824 getResult().dims().size(), size_t(2), this);
1825 isValid &= expectCompareTrue(
1826 "Lengths and Ranges must have the same outer dimensions",
1827 getResult().dims()[0], getLengths().dims()[0], this);
1828 isValid &= expectCompareTrue("Inner dimension of Ranges must be 2",
1829 getResult().dims()[1], dim_t(2), this);
1830 return isValid;
1831}
1832
1833bool LengthsRangeFillNode::verify() const {
1834 bool isValid = checkType(getLengths(), ElemKind::Int32ITy, this);
1835 isValid &= checkType(getResult(), getLengths().getElementType(), this);
1836 isValid &= expectCompareTrue("Lengths must be a 1D vector",
1837 getLengths().dims().size(), size_t(1), this);
1838 isValid &= expectCompareTrue("Result must be a 1D vector",
1839 getResult().dims().size(), size_t(1), this);
1840 return isValid;
1841}
1842
1843bool BatchSparseToDenseNode::verify() const {
1844 bool isValid = checkType(getResult(), getValues().getElementType(), this);
1845 isValid &=
1846 checkType(getIndices(), {ElemKind::Int64ITy, ElemKind::Int32ITy}, this);
1847 isValid &=
1848 checkType(getLengths(), {ElemKind::Int64ITy, ElemKind::Int32ITy}, this);
1849 isValid &= expectCompareTrue("Lengths must be a 1D vector",
1850 getLengths().dims().size(), size_t(1), this);
1851 isValid &= expectCompareTrue("Indices must be a 1D vector",
1852 getIndices().dims().size(), size_t(1), this);
1853 isValid &= expectCompareTrue("Indices and Values must have the same shape",
1854 getIndices().dims(), getValues().dims(), this);
1855 isValid &= expectCompareTrue(
1856 "The size of Lengths and batches in the result should be the same",
1857 getLengths().dims()[0], getResult().dims()[0], this);
1858 isValid &= expectCompareTrue(
1859 "The second dimension of the result should be equal to dense_last_dim",
1860 getDenseLastDim(), (unsigned)getResult().dims()[1], this);
1861 return isValid;
1862}
1863
1864bool FillExamplesWithIndicatorNode::verify() const {
1865 bool isValid = checkType(getResult(), getData().getElementType(), this);
1866 isValid &= checkType(
1867 getIndicator(),
1868 {ElemKind::Int64ITy, ElemKind::Int32ITy, ElemKind::BoolTy}, this);
1869 isValid &= expectCompareTrue("Indicator must be a 1D vector",
1870 getIndicator().dims().size(), size_t(1), this);
1871 isValid &= expectCompareTrue("Data must have at least one dimension",
1872 getData().dims().size(), size_t(1), this,
1873 CompareOperatorGreaterEqual<size_t>());
1874 return isValid;
1875}
1876
1877bool SparseToDenseMaskNode::verify() const {
1878 bool isValid = checkType(getResult(), getValues().getElementType(), this);
1879 isValid &= checkType(getResult(), getDefaultValue().getElementType(), this);
1880 isValid &= checkType(getIndices(), ElemKind::Int64ITy, this);
1881 isValid &= checkType(getLengths(), ElemKind::Int32ITy, this);
1882 isValid &= expectCompareTrue("Indices must be a 1D vector",
1883 getIndices().dims().size(), size_t(1), this);
1884 isValid &= expectCompareTrue("Lengths must be a scalar or 1D vector",
1885 getLengths().dims().size(), {0, 1}, this);
1886 isValid &=
1887 expectCompareTrue("Indices and Values must have the same first dimension",
1888 getIndices().dims()[0], getValues().dims()[0], this);
1889 isValid &= expectCompareTrue(
1890 "Values[i] must have the same dimensions as DefaultValue",
1891 getValues().dims().slice(1), getDefaultValue().dims(), this);
1892 return isValid;
1893}
1894
1895bool SparseLabelSplitNode::verify() const {
1896 bool isValid =
1897 checkType("Input and output values must be of the same type",
1898 getLabelValues(), getValues().getElementType(), this);
1899 isValid &= checkType("Lengths must be of type int32", getLengths(),
1900 ElemKind::Int32ITy, this);
1901 isValid &= checkType("Indices must be of type int64", getIndices(),
1902 ElemKind::Int64ITy, this);
1903 isValid &= checkType("ExampleIds must be of type int32", getExampleIds(),
1904 ElemKind::Int32ITy, this);
1905 isValid &= checkType("GradientOffsetMap must be of type in32",
1906 getGradientOffsetMap(), ElemKind::Int32ITy, this);
1907 isValid &= expectCompareTrue("Lengths must be a 1D vector",
1908 getLengths().dims().size(), size_t(1), this);
1909 isValid &= expectCompareTrue("Indices must be a 1D vector",
1910 getIndices().dims().size(), size_t(1), this);
1911 isValid &= expectCompareTrue("Values must be a 1D vector",
1912 getValues().dims().size(), size_t(1), this);
1913 isValid &= expectCompareTrue("Indices and values must have the same shape",
1914 getIndices().dims(), getValues().dims(), this);
1915 return isValid;
1916}
1917
1918bool SGDNode::verify() const {
1919 return checkSameType(getGradient(), getWeight(), this);
1920}
1921
1922bool QuantizationProfileNode::verify() const {
1923 // Make sure that input tensor is a floating point type.
1924 bool isValid = checkType(getInput(), ElemKind::FloatTy, this);
1925
1926 // Check computation info has proper size.
1927 isValid &=
1928 expectCompareTrue("Computation info should be 1 dimensional",
1929 getComputationInfo().dims().size(), size_t(1), this);
1930 isValid &= expectCompareTrue(
1931 "Computation info should contain Min and Max value only",
1932 getComputationInfo().dims()[0], (dim_t)(2), this);
1933 return isValid;
1934}
1935
1936bool IntLookupTableNode::verify() const {
1937 bool isValid =
1938 expectCompareTrue("Input should be quantized type",
1939 getInput().getType()->isQuantizedType(), true, this);
1940 isValid &=
1941 expectCompareTrue("Result should be quantized type",
1942 getResult().getType()->isQuantizedType(), true, this);
1943 isValid &= expectCompareTrue("Mapping should be 1 dimensional",
1944 getMapping().dims().size(), size_t(1), this);
1945 isValid &= expectCompareTrue(
1946 "Mapping should cover the whole input quantized range",
1947 getMapping().dims()[0],
1948 (dim_t)(getInput().getType()->getQuantizedValueCount()), this);
1949 isValid &= expectCompareTrue("Mapping and result type must be the same",
1950 getMapping().getType()->getElementType(),
1951 getResult().getType()->getElementType(), this);
1952 return isValid;
1953}
1954
1955bool LookupTableNode::verify() const {
1956 bool isValid = true;
1957 return isValid;
1958}
1959
1960bool QuantizeNode::verify() const {
1961 bool isValid =
1962 expectCompareTrue("Dest must be quantized",
1963 getResult().getType()->isQuantizedType(), true, this);
1964 isValid &= expectCompareTrue("Src must be an FP type",
1965 getInput().getType()->isFPType(), true, this);
1966 isValid &= checkSameShape(getResult(), getInput(), this);
1967 return isValid;
1968}
1969
1970bool DequantizeNode::verify() const {
1971 bool isValid = expectCompareTrue(
1972 "Dest must be an FP type", getResult().getType()->isFPType(), true, this);
1973 isValid &=
1974 expectCompareTrue("Src must be quantized",
1975 getInput().getType()->isQuantizedType(), true, this);
1976 if (getInput().getElementType() == ElemKind::UInt8FusedQTy) {
1977 isValid &= expectCompareTrue("Fused tensors should be 2D",
1978 getInput().dims().size(), size_t(2), this);
1979 isValid &= expectCompareTrue(
1980 "Expected space for per-row scale/offset", getInput().dims()[1],
1981 (dim_t)(2 * sizeof(float)), this, CompareOperatorGreaterThan<dim_t>());
1982 } else {
1983 isValid &= checkSameShape(getResult(), getInput(), this);
1984 }
1985 return isValid;
1986}
1987
1988bool RescaleQuantizedNode::verify() const {
1989 bool isValid =
1990 expectCompareTrue("Dest must be quantized",
1991 getResult().getType()->isQuantizedType(), true, this);
1992 isValid &=
1993 checkType(getResult(), getInput().getType()->getElementType(), this);
1994 isValid &= checkSameShape(getResult(), getInput(), this);
1995 return isValid;
1996}
1997
bool CollectRpnProposalsNode::verify() const {
  /// Validates the per-FPN-level ROI and probability inputs against the
  /// collected result and the node's RPN level/TopN parameters.
  auto result = getResult();
  auto rois = getRoisIn();
  auto probs = getRoisProbsIn();
  bool isValid = true;

  isValid &= expectCompareTrue("rpnPostNmsTopN should be greater than zero",
                               getRpnPostNmsTopN() > 0, true, this);

  isValid &= expectCompareTrue(
      "RPN min level should be less than or equal to RPN max level",
      getRpnMinLevel() <= getRpnMaxLevel(), true, this);

  // One roi/prob input pair is expected per level in [min, max].
  dim_t rpnLevels = getRpnMaxLevel() - getRpnMinLevel() + 1;

  isValid &= expectCompareTrue("Invalid number of inputs",
                               rpnLevels == rois.size(), true, this);
  isValid &= expectCompareTrue("Invalid number of inputs",
                               rpnLevels == probs.size(), true, this);

  // NOTE(review): the dims()[0]/dims()[1] indexing below assumes each roi is
  // at least 2D and each prob at least 1D, and that rpnLevels <= rois.size();
  // the checks above report mismatches but do not stop the loop -- confirm
  // inputs are shaped upstream.
  for (dim_t i = 0; i < rpnLevels; i++) {
    auto roi = rois[i];
    auto prob = probs[i];
    // Every level's tensors share the result's element type.
    isValid &= checkType(result, roi.getElementType(), this);
    isValid &= checkType(result, prob.getElementType(), this);
    isValid &=
        expectCompareTrue("Rois and result must have same second dimension",
                          roi.dims()[1], result.dims()[1], this);
    isValid &= expectCompareTrue(
        "Rois and respective probability scores must have same first dimension",
        roi.dims()[0], prob.dims()[0], this);
  }

  isValid &=
      expectCompareTrue("Result is capped to rpnPostNmsTopN",
                        result.dims()[0] == getRpnPostNmsTopN(), true, this);

  return isValid;
}
2037
2038bool TopKNode::verify() const {
2039 bool isValid = checkSameShape(getValues(), getIndices(), this);
2040 isValid &= checkNotQuantizedOrSameParams(getInput().getType(),
2041 getValues().getType(), this);
2042 return isValid;
2043}
2044
2045bool ArgMaxNode::verify() const {
2046 bool isValid = true;
2047
2048 // Check output type.
2049 isValid &= checkType(
2050 getResult(),
2051 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this);
2052
2053 // Check output shape.
2054 ShapeVector expDstDims =
2055 reduceDims(getInput().dims(), {getAxis()}, getKeepDims());
2056 isValid &= expectCompareTrue("Invalid output dims", getResult().dims(),
2057 llvm::makeArrayRef(expDstDims), this);
2058 return isValid;
2059}
2060
2061bool ArgMinNode::verify() const {
2062 bool isValid = true;
2063
2064 // Check output type.
2065 isValid &= checkType(
2066 getResult(),
2067 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this);
2068
2069 // Check output shape.
2070 ShapeVector expDstDims =
2071 reduceDims(getInput().dims(), {getAxis()}, getKeepDims());
2072 isValid &= expectCompareTrue("Invalid output dims", getResult().dims(),
2073 llvm::makeArrayRef(expDstDims), this);
2074 return isValid;
2075}
2076
2077bool VectorNormNode::verify() const {
2078 bool isValid = true;
2079
2080 isValid &= expectCompareTrue("Only support Frobenius, p should be 2", getP(),
2081 (unsigned)2, this);
2082 // Check output shape.
2083 ShapeVector expDstDims = reduceDims(getInput().dims(), {getAxis()}, false);
2084 isValid &= expectCompareTrue("Invalid output dims", getResult().dims(),
2085 llvm::makeArrayRef(expDstDims), this);
2086 return isValid;
2087}
2088
2089bool GaussianFillNode::verify() const {
2090 auto dest = getResult();
2091 bool isValid = dest.getElementType() == ElemKind::Float16Ty;
2092 isValid &= checkSameShape(getInput(), dest, this);
2093 return isValid;
2094}
2095
2096bool DynamicQuantizedFullyConnectedNode::verify() const {
2097 auto src = getInput();
2098 auto weights = getWeights();
2099 auto bias = getBias();
2100 auto dest = getResult();
2101 auto isPerBatchElement = getIsPerBatchElement();
2102 auto isSymmetric = getIsSymmetric();
2103
2104 bool isValid = expectCompareTrue("Inputs should be 2D tensor",
2105 src.dims().size(), size_t(2), this);
2106 isValid &= expectCompareTrue(
2107 "Only per batch quantized input DynQuantizedFC is supported now",
2108 isPerBatchElement, true, this);
2109 isValid &= expectCompareTrue(
2110 "Only symmetric quantized DynQuantizedFC is supported now", isSymmetric,
2111 true, this);
2112 isValid &= expectCompareTrue("Weights should be 2D tensor",
2113 weights.dims().size(), size_t(2), this);
2114 isValid &= expectCompareTrue("Result should be 2D tensor", dest.dims().size(),
2115 size_t(2), this);
2116 isValid &= expectCompareTrue("Bias should be 1D tensor", bias.dims().size(),
2117 size_t(1), this);
2118
2119 isValid &= expectCompareTrue("Mismatch on expected source dimension 0",
2120 src.dims()[0], dest.dims()[0], this);
2121 isValid &= expectCompareTrue("Mismatch on expected source dimension 1",
2122 src.dims()[1], weights.dims()[0], this);
2123
2124 isValid &= expectCompareTrue("Inconsistent bias/weights sizes",
2125 bias.dims()[0], weights.dims()[1], this);
2126 isValid &= expectCompareTrue("Inconsistent bias/dest sizes", bias.dims()[0],
2127 dest.dims()[1], this);
2128
2129 return isValid;
2130}
2131
2132bool DynamicRowwiseQuantizedFullyConnectedNode::verify() const {
2133 auto src = getInput();
2134 auto weights = getWeights();
2135 auto bias = getBias();
2136 auto dest = getResult();
2137 auto scales = getScales();
2138 auto offsets = getOffsets();
2139 auto isPerBatchElement = getIsPerBatchElement();
2140 auto isSymmetric = getIsSymmetric();
2141
2142 bool isValid = expectCompareTrue("Inputs should be 2D tensor",
2143 src.dims().size(), size_t(2), this);
2144 isValid &= expectCompareTrue(
2145 "Only per batch quantized input DynQuantizedFC is supported now",
2146 isPerBatchElement, true, this);
2147 isValid &= expectCompareTrue(
2148 "Only symmetric quantized DynQuantizedFC is supported now", isSymmetric,
2149 true, this);
2150 isValid &= expectCompareTrue("Weights should be 2D tensor",
2151 weights.dims().size(), size_t(2), this);
2152 isValid &= expectCompareTrue("Result should be 2D tensor", dest.dims().size(),
2153 size_t(2), this);
2154 isValid &= expectCompareTrue("Bias should be 1D tensor", bias.dims().size(),
2155 size_t(1), this);
2156 isValid &= expectCompareTrue("Offsets should be 1D tensor",
2157 offsets.dims().size(), size_t(1), this);
2158 isValid &= expectCompareTrue("Scales should be 1D tensor",
2159 scales.dims().size(), size_t(1), this);
2160
2161 isValid &= expectCompareTrue("Mismatch on expected source dimension 0",
2162 src.dims()[0], dest.dims()[0], this);
2163 isValid &= expectCompareTrue("Mismatch on expected source dimension 1",
2164 src.dims()[1], weights.dims()[0], this);
2165
2166 isValid &= expectCompareTrue("Inconsistent bias/weights sizes",
2167 bias.dims()[0], weights.dims()[1], this);
2168 isValid &= expectCompareTrue("Inconsistent bias/dest sizes", bias.dims()[0],
2169 dest.dims()[1], this);
2170 isValid &= expectCompareTrue("Inconsistent scales/offsets sizes",
2171 scales.dims()[0], offsets.dims()[0], this);
2172 isValid &= expectCompareTrue("Inconsistent scales/weights sizes",
2173 scales.dims()[0], weights.dims()[1], this);
2174
2175 return isValid;
2176}
2177
2178bool RowwiseQuantizedFullyConnectedNode::verify() const {
2179 auto src = getInput();
2180 auto weights = getWeights();
2181 auto scales = getScales();
2182 auto offsets = getOffsets();
2183 auto bias = getBias();
2184 auto dest = getResult();
2185
2186 bool isValid = expectCompareTrue("Inputs should be 2D tensor",
2187 src.dims().size(), size_t(2), this);
2188 isValid &= expectCompareTrue("Weights should be 2D tensor",
2189 weights.dims().size(), size_t(2), this);
2190 isValid &= expectCompareTrue("Result should be 2D tensor", dest.dims().size(),
2191 size_t(2), this);
2192 isValid &= expectCompareTrue("Offsets should be 1D tensor",
2193 offsets.dims().size(), size_t(1), this);
2194 isValid &= expectCompareTrue("Scales should be 1D tensor",
2195 scales.dims().size(), size_t(1), this);
2196 isValid &= expectCompareTrue("Bias should be 1D tensor", bias.dims().size(),
2197 size_t(1), this);
2198
2199 isValid &= expectCompareTrue("Mismatch on expected source dimension 0",
2200 src.dims()[0], dest.dims()[0], this);
2201 isValid &= expectCompareTrue("Mismatch on expected source dimension 1",
2202 src.dims()[1], weights.dims()[1], this);
2203
2204 isValid &= expectCompareTrue("Inconsistent bias/dest sizes", bias.dims()[0],
2205 weights.dims()[0], this);
2206 isValid &= expectCompareTrue("Inconsistent weights/dest sizes",
2207 weights.dims()[0], dest.dims()[1], this);
2208
2209 isValid &= expectCompareTrue("Inconsistent scales/offsets sizes",
2210 scales.dims()[0], offsets.dims()[0], this);
2211 isValid &= expectCompareTrue("Inconsistent scales/weights sizes",
2212 scales.dims()[0], weights.dims()[0], this);
2213 return isValid;
2214}
2215
2216bool GatherNode::verify() const {
2217 bool isValid = checkType(getResult(), getData().getElementType(), this);
2218 isValid &= checkType(
2219 getIndices(),
2220 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this);
2221 isValid &= expectCompareTrue(
2222 "Mismatching number of dimensions", getResult().dims().size(),
2223 getData().dims().size() + getIndices().dims().size() - 1, this);
2224 isValid &= checkNotQuantizedOrSameParams(getResult().getType(),
2225 getData().getType(), this);
2226 return isValid;
2227}
2228
2229bool GatherElementsNode::verify() const {
2230 bool isValid = checkType(getResult(), getData().getElementType(), this);
2231 isValid &= checkType(
2232 getIndices(),
2233 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this);
2234 isValid &= expectCompareTrue("Mismatching number of dimensions",
2235 getResult().dims().size(),
2236 getIndices().dims().size(), this);
2237 return isValid;
2238}
2239
2240bool GatherNDNode::verify() const {
2241 bool isValid = checkType(getResult(), getData().getElementType(), this);
2242 isValid &= checkType(
2243 getIndices(),
2244 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}), this);
2245 isValid &= expectCompareTrue(
2246 "Mismatching number of dimensions", getResult().dims().size(),
2247 getData().dims().size() + getIndices().dims().size() -
2248 getIndices().dims().back() - 1 - getBatchDims(),
2249 this);
2250 isValid &= checkNotQuantizedOrSameParams(getResult().getType(),
2251 getData().getType(), this);
2252 return isValid;
2253}
2254
2255bool GatherRangesNode::verify() const {
2256 bool isValid = expectCompareTrue("Data must be 1D", getData().dims().size(),
2257 size_t(1), this);
2258 isValid &= expectCompareTrue("Ranges must be 3D", getRanges().dims().size(),
2259 size_t(3), this);
2260 isValid &= expectCompareTrue("Last dimension of Ranges must be equal to 2",
2261 getRanges().dims()[2], dim_t(2), this);
2262 isValid &= expectCompareTrue("Output must be 1D", getOutput().dims().size(),
2263 size_t(1), this);
2264 isValid &= expectCompareTrue("Lengths must be 1D", getLengths().dims().size(),
2265 size_t(1), this);
2266 isValid &=
2267 expectCompareTrue("Number of examples must match number of lengths",
2268 getRanges().dims()[0], getLengths().dims()[0], this);
2269
2270 isValid &= checkTypeIgnoreShape(getOutput(), getData(), this);
2271 isValid &= checkTypeIgnoreShape(getRanges(), getLengths(), this);
2272
2273 return isValid;
2274}
2275
bool ScatterDataNode::verify() const {
  /// Validates that Slices can be scattered into Data at the positions given
  /// by Indices: matching element types (and quant params unless cumulative),
  /// one index row per slice, and slice shapes matching the tail of Data's
  /// shape past the indexed prefix.
  const auto &slicesDims = getSlices().dims();
  const auto &dataDims = getData().dims();
  const auto &indicesDims = getIndices().dims();
  bool isValid = true;
  isValid &= expectCompareTrue("Type mismatch",
                               getSlices().getType()->getElementType(),
                               getData().getType()->getElementType(), this);
  // Bail out early: the quantization checks below read both types.
  if (!isValid) {
    return false;
  }
  // TODO: Do we need support for different quant params of copy?
  if (getSlices().getType()->isQuantizedType() && !getCumulative()) {
    isValid &=
        expectCompareTrue("Scale mismatch", getSlices().getType()->getScale(),
                          getData().getType()->getScale(), this);
    isValid &=
        expectCompareTrue("Offset mismatch", getSlices().getType()->getOffset(),
                          getData().getType()->getOffset(), this);
  }
  isValid &= expectCompareTrue("There should be an index for each slice",
                               indicesDims[0], slicesDims[0], this);
  isValid &= expectCompareTrue("Indices should be a 2D tensor",
                               indicesDims.size(), size_t(2), this);
  // The code below may crash if these conditions are not met.
  if (!isValid) {
    return false;
  }
  // Each index row addresses an indexSize-deep prefix of Data's dimensions.
  const size_t indexSize = indicesDims[1];
  isValid &= expectCompareTrue("Dimensions of Data should be equal to "
                               "dimensions of indices + dimensions of updates",
                               slicesDims.size() - 1 + indexSize,
                               dataDims.size(), this);
  if (dataDims.size() > 1) {
    // Beyond the indexed prefix, every slice dim must equal the matching
    // Data dim (slice dim 0 is the per-slice count, hence the +1 shift).
    for (size_t i = indexSize; i < dataDims.size(); i++) {
      std::string msg = std::to_string(i);
      msg = "Slice shape should equal data shape for dim " + msg;
      isValid &= expectCompareTrue(msg.c_str(), dataDims[i],
                                   slicesDims[i - indexSize + 1], this);
    }
  }

  return isValid;
}
2320
2321bool BatchOneHotNode::verify() const {
2322 const auto &dataDims = getData().dims();
2323 const auto &lengthsDims = getLengths().dims();
2324 const auto &valuesDims = getValues().dims();
2325
2326 bool isValid = expectCompareTrue("Data should be a two dimensional matrix",
2327 dataDims.size(), size_t(2), this);
2328
2329 isValid &= expectCompareTrue("Lengths should be a single dimensional vectors",
2330 lengthsDims.size(), size_t(1), this);
2331 isValid &= checkType(getLengths(), ElemKind::Int32ITy, this);
2332
2333 isValid &= expectCompareTrue("Values should be a single dimensional vectors",
2334 valuesDims.size(), size_t(1), this);
2335
2336 isValid &=
2337 expectCompareTrue("Size of Lengths should be equal to width of Data",
2338 lengthsDims[0], dataDims[1], this);
2339 return isValid;
2340}
2341
2342bool SpaceToDepthNode::verify() const {
2343 auto inputN = getInput();
2344 auto resultN = getResult();
2345 auto inputDims = inputN.dims();
2346 auto outputDims = resultN.dims();
2347 unsigned blockSize = getBlockSize();
2348
2349 bool sameType = checkTypeIgnoreShape(inputN, resultN, this);
2350 bool dimTransform = inputDims[0] == outputDims[0] &&
2351 inputDims[1] == outputDims[1] * blockSize &&
2352 inputDims[2] == outputDims[2] * blockSize &&
2353 inputDims[3] * blockSize * blockSize == outputDims[3];
2354
2355 return sameType && dimTransform;
2356}
2357
2358bool ResizeNearestNode::verify() const {
2359 auto input = getInput();
2360 auto scale = getScale();
2361 auto result = getResult();
2362 auto inputDims = input.dims();
2363 auto outputDims = result.dims();
2364
2365 bool isValid = checkTypeIgnoreShape(input, result, this);
2366 isValid &=
2367 expectCompareTrue("Input size must be greater than 2", inputDims.size(),
2368 size_t(2), this, CompareOperatorGreaterThan<size_t>());
2369 isValid &=
2370 expectCompareTrue("Output size must be greater than 2", outputDims.size(),
2371 size_t(2), this, CompareOperatorGreaterThan<size_t>());
2372 isValid &= expectCompareTrue("Input size must be equal to the output size",
2373 inputDims.size(), outputDims.size(), this);
2374
2375 for (size_t i = 0, e = scale.size(); i < e; i++) {
2376 isValid &= expectCompareTrue("Unexpected output",
2377 dim_t(std::floor(inputDims[i] * scale[i])),
2378 outputDims[i], this);
2379 isValid &= expectCompareTrue("Invalid scale", scale[i], float(0.0), this,
2380 CompareOperatorGreaterThan<float>());
2381 }
2382
2383 return isValid;
2384}
2385
2386bool ResizeBilinearNode::verify() const {
2387 auto input = getInput();
2388 auto scale = getScale();
2389 auto result = getResult();
2390 auto inputDims = input.dims();
2391 auto outputDims = result.dims();
2392
2393 bool isValid = checkTypeIgnoreShape(input, result, this);
2394 isValid &= expectCompareTrue("Input must be a 4D tensor", inputDims.size(),
2395 size_t(4), this);
2396 isValid &= expectCompareTrue("Output must be a 4D tensor", outputDims.size(),
2397 size_t(4), this);
2398
2399 for (size_t i = 0, e = scale.size(); i < e; i++) {
2400 isValid &= expectCompareTrue("Unexpected output",
2401 dim_t(std::floor(inputDims[i] * scale[i])),
2402 outputDims[i], this);
2403 isValid &= expectCompareTrue("Invalid scale", scale[i], float(0.0), this,
2404 CompareOperatorGreaterThan<float>());
2405 }
2406
2407 return isValid;
2408}
2409
bool NonMaxSuppressionNode::verify() const {
  /// Validates the relative layout of the Boxes and Scores inputs, following
  /// either the ONNX layout or the TF-V4 layout depending on the node flag.
  NodeValue boxes = getBoxes();
  NodeValue scores = getScores();
  auto boxesDims = boxes.dims();
  auto scoresDims = scores.dims();
  bool isV4 = getIsTFVersion4();

  // NOTE(review): these index computations assume scores/boxes have rank >= 3
  // (the batch-dim computations subtract 3); with smaller ranks the unsigned
  // subtraction wraps around -- confirm the importer guarantees the rank.
  size_t scoresBoxDim = scores.dims().size() - 1;
  size_t scoresBatchDim = scores.dims().size() - 3;

  size_t boxesBoxDim = boxes.dims().size() - 2;
  size_t boxesBatchDim = boxes.dims().size() - 3;

  bool isValid = true;
  if (isV4) {
    // TF V4 layout: only the box counts need to agree.
    isValid &= expectCompareTrue(
        "Number of boxes doesn't match number of confidence scores.",
        boxesDims[boxesBoxDim], scoresDims[scoresBoxDim], this,
        CompareOperatorEqual<dim_t>());
  }

  // checking layout matching. See ONNX spec for details.
  if (!isV4) {
    isValid &= expectCompareTrue(
        "Batch dimension doesn't match.", boxesDims[boxesBatchDim],
        scoresDims[scoresBatchDim], this, CompareOperatorEqual<dim_t>());

    isValid &= expectCompareTrue(
        "Number of boxes doesn't match number of confidence scores.",
        boxesDims[boxesBoxDim], scoresDims[scoresBoxDim], this,
        CompareOperatorEqual<dim_t>());
  }

  // Boxes and scores must share an element type.
  isValid &= checkType(boxes, scores.getElementType(), this);

  return isValid;
}
2447
bool TFLiteDetectionPostProcessNode::verify() const {
  /// Validates the TFLite detection post-processing inputs: Boxes
  /// [batch, numBoxes, 4], Scores [batch, numBoxes, numClasses], Anchors
  /// [numBoxes, 4], plus range checks on the node parameters.
  NodeValue boxes = getBoxes();
  NodeValue scores = getScores();
  NodeValue anchors = getAnchors();

  auto boxesDims = boxes.dims();
  auto scoresDims = scores.dims();
  auto anchorsDims = anchors.dims();

  bool isValid = true;

  // Validate input tensor sizes.
  isValid &= expectCompareTrue("Input boxes must be a 3D tensor!",
                               boxesDims.size(), size_t(3), this);
  isValid &= expectCompareTrue("Input scores must be a 3D tensor!",
                               scoresDims.size(), size_t(3), this);
  isValid &= expectCompareTrue("Input anchors must be a 2D tensor!",
                               anchorsDims.size(), size_t(2), this);
  dim_t numBoxes = boxesDims[1];
  dim_t numTotClasses = scoresDims[2];
  // NOTE(review): the boxesDims[1]==numBoxes and scoresDims[2]==numTotClasses
  // checks below compare values that are equal by construction and can never
  // fail -- presumably one operand was meant to come from another tensor;
  // confirm intent before relying on them.
  isValid &= expectCompareTrue("Input boxes size invalid!", boxesDims[1],
                               numBoxes, this);
  isValid &= expectCompareTrue("Input boxes size invalid!", boxesDims[2],
                               dim_t(4), this);
  isValid &= expectCompareTrue("Input scores size invalid!", scoresDims[0],
                               boxesDims[0], this);
  isValid &= expectCompareTrue("Input scores size invalid!", scoresDims[1],
                               numBoxes, this);
  isValid &= expectCompareTrue("Input scores size invalid!", scoresDims[2],
                               numTotClasses, this);
  isValid &= expectCompareTrue("Input anchors size invalid!", anchorsDims[0],
                               numBoxes, this);
  isValid &= expectCompareTrue("Input anchors size invalid!", anchorsDims[1],
                               dim_t(4), this);

  // Validate parameters.
  // IOU and score thresholds must lie in (0, 1].
  isValid &=
      expectCompareTrue("Invalid IOU threshold!", getIouThreshold(), float(0.0),
                        this, CompareOperatorGreaterThan<float>());
  isValid &=
      expectCompareTrue("Invalid IOU threshold!", getIouThreshold(), float(1.0),
                        this, CompareOperatorLessEqual<float>());
  isValid &=
      expectCompareTrue("Invalid score threshold!", getScoreThreshold(),
                        float(0.0), this, CompareOperatorGreaterThan<float>());
  isValid &=
      expectCompareTrue("Invalid score threshold!", getScoreThreshold(),
                        float(1.0), this, CompareOperatorLessEqual<float>());
  // The class count cannot exceed the score tensor's class axis, and at least
  // one detection must be requested.
  isValid &=
      expectCompareTrue("Invalid number of classes!", dim_t(getNumClasses()),
                        numTotClasses, this, CompareOperatorLessEqual<dim_t>());
  isValid &=
      expectCompareTrue("Invalid max detections!", dim_t(getMaxDetections()),
                        dim_t(0), this, CompareOperatorGreaterThan<dim_t>());
  return isValid;
}
2504
2505bool AudioSpectrogramNode::verify() const {
2506 NodeValue input = getInput();
2507 NodeValue spectrogram = getSpectrogram();
2508 auto inputLength = input.getType()->size();
2509 auto windowSize = getWindowSize();
2510 auto windowStride = getWindowStride();
2511 auto windowCount = std::floor((inputLength - windowSize) / windowStride) + 1;
2512 auto fftLen = 1 << (int)std::ceil(std::log2((double)windowSize));
2513
2514 bool isValid = true;
2515 isValid &= expectCompareTrue("Input audio is too short for given window size",
2516 dim_t(windowCount), dim_t(0), this,
2517 CompareOperatorGreaterThan<dim_t>());
2518 isValid &= expectCompareTrue("Output spectrogram must be a 2D tensor",
2519 spectrogram.dims().size(), size_t(2), this);
2520 isValid &= expectCompareTrue("Output spectrogram size is invalid",
2521 spectrogram.dims()[0], dim_t(windowCount), this,
2522 CompareOperatorEqual<dim_t>());
2523 isValid &= expectCompareTrue("Output spectrogram size is invalid",
2524 spectrogram.dims()[1], dim_t(fftLen / 2 + 1),
2525 this, CompareOperatorEqual<dim_t>());
2526 return isValid;
2527}
2528
// Verifies the MFCC (Mel Frequency Cepstral Coefficients) node: the input
// spectrogram is {windowCount, fftLen / 2 + 1} and the output coefficients
// tensor is {windowCount, numCoefficients}. Also validates the frequency
// range attributes against each other and the sample rate.
bool MFCCNode::verify() const {
  NodeValue spectrogram = getSpectrogram();
  NodeValue coefficients = getCoefficients();
  float sampleRate = getSampleRate();
  float lowerFrequency = getLowerFrequency();
  float upperFrequency = getUpperFrequency();
  auto filterBankCount = getFilterBankCount();
  auto numCoefficients = getNumCoefficients();
  // Reconstruct the FFT length from the spectrogram width (fftLen / 2 + 1).
  auto fftLen = (spectrogram.dims()[1] - 1) * 2;
  // Exponent output parameter for std::frexp below.
  int exp;

  bool isValid = true;
  isValid &= expectCompareTrue("Input spectrogram must be a 2D tensor",
                               spectrogram.dims().size(), size_t(2), this);
  // std::frexp returns a mantissa of exactly 0.5 iff its argument is a power
  // of 2; this verifies the reconstructed FFT length is a valid power of 2.
  isValid &= expectCompareTrue(
      "Input spectrogram size is invalid. Should be of the form 2^N/2+1.",
      std::abs(std::frexp((float)(fftLen), &exp)), float(0.5), this,
      CompareOperatorEqual<float>());
  isValid &= expectCompareTrue("Output coefficients must be a 2D tensor",
                               coefficients.dims().size(), size_t(2), this);
  isValid &= expectCompareTrue("Output coefficients size is invalid",
                               coefficients.dims()[1], dim_t(numCoefficients),
                               this, CompareOperatorEqual<dim_t>());
  isValid &= expectCompareTrue(
      "Number of windows should be same for both input and output",
      spectrogram.dims()[0], coefficients.dims()[0], this,
      CompareOperatorEqual<dim_t>());
  // Frequency-range sanity: 0 < lower < upper < sampleRate / 2 (Nyquist).
  isValid &= expectCompareTrue("Lower frequency should be greater than 0",
                               lowerFrequency, float(0.0), this,
                               CompareOperatorGreaterThan<float>());
  isValid &= expectCompareTrue("Upper frequency should be greater than 0",
                               upperFrequency, float(0.0), this,
                               CompareOperatorGreaterThan<float>());
  isValid &= expectCompareTrue(
      "Upper frequency must be greater than lower frequency", upperFrequency,
      lowerFrequency, this, CompareOperatorGreaterThan<float>());
  isValid &= expectCompareTrue(
      "Upper frequency must be lower than half the sample rate", sampleRate,
      float(2.0 * upperFrequency), this, CompareOperatorGreaterThan<float>());
  isValid &= expectCompareTrue(
      "Number of coefficients should be smaller or equal than the filter bank",
      dim_t(filterBankCount), dim_t(numCoefficients), this,
      CompareOperatorGreaterEqual<dim_t>());
  return isValid;
}
2574
2575bool ROIAlignNode::verify() const {
2576 auto featureMap = getFeatureMap();
2577 auto boxes = getBoxes();
2578 auto batchIndices = getBatchIndices();
2579 auto result = getResult();
2580 auto featureMapDims = featureMap.dims();
2581 auto boxesDims = boxes.dims();
2582 auto outputDims = result.dims();
2583
2584 bool isValid = checkTypeIgnoreShape(featureMap, result, this);
2585 isValid &= checkTypeIgnoreShape(boxes, result, this);
2586 isValid &=
2587 checkType(featureMap, {ElemKind::FloatTy, ElemKind::Float16Ty}, this);
2588 isValid &= expectCompareTrue("FeatureMap must be a 4D tensor",
2589 featureMapDims.size(), size_t(4), this);
2590 isValid &= expectCompareTrue("Boxes must be a 2D tensor", boxesDims.size(),
2591 size_t(2), this);
2592 isValid &= expectCompareTrue("Output must be a 4D tensor", outputDims.size(),
2593 size_t(4), this);
2594 // If batch size > 1 batch indices must be provided.
2595 if (featureMapDims[0] > 1) {
2596 // Caffe2 gets indices using boxes tensor
2597 bool indicesInBoxesTensor = boxesDims[1] == (getRotated() ? 6 : 5);
2598 // Onnx requires batchIndices to be valid
2599 if (!indicesInBoxesTensor) {
2600 auto batchIndicesDims = batchIndices.dims();
2601 isValid &= checkType(batchIndices,
2602 {ElemKind::Int64ITy, ElemKind::Int32ITy}, this);
2603 isValid &= expectCompareTrue("BatchIndices must be a 1D tensor",
2604 batchIndicesDims.size(), size_t(1), this);
2605 isValid &=
2606 expectCompareTrue("BatchIndices must have same length as Boxes",
2607 batchIndicesDims[0], boxesDims[0], this);
2608 }
2609 }
2610 return isValid;
2611}
2612
2613bool BBoxTransformNode::verify() const {
2614 auto rois = getRois();
2615 auto deltas = getDeltas();
2616 auto imInfo = getImInfo();
2617 auto boxOut = getBoxOut();
2618 auto weights = getWeights();
2619 auto period = getAngleBoundHi() - getAngleBoundLo();
2620
2621 auto roisDims = rois.dims();
2622 auto deltasDims = deltas.dims();
2623 auto imInfoDims = imInfo.dims();
2624
2625 bool rotated = getRotated();
2626 bool angleBoundOn = getAngleBoundOn();
2627 // BoxDim is of the format
2628 // <x1, y1, x2, y2, [optional_angle]>
2629 dim_t expectedBoxDim = rotated ? 5 : 4;
2630
2631 // Rois row is of the format
2632 // <[optinal_batch_index], x1, y1, x2, y2, [optional_angle]>
2633 bool validRoiDim =
2634 roisDims[1] == expectedBoxDim || roisDims[1] == expectedBoxDim + 1;
2635
2636 bool isValid = checkTypeIgnoreShape(rois, boxOut, this);
2637 isValid &= checkSameType(deltas, boxOut, this);
2638 isValid &= checkTypeIgnoreShape(imInfo, boxOut, this);
2639 // ROIs can be float32 or float16.
2640 isValid &= checkType(rois, {ElemKind::FloatTy, ElemKind::Float16Ty}, this);
2641 isValid &= expectCompareTrue("Rois must be a 2D tensor", roisDims.size(),
2642 size_t(2), this);
2643 isValid &=
2644 expectCompareTrue("Rois must have with equals boxDim or larger in 1",
2645 validRoiDim, true, this);
2646 isValid &= expectCompareTrue("Deltas must be a 2D tensor", deltasDims.size(),
2647 size_t(2), this);
2648 isValid &= expectCompareTrue("ImInfo must be a 2D tensor", imInfoDims.size(),
2649 size_t(2), this);
2650 isValid &= expectCompareTrue("ImInfo must be a {batch_size, 3} tensor",
2651 imInfoDims[1], dim_t(3), this);
2652 isValid &= expectCompareTrue("Rois and Deltas must have same 0 dimension",
2653 roisDims[0], deltasDims[0], this);
2654 isValid &= expectCompareTrue(
2655 "Number of rois must be <= 2048 to be represented in FP16.", roisDims[0],
2656 dim_t(2048), this, CompareOperatorLessEqual<dim_t>());
2657 isValid &= expectCompareTrue("Deltas must be divisible by box dimensions",
2658 deltasDims[1] % expectedBoxDim, dim_t(0), this);
2659 isValid &= expectCompareTrue("Weights must be a 1D vector of length 4",
2660 weights.size(), size_t(4), this);
2661 if (roisDims[1] == expectedBoxDim) {
2662 isValid &= expectCompareTrue(
2663 "The batch size should be 1 if there's no batch index in rois",
2664 imInfoDims[0], dim_t(1), this);
2665 }
2666 if (rotated && angleBoundOn) {
2667 isValid &= expectCompareTrue(
2668 "The difference between angleBoundHi and angleBoundLo "
2669 "should be greater than 0 and divisible by 180",
2670 period > 0 && period % 180 == 0, true, this);
2671 }
2672
2673 return isValid;
2674}
2675
2676bool SaveNode::verify() const {
2677 return checkSameType(getInput(), getOutput(), this);
2678}
2679
2680bool LogNode::verify() const {
2681 if (getResult().getType()->isQuantizedType()) {
2682 return checkSameShape(getInput(), getResult(), this);
2683 }
2684 return checkSameType(getInput(), getResult(), this);
2685}
2686
2687bool IsNaNNode::verify() const {
2688 bool isValid = checkSameShape(getResult(), getInput(), this);
2689 isValid &= checkType(getResult(), ElemKind::BoolTy, this);
2690 return isValid;
2691}
2692
2693bool ReplaceNaNNode::verify() const {
2694 return checkSameType(getResult(), getInput(), this);
2695}
2696
2697bool NonZeroNode::verify() const {
2698 return checkType(getCond(), ElemKind::BoolTy, this) &&
2699 checkType(getResult(), ElemKind::Int32ITy, this);
2700}
2701
2702bool SelectNode::verify() const {
2703 bool isValid = checkSameShape(getResult(), getLHS(), this);
2704 isValid &= checkSameShape(getResult(), getRHS(), this);
2705 isValid &= checkSameShape(getResult(), getCond(), this);
2706 isValid &= checkType(getLHS(), getRHS().getElementType(), this);
2707 isValid &= checkType(getLHS(), getResult().getElementType(), this);
2708 isValid &= checkType(getCond(), ElemKind::BoolTy, this);
2709 return isValid;
2710}
2711
2712bool ReluNode::verify() const { return verifyRelu(getResult(), getInput()); }
2713
2714bool GeluNode::verify() const {
2715 const Node *parent = getResult().getNode();
2716 return checkSameType(getResult(), getInput(), parent);
2717}
2718
2719bool ReluGradNode::verify() const {
2720 return verifyInputAndGradInputTypes(getInput(), getGradOfInputNamedInput(),
2721 this) &&
2722 verifyOutputAndGradOutputTypes(getOriginalOutputForResult(),
2723 getGradOfOriginalOutputNamedResult(),
2724 this) &&
2725 verifyRelu(getGradOfOriginalOutputNamedResult(), getInput());
2726}
2727
2728bool LeakyReluNode::verify() const {
2729 return verifyRelu(getResult(), getInput());
2730}
2731
2732bool PReluNode::verify() const {
2733 return verifyPRelu(getResult(), getInput(), getSlope());
2734}
2735
2736bool RegressionNode::verify() const {
2737 return verifyRegression(getInput(), getResult(), getExpected());
2738}
2739
2740bool RegressionGradNode::verify() const {
2741 return verifyInputAndGradInputTypes(getExpected(),
2742 getGradOfInputNamedExpected(), this) &&
2743 verifyInputAndGradInputTypes(getInput(), getGradOfInputNamedInput(),
2744 this) &&
2745 verifyOutputAndGradOutputTypes(getOriginalOutputForResult(),
2746 getGradOfOriginalOutputNamedResult(),
2747 this) &&
2748 verifyRegression(getGradOfInputNamedInput(),
2749 getGradOfOriginalOutputNamedResult(),
2750 getGradOfInputNamedExpected());
2751}
2752
2753bool SigmoidCrossEntropyWithLogitsNode::verify() const {
2754 bool isValid = checkType(getResult(), getLogits().getElementType(), this);
2755 isValid &= checkSameType(getLogits(), getTargets(), this);
2756 return isValid;
2757}
2758
// Verifies GEMM: Y = A' * B' (+ C), where A' = transposeA ? A^T : A and
// B' = transposeB ? B^T : B. The bias C is optional and may be 1D (length N,
// broadcast over rows) or 2D ({M, N}).
bool GemmNode::verify() const {
  NodeValue A = getA();
  NodeValue B = getB();
  NodeValue C = getC();
  NodeValue Y = getResult();
  bool transA = getTransposeA();
  bool transB = getTransposeB();
  const Node *parent = Y.getNode();

  // Check types.
  bool isValid = checkType(B, A.getElementType(), this);
  // Check for element kind of bias
  if (C.getNode()) {
    // Non quantization type check.
    if (A.getElementType() == ElemKind::FloatTy ||
        A.getElementType() == ElemKind::Float16Ty) {
      isValid &= checkType(C, A.getElementType(), parent);
    }
    // Quantization type check.
    if (A.getElementType() == ElemKind::Int8QTy) {
      isValid &= expectCompareTrue("Bias type should be Int8 or Int32 for Gemm",
                                   C.getElementType() == ElemKind::Int8QTy ||
                                       C.getElementType() == ElemKind::Int32QTy,
                                   true, parent);
    }
  }
  isValid &= checkType(Y, A.getElementType(), this);

  // Check shapes.
  isValid &=
      expectCompareTrue("Input A must be 2D", A.dims().size(), size_t(2), this);
  isValid &=
      expectCompareTrue("Input B must be 2D", B.dims().size(), size_t(2), this);
  if (C.getNode()) {
    isValid &=
        expectCompareTrue("Input C must be 1D or 2D", C.dims().size(),
                          size_t(2), this, CompareOperatorLessEqual<size_t>());
  }
  isValid &=
      expectCompareTrue("Output must be 2D", Y.dims().size(), size_t(2), this);
  // Apply the transpose flags before validating the matmul shapes.
  std::vector<dim_t> dimsA = A.dims();
  std::vector<dim_t> dimsB = B.dims();
  if (transA) {
    dimsA[0] = A.dims()[1];
    dimsA[1] = A.dims()[0];
  }
  if (transB) {
    dimsB[0] = B.dims()[1];
    dimsB[1] = B.dims()[0];
  }
  // (A' : {M, K}) x (B' : {K, N}) => Y : {M, N}.
  isValid &= expectCompareTrue("Input A (transposed) dimension 0 size invalid",
                               dimsA[0], Y.dims()[0], this,
                               CompareOperatorEqual<dim_t>());
  isValid &= expectCompareTrue("Input A (transposed) dimension 1 size invalid",
                               dimsA[1], dimsB[0], this,
                               CompareOperatorEqual<dim_t>());
  isValid &= expectCompareTrue("Input B (transposed) dimension 1 size invalid",
                               dimsB[1], Y.dims()[1], this,
                               CompareOperatorEqual<dim_t>());
  if (C.getNode()) {
    // A 1D bias broadcasts along rows; a 2D bias must match Y exactly.
    if (C.dims().size() == 1) {
      isValid &=
          expectCompareTrue("Input C size invalid", C.dims()[0], Y.dims()[1],
                            this, CompareOperatorEqual<dim_t>());
    } else {
      isValid &=
          expectCompareTrue("Input C dimension 0 size invalid", C.dims()[0],
                            Y.dims()[0], this, CompareOperatorEqual<dim_t>());
      isValid &=
          expectCompareTrue("Input C dimension 1 size invalid", C.dims()[1],
                            Y.dims()[1], this, CompareOperatorEqual<dim_t>());
    }
  }
  return isValid;
}
2834
2835bool LSTMUnitNode::verify() const {
2836 bool isValid = true;
2837 NodeValue C = getC();
2838 auto cDim = C.dims();
2839 NodeValue Input = getInput();
2840 auto inputDim = Input.dims();
2841
2842 isValid &=
2843 expectCompareTrue("Input must be 2D", inputDim.size(), size_t(2), this);
2844 isValid &=
2845 expectCompareTrue("Cell State must be 2D", cDim.size(), size_t(2), this);
2846 isValid &= expectCompareTrue("Input dims[1] must be 4 * C dims[1]",
2847 inputDim[1], 4 * cDim[1], this);
2848 isValid &=
2849 expectCompareTrue("Input dims[0] must be must be the same to C dims[0]",
2850 inputDim[0], cDim[0], this);
2851
2852 return isValid;
2853}
2854
2855bool FullyConnectedNode::verify() const {
2856 return verifyFullyConnected(getInput(), getWeights(), getBias(), getResult());
2857}
2858
2859bool FullyConnectedGradNode::verify() const {
2860 return verifyInputAndGradInputTypes(getBias(), getGradOfInputNamedBias(),
2861 this) &&
2862 verifyInputAndGradInputTypes(getInput(), getGradOfInputNamedInput(),
2863 this) &&
2864 verifyInputAndGradInputTypes(getWeights(),
2865 getGradOfInputNamedWeights(), this) &&
2866 verifyOutputAndGradOutputTypes(getOriginalOutputForResult(),
2867 getGradOfOriginalOutputNamedResult(),
2868 this) &&
2869 verifyFullyConnected(
2870 getGradOfInputNamedInput(), getGradOfInputNamedWeights(),
2871 getGradOfInputNamedBias(), getGradOfOriginalOutputNamedResult());
2872}
2873
2874bool ConcatNode::verify() const {
2875 auto inputs = getInputs();
2876 auto dimension = getDim();
2877 if (!expectCompareTrue("Empty concat?!", inputs.empty(), false, this)) {
2878 return false;
2879 }
2880 bool isValid = expectCompareTrue("concat on invalid dimension",
2881 inputs[0].dims().size(), size_t(dimension),
2882 this, CompareOperatorGreaterThan<size_t>());
2883
2884 size_t nbDims = inputs[0].dims().size();
2885 for (size_t i = 1; i < inputs.size(); i++) {
2886 std::string istr = std::to_string(i);
2887 std::string msg =
2888 "input " + istr + "#dims are incompatible between elements";
2889 if (!expectCompareTrue(msg.c_str(), nbDims, inputs[i].dims().size(),
2890 this)) {
2891 isValid = false;
2892 // The following loop depends on this condition being true.
2893 continue;
2894 }
2895 for (size_t j = 0; j < nbDims; j++) {
2896 if (j == dimension) {
2897 continue;
2898 }
2899 std::string innerMsg = std::to_string(j);
2900 innerMsg =
2901 "Mismatching dimension " + innerMsg + " for input 0 and " + istr;
2902 isValid &= expectCompareTrue(innerMsg.c_str(), inputs[0].dims()[j],
2903 inputs[i].dims()[j], this);
2904 }
2905
2906 for (size_t i = 0; i < inputs.size(); i++) {
2907 isValid &= checkType(inputs[i], getResult().getElementType(), this);
2908 isValid &= checkNotQuantizedOrSameParams(getResult().getType(),
2909 inputs[i].getType(), this);
2910 }
2911 }
2912 return isValid;
2913}
2914
2915bool BatchBoxCoxNode::verify() const {
2916 auto result = getResult();
2917 auto data = getInput();
2918 auto lambda1 = getLambda1();
2919 auto lambda2 = getLambda2();
2920 bool isValid = checkSameType(lambda1, lambda2, this);
2921 isValid &= checkSameType(data, result, this);
2922 isValid &= checkType(data, lambda1.getElementType(), this);
2923 isValid &= checkType(
2924 data, {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
2925 this);
2926 isValid &= expectCompareTrue("Input must be a 2D tensor", data.dims().size(),
2927 size_t(2), this);
2928 isValid &= expectCompareTrue("Lambda1 must be a 1D vector",
2929 lambda1.dims().size(), size_t(1), this);
2930 if (isValid) {
2931 isValid &= expectCompareTrue("Data dim 1 must equal lambda dim",
2932 data.dims()[1], lambda1.dims()[0], this);
2933 }
2934 return isValid;
2935}
2936
2937bool BroadcastNode::verify() const {
2938 const auto inputDims = getInput().dims();
2939 const auto axis = getAxis();
2940 const auto targetDims = getTargetDim();
2941 bool isValid = (axis + inputDims.size() <= targetDims.size());
2942
2943 // Iterate over the new shape; if the original shape had a dimension here
2944 // (when considering the axis) then verify the dimension either matches the
2945 // new shape (no action taken) or == 1 (broadcast in that direction).
2946 for (dim_t i = 0; i < targetDims.size(); i++) {
2947 if (i >= axis && i < inputDims.size() + axis) {
2948 const int origIdx = i - axis;
2949 isValid &=
2950 (inputDims[origIdx] == targetDims[i] || inputDims[origIdx] == 1);
2951 }
2952 }
2953 isValid &= checkTypeIgnoreShape(getInput(), getResult(), this);
2954
2955 return isValid;
2956}
2957
2958bool ModuloNode::verify() const { return getDivisor() >= 1; }
2959
2960bool ExternalFunctionCallNode::verify() const { return true; }
2961
2962static bool verifyBatchedUnaryEmbeddingsBags(NodeValue dest, NodeValue weights,
2963 NodeValue indices,
2964 NodeValue offsets) {
2965 bool isValid = checkType(dest, weights.getElementType(), dest.getNode());
2966 isValid &= checkType(
2967 indices,
2968 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
2969 dest.getNode());
2970 isValid &= checkType(
2971 offsets,
2972 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
2973 dest.getNode());
2974 isValid &=
2975 expectCompareTrue("Indices must be a 1D vector", indices.dims().size(),
2976 size_t(1), dest.getNode());
2977 isValid &=
2978 expectCompareTrue("Offsets must be a 1D vector", offsets.dims().size(),
2979 size_t(1), dest.getNode());
2980 isValid &=
2981 expectCompareTrue("Weights must be a 3D vector", weights.dims().size(),
2982 size_t(3), dest.getNode());
2983 return isValid;
2984}
2985
2986bool BatchedUnaryEmbeddingsBagsNode::verify() const {
2987 return verifyBatchedUnaryEmbeddingsBags(getResult(), getWeights(),
2988 getIndices(), getOffsets());
2989}
2990
2991static bool verifyIntNBitSplitEmbeddingBagsNode(NodeValue dest,
2992 NodeValue devWeights,
2993 NodeValue weightsOffsets,
2994 NodeValue weightsPlacements,
2995 NodeValue weightsTys) {
2996 bool isValid = checkType(dest, devWeights.getElementType(), dest.getNode());
2997 isValid &= checkType(devWeights, ElemKind::UInt8ITy, devWeights.getNode());
2998 isValid &= checkSameShape(weightsPlacements, weightsTys,
2999 weightsPlacements.getNode());
3000 isValid &= checkType(
3001 weightsOffsets,
3002 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
3003 dest.getNode());
3004 return isValid;
3005}
3006
3007bool IntNBitSplitEmbeddingBagsNode::verify() const {
3008 return verifyIntNBitSplitEmbeddingBagsNode(
3009 getResult(), getDevWeights(), getWeightsOffsets(), getWeightsPlacements(),
3010 getWeightsTys());
3011}
3012
3013static bool verifyIntNBitSplitEmbeddingWeightedBagsNode(
3014 NodeValue dest, NodeValue devWeights, NodeValue weightsOffsets,
3015 NodeValue weightsPlacements, NodeValue weightsTys, NodeValue indices,
3016 NodeValue indiceWeights) {
3017 bool isValid = checkType(dest, devWeights.getElementType(), dest.getNode());
3018 isValid &= checkType(devWeights, ElemKind::UInt8ITy, devWeights.getNode());
3019 isValid &= checkSameShape(weightsPlacements, weightsTys,
3020 weightsPlacements.getNode());
3021 isValid &= checkType(
3022 weightsOffsets,
3023 llvm::ArrayRef<ElemKind>({ElemKind::Int64ITy, ElemKind::Int32ITy}),
3024 dest.getNode());
3025 isValid &= checkSameShape(indices, indiceWeights, indiceWeights.getNode());
3026 return isValid;
3027}
3028
3029bool IntNBitSplitEmbeddingWeightedBagsNode::verify() const {
3030 return verifyIntNBitSplitEmbeddingWeightedBagsNode(
3031 getResult(), getDevWeights(), getWeightsOffsets(), getWeightsPlacements(),
3032 getWeightsTys(), getIndices(), getIndiceWeight());
3033}
3034
3035//===----------------------------------------------------------------------===//
3036// Node hashing support
3037//===----------------------------------------------------------------------===//
3038
3039/// These hash functions are required for using llvm::hash_combine.
3040/// hash_value functions should be defined in the same namespace as
3041/// the types they apply to.
3042namespace glow {
/// Reinterpret the bit pattern of the float \p f as an unsigned integer.
size_t toBinary(float f) {
  // memcpy is the portable way to type-pun; it compiles to a plain register
  // move on the platforms we support.
  static_assert(sizeof(size_t) >= sizeof(float),
                "size_t is too small on this platform");
  size_t bits = 0;
  memcpy(&bits, &f, sizeof(float));
  return bits;
}
3053/// Convert a collection of floats into a vector of
3054/// unsigned integer binary representation.
3055std::vector<size_t> toBinary(llvm::ArrayRef<float> vec) {
3056 std::vector<size_t> sizeVec(vec.size());
3057 std::for_each(vec.begin(), vec.end(), [&sizeVec](float f) -> void {
3058 sizeVec.push_back(toBinary(f));
3059 });
3060 return sizeVec;
3061}
3062
3063llvm::hash_code hash_value(const glow::Tensor &T) { return T.size(); }
3064
3065// Types are uniqued, so just a pointer can be used.
3066llvm::hash_code hash_value(const glow::Type *T) {
3067 return llvm::hash_value((void *)(T));
3068}
3069
3070llvm::hash_code hash_value(glow::Node *N) { return N->getHash(); }
3071
// A NodeValue is identified by its node plus the result number, so both
// participate in the hash.
llvm::hash_code hash_value(const glow::NodeValue &NV) {
  return llvm::hash_combine(NV.getNode(), NV.getResNo());
}
3075
// NodeHandle hashes the same way as NodeValue: node pointer + result number.
llvm::hash_code hash_value(const glow::NodeHandle &NV) {
  return llvm::hash_combine(NV.getNode(), NV.getResNo());
}
3079
3080llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
3081 FusedActivation fusedActivation) {
3082 switch (fusedActivation) {
3083 case FusedActivation::NONE:
3084 os << "NONE";
3085 break;
3086 case FusedActivation::RELU:
3087 os << "RELU";
3088 break;
3089 case FusedActivation::CLIP:
3090 os << "CLIP";
3091 break;
3092 case FusedActivation::SIGMOID:
3093 os << "SIGMOID";
3094 break;
3095 case FusedActivation::TANH:
3096 os << "TANH";
3097 break;
3098 case FusedActivation::LEAKY_RELU:
3099 os << "LEAKY_RELU";
3100 break;
3101 }
3102 return os;
3103}
3104
3105llvm::raw_ostream &operator<<(llvm::raw_ostream &os, LUTOperator lutOperator) {
3106 switch (lutOperator) {
3107 case LUTOperator::NONE:
3108 os << "NONE";
3109 break;
3110 case LUTOperator::RELU:
3111 os << "RELU";
3112 break;
3113 case LUTOperator::CLIP:
3114 os << "CLIP";
3115 break;
3116 case LUTOperator::SIGMOID:
3117 os << "SIGMOID";
3118 break;
3119 case LUTOperator::TANH:
3120 os << "TANH";
3121 break;
3122 case LUTOperator::LEAKY_RELU:
3123 os << "LEAKY_RELU";
3124 break;
3125 }
3126 return os;
3127}
3128
3129llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ConvolutionLayout layout) {
3130 switch (layout) {
3131 case ConvolutionLayout::NCHW:
3132 os << "NCHW";
3133 break;
3134 case ConvolutionLayout::NHWC:
3135 os << "NHWC";
3136 break;
3137 case ConvolutionLayout::NCTHW:
3138 os << "NCTHW";
3139 break;
3140 case ConvolutionLayout::NTHWC:
3141 os << "NTHWC";
3142 break;
3143 default:
3144 llvm_unreachable("Unknown format");
3145 }
3146 return os;
3147}
3148
3149llvm::raw_ostream &operator<<(llvm::raw_ostream &os, LengthsMode lengthsMode) {
3150 switch (lengthsMode) {
3151 case LengthsMode::AllOne:
3152 os << "AllOne";
3153 break;
3154 case LengthsMode::Variable:
3155 os << "Variable";
3156 break;
3157 }
3158 return os;
3159}
3160} // namespace glow
3161