1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/op.h
22 * \brief Primitive operators(builtin intrinsics)
23 * and registry for them.
24 */
25#ifndef TVM_IR_OP_H_
26#define TVM_IR_OP_H_
27
28#include <dmlc/registry.h>
29#include <tvm/ir/attrs.h>
30#include <tvm/ir/expr.h>
31#include <tvm/ir/type.h>
32#include <tvm/ir/type_relation.h>
33#include <tvm/node/attr_registry_map.h>
34#include <tvm/runtime/registry.h>
35
36#include <string>
37#include <utility>
38#include <vector>
39
40namespace tvm {
41
// Forward declaration of the per-operator attribute map template.
// It is named in the declaration of Op::GetAttrMap before its full
// definition appears later in this header.
template <typename>
class OpAttrMap;
45
46// TODO(tvm-team): migrate low-level intrinsics to use Op
47/*!
48 * \brief Primitive Op(builtin intrinsics)
49 *
50 * This data structure stores the meta-data
51 * about primitive operators that can be invoked via Call.
52 *
53 * Low-level IR intrinsics(such as libc.expf) are also
54 * implemented via Op.
55 *
56 * \sa Op
57 */
58class OpNode : public RelayExprNode {
59 public:
60 /*! \brief name of the operator */
61 String name;
62 /*! \brief the type of the operator */
63 mutable FuncType op_type;
64 /*!
65 * \brief detailed description of the operator
66 * This can be used to generate docstring automatically for the operator.
67 */
68 String description;
69 /* \brief Information of input arguments to the operator */
70 Array<AttrFieldInfo> arguments;
71 /*!
72 * \brief The type key of the attribute field
73 * This can be empty, in which case it defaults to anything.
74 */
75 String attrs_type_key;
76 /*!
77 * \brief attribute type index,
78 * this field varies in each run and is not exposed to frontend.
79 */
80 uint32_t attrs_type_index{0};
81 /*!
82 * \brief number of input arguments to the operator,
83 * -1 means it is variable length
84 */
85 int32_t num_inputs = -1;
86 /*!
87 * \brief support level of the operator,
88 * The lower the more priority it contains.
89 * This is in analogies to BLAS levels.
90 */
91 int32_t support_level = 10;
92
93 void VisitAttrs(AttrVisitor* v) {
94 v->Visit("name", &name);
95 v->Visit("op_type", &op_type);
96 v->Visit("description", &description);
97 v->Visit("arguments", &arguments);
98 v->Visit("attrs_type_key", &attrs_type_key);
99 v->Visit("num_inputs", &num_inputs);
100 v->Visit("support_level", &support_level);
101 }
102
103 bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
104 // pointer equality is fine as there is only one op with the same name.
105 return this == other;
106 }
107
108 void SHashReduce(SHashReducer hash_reduce) const {
109 // Name uniquely identifies an Op.
110 hash_reduce(name);
111 }
112
113 /*!
114 * \brief Check that if current op is a "primtive operator".
115 * That is the arguments are all type variables, and there is a single
116 * type relation applied to the input and output types.
117 */
118 bool IsPrimitiveOp() const {
119 if (is_primitive_ != -1) return is_primitive_ != 0;
120 is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
121 return is_primitive_ != 0;
122 }
123
124 static constexpr const char* _type_key = "Op";
125 TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
126
127 private:
128 /*! \return the internal attr registry index. */
129 uint32_t AttrRegistryIndex() const { return index_; }
130 /*! \brief repr to be printed in registry*/
131 std::string AttrRegistryName() const { return name; }
132
133 // friend class
134 template <typename>
135 friend class AttrRegistryMapContainerMap;
136 template <typename, typename>
137 friend class AttrRegistry;
138 friend class OpRegEntry;
139
140 friend bool IsPrimitiveOp(const RelayExpr&);
141 // Program internal unique index of operator.
142 // Used to help index the program.
143 uint32_t index_{0};
144 // whether this is a primitive op. -1 means unknown.
145 mutable int is_primitive_{-1};
146 // Internal function to compute if it is primitive op
147 bool IsPrimitiveOp_() const {
148 const auto& fn_ty = this->op_type;
149 ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered";
150 if (fn_ty->type_constraints.size() != 1) return false;
151 const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
152 if (rel == nullptr) return false;
153 // validate if the type parameter matches up
154 for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
155 if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
156 }
157 return true;
158 }
159};
160
/*!
 * \brief Managed reference class to OpNode.
 * \sa OpNode
 */
