1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file match_exhaustion.cc |
22 | * \brief Checking Relay match expression exhaustiveness. |
23 | * |
 * This file implements a function that checks whether a match
 * expression is exhaustive, that is, whether its clauses together
 * match every possible case. This is important for ensuring code
 * correctness, since hitting an unmatched case results in a
 * dynamic error unless exhaustiveness is checked in advance.
29 | */ |
30 | #include <tvm/relay/adt.h> |
31 | #include <tvm/relay/error.h> |
32 | #include <tvm/relay/expr_functor.h> |
33 | #include <tvm/relay/pattern_functor.h> |
34 | |
35 | #include <stack> |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
/*!
 * \brief Outcome of checking a candidate pattern against a clause pattern.
 */
enum MatchResult : int {
  /*! \brief The clause pattern matches the candidate. */
  kMatch = 0,
  /*! \brief The clause pattern conflicts with the candidate. */
  kClash = 1,
  /*! \brief Ambiguous: the candidate needs more constructors specified. */
  kUnspecified = 2
};
46 | |
47 | class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> { |
48 | public: |
49 | explicit CandidateChecker() {} |
50 | |
51 | MatchResult Check(const Pattern& pat, const Pattern& candidate) { |
52 | return this->VisitPattern(pat, candidate); |
53 | } |
54 | |
55 | // for a constructor pattern, we must ensure that the candidate is |
56 | // a ConstructorPattern, that it has the same constructor, and |
57 | // that its fields match the subpatterns. |
58 | MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override { |
59 | auto* ctor_cand = cand.as<PatternConstructorNode>(); |
60 | // attempting to match non-constructor to constructor pattern: need to specify |
61 | if (ctor_cand == nullptr) { |
62 | return MatchResult::kUnspecified; |
63 | } |
64 | |
65 | // check that constructors match |
66 | if (!op->constructor.same_as(ctor_cand->constructor)) { |
67 | return MatchResult::kClash; |
68 | } |
69 | |
70 | // now check that subpatterns match |
71 | ICHECK_EQ(op->patterns.size(), ctor_cand->patterns.size()); |
72 | bool unspecified = false; |
73 | for (size_t i = 0; i < op->patterns.size(); i++) { |
74 | MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); |
75 | // if we have a clash anywhere, then we can return clash |
76 | if (submatch == MatchResult::kClash) { |
77 | return MatchResult::kClash; |
78 | } |
79 | if (submatch == MatchResult::kUnspecified) { |
80 | unspecified = true; |
81 | } |
82 | } |
83 | // only return unspecified if we have ruled out a clash |
84 | if (unspecified) { |
85 | return MatchResult::kUnspecified; |
86 | } |
87 | return MatchResult::kMatch; |
88 | } |
89 | |
90 | MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override { |
91 | auto* tuple_cand = cand.as<PatternTupleNode>(); |
92 | // attempting to match non-tuple to constructor pattern: need to specify |
93 | if (tuple_cand == nullptr) { |
94 | return MatchResult::kUnspecified; |
95 | } |
96 | |
97 | // now check that subpatterns match |
98 | ICHECK_EQ(op->patterns.size(), tuple_cand->patterns.size()); |
99 | bool unspecified = false; |
100 | for (size_t i = 0; i < op->patterns.size(); i++) { |
101 | MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]); |
102 | // if we have a clash anywhere, then we can return clash |
103 | if (submatch == MatchResult::kClash) { |
104 | return MatchResult::kClash; |
105 | } |
106 | if (submatch == MatchResult::kUnspecified) { |
107 | unspecified = true; |
108 | } |
109 | } |
110 | // only return unspecified if we have ruled out a clash |
111 | if (unspecified) { |
112 | return MatchResult::kUnspecified; |
113 | } |
114 | return MatchResult::kMatch; |
115 | } |
116 | |
117 | // wildcard and var patterns always match |
118 | MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { |
119 | return MatchResult::kMatch; |
120 | } |
121 | |
122 | MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override { |
123 | return MatchResult::kMatch; |
124 | } |
125 | }; |
126 | |
127 | // Returns list of arrays corresponding to Cartesian product of input list. |
128 | // Note: CartesianProduct({}) = {{}} |
129 | Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) { |
130 | // the only combination of 0 fields is 0 fields |
131 | if (fields.size() == 0) { |
132 | return {{}}; |
133 | } |
134 | |
135 | Array<Pattern> field_vals = fields[fields.size() - 1]; |
136 | Array<Array<Pattern>> ret; |
137 | |
138 | // base case: this is the last field left |
139 | if (fields.size() == 1) { |
140 | for (auto val : field_vals) { |
141 | ret.push_back(Array<Pattern>{val}); |
142 | } |
143 | return ret; |
144 | } |
145 | |
146 | // if we have more fields left, get the sub-candidates by getting |
147 | // their cartesian product and appending the elements here onto those |
148 | Array<Array<Pattern>> remaining_fields; |
149 | for (size_t i = 0; i < fields.size() - 1; i++) { |
150 | remaining_fields.push_back(fields[i]); |
151 | } |
152 | Array<Array<Pattern>> candidates = CartesianProduct(remaining_fields); |
153 | for (auto val : field_vals) { |
154 | for (auto candidate : candidates) { |
155 | candidate.push_back(val); |
156 | ret.push_back(candidate); |
157 | } |
158 | } |
159 | return ret; |
160 | } |
161 | |
// Forward declarations: ExpandWildcards (defined below) dispatches to these two
// helpers, and each of them calls ExpandWildcards again on its subpatterns.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
                                          const Pattern& cand, const IRModule& mod);

Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand,
                                    const IRModule& mod);
