1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
/// This file describes the API used for graph verification.
/// It mainly consists of helper classes and functions for performing checks
/// and for printing errors together with their context.
19#ifndef GLOW_GRAPH_VERIFIERHELPER_H
20#define GLOW_GRAPH_VERIFIERHELPER_H
21
22#include "glow/Base/Type.h"
23#include "glow/Graph/Graph.h"
24#include "glow/Graph/Node.h"
25#include "glow/Support/Support.h"
26#include "llvm/ADT/StringRef.h"
27
28namespace glow {
29
30//===----------------------------------------------------------------------===//
31// Printing
32//===----------------------------------------------------------------------===//
33
34/// Default reportContext function used to print \p a.
35/// The default implementation relies on operator<< being available.
36/// The actual printing is done calling glow::report.
37template <typename Ty> void reportContext(const Ty &a) {
38 std::string storage;
39 llvm::raw_string_ostream stringStream(storage);
40 stringStream << a;
41 report(stringStream.str());
42}
43
44template <typename Ty> void reportContext(const llvm::ArrayRef<Ty> &arrayRef) {
45 bool isFirst = true;
46 report("{");
47 for (auto elt : arrayRef) {
48 if (!isFirst) {
49 report(", ");
50 }
51 isFirst = false;
52 reportContext(elt);
53 }
54 report("}");
55}
56
/// Dedicated reportContext overloads for Glow types. They are only declared
/// here; their definitions are not part of this header.
/// NOTE: pass node/function pointers as `const Node *` / `const Function *`.
/// For a non-const `Node *` (or `Function *`) argument, overload resolution
/// prefers the generic `reportContext(const Ty &)` template (an exact match
/// with Ty = Node *) over `reportContext(const Node *)`, which requires a
/// qualification conversion, so the generic operator<<-based printer is
/// selected instead.
void reportContext(ElemKind Ty);
void reportContext(const ShapeNHWC &shapeNHWC);
void reportContext(const ShapeNCHW &shapeNCHW);
void reportContext(const ShapeNTHWC &shapeNTHWC);
void reportContext(const ShapeNCTHW &shapeNCTHW);
void reportContext(const Node *node);
void reportContext(const Function *function);
64
65//===----------------------------------------------------------------------===//
66// Checks
67//===----------------------------------------------------------------------===//
68
/// Wrappers around comparison operators.
/// They select the comparison performed by expectCompareTrue and provide
/// a printable name for the operator, which is reported when a check
/// fails.
73/// @{
74
75/// Interface that the comparison operator must implement.
76template <typename Ty> struct CompareWithName {
77 virtual ~CompareWithName() {}
78 /// Binary comparison operation.
79 virtual bool operator()(const Ty &a, const Ty &b) const = 0;
80 /// Name of the operator used for pretty printing.
81 virtual llvm::StringRef getCompareName() const = 0;
82};
83
84/// Operator ==.
85template <typename Ty>
86struct CompareOperatorEqual : public CompareWithName<Ty> {
87 bool operator()(const Ty &a, const Ty &b) const override { return a == b; }
88 llvm::StringRef getCompareName() const override { return "Equal"; }
89};
90
91/// Operator >=.
92template <typename Ty>
93struct CompareOperatorGreaterEqual : public CompareWithName<Ty> {
94 bool operator()(const Ty &a, const Ty &b) const override { return a >= b; }
95 llvm::StringRef getCompareName() const override { return "GreaterEqual"; }
96};
97
98/// Operator >.
99template <typename Ty>
100struct CompareOperatorGreaterThan : public CompareWithName<Ty> {
101 bool operator()(const Ty &a, const Ty &b) const override { return a > b; }
102 llvm::StringRef getCompareName() const override { return "GreaterThan"; }
103};
104
105/// Operator <=.
106template <typename Ty>
107struct CompareOperatorLessEqual : public CompareWithName<Ty> {
108 bool operator()(const Ty &a, const Ty &b) const override { return a <= b; }
109 llvm::StringRef getCompareName() const override { return "LessEqual"; }
110};
111
112/// Operator <.
113template <typename Ty> struct CompareOperatorLess : public CompareWithName<Ty> {
114 bool operator()(const Ty &a, const Ty &b) const override { return a < b; }
115 llvm::StringRef getCompareName() const override { return "Less"; }
116};
117/// @}
118
119/// Main API of the verifier.
120/// Check whether \p comp(\p a, \p b) is true.
121/// If that check fails, \p msg is printed out using glow::report
122/// and \p parent (if not nullptr), \p a, and \p b are printed out
123/// using glow::reportContext.
124/// \returns \p comp(\p a, \p b).
125template <typename InputTy, typename ParentTy>
126bool expectCompareTrue(
127 const char *msg, const InputTy &a, const InputTy &b, const ParentTy *parent,
128 const CompareWithName<InputTy> &comp = CompareOperatorEqual<InputTy>()) {
129 if (comp(a, b)) {
130 return true;
131 }
132 if (parent) {
133 reportContext(parent);
134 report("\n");
135 }
136 report(msg);
137 report("\nFor comparison `LHS ");
138 report(comp.getCompareName());
139 report(" RHS` with:");
140 report("\nLHS: ");
141 reportContext(a);
142 report("\nRHS: ");
143 reportContext(b);
144 report("\n");
145 return false;
146}
147
148/// Check whether $V_{0,n}{comp(\p a, \p b_i)}$ is true.
149/// If that check fails, \p msg is printed out using glow::report
150/// and \p parent (if not nullptr), \p a, and \p b are printed out
151/// using glow::reportContext.
152/// \returns \p comp(\p a, \p b_0) v ... v comp(\p a, \p b_i).
153template <typename InputTy>
154bool expectCompareTrue(
155 const char *msg, const InputTy &a, llvm::ArrayRef<InputTy> b,
156 const Node *parent,
157 const CompareWithName<InputTy> &comp = CompareOperatorEqual<InputTy>()) {
158 bool result = false;
159 for (const auto &bi : b) {
160 result |= comp(a, bi);
161 }
162 if (result) {
163 return true;
164 }
165 if (parent) {
166 reportContext(parent);
167 }
168 report(msg);
169 report("\nFor comparison `LHS ");
170 report(comp.getCompareName());
171 report(" RHS` with:");
172 report("\nLHS: ");
173 reportContext(a);
174 report("\nRHS: ");
175 for (const auto &bi : b) {
176 reportContext(bi);
177 report(", ");
178 }
179 report("\n");
180 return false;
181}
182
/// Check that the type of the first operand \p A matches the type of the second
/// operand \p B. \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkSameType(NodeValue A, NodeValue B, const Node *parent);
188
/// Check that the shape of the first operand \p A matches the shape of the
/// second operand \p B. \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkSameShape(NodeValue A, NodeValue B, const Node *parent);
194
/// Check that the element type of the operand \p A matches the expected type
/// \p expectedType. \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkType(NodeValue A, ElemKind expectedType, const Node *parent);
200
/// Check that the element type of the operand \p A matches the expected type
/// \p expectedType. \p parent is used to print the context of that check
/// in case it fails.
/// NOTE(review): \p msg is presumably the message reported when the check
/// fails (like the \p msg argument of expectCompareTrue); confirm against the
/// definition.
/// \see expectCompareTrue for more details.
bool checkType(llvm::StringRef msg, NodeValue A, ElemKind expectedType,
               const Node *parent);
207
/// Check that the element type of the operand \p A matches any of the expected
/// types \p expectedTypes. \p parent is used to print the context of that
/// check in case it fails. \see expectCompareTrue for more details.
bool checkType(NodeValue A, llvm::ArrayRef<ElemKind> expectedTypes,
               const Node *parent);
213
/// Check that the element type of the operand \p A matches any of the expected
/// types \p expectedTypes. \p parent is used to print the context of that
/// check in case it fails.
/// NOTE(review): \p msg is presumably the message reported when the check
/// fails (like the \p msg argument of expectCompareTrue); confirm against the
/// definition.
/// \see expectCompareTrue for more details.
bool checkType(llvm::StringRef msg, NodeValue A,
               llvm::ArrayRef<ElemKind> expectedTypes, const Node *parent);
219
/// Check whether \p A and \p B have the same value for isQuantized. \p parent
/// is used to print the context of that check in case it fails.
/// \see expectCompareTrue for more details.
bool checkSameIsQuantized(const TypeRef A, const TypeRef B, const Node *parent);
224
/// \return True if \p A is not quantized or if its quantization parameters
/// match \p scale and \p offset. False otherwise. \p parent is used to print
/// the context of that check in case it fails.
/// \see expectCompareTrue for more details.
bool checkNotQuantizedOrSameParams(const TypeRef A, float scale, int32_t offset,
                                   const Node *parent);
231
/// \return True if \p A is not quantized or if it matches the quantization
/// parameters of \p B. False otherwise.
/// In particular, this returns false if \p A is quantized and \p B
/// is not. The converse does not hold: a non-quantized \p A passes
/// regardless of \p B.
/// \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkNotQuantizedOrSameParams(const TypeRef A, const TypeRef B,
                                   const Node *parent);
241
/// Check that the type of the first operand \p A matches the type of the second
/// operand \p B while ignoring the actual shape: only the element type and the
/// quantization parameters are compared.
/// \p parent is used to print the context of that check
/// in case it fails.
/// \see expectCompareTrue for more details.
bool checkTypeIgnoreShape(NodeValue A, NodeValue B, const Node *parent);
249} // namespace glow
250#endif // End of GLOW_GRAPH_VERIFIERHELPER_H.
251