1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Graph/VerifierHelper.h"
18
19using namespace glow;
20
21//===----------------------------------------------------------------------===//
22// Printing
23//===----------------------------------------------------------------------===//
24
25void glow::reportContext(ElemKind Ty) { report(Type::getElementName(Ty)); }
26
27void glow::reportContext(const ShapeNHWC &shapeNHWC) {
28 report("NHWC: ");
29 reportContext(llvm::ArrayRef<size_t>(
30 {shapeNHWC.n, shapeNHWC.h, shapeNHWC.w, shapeNHWC.c}));
31}
32
33void glow::reportContext(const ShapeNCHW &shapeNCHW) {
34 report("NCHW: ");
35 reportContext(llvm::ArrayRef<size_t>(
36 {shapeNCHW.n, shapeNCHW.c, shapeNCHW.h, shapeNCHW.w}));
37}
38
39void glow::reportContext(const ShapeNTHWC &shapeNTHWC) {
40 report("NTHWC: ");
41 reportContext(llvm::ArrayRef<size_t>(
42 {shapeNTHWC.n, shapeNTHWC.t, shapeNTHWC.h, shapeNTHWC.w, shapeNTHWC.c}));
43}
44
45void glow::reportContext(const ShapeNCTHW &shapeNCTHW) {
46 report("NCTHW: ");
47 reportContext(llvm::ArrayRef<size_t>(
48 {shapeNCTHW.n, shapeNCTHW.c, shapeNCTHW.t, shapeNCTHW.h, shapeNCTHW.w}));
49}
50
51void glow::reportContext(const Node *node) {
52 report("In '");
53 report(node->getName());
54 report("'");
55 if (const Function *function = node->getParent()) {
56 report(" ");
57 reportContext(function);
58 }
59}
60
61void glow::reportContext(const Function *function) {
62 report("From '");
63 report(function->getName());
64 report("'");
65}
66
67//===----------------------------------------------------------------------===//
68// Checks
69//===----------------------------------------------------------------------===//
70
71bool glow::checkSameType(NodeValue A, NodeValue B, const Node *parent) {
72 return expectCompareTrue("Mismatching type", *A.getType(), *B.getType(),
73 parent);
74}
75
76bool glow::checkSameShape(NodeValue A, NodeValue B, const Node *parent) {
77 return expectCompareTrue("Mismatching dimensions", A.dims(), B.dims(),
78 parent);
79}
80
81bool glow::checkType(NodeValue A, ElemKind expectedType, const Node *parent) {
82 return expectCompareTrue("Mismatching element type", A.getElementType(),
83 expectedType, parent);
84}
85
86bool glow::checkType(llvm::StringRef msg, NodeValue A, ElemKind expectedType,
87 const Node *parent) {
88 std::string errorMsg{msg};
89 errorMsg += ", Mismatching element type";
90 return expectCompareTrue(errorMsg.c_str(), A.getElementType(), expectedType,
91 parent);
92}
93
94bool glow::checkType(NodeValue A, llvm::ArrayRef<ElemKind> expectedTypes,
95 const Node *parent) {
96 return expectCompareTrue("Mismatching element type", A.getElementType(),
97 expectedTypes, parent);
98}
99
100bool glow::checkType(llvm::StringRef msg, NodeValue A,
101 llvm::ArrayRef<ElemKind> expectedTypes,
102 const Node *parent) {
103 std::string errorMsg{msg};
104 errorMsg += ", Mismatching element type";
105 return expectCompareTrue(errorMsg.c_str(), A.getElementType(), expectedTypes,
106 parent);
107}
108
109bool glow::checkSameIsQuantized(const TypeRef A, const TypeRef B,
110 const Node *parent) {
111 return expectCompareTrue("Mismatching isQuantized", A->isQuantizedType(),
112 B->isQuantizedType(), parent);
113}
114
115bool glow::checkNotQuantizedOrSameParams(const TypeRef A, float scale,
116 int32_t offset, const Node *parent) {
117 if (A->isQuantizedType()) {
118 if (!expectCompareTrue("Mismatching scale", A->getScale(), scale, parent) ||
119 !expectCompareTrue("Mismatching offset", A->getOffset(), offset,
120 parent)) {
121 return false;
122 }
123 }
124 return true;
125}
126
127bool glow::checkNotQuantizedOrSameParams(const TypeRef A, const TypeRef B,
128 const Node *parent) {
129 if (!B->isQuantizedType()) {
130 return checkSameIsQuantized(A, B, parent);
131 }
132 return checkNotQuantizedOrSameParams(A, B->getScale(), B->getOffset(),
133 parent);
134}
135
136bool glow::checkTypeIgnoreShape(NodeValue A, NodeValue B, const Node *parent) {
137 bool isValid = checkType(A, B.getElementType(), parent);
138 isValid &= checkSameIsQuantized(A.getType(), B.getType(), parent);
139 isValid &= checkNotQuantizedOrSameParams(A.getType(), B.getType(), parent);
140 return isValid;
141}
142