1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/type_functor.h |
 * \brief A way to define arbitrary function signatures with dispatch on types.
23 | */ |
24 | #ifndef TVM_IR_TYPE_FUNCTOR_H_ |
25 | #define TVM_IR_TYPE_FUNCTOR_H_ |
26 | |
27 | #include <tvm/ir/tensor_type.h> |
28 | #include <tvm/ir/type_relation.h> |
29 | #include <tvm/node/functor.h> |
30 | |
31 | #include <string> |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | |
/*!
 * \brief Functor that dispatches on the node type of a Type.
 * \tparam FType The function signature of the functor. Only signatures of the
 *         form R(const Type& n, Args...) have a definition (the partial
 *         specialization that follows); the primary template is left undefined.
 */
template <class FType>
class TypeFunctor;
39 | |
40 | // functions to be overriden. |
41 | #define TYPE_FUNCTOR_DEFAULT \ |
42 | { return VisitTypeDefault_(op, std::forward<Args>(args)...); } |
43 | |
// Register the vtable entry for node type OP: the entry downcasts the generic
// ObjectRef to `const OP*` and calls the matching VisitType_ overload on the
// functor. Relies on `vtable`, `TSelf` and `Args` being in scope, as they are
// inside TypeFunctor::InitVTable.
#define TVM_TYPE_FUNCTOR_DISPATCH(OP)                                                   \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* self, Args... args) { \
    return self->VisitType_(static_cast<const OP*>(node.get()),                         \
                            std::forward<Args>(args)...);                               \
  });
48 | |
49 | template <typename R, typename... Args> |
50 | class TypeFunctor<R(const Type& n, Args...)> { |
51 | private: |
52 | using TSelf = TypeFunctor<R(const Type& n, Args...)>; |
53 | using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; |
54 | |
55 | public: |
56 | /*! \brief the result type of this functor */ |
57 | using result_type = R; |
58 | /*! \brief virtual destructor */ |
59 | virtual ~TypeFunctor() {} |
60 | /*! |
61 | * \brief Same as call. |
62 | * \param n The expression node. |
63 | * \param args Additional arguments. |
64 | * \return The result of the call |
65 | */ |
66 | R operator()(const Type& n, Args... args) { return VisitType(n, std::forward<Args>(args)...); } |
67 | /*! |
68 | * \brief The functor call. |
69 | * \param n The expression node. |
70 | * \param args Additional arguments. |
71 | * \return The result of the call |
72 | */ |
73 | virtual R VisitType(const Type& n, Args... args) { |
74 | ICHECK(n.defined()); |
75 | static FType vtable = InitVTable(); |
76 | return vtable(n, this, std::forward<Args>(args)...); |
77 | } |
78 | // Functions that can be overriden by subclass |
79 | virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
80 | virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
81 | virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
82 | virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
83 | virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
84 | virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
85 | virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
86 | virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
87 | virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
88 | virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
89 | virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
90 | virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
91 | virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; |
92 | virtual R VisitTypeDefault_(const Object* op, Args...) { |
93 | LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); |
94 | throw; // unreachable, written to stop compiler warning |
95 | } |
96 | |
97 | private: |
98 | // initialize the vtable. |
99 | static FType InitVTable() { |
100 | FType vtable; |
101 | // Set dispatch |
102 | TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); |
103 | TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode); |
104 | TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); |
105 | TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); |
106 | TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); |
107 | TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); |
108 | TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); |
109 | TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode); |
110 | TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode); |
111 | TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode); |
112 | TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode); |
113 | TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); |
114 | TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode); |
115 | return vtable; |
116 | } |
117 | }; |
118 | |
119 | #undef TVM_TYPE_FUNCTOR_DISPATCH |
120 | |
121 | /*! |
122 | * \brief A type visitor that recursively visit types. |
123 | */ |
124 | class TVM_DLL TypeVisitor : public TypeFunctor<void(const Type& n)> { |
125 | public: |
126 | void VisitType_(const TypeVarNode* op) override; |
127 | void VisitType_(const IncompleteTypeNode* op) override; |
128 | void VisitType_(const TensorTypeNode* op) override; |
129 | void VisitType_(const FuncTypeNode* op) override; |
130 | void VisitType_(const TupleTypeNode* op) override; |
131 | void VisitType_(const TypeRelationNode* op) override; |
132 | void VisitType_(const RelayRefTypeNode* op) override; |
133 | void VisitType_(const GlobalTypeVarNode* op) override; |
134 | void VisitType_(const TypeCallNode* op) override; |
135 | void VisitType_(const TypeDataNode* op) override; |
136 | void VisitType_(const PrimTypeNode* op) override; |
137 | void VisitType_(const PointerTypeNode* op) override; |
138 | }; |
139 | |
140 | /*! |
141 | * \brief TypeMutator that mutates expressions. |
142 | */ |
143 | class TVM_DLL TypeMutator : public TypeFunctor<Type(const Type& n)> { |
144 | public: |
145 | Type VisitType(const Type& t) override; |
146 | Type VisitType_(const TypeVarNode* op) override; |
147 | Type VisitType_(const TensorTypeNode* op) override; |
148 | Type VisitType_(const IncompleteTypeNode* op) override; |
149 | Type VisitType_(const FuncTypeNode* op) override; |
150 | Type VisitType_(const TupleTypeNode* op) override; |
151 | Type VisitType_(const TypeRelationNode* type_rel) override; |
152 | Type VisitType_(const RelayRefTypeNode* op) override; |
153 | Type VisitType_(const GlobalTypeVarNode* op) override; |
154 | Type VisitType_(const TypeCallNode* op) override; |
155 | Type VisitType_(const TypeDataNode* op) override; |
156 | Type VisitType_(const PrimTypeNode* op) override; |
157 | Type VisitType_(const PointerTypeNode* op) override; |
158 | |
159 | private: |
160 | Array<Type> MutateArray(Array<Type> arr); |
161 | }; |
162 | |
163 | /*! |
164 | * \brief Bind free type variables in the type. |
165 | * \param type The type to be updated. |
166 | * \param args_map The binding map. |
167 | */ |
168 | Type Bind(const Type& type, const Map<TypeVar, Type>& args_map); |
169 | |
170 | } // namespace tvm |
171 | #endif // TVM_IR_TYPE_FUNCTOR_H_ |
172 | |