class Op : public RelayExpr {
 public:
  /*!
   * \brief Get additional registered attribute about operators.
   *  If nothing has been registered, an empty OpAttrMap will be returned.
   * \param attr_name The name of the attribute.
   * \return An OpAttrMap of specified attr_name.
   * \tparam ValueType The type of the attribute.
   */
  template <typename ValueType>
  inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
  /*!
   * \brief Checks if an attr map is present in the registry.
   * \param attr_name The name of the attribute.
   * \return bool True if the attr is present.
   */
  TVM_DLL static bool HasAttrMap(const String& attr_name);
  /*!
   * \brief Get an Op for a given operator name.
   *  Will raise an error if the op has not been registered.
   * \param op_name Name of the operator.
   * \return Reference to the Op, valid throughout program lifetime.
   */
  TVM_DLL static const Op& Get(const String& op_name);

  // Reference-type boilerplate tying Op (derived from RelayExpr) to OpNode.
  TVM_DEFINE_OBJECT_REF_METHODS(Op, RelayExpr, OpNode)

 private:
  /*!
   * \brief Get generic attrmap given attr name
   *  Used by GetAttrMap, which wraps the result in a typed OpAttrMap.
   * \param key The attribute key
   * \return The attr map.
   */
  TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
};
200
201/*!
202 * \brief Helper structure to register operators
203 * \sa TVM_REGISTER_OP
204 */
205class OpRegEntry {
206 public:
207 /*! \return the operator */
208 const Op& op() const { return op_; }
209 /*!
210 * \brief setter function during registration
211 * Set the description of operator
212 * \param descr the description string.
213 * \return reference to self.
214 */
215 inline OpRegEntry& describe(const std::string& descr); // NOLINT(*)
216 /*!
217 * \brief Add argument information to the function.
218 * \param name Name of the argument.
219 * \param type Type of the argument.
220 * \param description Description of the argument.
221 * \return reference to self.
222 */
223 inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
224 const std::string& description);
225 /*!
226 * \brief Attach the type function corresponding to the return type.
227 * \param rel_name The type relation name to register.
228 * \param type_rel_func The backing relation function which can solve an arbitrary
229 * relation on variables.
230 * \return reference to self.
231 */
232 inline OpRegEntry& add_type_rel(
233 const std::string& rel_name,
234 runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
235 type_rel_func);
236 /*!
237 * \brief Set the attrs type key and index to be AttrsType.
238 * \tparam AttrsType the attribute type to b set.
239 * \return reference to self.
240 */
241 template <typename AttrsType>
242 inline OpRegEntry& set_attrs_type();
243 /*!
244 * \brief Set the attrs type key and index to be AttrsType.
245 * \param key The attribute type key to be set.
246 * \return reference to self.
247 */
248 inline OpRegEntry& set_attrs_type_key(const String& key);
249 /*!
250 * \brief Set the num_inputs
251 * \param n The number of inputs to be set.
252 * \return reference to self.
253 */
254 inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*)
255 /*!
256 * \brief Set the support level of op.
257 * \param level The support level.
258 * \return reference to self.
259 */
260 inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*)
261 /*!
262 * \brief Register additional attributes to operator.
263 * \param attr_name The name of the attribute.
264 * \param value The value to be set.
265 * \param plevel The priority level of this set,
266 * an higher priority level attribute
267 * will replace lower priority level attribute.
268 * Must be bigger than 0.
269 *
270 * Cannot set with same plevel twice in the code.
271 *
272 * \tparam ValueType The type of the value to be set.
273 */
274 template <typename ValueType>
275 inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*)
276 const ValueType& value, int plevel = 10);
277
278 /*!
279 * \brief Resets an attr of the registry.
280 * \param attr_name The name of the attribute.
281 */
282 inline void reset_attr(const std::string& attr_name);
283
284 // set the name of the op to be the same as registry
285 inline OpRegEntry& set_name() { // NOLINT(*)
286 if (get()->name.length() == 0) {
287 get()->name = name;
288 }
289 return *this;
290 }
291 /*!
292 * \brief Register or get a new entry.
293 * \param name The name of the operator.
294 * \return the corresponding entry.
295 */
296 TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
297
298 private:
299 template <typename, typename>
300 friend class AttrRegistry;
301 // the name
302 std::string name;
303 /*! \brief The operator */
304 Op op_;
305 // private constructor
306 TVM_DLL OpRegEntry(uint32_t reg_index);
307 // return internal pointer to op.
308 inline OpNode* get();
309 // update the attribute OpAttrMap
310 TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
311};
312
313/*!
314 * \brief Map<Op,ValueType> used to store meta-information about Op.
315 * \tparam ValueType The type of the value stored in map.
316 */
317template <typename ValueType>
318class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
319 public:
320 /*!
321 * \brief get the corresponding value element at op with default value.
322 * \param expr The key to the map
323 * \param def_value The default value when the key does not exist
324 * or if expr is not an Op.
325 * \return the const reference to the content value.
326 */
327 inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
328
329 using TParent = AttrRegistryMap<Op, ValueType>;
330 using TParent::count;
331 using TParent::get;
332 using TParent::operator[];
333
334 private:
335 friend class Op;
336 // constructor
337 explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
338};
339
// Internal macro: the declaration prefix of the static, possibly unused
// ::tvm::OpRegEntry& variable that TVM_REGISTER_OP defines.
// `Op` is not a macro parameter here, so the pasted identifier is literally `__make_Op`.
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
342
/*!
 * \def TVM_REGISTER_OP
 * \brief Register a new operator, or set attribute of the corresponding op.
 *
 * Expands to a static OpRegEntry reference, whose name is built from
 * TVM_OP_REGISTER_VAR_DEF and __COUNTER__, bound to the result of
 * OpRegEntry::RegisterOrGet(OpName).set_name(). Further setters can be
 * chained after the macro, as in the example below.
 *
 * \param OpName The name of registry
 *
 * \code
 *
 *  TVM_REGISTER_OP("add")
 *  .describe("add two inputs together")
 *  .set_num_inputs(2)
 *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
 *
 * \endcode
 */
