1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Graph/VerifierHelper.h" |
18 | |
19 | using namespace glow; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Printing |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | void glow::reportContext(ElemKind Ty) { report(Type::getElementName(Ty)); } |
26 | |
27 | void glow::reportContext(const ShapeNHWC &shapeNHWC) { |
28 | report("NHWC: " ); |
29 | reportContext(llvm::ArrayRef<size_t>( |
30 | {shapeNHWC.n, shapeNHWC.h, shapeNHWC.w, shapeNHWC.c})); |
31 | } |
32 | |
33 | void glow::reportContext(const ShapeNCHW &shapeNCHW) { |
34 | report("NCHW: " ); |
35 | reportContext(llvm::ArrayRef<size_t>( |
36 | {shapeNCHW.n, shapeNCHW.c, shapeNCHW.h, shapeNCHW.w})); |
37 | } |
38 | |
39 | void glow::reportContext(const ShapeNTHWC &shapeNTHWC) { |
40 | report("NTHWC: " ); |
41 | reportContext(llvm::ArrayRef<size_t>( |
42 | {shapeNTHWC.n, shapeNTHWC.t, shapeNTHWC.h, shapeNTHWC.w, shapeNTHWC.c})); |
43 | } |
44 | |
45 | void glow::reportContext(const ShapeNCTHW &shapeNCTHW) { |
46 | report("NCTHW: " ); |
47 | reportContext(llvm::ArrayRef<size_t>( |
48 | {shapeNCTHW.n, shapeNCTHW.c, shapeNCTHW.t, shapeNCTHW.h, shapeNCTHW.w})); |
49 | } |
50 | |
51 | void glow::reportContext(const Node *node) { |
52 | report("In '" ); |
53 | report(node->getName()); |
54 | report("'" ); |
55 | if (const Function *function = node->getParent()) { |
56 | report(" " ); |
57 | reportContext(function); |
58 | } |
59 | } |
60 | |
61 | void glow::reportContext(const Function *function) { |
62 | report("From '" ); |
63 | report(function->getName()); |
64 | report("'" ); |
65 | } |
66 | |
67 | //===----------------------------------------------------------------------===// |
68 | // Checks |
69 | //===----------------------------------------------------------------------===// |
70 | |
71 | bool glow::checkSameType(NodeValue A, NodeValue B, const Node *parent) { |
72 | return expectCompareTrue("Mismatching type" , *A.getType(), *B.getType(), |
73 | parent); |
74 | } |
75 | |
76 | bool glow::checkSameShape(NodeValue A, NodeValue B, const Node *parent) { |
77 | return expectCompareTrue("Mismatching dimensions" , A.dims(), B.dims(), |
78 | parent); |
79 | } |
80 | |
81 | bool glow::checkType(NodeValue A, ElemKind expectedType, const Node *parent) { |
82 | return expectCompareTrue("Mismatching element type" , A.getElementType(), |
83 | expectedType, parent); |
84 | } |
85 | |
86 | bool glow::checkType(llvm::StringRef msg, NodeValue A, ElemKind expectedType, |
87 | const Node *parent) { |
88 | std::string errorMsg{msg}; |
89 | errorMsg += ", Mismatching element type" ; |
90 | return expectCompareTrue(errorMsg.c_str(), A.getElementType(), expectedType, |
91 | parent); |
92 | } |
93 | |
94 | bool glow::checkType(NodeValue A, llvm::ArrayRef<ElemKind> expectedTypes, |
95 | const Node *parent) { |
96 | return expectCompareTrue("Mismatching element type" , A.getElementType(), |
97 | expectedTypes, parent); |
98 | } |
99 | |
100 | bool glow::checkType(llvm::StringRef msg, NodeValue A, |
101 | llvm::ArrayRef<ElemKind> expectedTypes, |
102 | const Node *parent) { |
103 | std::string errorMsg{msg}; |
104 | errorMsg += ", Mismatching element type" ; |
105 | return expectCompareTrue(errorMsg.c_str(), A.getElementType(), expectedTypes, |
106 | parent); |
107 | } |
108 | |
109 | bool glow::checkSameIsQuantized(const TypeRef A, const TypeRef B, |
110 | const Node *parent) { |
111 | return expectCompareTrue("Mismatching isQuantized" , A->isQuantizedType(), |
112 | B->isQuantizedType(), parent); |
113 | } |
114 | |
115 | bool glow::checkNotQuantizedOrSameParams(const TypeRef A, float scale, |
116 | int32_t offset, const Node *parent) { |
117 | if (A->isQuantizedType()) { |
118 | if (!expectCompareTrue("Mismatching scale" , A->getScale(), scale, parent) || |
119 | !expectCompareTrue("Mismatching offset" , A->getOffset(), offset, |
120 | parent)) { |
121 | return false; |
122 | } |
123 | } |
124 | return true; |
125 | } |
126 | |
127 | bool glow::checkNotQuantizedOrSameParams(const TypeRef A, const TypeRef B, |
128 | const Node *parent) { |
129 | if (!B->isQuantizedType()) { |
130 | return checkSameIsQuantized(A, B, parent); |
131 | } |
132 | return checkNotQuantizedOrSameParams(A, B->getScale(), B->getOffset(), |
133 | parent); |
134 | } |
135 | |
136 | bool glow::checkTypeIgnoreShape(NodeValue A, NodeValue B, const Node *parent) { |
137 | bool isValid = checkType(A, B.getElementType(), parent); |
138 | isValid &= checkSameIsQuantized(A.getType(), B.getType(), parent); |
139 | isValid &= checkNotQuantizedOrSameParams(A.getType(), B.getType(), parent); |
140 | return isValid; |
141 | } |
142 | |