1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/ir/pattern_functor.cc
22 * \brief Implementations of visitors and mutators for ADT patterns.
23 */
24
25#include <tvm/relay/pattern_functor.h>
26
27namespace tvm {
28namespace relay {
29
30Pattern PatternMutator::Mutate(const Pattern& pat) { return (*this)(pat); }
31
32Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { return GetRef<Pattern>(op); }
33
34Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) {
35 return PatternVar(VisitVar(op->var));
36}
37
38Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
39 std::vector<Pattern> pat;
40 for (const auto& p : op->patterns) {
41 pat.push_back(VisitPattern(p));
42 }
43 return PatternConstructor(VisitConstructor(op->constructor), pat);
44}
45
46Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
47 std::vector<Pattern> pat;
48 for (const auto& p : op->patterns) {
49 pat.push_back(VisitPattern(p));
50 }
51 return PatternTuple(pat);
52}
53
54Type PatternMutator::VisitType(const Type& t) { return t; }
55
56Var PatternMutator::VisitVar(const Var& v) {
57 if (var_map_.count(v) == 0) {
58 var_map_.insert(std::pair<Var, Var>(v, Var(v->name_hint(), VisitType(v->type_annotation))));
59 }
60 return var_map_.at(v);
61}
62
63Constructor PatternMutator::VisitConstructor(const Constructor& v) { return v; }
64
65void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) {}
66
67void PatternVisitor::VisitPattern_(const PatternVarNode* op) { VisitVar(op->var); }
68
69void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
70 VisitConstructor(op->constructor);
71 for (const auto& p : op->patterns) {
72 VisitPattern(p);
73 }
74}
75
76void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
77 for (const auto& p : op->patterns) {
78 VisitPattern(p);
79 }
80}
81
82void PatternVisitor::VisitType(const Type& t) {}
83
84void PatternVisitor::VisitVar(const Var& v) { VisitType(v->type_annotation); }
85
86void PatternVisitor::VisitConstructor(const Constructor& c) {
87 for (const auto& inp : c->inputs) {
88 VisitType(inp);
89 }
90}
91
92} // namespace relay
93} // namespace tvm
94