1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/type_functor.h
 * \brief A way to define arbitrary function signatures with dispatch on types.
23 */
24#ifndef TVM_IR_TYPE_FUNCTOR_H_
25#define TVM_IR_TYPE_FUNCTOR_H_
26
27#include <tvm/ir/tensor_type.h>
28#include <tvm/ir/type_relation.h>
29#include <tvm/node/functor.h>
30
31#include <string>
32#include <utility>
33#include <vector>
34
35namespace tvm {
36
/*!
 * \brief Primary template of TypeFunctor. It is intentionally only declared;
 *  TypeFunctor is used through its R(const Type&, Args...) specialization.
 */
template <class FType>
class TypeFunctor;
39
// Default body for each overridable VisitType_ overload of TypeFunctor.
// It forwards the visited node `op` and the extra arguments `args` to
// VisitTypeDefault_, so it must be expanded inside a member function whose
// parameters are named `op` and `args` and whose class declares the pack `Args`.
#define TYPE_FUNCTOR_DEFAULT \
  { return VisitTypeDefault_(op, std::forward<Args>(args)...); }

// Registers one entry in the dispatch table named `vtable` for node type OP.
// The entry downcasts the ObjectRef's underlying pointer to `const OP*` and
// calls the VisitType_ overload for OP on the functor. Expects `vtable`,
// `TSelf`, and the pack `Args` to be in scope at the expansion site.
#define TVM_TYPE_FUNCTOR_DISPATCH(OP)                                                       \
  vtable.template set_dispatch<OP>([](const ObjectRef& obj, TSelf* functor, Args... rest) { \
    return functor->VisitType_(static_cast<const OP*>(obj.get()),                           \
                               std::forward<Args>(rest)...);                                \
  });
48
49template <typename R, typename... Args>
50class TypeFunctor<R(const Type& n, Args...)> {
51 private:
52 using TSelf = TypeFunctor<R(const Type& n, Args...)>;
53 using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
54
55 public:
56 /*! \brief the result type of this functor */
57 using result_type = R;
58 /*! \brief virtual destructor */
59 virtual ~TypeFunctor() {}
60 /*!
61 * \brief Same as call.
62 * \param n The expression node.
63 * \param args Additional arguments.
64 * \return The result of the call
65 */
66 R operator()(const Type& n, Args... args) { return VisitType(n, std::forward<Args>(args)...); }
67 /*!
68 * \brief The functor call.
69 * \param n The expression node.
70 * \param args Additional arguments.
71 * \return The result of the call
72 */
73 virtual R VisitType(const Type& n, Args... args) {
74 ICHECK(n.defined());
75 static FType vtable = InitVTable();
76 return vtable(n, this, std::forward<Args>(args)...);
77 }
78 // Functions that can be overriden by subclass
79 virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
80 virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
81 virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
82 virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
83 virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
84 virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
85 virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
86 virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
87 virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
88 virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
89 virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
90 virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
91 virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
92 virtual R VisitTypeDefault_(const Object* op, Args...) {
93 LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
94 throw; // unreachable, written to stop compiler warning
95 }
96
97 private:
98 // initialize the vtable.
99 static FType InitVTable() {
100 FType vtable;
101 // Set dispatch
102 TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
103 TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
104 TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
105 TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
106 TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
107 TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
108 TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
109 TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
110 TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
111 TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
112 TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
113 TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
114 TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
115 return vtable;
116 }
117};
118
119#undef TVM_TYPE_FUNCTOR_DISPATCH
120
121/*!
122 * \brief A type visitor that recursively visit types.
123 */
124class TVM_DLL TypeVisitor : public TypeFunctor<void(const Type& n)> {
125 public:
126 void VisitType_(const TypeVarNode* op) override;
127 void VisitType_(const IncompleteTypeNode* op) override;
128 void VisitType_(const TensorTypeNode* op) override;
129 void VisitType_(const FuncTypeNode* op) override;
130 void VisitType_(const TupleTypeNode* op) override;
131 void VisitType_(const TypeRelationNode* op) override;
132 void VisitType_(const RelayRefTypeNode* op) override;
133 void VisitType_(const GlobalTypeVarNode* op) override;
134 void VisitType_(const TypeCallNode* op) override;
135 void VisitType_(const TypeDataNode* op) override;
136 void VisitType_(const PrimTypeNode* op) override;
137 void VisitType_(const PointerTypeNode* op) override;
138};
139
140/*!
141 * \brief TypeMutator that mutates expressions.
142 */
143class TVM_DLL TypeMutator : public TypeFunctor<Type(const Type& n)> {
144 public:
145 Type VisitType(const Type& t) override;
146 Type VisitType_(const TypeVarNode* op) override;
147 Type VisitType_(const TensorTypeNode* op) override;
148 Type VisitType_(const IncompleteTypeNode* op) override;
149 Type VisitType_(const FuncTypeNode* op) override;
150 Type VisitType_(const TupleTypeNode* op) override;
151 Type VisitType_(const TypeRelationNode* type_rel) override;
152 Type VisitType_(const RelayRefTypeNode* op) override;
153 Type VisitType_(const GlobalTypeVarNode* op) override;
154 Type VisitType_(const TypeCallNode* op) override;
155 Type VisitType_(const TypeDataNode* op) override;
156 Type VisitType_(const PrimTypeNode* op) override;
157 Type VisitType_(const PointerTypeNode* op) override;
158
159 private:
160 Array<Type> MutateArray(Array<Type> arr);
161};
162
/*!
 * \brief Bind free type variables in the type.
 * \param type The type to be updated.
 * \param args_map The binding map, from each type variable to the type bound to it.
 * \return The updated type.
 */
Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
169
170} // namespace tvm
171#endif // TVM_IR_TYPE_FUNCTOR_H_
172