1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
/// This file describes the API used for graph verification.
/// It mainly provides helper classes and functions for performing checks and
/// for printing errors together with their related context.
19 | #ifndef GLOW_GRAPH_VERIFIERHELPER_H |
20 | #define GLOW_GRAPH_VERIFIERHELPER_H |
21 | |
22 | #include "glow/Base/Type.h" |
23 | #include "glow/Graph/Graph.h" |
24 | #include "glow/Graph/Node.h" |
25 | #include "glow/Support/Support.h" |
26 | #include "llvm/ADT/StringRef.h" |
27 | |
28 | namespace glow { |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // Printing |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | /// Default reportContext function used to print \p a. |
35 | /// The default implementation relies on operator<< being available. |
36 | /// The actual printing is done calling glow::report. |
37 | template <typename Ty> void reportContext(const Ty &a) { |
38 | std::string storage; |
39 | llvm::raw_string_ostream stringStream(storage); |
40 | stringStream << a; |
41 | report(stringStream.str()); |
42 | } |
43 | |
44 | template <typename Ty> void reportContext(const llvm::ArrayRef<Ty> &arrayRef) { |
45 | bool isFirst = true; |
46 | report("{" ); |
47 | for (auto elt : arrayRef) { |
48 | if (!isFirst) { |
49 | report(", " ); |
50 | } |
51 | isFirst = false; |
52 | reportContext(elt); |
53 | } |
54 | report("}" ); |
55 | } |
56 | |
/// Non-template reportContext overloads for types that get dedicated printing
/// instead of the generic operator<<-based template. Their definitions are not
/// part of this header.
/// NOTE(review): a non-template overload only wins over the generic template on
/// a tie. For a non-const `Node *` / `Function *` argument, or a pointer to a
/// Node subclass, the template is a better (identity) match than the
/// conversion to `const Node *` / `const Function *`. Printing then goes
/// through operator<< rather than these overloads. Confirm this is intended.
void reportContext(ElemKind Ty);
void reportContext(const ShapeNHWC &shapeNHWC);
void reportContext(const ShapeNCHW &shapeNCHW);
void reportContext(const ShapeNTHWC &shapeNTHWC);
void reportContext(const ShapeNCTHW &shapeNCTHW);
void reportContext(const Node *node);
void reportContext(const Function *function);
64 | |
65 | //===----------------------------------------------------------------------===// |
66 | // Checks |
67 | //===----------------------------------------------------------------------===// |
68 | |
/// Wrappers around comparison operators.
/// They specify the comparison performed by expectCompareTrue and provide a
/// printable name for the operator, used when a check fails.
73 | /// @{ |
74 | |
75 | /// Interface that the comparison operator must implement. |
76 | template <typename Ty> struct CompareWithName { |
77 | virtual ~CompareWithName() {} |
78 | /// Binary comparison operation. |
79 | virtual bool operator()(const Ty &a, const Ty &b) const = 0; |
80 | /// Name of the operator used for pretty printing. |
81 | virtual llvm::StringRef getCompareName() const = 0; |
82 | }; |
83 | |
84 | /// Operator ==. |
85 | template <typename Ty> |
86 | struct CompareOperatorEqual : public CompareWithName<Ty> { |
87 | bool operator()(const Ty &a, const Ty &b) const override { return a == b; } |
88 | llvm::StringRef getCompareName() const override { return "Equal" ; } |
89 | }; |
90 | |
91 | /// Operator >=. |
92 | template <typename Ty> |
93 | struct CompareOperatorGreaterEqual : public CompareWithName<Ty> { |
94 | bool operator()(const Ty &a, const Ty &b) const override { return a >= b; } |
95 | llvm::StringRef getCompareName() const override { return "GreaterEqual" ; } |
96 | }; |
97 | |
98 | /// Operator >. |
99 | template <typename Ty> |
100 | struct CompareOperatorGreaterThan : public CompareWithName<Ty> { |
101 | bool operator()(const Ty &a, const Ty &b) const override { return a > b; } |
102 | llvm::StringRef getCompareName() const override { return "GreaterThan" ; } |
103 | }; |
104 | |
105 | /// Operator <=. |
106 | template <typename Ty> |
107 | struct CompareOperatorLessEqual : public CompareWithName<Ty> { |
108 | bool operator()(const Ty &a, const Ty &b) const override { return a <= b; } |
109 | llvm::StringRef getCompareName() const override { return "LessEqual" ; } |
110 | }; |
111 | |
112 | /// Operator <. |
113 | template <typename Ty> struct CompareOperatorLess : public CompareWithName<Ty> { |
114 | bool operator()(const Ty &a, const Ty &b) const override { return a < b; } |
115 | llvm::StringRef getCompareName() const override { return "Less" ; } |
116 | }; |
117 | /// @} |
118 | |
119 | /// Main API of the verifier. |
120 | /// Check whether \p comp(\p a, \p b) is true. |
121 | /// If that check fails, \p msg is printed out using glow::report |
122 | /// and \p parent (if not nullptr), \p a, and \p b are printed out |
123 | /// using glow::reportContext. |
124 | /// \returns \p comp(\p a, \p b). |
125 | template <typename InputTy, typename ParentTy> |
126 | bool expectCompareTrue( |
127 | const char *msg, const InputTy &a, const InputTy &b, const ParentTy *parent, |
128 | const CompareWithName<InputTy> &comp = CompareOperatorEqual<InputTy>()) { |
129 | if (comp(a, b)) { |
130 | return true; |
131 | } |
132 | if (parent) { |
133 | reportContext(parent); |
134 | report("\n" ); |
135 | } |
136 | report(msg); |
137 | report("\nFor comparison `LHS " ); |
138 | report(comp.getCompareName()); |
139 | report(" RHS` with:" ); |
140 | report("\nLHS: " ); |
141 | reportContext(a); |
142 | report("\nRHS: " ); |
143 | reportContext(b); |
144 | report("\n" ); |
145 | return false; |
146 | } |
147 | |
148 | /// Check whether $V_{0,n}{comp(\p a, \p b_i)}$ is true. |
149 | /// If that check fails, \p msg is printed out using glow::report |
150 | /// and \p parent (if not nullptr), \p a, and \p b are printed out |
151 | /// using glow::reportContext. |
152 | /// \returns \p comp(\p a, \p b_0) v ... v comp(\p a, \p b_i). |
153 | template <typename InputTy> |
154 | bool expectCompareTrue( |
155 | const char *msg, const InputTy &a, llvm::ArrayRef<InputTy> b, |
156 | const Node *parent, |
157 | const CompareWithName<InputTy> &comp = CompareOperatorEqual<InputTy>()) { |
158 | bool result = false; |
159 | for (const auto &bi : b) { |
160 | result |= comp(a, bi); |
161 | } |
162 | if (result) { |
163 | return true; |
164 | } |
165 | if (parent) { |
166 | reportContext(parent); |
167 | } |
168 | report(msg); |
169 | report("\nFor comparison `LHS " ); |
170 | report(comp.getCompareName()); |
171 | report(" RHS` with:" ); |
172 | report("\nLHS: " ); |
173 | reportContext(a); |
174 | report("\nRHS: " ); |
175 | for (const auto &bi : b) { |
176 | reportContext(bi); |
177 | report(", " ); |
178 | } |
179 | report("\n" ); |
180 | return false; |
181 | } |
182 | |
/// Check that the type of the first operand \p A matches the type of the second
/// operand \p B. \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkSameType(NodeValue A, NodeValue B, const Node *parent);
188 | |
/// Check that the shape of the first operand \p A matches the shape of the
/// second operand \p B. \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkSameShape(NodeValue A, NodeValue B, const Node *parent);
194 | |
/// Check that the element type of the operand \p A matches the expected type
/// \p expectedType. \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkType(NodeValue A, ElemKind expectedType, const Node *parent);
200 | |
/// Check that the element type of the operand \p A matches the expected type
/// \p expectedType. \p parent is used to print the context of that check
/// in case it fails.
/// \p msg is a caller-supplied diagnostic message; presumably it is reported
/// on failure in place of a default one (definition not visible here).
/// \see expectCompareTrue for more details.
bool checkType(llvm::StringRef msg, NodeValue A, ElemKind expectedType,
               const Node *parent);
207 | |
/// Check that the element type of the operand \p A matches any of the expected
/// types \p expectedTypes. \p parent is used to print the context of that
/// check in case it fails. \see expectCompareTrue for more details.
bool checkType(NodeValue A, llvm::ArrayRef<ElemKind> expectedTypes,
               const Node *parent);
213 | |
/// Check that the element type of the operand \p A matches any of the expected
/// types \p expectedTypes. \p parent is used to print the context of that
/// check in case it fails.
/// \p msg is a caller-supplied diagnostic message; presumably it is reported
/// on failure in place of a default one (definition not visible here).
/// \see expectCompareTrue for more details.
bool checkType(llvm::StringRef msg, NodeValue A,
               llvm::ArrayRef<ElemKind> expectedTypes, const Node *parent);
219 | |
/// Check if \p A and \p B have the same value for isQuantized. \p parent is
/// used to print the context of that check in case it fails.
/// \see expectCompareTrue for more details.
bool checkSameIsQuantized(const TypeRef A, const TypeRef B, const Node *parent);
224 | |
/// \return True if \p A is not quantized or its quantization parameters
/// match \p scale and \p offset. False otherwise. \p parent is used to print
/// the context of that check in case it fails.
/// \see expectCompareTrue for more details.
bool checkNotQuantizedOrSameParams(const TypeRef A, float scale, int32_t offset,
                                   const Node *parent);
231 | |
/// \return True if \p A is not quantized or matches \p B's quantization
/// parameters. False otherwise.
/// In particular, this returns false if \p A is quantized and \p B is not.
/// The converse does not hold: if \p A is not quantized, this returns true
/// regardless of \p B.
/// \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkNotQuantizedOrSameParams(const TypeRef A, const TypeRef B,
                                   const Node *parent);
241 | |
/// Check that the type of the first operand \p A matches the type of the second
/// operand \p B, ignoring the shape: only the element type and the
/// quantization parameters are compared.
/// \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkTypeIgnoreShape(NodeValue A, NodeValue B, const Node *parent);
249 | } // namespace glow |
250 | #endif // End of GLOW_GRAPH_VERIFIERHELPER_H. |
251 | |