1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/type_relation.h
 * \brief Type relation and function for type inference (checking).
23 */
24#ifndef TVM_IR_TYPE_RELATION_H_
25#define TVM_IR_TYPE_RELATION_H_
26
#include <tvm/ir/attrs.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/module.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/logging.h>

#include <utility>
33
34namespace tvm {
35
36/*!
37 * \brief Type function application.
38 * \sa TypeCall
39 */
40class TypeCallNode : public TypeNode {
41 public:
42 /*!
43 * \brief The type-level function (ADT that takes type params).
44 */
45 Type func;
46 /*! \brief The arguments. */
47 Array<Type> args;
48
49 void VisitAttrs(AttrVisitor* v) {
50 v->Visit("func", &func);
51 v->Visit("args", &args);
52 v->Visit("span", &span);
53 }
54
55 bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
56 return equal(func, other->func) && equal(args, other->args);
57 }
58
59 void SHashReduce(SHashReducer hash_reduce) const {
60 hash_reduce(func);
61 hash_reduce(args);
62 }
63
64 static constexpr const char* _type_key = "TypeCall";
65 TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
66};
67
68/*!
69 * \brief Managed reference to TypeCallNode.
70 * \sa TypeCallNode
71 */
72class TypeCall : public Type {
73 public:
74 /*!
75 * \brief Constructor
76 * \param func The type function to apply.
77 * \param args The arguments to the type function.
78 */
79 TVM_DLL TypeCall(Type func, Array<Type> args);
80
81 TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
82};
83
/*!
 * \brief Reporter through which type relation functions send
 *  type resolution information back to the solver.
 *
 * This class is an abstract interface: every reporting operation
 * below is pure virtual and must be provided by an implementation.
 */
class TypeReporterNode : public Object {
 public:
  /*! \brief virtual destructor */
  virtual ~TypeReporterNode() {}
  /*!
   * \brief Create a type equality constraint.
   *
   * The "assign direction" acts as a hint to the solver
   * showing that it is more likely to resolve dst by src.
   * But it is possible for the solver to resolve src by dst as well.
   *
   * \param dst The type that is expected to be resolved.
   * \param src The type that is expected to resolve \p dst.
   */
  TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;

  /*!
   * \brief Assert a shape expression condition.
   * \note Use assert only if any of the condition inputs are symbolic.
   * \param cond The condition of operation.
   * \return false if the assertion can be proven to have failed,
   *         true if the solver can still proceed.
   */
  TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0;
  /*!
   * \brief Assert that two shape expressions are equal.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return false if the assertion can be proven to have failed,
   *         true if the solver can still proceed.
   */
  TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0;

  /*!
   * \brief Set the location at which to report unification errors.
   * \param span The span at which to report the error.
   */
  TVM_DLL virtual void SetSpan(const Span& span) = 0;

  /*!
   * \brief Get the location currently used for error reporting.
   * \return A span; presumably the one last passed to SetSpan
   *  (depends on the implementation -- confirm).
   */
  TVM_DLL virtual Span GetSpan() = 0;

  /*!
   * \brief Get the diagnostic context provided by the implementation.
   * \return The diagnostic context; presumably where unification errors
   *  are reported (implementation-defined -- confirm).
   */
  TVM_DLL virtual DiagnosticContext GetDiagCtx() = 0;

  /*!
   * \brief Retrieve the current global module.
   * \return The global module.
   */
  TVM_DLL virtual IRModule GetModule() = 0;

  // solver is not serializable, so no attributes are visited.
  void VisitAttrs(AttrVisitor* v) {}

  static constexpr const char* _type_key = "TypeReporter";
  // NOTE(review): declared with FINAL object info although the class is
  // abstract; concrete reporters presumably subclass it without registering
  // their own type key. Confirm before changing this macro.
  TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};
140
141/*!
142 * \brief Container class of TypeReporter.
143 * \sa TypeReporterNode
144 */
145class TypeReporter : public ObjectRef {
146 public:
147 TypeReporter() {}
148 explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {}
149 TypeReporterNode* operator->() const {
150 return const_cast<TypeReporterNode*>(static_cast<const TypeReporterNode*>(get()));
151 }
152 using ContainerType = TypeReporterNode;
153};
154
/*!
 * \brief User defined type constraint function.
 *
 * If the input type information can be used to fully decide
 * the IncompleteTypes, then the function should call
 * reporter->Assign to report the new types, and return true.
 * Otherwise, the function should return false.
 *
 * \param args The arguments to the relation.
 *  The types are stored in the form of
 *  [input_type_0, input_type_1, ... input_type_n,
 *   output_type_0, output_type_1, ... output_type_m],
 *  i.e. the first num_inputs entries are inputs and the rest are outputs.
 *
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report the solution to.
 * \return false if this relation cannot be resolved,
 *         true if this relation has been resolved.
 */
using TypeRelationFn = TypedEnvFunc<bool(const Array<Type>& args, int num_inputs,
                                         const Attrs& attrs, const TypeReporter& reporter)>;
176
177/*!
178 * \brief User defined type relation, it is an input-output relation on types.
179 *
180 * TypeRelation is more generalized than type call as it allows inference
181 * of both inputs and outputs.
182 *
183 * \sa TypeRelation
184 */
185class TypeRelationNode : public TypeConstraintNode {
186 public:
187 /*!
188 * \brief The function on input and output variables which
189 * this is not directly serializable,
190 * need to be looked-up in the module.
191 */
192 TypeRelationFn func;
193 /*! \brief The type arguments to the type function. */
194 Array<Type> args;
195 /*! \brief Number of inputs arguments */
196 int num_inputs;
197 /*! \brief Attributes to the relation function */
198 Attrs attrs;
199
200 void VisitAttrs(AttrVisitor* v) {
201 v->Visit("func", &func);
202 v->Visit("args", &args);
203 v->Visit("num_inputs", &num_inputs);
204 v->Visit("attrs", &attrs);
205 v->Visit("span", &span);
206 }
207
208 bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
209 return equal(func, other->func) && equal(args, other->args) &&
210 equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs);
211 }
212
213 void SHashReduce(SHashReducer hash_reduce) const {
214 hash_reduce(func);
215 hash_reduce(args);
216 hash_reduce(num_inputs);
217 hash_reduce(attrs);
218 }
219
220 static constexpr const char* _type_key = "TypeRelation";
221 TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
222};
223
224/*!
225 * \brief Managed reference to TypeRelationNode.
226 * \sa TypeRelationNode
227 */
228class TypeRelation : public TypeConstraint {
229 public:
230 /*!
231 * \brief Constructor
232 * \param func The relation function.
233 * \param args The arguments to the type relation.
234 * \param num_inputs Number of inputs.
235 * \param attrs Attributes to the relation function.
236 * \sa TypeRelationNode for more docs about these fields.
237 */
238 TVM_DLL TypeRelation(TypeRelationFn func, Array<Type> args, int num_inputs, Attrs attrs);
239
240 TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
241};
242} // namespace tvm
243#endif // TVM_IR_TYPE_RELATION_H_
244