#define TVM_REGISTER_OP(OpName)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
361
362// implementations
363
364template <typename ValueType>
365inline OpAttrMap<ValueType> Op::GetAttrMap(const String& key) {
366 return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
367}
368
369inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
370
371inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*)
372 get()->description = descr;
373 return *this;
374}
375
376inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
377 const std::string& description) {
378 auto n = make_object<AttrFieldInfoNode>();
379 n->name = name;
380 n->type_info = type;
381 n->description = description;
382 get()->arguments.push_back(AttrFieldInfo(n));
383 return *this;
384}
385
386inline OpRegEntry& OpRegEntry::add_type_rel(
387 const std::string& rel_name,
388 runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
389 type_rel_func) {
390 auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
391 TypeRelationFn env_type_rel_func;
392
393 if (runtime::Registry::Get(func_name)) {
394 auto env_func = EnvFunc::Get(func_name);
395 env_type_rel_func = env_func;
396 } else {
397 runtime::Registry::Register(func_name).set_body(type_rel_func.packed());
398 auto env_func = EnvFunc::Get(func_name);
399 env_type_rel_func = env_func;
400 }
401
402 Array<TypeVar> type_params;
403 Array<Type> arg_types;
404
405 // Add inputs.
406 std::string input_name_prefix = "in";
407 for (int i = 0; i < get()->num_inputs; i++) {
408 auto name = input_name_prefix + std::to_string(i);
409 auto param = TypeVar(name, TypeKind::kType);
410 type_params.push_back(param);
411 arg_types.push_back(param);
412 }
413
414 Array<Type> ty_call_args = arg_types;
415
416 // Add output type.
417 auto out_param = TypeVar("out", TypeKind::kType);
418 type_params.push_back(out_param);
419 // this will trigger copy on write.
420 ty_call_args.push_back(out_param);
421
422 // The attributes of primitive op is nullptr
423 //
424 // The attributes of primitive operator can vary at the call site.
425 // The type of sum is also dependent on Attrs being passed.
426 // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs.
427 //
428 // A common example is sum(x, axis), where the choice of axis
429 // can affect the type of the function.
430 TypeConstraint type_rel =
431 TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());
432
433 auto func_type = FuncType(arg_types, out_param, type_params, {type_rel});
434
435 get()->op_type = func_type;
436
437 return *this;
438}
439
440inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*)
441 get()->num_inputs = n;
442 return *this;
443}
444
445template <typename AttrsType>
446inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
447 get()->attrs_type_key = AttrsType::_type_key;
448 get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
449 return *this;
450}
451
452inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*)
453 get()->attrs_type_key = key;
454 get()->attrs_type_index = Object::TypeKey2Index(key);
455 return *this;
456}
457
458inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
459 get()->support_level = n;
460 return *this;
461}
462
463template <typename ValueType>
464inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*)
465 const std::string& attr_name, const ValueType& value, int plevel) {
466 ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
467 runtime::TVMRetValue rv;
468 rv = value;
469 UpdateAttr(attr_name, rv, plevel);
470 return *this;
471}
472
473// member functions of OpAttrMap
474
475template <typename ValueType>
476inline ValueType OpAttrMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
477 ICHECK(expr.defined());
478 if (const OpNode* op = expr.as<OpNode>()) {
479 return this->map_.get(GetRef<Op>(op), def_value);
480 } else {
481 return def_value;
482 }
483}
484
485/*!
486 * \brief Check that an expression is a "primitive operator".
487 *
488 * Will return true if the expression is an operator which
489 * matches the form of primitive operators registered directly
490 * by the Relay codebase.
491 *
492 * That is the arguments are all type variables, and there is a single
493 * type relation applied to the input and output types.
494 *
495 * \param expr An expression.
496 * \return Whether the expression is primitive op.
497 */
498inline bool IsPrimitiveOp(const RelayExpr& expr) {
499 const auto* op = expr.as<OpNode>();
500 return op != nullptr && op->IsPrimitiveOp();
501}
502
503} // namespace tvm
504#endif // TVM_IR_OP_H_
505