1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/type_relation.h |
22 | * \brief Type relation and function for type inference (checking). |
23 | */ |
24 | #ifndef TVM_IR_TYPE_RELATION_H_ |
25 | #define TVM_IR_TYPE_RELATION_H_ |
26 | |
#include <tvm/ir/attrs.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/module.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/logging.h>

#include <utility>
33 | |
34 | namespace tvm { |
35 | |
36 | /*! |
37 | * \brief Type function application. |
38 | * \sa TypeCall |
39 | */ |
40 | class TypeCallNode : public TypeNode { |
41 | public: |
42 | /*! |
43 | * \brief The type-level function (ADT that takes type params). |
44 | */ |
45 | Type func; |
46 | /*! \brief The arguments. */ |
47 | Array<Type> args; |
48 | |
49 | void VisitAttrs(AttrVisitor* v) { |
50 | v->Visit("func" , &func); |
51 | v->Visit("args" , &args); |
52 | v->Visit("span" , &span); |
53 | } |
54 | |
55 | bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const { |
56 | return equal(func, other->func) && equal(args, other->args); |
57 | } |
58 | |
59 | void SHashReduce(SHashReducer hash_reduce) const { |
60 | hash_reduce(func); |
61 | hash_reduce(args); |
62 | } |
63 | |
64 | static constexpr const char* _type_key = "TypeCall" ; |
65 | TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); |
66 | }; |
67 | |
68 | /*! |
69 | * \brief Managed reference to TypeCallNode. |
70 | * \sa TypeCallNode |
71 | */ |
72 | class TypeCall : public Type { |
73 | public: |
74 | /*! |
75 | * \brief Constructor |
76 | * \param func The type function to apply. |
77 | * \param args The arguments to the type function. |
78 | */ |
79 | TVM_DLL TypeCall(Type func, Array<Type> args); |
80 | |
81 | TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode); |
82 | }; |
83 | |
84 | /*! |
85 | * \brief reporter that reports back to the |
86 | * type resolution information. |
87 | */ |
88 | class TypeReporterNode : public Object { |
89 | public: |
90 | /*! \brief virtual destructor */ |
91 | virtual ~TypeReporterNode() {} |
92 | /*! |
93 | * \brief Create a type equality constraint. |
94 | * |
95 | * The "assign direction" acts as a hint to the solver |
96 | * showing that it is more likely to resolve dst by src. |
97 | * But it is possible for the solver to resolve src by dst as well. |
98 | */ |
99 | TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; |
100 | |
101 | /*! |
102 | * \brief assert shape expression comparison. |
103 | * \note Use assert only if any of the condition input is symbolic. |
104 | * \param cond The condition of operation. |
105 | * \return false if assertion can be proven to have failed |
106 | * true if solver can still proceed. |
107 | */ |
108 | TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0; |
109 | /*! |
110 | * \brief assert shape expression equals each other. |
111 | * \param lhs The left operand. |
112 | * \param rhs The right operand. |
113 | * \return false if assertion can be proven to have failed |
114 | * true if solver can still proceed. |
115 | */ |
116 | TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0; |
117 | |
118 | /*! |
119 | * \brief Set the location at which to report unification errors. |
120 | * \param span The span at which to report the error. |
121 | */ |
122 | TVM_DLL virtual void SetSpan(const Span& span) = 0; |
123 | |
124 | TVM_DLL virtual Span GetSpan() = 0; |
125 | |
126 | TVM_DLL virtual DiagnosticContext GetDiagCtx() = 0; |
127 | |
128 | /*! |
129 | * \brief Retrieve the current global module. |
130 | * \return The global module. |
131 | */ |
132 | TVM_DLL virtual IRModule GetModule() = 0; |
133 | |
134 | // solver is not serializable. |
135 | void VisitAttrs(AttrVisitor* v) {} |
136 | |
137 | static constexpr const char* _type_key = "TypeReporter" ; |
138 | TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); |
139 | }; |
140 | |
141 | /*! |
142 | * \brief Container class of TypeReporter. |
143 | * \sa TypeReporterNode |
144 | */ |
145 | class TypeReporter : public ObjectRef { |
146 | public: |
147 | TypeReporter() {} |
148 | explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {} |
149 | TypeReporterNode* operator->() const { |
150 | return const_cast<TypeReporterNode*>(static_cast<const TypeReporterNode*>(get())); |
151 | } |
152 | using ContainerType = TypeReporterNode; |
153 | }; |
154 | |
155 | /*! |
156 | * \brief User defined type constraint function. |
157 | * |
158 | * If the input type information can be used to fully decide |
159 | * the IncompleteTypes, then the function should call |
160 | * reporter.Assign to report the new types, and return true. |
161 | * Otherwise, the function should return false. |
162 | * |
163 | * \param args The arguments to the relation. |
164 | * The types are stored in the form of |
165 | * [input_type_0, input_type_1, ... input_type_n, |
166 | * output_type_0, output_type_1, ... output_type_m] |
167 | * |
168 | * \param num_inputs Number of input types in the args. |
169 | * \param attrs The additional attributes of the operator. |
170 | * \param reporter The reporter to report solution to. |
171 | * \return false if This relation cannot be resolved. |
172 | * true if this relation has been resolved. |
173 | */ |
174 | using TypeRelationFn = TypedEnvFunc<bool(const Array<Type>& args, int num_inputs, |
175 | const Attrs& attrs, const TypeReporter& reporter)>; |
176 | |
177 | /*! |
178 | * \brief User defined type relation, it is an input-output relation on types. |
179 | * |
180 | * TypeRelation is more generalized than type call as it allows inference |
181 | * of both inputs and outputs. |
182 | * |
183 | * \sa TypeRelation |
184 | */ |
185 | class TypeRelationNode : public TypeConstraintNode { |
186 | public: |
187 | /*! |
188 | * \brief The function on input and output variables which |
189 | * this is not directly serializable, |
190 | * need to be looked-up in the module. |
191 | */ |
192 | TypeRelationFn func; |
193 | /*! \brief The type arguments to the type function. */ |
194 | Array<Type> args; |
195 | /*! \brief Number of inputs arguments */ |
196 | int num_inputs; |
197 | /*! \brief Attributes to the relation function */ |
198 | Attrs attrs; |
199 | |
200 | void VisitAttrs(AttrVisitor* v) { |
201 | v->Visit("func" , &func); |
202 | v->Visit("args" , &args); |
203 | v->Visit("num_inputs" , &num_inputs); |
204 | v->Visit("attrs" , &attrs); |
205 | v->Visit("span" , &span); |
206 | } |
207 | |
208 | bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { |
209 | return equal(func, other->func) && equal(args, other->args) && |
210 | equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs); |
211 | } |
212 | |
213 | void SHashReduce(SHashReducer hash_reduce) const { |
214 | hash_reduce(func); |
215 | hash_reduce(args); |
216 | hash_reduce(num_inputs); |
217 | hash_reduce(attrs); |
218 | } |
219 | |
220 | static constexpr const char* _type_key = "TypeRelation" ; |
221 | TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); |
222 | }; |
223 | |
224 | /*! |
225 | * \brief Managed reference to TypeRelationNode. |
226 | * \sa TypeRelationNode |
227 | */ |
228 | class TypeRelation : public TypeConstraint { |
229 | public: |
230 | /*! |
231 | * \brief Constructor |
232 | * \param func The relation function. |
233 | * \param args The arguments to the type relation. |
234 | * \param num_inputs Number of inputs. |
235 | * \param attrs Attributes to the relation function. |
236 | * \sa TypeRelationNode for more docs about these fields. |
237 | */ |
238 | TVM_DLL TypeRelation(TypeRelationFn func, Array<Type> args, int num_inputs, Attrs attrs); |
239 | |
240 | TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); |
241 | }; |
242 | } // namespace tvm |
243 | #endif // TVM_IR_TYPE_RELATION_H_ |
244 | |