1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file match_exhaustion.cc
22 * \brief Checking Relay match expression exhaustiveness.
23 *
 * This file implements a function that checks whether a match
 * expression is exhaustive, that is, whether its clauses together
 * match every possible value of the matched expression. This is
 * important for ensuring code correctness, since hitting an unmatched
 * case results in a dynamic error unless exhaustiveness is checked in
 * advance.
29 */
30#include <tvm/relay/adt.h>
31#include <tvm/relay/error.h>
32#include <tvm/relay/expr_functor.h>
33#include <tvm/relay/pattern_functor.h>
34
35#include <stack>
36
37namespace tvm {
38namespace relay {
39
/*!
 * \brief Outcome of comparing a clause pattern against a candidate pattern.
 */
enum MatchResult : int {
  kMatch = 0,        // the clause pattern covers the candidate
  kClash = 1,        // the clause pattern can never match the candidate
  kUnspecified = 2,  // undecided: the candidate must be refined with more constructors
};
46
47class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> {
48 public:
49 explicit CandidateChecker() {}
50
51 MatchResult Check(const Pattern& pat, const Pattern& candidate) {
52 return this->VisitPattern(pat, candidate);
53 }
54
55 // for a constructor pattern, we must ensure that the candidate is
56 // a ConstructorPattern, that it has the same constructor, and
57 // that its fields match the subpatterns.
58 MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override {
59 auto* ctor_cand = cand.as<PatternConstructorNode>();
60 // attempting to match non-constructor to constructor pattern: need to specify
61 if (ctor_cand == nullptr) {
62 return MatchResult::kUnspecified;
63 }
64
65 // check that constructors match
66 if (!op->constructor.same_as(ctor_cand->constructor)) {
67 return MatchResult::kClash;
68 }
69
70 // now check that subpatterns match
71 ICHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
72 bool unspecified = false;
73 for (size_t i = 0; i < op->patterns.size(); i++) {
74 MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
75 // if we have a clash anywhere, then we can return clash
76 if (submatch == MatchResult::kClash) {
77 return MatchResult::kClash;
78 }
79 if (submatch == MatchResult::kUnspecified) {
80 unspecified = true;
81 }
82 }
83 // only return unspecified if we have ruled out a clash
84 if (unspecified) {
85 return MatchResult::kUnspecified;
86 }
87 return MatchResult::kMatch;
88 }
89
90 MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
91 auto* tuple_cand = cand.as<PatternTupleNode>();
92 // attempting to match non-tuple to constructor pattern: need to specify
93 if (tuple_cand == nullptr) {
94 return MatchResult::kUnspecified;
95 }
96
97 // now check that subpatterns match
98 ICHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
99 bool unspecified = false;
100 for (size_t i = 0; i < op->patterns.size(); i++) {
101 MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
102 // if we have a clash anywhere, then we can return clash
103 if (submatch == MatchResult::kClash) {
104 return MatchResult::kClash;
105 }
106 if (submatch == MatchResult::kUnspecified) {
107 unspecified = true;
108 }
109 }
110 // only return unspecified if we have ruled out a clash
111 if (unspecified) {
112 return MatchResult::kUnspecified;
113 }
114 return MatchResult::kMatch;
115 }
116
117 // wildcard and var patterns always match
118 MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
119 return MatchResult::kMatch;
120 }
121
122 MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override {
123 return MatchResult::kMatch;
124 }
125};
126
127// Returns list of arrays corresponding to Cartesian product of input list.
128// Note: CartesianProduct({}) = {{}}
129Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
130 // the only combination of 0 fields is 0 fields
131 if (fields.size() == 0) {
132 return {{}};
133 }
134
135 Array<Pattern> field_vals = fields[fields.size() - 1];
136 Array<Array<Pattern>> ret;
137
138 // base case: this is the last field left
139 if (fields.size() == 1) {
140 for (auto val : field_vals) {
141 ret.push_back(Array<Pattern>{val});
142 }
143 return ret;
144 }
145
146 // if we have more fields left, get the sub-candidates by getting
147 // their cartesian product and appending the elements here onto those
148 Array<Array<Pattern>> remaining_fields;
149 for (size_t i = 0; i < fields.size() - 1; i++) {
150 remaining_fields.push_back(fields[i]);
151 }
152 Array<Array<Pattern>> candidates = CartesianProduct(remaining_fields);
153 for (auto val : field_vals) {
154 for (auto candidate : candidates) {
155 candidate.push_back(val);
156 ret.push_back(candidate);
157 }
158 }
159 return ret;
160}
161
162Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
163 const Pattern& cand, const IRModule& mod);
164
165Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand,
166 const IRModule& mod);
167
168// Expands all wildcards in the candidate pattern once
169// Returns a list of all possible expansions.
170Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
171 const IRModule& mod) {
172 if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
173 return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
174 } else if (auto clause_tup = clause_pat.as<PatternTupleNode>()) {
175 return ExpandWildcardsTuple(GetRef<PatternTuple>(clause_tup), cand, mod);
176 } else {
177 return {cand};
178 }
179}
180
181// Expands all wildcards in the candidate pattern once.
182// Use the pattern to decide which constructors to insert.
183// Returns a list of all possible expansions.
184Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
185 const Pattern& cand, const IRModule& mod) {
186 auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
187
188 // for a wildcard node, create constructor nodes with wildcards for all args.
189 if (cand.as<PatternWildcardNode>()) {
190 TypeData td = mod->LookupTypeDef(gtv);
191 // for each constructor add a candidate.
192 Array<Pattern> ret;
193 for (auto constructor : td->constructors) {
194 Array<Pattern> args;
195 for (auto inp : constructor->inputs) {
196 args.push_back(PatternWildcard());
197 }
198 ret.push_back(PatternConstructor(constructor, args));
199 }
200 return ret;
201 }
202
203 auto ctor_cand = Downcast<PatternConstructor>(cand);
204
205 // expand all fields' wildcards
206 Array<Array<Pattern>> values_by_field;
207 for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
208 values_by_field.push_back(
209 ExpandWildcards(clause_ctor->patterns[i], ctor_cand->patterns[i], mod));
210 }
211
212 // generate new candidates using a cartesian product.
213 auto all_subfields = CartesianProduct(values_by_field);
214 Array<Pattern> ret;
215 for (auto subfields : all_subfields) {
216 ret.push_back(PatternConstructor(ctor_cand->constructor, subfields));
217 }
218 return ret;
219}
220
221// Expands all wildcards in the candidate pattern once.
222// Returns a list of all possible expansions.
223Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand,
224 const IRModule& mod) {
225 // for a wildcard node, create tuple with wildcards for all args.
226 if (cand.as<PatternWildcardNode>()) {
227 Array<Pattern> args;
228 for (auto inp : clause_tuple->patterns) {
229 args.push_back(PatternWildcard());
230 }
231 return {PatternTuple(args)};
232 }
233
234 auto tuple_cand = Downcast<PatternTuple>(cand);
235
236 // expand all members' patterns
237 Array<Array<Pattern>> values_by_field;
238 for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
239 values_by_field.push_back(
240 ExpandWildcards(clause_tuple->patterns[i], tuple_cand->patterns[i], mod));
241 }
242
243 // generate new candidates using a cartesian product
244 auto all_subfields = CartesianProduct(values_by_field);
245 Array<Pattern> ret;
246 for (auto subfields : all_subfields) {
247 ret.push_back(PatternTuple(subfields));
248 }
249 return ret;
250}
251
252/*!
253 * \brief Finds cases that the match expression does not catch, if any.
254 * \return Returns a list of cases that are not handled by the match
255 * expression.
256 */
257Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
258 /* algorithm:
259 * candidates = { Wildcard }
260 * while candidates not empty {
261 * cand = candidates.pop()
262 * for clause in clauses {
263 * if clause fails: next clause
264 * if clause matches candidate: next candidate
265 * if candidate is not specific enough:
266 * candidates += expand_possible_wildcards(cand)
267 * next candidate
268 * }
269 * failed_candidates += { cand }
270 * }
271 * return failed_candidates
272 */
273 std::stack<Pattern> candidates;
274 candidates.push(PatternWildcard());
275 CandidateChecker checker;
276
277 Array<Pattern> failures;
278
279 while (!candidates.empty()) {
280 Pattern cand = candidates.top();
281 candidates.pop();
282
283 bool failure = true;
284 for (auto clause : match->clauses) {
285 // if the check fails, we move on to the next
286 MatchResult check = checker.Check(clause->lhs, cand);
287 if (check == MatchResult::kClash) {
288 continue;
289 }
290
291 // either success or we need to generate more candidates;
292 // either way, we're done with this candidate
293 failure = false;
294 if (check == MatchResult::kUnspecified) {
295 auto new_candidates = ExpandWildcards(clause->lhs, cand, mod);
296 for (auto candidate : new_candidates) {
297 candidates.push(candidate);
298 }
299 }
300 break;
301 }
302
303 if (failure) {
304 failures.push_back(cand);
305 }
306 }
307
308 return failures;
309}
310
311// expose for testing only
312TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
313 .set_body_typed([](const Match& match, const Optional<IRModule>& mod_ref) {
314 IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {});
315 return UnmatchedCases(match, call_mod);
316 });
317
318} // namespace relay
319} // namespace tvm
320