1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file type_functor.cc
22 * \brief Implementations of type functors.
23 */
24#include <tvm/ir/type_functor.h>
25
26#include <utility>
27
28namespace tvm {
29
30void TypeVisitor::VisitType_(const TypeVarNode* op) {}
31
32void TypeVisitor::VisitType_(const TensorTypeNode* op) {}
33
34void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {}
35
36void TypeVisitor::VisitType_(const FuncTypeNode* op) {
37 for (auto type_param : op->type_params) {
38 this->VisitType(type_param);
39 }
40
41 for (auto type_cs : op->type_constraints) {
42 this->VisitType(type_cs);
43 }
44
45 for (auto arg_type : op->arg_types) {
46 this->VisitType(arg_type);
47 }
48 this->VisitType(op->ret_type);
49}
50
51void TypeVisitor::VisitType_(const TupleTypeNode* op) {
52 for (const Type& t : op->fields) {
53 this->VisitType(t);
54 }
55}
56
57void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); }
58
59void TypeVisitor::VisitType_(const TypeRelationNode* op) {
60 for (const Type& t : op->args) {
61 this->VisitType(t);
62 }
63}
64
65void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {}
66
67void TypeVisitor::VisitType_(const TypeCallNode* op) {
68 this->VisitType(op->func);
69 for (const Type& t : op->args) {
70 this->VisitType(t);
71 }
72}
73
74void TypeVisitor::VisitType_(const TypeDataNode* op) {
75 this->VisitType(op->header);
76 for (const auto& v : op->type_vars) {
77 this->VisitType(v);
78 }
79
80 for (const auto& c : op->constructors) {
81 this->VisitType(c->belong_to);
82 for (const auto& t : c->inputs) {
83 this->VisitType(t);
84 }
85 }
86}
87
88void TypeVisitor::VisitType_(const PrimTypeNode* op) {}
89
90void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); }
91
92Type TypeMutator::VisitType(const Type& t) {
93 return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
94}
95
96// Type Mutator.
97Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
98 // The array will do copy on write
99 // If no changes are made, the original array will be returned.
100 return arr.Map([this](const Type& ty) { return VisitType(ty); });
101}
102
103Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef<TypeVar>(op); }
104
105Type TypeMutator::VisitType_(const TensorTypeNode* op) {
106 // TODO(tvm-team) recursively visit to replace Var
107 return GetRef<Type>(op);
108}
109
110Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { return GetRef<Type>(op); }
111
112Type TypeMutator::VisitType_(const FuncTypeNode* op) {
113 bool changed = false;
114 Array<TypeVar> type_params;
115 for (auto type_param : op->type_params) {
116 auto new_type_param = VisitType(type_param);
117 changed = changed || !new_type_param.same_as(type_param);
118 if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
119 type_params.push_back(GetRef<TypeVar>(tin));
120 } else {
121 LOG(FATAL) << new_type_param;
122 }
123 }
124
125 Array<TypeConstraint> type_constraints;
126 for (auto type_cs : op->type_constraints) {
127 auto new_type_cs = VisitType(type_cs);
128 changed = changed || !new_type_cs.same_as(type_cs);
129 if (const TypeConstraintNode* tin = new_type_cs.as<TypeConstraintNode>()) {
130 type_constraints.push_back(GetRef<TypeConstraint>(tin));
131 } else {
132 LOG(FATAL) << new_type_cs;
133 }
134 }
135
136 Array<Type> new_args = MutateArray(op->arg_types);
137 changed = changed || !new_args.same_as(op->arg_types);
138
139 Type new_ret_type = VisitType(op->ret_type);
140 changed = changed || !new_ret_type.same_as(op->ret_type);
141
142 if (!changed) return GetRef<Type>(op);
143 return FuncType(new_args, new_ret_type, type_params, type_constraints);
144}
145
146Type TypeMutator::VisitType_(const TupleTypeNode* op) {
147 Array<Type> new_fields = MutateArray(op->fields);
148 if (new_fields.same_as(op->fields)) {
149 return GetRef<Type>(op);
150 } else {
151 return TupleType(new_fields);
152 }
153}
154
155Type TypeMutator::VisitType_(const RelayRefTypeNode* op) {
156 return RelayRefType(this->VisitType(op->value));
157}
158
159Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
160 Array<Type> new_args = MutateArray(type_rel->args);
161 if (new_args.same_as(type_rel->args)) {
162 return GetRef<Type>(type_rel);
163 } else {
164 return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs);
165 }
166}
167
168Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { return GetRef<Type>(op); }
169
170Type TypeMutator::VisitType_(const TypeCallNode* op) {
171 Type new_func = VisitType(op->func);
172 Array<Type> new_args = MutateArray(op->args);
173 if (new_args.same_as(op->args) && new_func.same_as(op->func)) {
174 return GetRef<TypeCall>(op);
175 } else {
176 return TypeCall(new_func, new_args);
177 }
178}
179
180Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef<Type>(op); }
181
182Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef<Type>(op); }
183
184Type TypeMutator::VisitType_(const PointerTypeNode* op) {
185 Type element_type = VisitType(op->element_type);
186
187 if (element_type.same_as(op->element_type)) {
188 return GetRef<Type>(op);
189 } else {
190 return PointerType(element_type, op->storage_scope);
191 }
192}
193
194// Implements bind.
195class TypeBinder : public TypeMutator {
196 public:
197 explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map) : args_map_(args_map) {}
198
199 Type VisitType_(const TypeVarNode* op) override {
200 auto id = GetRef<TypeVar>(op);
201 auto it = args_map_.find(id);
202 if (it != args_map_.end()) {
203 return (*it).second;
204 } else {
205 return std::move(id);
206 }
207 }
208
209 private:
210 const tvm::Map<TypeVar, Type>& args_map_;
211};
212
213Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
214 return TypeBinder(args_map).VisitType(type);
215}
216
217} // namespace tvm
218