1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file type_functor.cc |
22 | * \brief Implementations of type functors. |
23 | */ |
24 | #include <tvm/ir/type_functor.h> |
25 | |
26 | #include <utility> |
27 | |
28 | namespace tvm { |
29 | |
30 | void TypeVisitor::VisitType_(const TypeVarNode* op) {} |
31 | |
32 | void TypeVisitor::VisitType_(const TensorTypeNode* op) {} |
33 | |
34 | void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {} |
35 | |
36 | void TypeVisitor::VisitType_(const FuncTypeNode* op) { |
37 | for (auto type_param : op->type_params) { |
38 | this->VisitType(type_param); |
39 | } |
40 | |
41 | for (auto type_cs : op->type_constraints) { |
42 | this->VisitType(type_cs); |
43 | } |
44 | |
45 | for (auto arg_type : op->arg_types) { |
46 | this->VisitType(arg_type); |
47 | } |
48 | this->VisitType(op->ret_type); |
49 | } |
50 | |
51 | void TypeVisitor::VisitType_(const TupleTypeNode* op) { |
52 | for (const Type& t : op->fields) { |
53 | this->VisitType(t); |
54 | } |
55 | } |
56 | |
57 | void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); } |
58 | |
59 | void TypeVisitor::VisitType_(const TypeRelationNode* op) { |
60 | for (const Type& t : op->args) { |
61 | this->VisitType(t); |
62 | } |
63 | } |
64 | |
65 | void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {} |
66 | |
67 | void TypeVisitor::VisitType_(const TypeCallNode* op) { |
68 | this->VisitType(op->func); |
69 | for (const Type& t : op->args) { |
70 | this->VisitType(t); |
71 | } |
72 | } |
73 | |
74 | void TypeVisitor::VisitType_(const TypeDataNode* op) { |
75 | this->VisitType(op->header); |
76 | for (const auto& v : op->type_vars) { |
77 | this->VisitType(v); |
78 | } |
79 | |
80 | for (const auto& c : op->constructors) { |
81 | this->VisitType(c->belong_to); |
82 | for (const auto& t : c->inputs) { |
83 | this->VisitType(t); |
84 | } |
85 | } |
86 | } |
87 | |
88 | void TypeVisitor::VisitType_(const PrimTypeNode* op) {} |
89 | |
90 | void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); } |
91 | |
92 | Type TypeMutator::VisitType(const Type& t) { |
93 | return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t; |
94 | } |
95 | |
96 | // Type Mutator. |
97 | Array<Type> TypeMutator::MutateArray(Array<Type> arr) { |
98 | // The array will do copy on write |
99 | // If no changes are made, the original array will be returned. |
100 | return arr.Map([this](const Type& ty) { return VisitType(ty); }); |
101 | } |
102 | |
103 | Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef<TypeVar>(op); } |
104 | |
105 | Type TypeMutator::VisitType_(const TensorTypeNode* op) { |
106 | // TODO(tvm-team) recursively visit to replace Var |
107 | return GetRef<Type>(op); |
108 | } |
109 | |
110 | Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { return GetRef<Type>(op); } |
111 | |
112 | Type TypeMutator::VisitType_(const FuncTypeNode* op) { |
113 | bool changed = false; |
114 | Array<TypeVar> type_params; |
115 | for (auto type_param : op->type_params) { |
116 | auto new_type_param = VisitType(type_param); |
117 | changed = changed || !new_type_param.same_as(type_param); |
118 | if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) { |
119 | type_params.push_back(GetRef<TypeVar>(tin)); |
120 | } else { |
121 | LOG(FATAL) << new_type_param; |
122 | } |
123 | } |
124 | |
125 | Array<TypeConstraint> type_constraints; |
126 | for (auto type_cs : op->type_constraints) { |
127 | auto new_type_cs = VisitType(type_cs); |
128 | changed = changed || !new_type_cs.same_as(type_cs); |
129 | if (const TypeConstraintNode* tin = new_type_cs.as<TypeConstraintNode>()) { |
130 | type_constraints.push_back(GetRef<TypeConstraint>(tin)); |
131 | } else { |
132 | LOG(FATAL) << new_type_cs; |
133 | } |
134 | } |
135 | |
136 | Array<Type> new_args = MutateArray(op->arg_types); |
137 | changed = changed || !new_args.same_as(op->arg_types); |
138 | |
139 | Type new_ret_type = VisitType(op->ret_type); |
140 | changed = changed || !new_ret_type.same_as(op->ret_type); |
141 | |
142 | if (!changed) return GetRef<Type>(op); |
143 | return FuncType(new_args, new_ret_type, type_params, type_constraints); |
144 | } |
145 | |
146 | Type TypeMutator::VisitType_(const TupleTypeNode* op) { |
147 | Array<Type> new_fields = MutateArray(op->fields); |
148 | if (new_fields.same_as(op->fields)) { |
149 | return GetRef<Type>(op); |
150 | } else { |
151 | return TupleType(new_fields); |
152 | } |
153 | } |
154 | |
155 | Type TypeMutator::VisitType_(const RelayRefTypeNode* op) { |
156 | return RelayRefType(this->VisitType(op->value)); |
157 | } |
158 | |
159 | Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { |
160 | Array<Type> new_args = MutateArray(type_rel->args); |
161 | if (new_args.same_as(type_rel->args)) { |
162 | return GetRef<Type>(type_rel); |
163 | } else { |
164 | return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs); |
165 | } |
166 | } |
167 | |
168 | Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { return GetRef<Type>(op); } |
169 | |
170 | Type TypeMutator::VisitType_(const TypeCallNode* op) { |
171 | Type new_func = VisitType(op->func); |
172 | Array<Type> new_args = MutateArray(op->args); |
173 | if (new_args.same_as(op->args) && new_func.same_as(op->func)) { |
174 | return GetRef<TypeCall>(op); |
175 | } else { |
176 | return TypeCall(new_func, new_args); |
177 | } |
178 | } |
179 | |
180 | Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef<Type>(op); } |
181 | |
182 | Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef<Type>(op); } |
183 | |
184 | Type TypeMutator::VisitType_(const PointerTypeNode* op) { |
185 | Type element_type = VisitType(op->element_type); |
186 | |
187 | if (element_type.same_as(op->element_type)) { |
188 | return GetRef<Type>(op); |
189 | } else { |
190 | return PointerType(element_type, op->storage_scope); |
191 | } |
192 | } |
193 | |
194 | // Implements bind. |
195 | class TypeBinder : public TypeMutator { |
196 | public: |
197 | explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map) : args_map_(args_map) {} |
198 | |
199 | Type VisitType_(const TypeVarNode* op) override { |
200 | auto id = GetRef<TypeVar>(op); |
201 | auto it = args_map_.find(id); |
202 | if (it != args_map_.end()) { |
203 | return (*it).second; |
204 | } else { |
205 | return std::move(id); |
206 | } |
207 | } |
208 | |
209 | private: |
210 | const tvm::Map<TypeVar, Type>& args_map_; |
211 | }; |
212 | |
213 | Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) { |
214 | return TypeBinder(args_map).VisitType(type); |
215 | } |
216 | |
217 | } // namespace tvm |
218 | |