167 | |
168 | // Expands all wildcards in the candidate pattern once |
169 | // Returns a list of all possible expansions. |
170 | Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, |
171 | const IRModule& mod) { |
172 | if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) { |
173 | return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod); |
174 | } else if (auto clause_tup = clause_pat.as<PatternTupleNode>()) { |
175 | return ExpandWildcardsTuple(GetRef<PatternTuple>(clause_tup), cand, mod); |
176 | } else { |
177 | return {cand}; |
178 | } |
179 | } |
180 | |
181 | // Expands all wildcards in the candidate pattern once. |
182 | // Use the pattern to decide which constructors to insert. |
183 | // Returns a list of all possible expansions. |
184 | Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, |
185 | const Pattern& cand, const IRModule& mod) { |
186 | auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to); |
187 | |
188 | // for a wildcard node, create constructor nodes with wildcards for all args. |
189 | if (cand.as<PatternWildcardNode>()) { |
190 | TypeData td = mod->LookupTypeDef(gtv); |
191 | // for each constructor add a candidate. |
192 | Array<Pattern> ret; |
193 | for (auto constructor : td->constructors) { |
194 | Array<Pattern> args; |
195 | for (auto inp : constructor->inputs) { |
196 | args.push_back(PatternWildcard()); |
197 | } |
198 | ret.push_back(PatternConstructor(constructor, args)); |
199 | } |
200 | return ret; |
201 | } |
202 | |
203 | auto ctor_cand = Downcast<PatternConstructor>(cand); |
204 | |
205 | // expand all fields' wildcards |
206 | Array<Array<Pattern>> values_by_field; |
207 | for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { |
208 | values_by_field.push_back( |
209 | ExpandWildcards(clause_ctor->patterns[i], ctor_cand->patterns[i], mod)); |
210 | } |
211 | |
212 | // generate new candidates using a cartesian product. |
213 | auto all_subfields = CartesianProduct(values_by_field); |
214 | Array<Pattern> ret; |
215 | for (auto subfields : all_subfields) { |
216 | ret.push_back(PatternConstructor(ctor_cand->constructor, subfields)); |
217 | } |
218 | return ret; |
219 | } |
220 | |
221 | // Expands all wildcards in the candidate pattern once. |
222 | // Returns a list of all possible expansions. |
223 | Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, |
224 | const IRModule& mod) { |
225 | // for a wildcard node, create tuple with wildcards for all args. |
226 | if (cand.as<PatternWildcardNode>()) { |
227 | Array<Pattern> args; |
228 | for (auto inp : clause_tuple->patterns) { |
229 | args.push_back(PatternWildcard()); |
230 | } |
231 | return {PatternTuple(args)}; |
232 | } |
233 | |
234 | auto tuple_cand = Downcast<PatternTuple>(cand); |
235 | |
236 | // expand all members' patterns |
237 | Array<Array<Pattern>> values_by_field; |
238 | for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { |
239 | values_by_field.push_back( |
240 | ExpandWildcards(clause_tuple->patterns[i], tuple_cand->patterns[i], mod)); |
241 | } |
242 | |
243 | // generate new candidates using a cartesian product |
244 | auto all_subfields = CartesianProduct(values_by_field); |
245 | Array<Pattern> ret; |
246 | for (auto subfields : all_subfields) { |
247 | ret.push_back(PatternTuple(subfields)); |
248 | } |
249 | return ret; |
250 | } |
251 | |
252 | /*! |
253 | * \brief Finds cases that the match expression does not catch, if any. |
254 | * \return Returns a list of cases that are not handled by the match |
255 | * expression. |
256 | */ |
257 | Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) { |
258 | /* algorithm: |
259 | * candidates = { Wildcard } |
260 | * while candidates not empty { |
261 | * cand = candidates.pop() |
262 | * for clause in clauses { |
263 | * if clause fails: next clause |
264 | * if clause matches candidate: next candidate |
265 | * if candidate is not specific enough: |
266 | * candidates += expand_possible_wildcards(cand) |
267 | * next candidate |
268 | * } |
269 | * failed_candidates += { cand } |
270 | * } |
271 | * return failed_candidates |
272 | */ |
273 | std::stack<Pattern> candidates; |
274 | candidates.push(PatternWildcard()); |
275 | CandidateChecker checker; |
276 | |
277 | Array<Pattern> failures; |
278 | |
279 | while (!candidates.empty()) { |
280 | Pattern cand = candidates.top(); |
281 | candidates.pop(); |
282 | |
283 | bool failure = true; |
284 | for (auto clause : match->clauses) { |
285 | // if the check fails, we move on to the next |
286 | MatchResult check = checker.Check(clause->lhs, cand); |
287 | if (check == MatchResult::kClash) { |
288 | continue; |
289 | } |
290 | |
291 | // either success or we need to generate more candidates; |
292 | // either way, we're done with this candidate |
293 | failure = false; |
294 | if (check == MatchResult::kUnspecified) { |
295 | auto new_candidates = ExpandWildcards(clause->lhs, cand, mod); |
296 | for (auto candidate : new_candidates) { |
297 | candidates.push(candidate); |
298 | } |
299 | } |
300 | break; |
301 | } |
302 | |
303 | if (failure) { |
304 | failures.push_back(cand); |
305 | } |
306 | } |
307 | |
308 | return failures; |
309 | } |
310 | |
311 | // expose for testing only |
312 | TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases" ) |
313 | .set_body_typed([](const Match& match, const Optional<IRModule>& mod_ref) { |
314 | IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {}); |
315 | return UnmatchedCases(match, call_mod); |
316 | }); |
317 | |
318 | } // namespace relay |
319 | } // namespace tvm |
320 | |