1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file tvm/ir/op.h
 * \brief Primitive operators (builtin intrinsics)
 *        and the registry for them.
 */
25 | #ifndef TVM_IR_OP_H_ |
26 | #define TVM_IR_OP_H_ |
27 | |
28 | #include <dmlc/registry.h> |
29 | #include <tvm/ir/attrs.h> |
30 | #include <tvm/ir/expr.h> |
31 | #include <tvm/ir/type.h> |
32 | #include <tvm/ir/type_relation.h> |
33 | #include <tvm/node/attr_registry_map.h> |
34 | #include <tvm/runtime/registry.h> |
35 | |
36 | #include <string> |
37 | #include <utility> |
38 | #include <vector> |
39 | |
40 | namespace tvm { |
41 | |
// Forward declaration: OpAttrMap is defined further below, but Op::GetAttrMap
// needs to name it before that definition is visible.
template <typename ValueType>
class OpAttrMap;
45 | |
46 | // TODO(tvm-team): migrate low-level intrinsics to use Op |
47 | /*! |
48 | * \brief Primitive Op(builtin intrinsics) |
49 | * |
50 | * This data structure stores the meta-data |
51 | * about primitive operators that can be invoked via Call. |
52 | * |
53 | * Low-level IR intrinsics(such as libc.expf) are also |
54 | * implemented via Op. |
55 | * |
56 | * \sa Op |
57 | */ |
58 | class OpNode : public RelayExprNode { |
59 | public: |
60 | /*! \brief name of the operator */ |
61 | String name; |
62 | /*! \brief the type of the operator */ |
63 | mutable FuncType op_type; |
64 | /*! |
65 | * \brief detailed description of the operator |
66 | * This can be used to generate docstring automatically for the operator. |
67 | */ |
68 | String description; |
69 | /* \brief Information of input arguments to the operator */ |
70 | Array<AttrFieldInfo> arguments; |
71 | /*! |
72 | * \brief The type key of the attribute field |
73 | * This can be empty, in which case it defaults to anything. |
74 | */ |
75 | String attrs_type_key; |
76 | /*! |
77 | * \brief attribute type index, |
78 | * this field varies in each run and is not exposed to frontend. |
79 | */ |
80 | uint32_t attrs_type_index{0}; |
81 | /*! |
82 | * \brief number of input arguments to the operator, |
83 | * -1 means it is variable length |
84 | */ |
85 | int32_t num_inputs = -1; |
86 | /*! |
87 | * \brief support level of the operator, |
88 | * The lower the more priority it contains. |
89 | * This is in analogies to BLAS levels. |
90 | */ |
91 | int32_t support_level = 10; |
92 | |
93 | void VisitAttrs(AttrVisitor* v) { |
94 | v->Visit("name" , &name); |
95 | v->Visit("op_type" , &op_type); |
96 | v->Visit("description" , &description); |
97 | v->Visit("arguments" , &arguments); |
98 | v->Visit("attrs_type_key" , &attrs_type_key); |
99 | v->Visit("num_inputs" , &num_inputs); |
100 | v->Visit("support_level" , &support_level); |
101 | } |
102 | |
103 | bool SEqualReduce(const OpNode* other, SEqualReducer equal) const { |
104 | // pointer equality is fine as there is only one op with the same name. |
105 | return this == other; |
106 | } |
107 | |
108 | void SHashReduce(SHashReducer hash_reduce) const { |
109 | // Name uniquely identifies an Op. |
110 | hash_reduce(name); |
111 | } |
112 | |
113 | /*! |
114 | * \brief Check that if current op is a "primtive operator". |
115 | * That is the arguments are all type variables, and there is a single |
116 | * type relation applied to the input and output types. |
117 | */ |
118 | bool IsPrimitiveOp() const { |
119 | if (is_primitive_ != -1) return is_primitive_ != 0; |
120 | is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; |
121 | return is_primitive_ != 0; |
122 | } |
123 | |
124 | static constexpr const char* _type_key = "Op" ; |
125 | TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode); |
126 | |
127 | private: |
128 | /*! \return the internal attr registry index. */ |
129 | uint32_t AttrRegistryIndex() const { return index_; } |
130 | /*! \brief repr to be printed in registry*/ |
131 | std::string AttrRegistryName() const { return name; } |
132 | |
133 | // friend class |
134 | template <typename> |
135 | friend class AttrRegistryMapContainerMap; |
136 | template <typename, typename> |
137 | friend class AttrRegistry; |
138 | friend class OpRegEntry; |
139 | |
140 | friend bool IsPrimitiveOp(const RelayExpr&); |
141 | // Program internal unique index of operator. |
142 | // Used to help index the program. |
143 | uint32_t index_{0}; |
144 | // whether this is a primitive op. -1 means unknown. |
145 | mutable int is_primitive_{-1}; |
146 | // Internal function to compute if it is primitive op |
147 | bool IsPrimitiveOp_() const { |
148 | const auto& fn_ty = this->op_type; |
149 | ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered" ; |
150 | if (fn_ty->type_constraints.size() != 1) return false; |
151 | const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>(); |
152 | if (rel == nullptr) return false; |
153 | // validate if the type parameter matches up |
154 | for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { |
155 | if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; |
156 | } |
157 | return true; |
158 | } |
159 | }; |
160 | |
161 | /*! |
162 | * \brief Managed reference class to OpNode. |
163 | * \sa OpNode |
164 | */ |
165 | class Op : public RelayExpr { |
166 | public: |
167 | /*! |
168 | * \brief Get additional registered attribute about operators. |
169 | * If nothing has been registered, an empty OpAttrMap will be returned. |
170 | * \param attr_name The name of the attribute. |
171 | * \return An OpAttrMap of specified attr_name. |
172 | * \tparam ValueType The type of the attribute. |
173 | */ |
174 | template <typename ValueType> |
175 | inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name); |
176 | /*! |
177 | * \brief Checks if an attr map is present in the registry. |
178 | * \param attr_name The name of the attribute. |
179 | * \return bool True if the attr is present. |
180 | */ |
181 | TVM_DLL static bool HasAttrMap(const String& attr_name); |
182 | /*! |
183 | * \brief Get an Op for a given operator name. |
184 | * Will raise an error if the op has not been registered. |
185 | * \param op_name Name of the operator. |
186 | * \return Pointer to a Op, valid throughout program lifetime. |
187 | */ |
188 | TVM_DLL static const Op& Get(const String& op_name); |
189 | |
190 | TVM_DEFINE_OBJECT_REF_METHODS(Op, RelayExpr, OpNode) |
191 | |
192 | private: |
193 | /*! |
194 | * \brief Get generic attrmap given attr name |
195 | * \param key The attribute key |
196 | * \return The attr map. |
197 | */ |
198 | TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key); |
199 | }; |
200 | |
201 | /*! |
202 | * \brief Helper structure to register operators |
203 | * \sa TVM_REGISTER_OP |
204 | */ |
205 | class OpRegEntry { |
206 | public: |
207 | /*! \return the operator */ |
208 | const Op& op() const { return op_; } |
209 | /*! |
210 | * \brief setter function during registration |
211 | * Set the description of operator |
212 | * \param descr the description string. |
213 | * \return reference to self. |
214 | */ |
215 | inline OpRegEntry& describe(const std::string& descr); // NOLINT(*) |
216 | /*! |
217 | * \brief Add argument information to the function. |
218 | * \param name Name of the argument. |
219 | * \param type Type of the argument. |
220 | * \param description Description of the argument. |
221 | * \return reference to self. |
222 | */ |
223 | inline OpRegEntry& add_argument(const std::string& name, const std::string& type, |
224 | const std::string& description); |
225 | /*! |
226 | * \brief Attach the type function corresponding to the return type. |
227 | * \param rel_name The type relation name to register. |
228 | * \param type_rel_func The backing relation function which can solve an arbitrary |
229 | * relation on variables. |
230 | * \return reference to self. |
231 | */ |
232 | inline OpRegEntry& add_type_rel( |
233 | const std::string& rel_name, |
234 | runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)> |
235 | type_rel_func); |
236 | /*! |
237 | * \brief Set the attrs type key and index to be AttrsType. |
238 | * \tparam AttrsType the attribute type to b set. |
239 | * \return reference to self. |
240 | */ |
241 | template <typename AttrsType> |
242 | inline OpRegEntry& set_attrs_type(); |
243 | /*! |
244 | * \brief Set the attrs type key and index to be AttrsType. |
245 | * \param key The attribute type key to be set. |
246 | * \return reference to self. |
247 | */ |
248 | inline OpRegEntry& set_attrs_type_key(const String& key); |
249 | /*! |
250 | * \brief Set the num_inputs |
251 | * \param n The number of inputs to be set. |
252 | * \return reference to self. |
253 | */ |
254 | inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*) |
255 | /*! |
256 | * \brief Set the support level of op. |
257 | * \param level The support level. |
258 | * \return reference to self. |
259 | */ |
260 | inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*) |
261 | /*! |
262 | * \brief Register additional attributes to operator. |
263 | * \param attr_name The name of the attribute. |
264 | * \param value The value to be set. |
265 | * \param plevel The priority level of this set, |
266 | * an higher priority level attribute |
267 | * will replace lower priority level attribute. |
268 | * Must be bigger than 0. |
269 | * |
270 | * Cannot set with same plevel twice in the code. |
271 | * |
272 | * \tparam ValueType The type of the value to be set. |
273 | */ |
274 | template <typename ValueType> |
275 | inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*) |
276 | const ValueType& value, int plevel = 10); |
277 | |
278 | /*! |
279 | * \brief Resets an attr of the registry. |
280 | * \param attr_name The name of the attribute. |
281 | */ |
282 | inline void reset_attr(const std::string& attr_name); |
283 | |
284 | // set the name of the op to be the same as registry |
285 | inline OpRegEntry& set_name() { // NOLINT(*) |
286 | if (get()->name.length() == 0) { |
287 | get()->name = name; |
288 | } |
289 | return *this; |
290 | } |
291 | /*! |
292 | * \brief Register or get a new entry. |
293 | * \param name The name of the operator. |
294 | * \return the corresponding entry. |
295 | */ |
296 | TVM_DLL static OpRegEntry& RegisterOrGet(const String& name); |
297 | |
298 | private: |
299 | template <typename, typename> |
300 | friend class AttrRegistry; |
301 | // the name |
302 | std::string name; |
303 | /*! \brief The operator */ |
304 | Op op_; |
305 | // private constructor |
306 | TVM_DLL OpRegEntry(uint32_t reg_index); |
307 | // return internal pointer to op. |
308 | inline OpNode* get(); |
309 | // update the attribute OpAttrMap |
310 | TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel); |
311 | }; |
312 | |
313 | /*! |
314 | * \brief Map<Op,ValueType> used to store meta-information about Op. |
315 | * \tparam ValueType The type of the value stored in map. |
316 | */ |
317 | template <typename ValueType> |
318 | class OpAttrMap : public AttrRegistryMap<Op, ValueType> { |
319 | public: |
320 | /*! |
321 | * \brief get the corresponding value element at op with default value. |
322 | * \param expr The key to the map |
323 | * \param def_value The default value when the key does not exist |
324 | * or if expr is not an Op. |
325 | * \return the const reference to the content value. |
326 | */ |
327 | inline ValueType get(const RelayExpr& expr, ValueType def_value) const; |
328 | |
329 | using TParent = AttrRegistryMap<Op, ValueType>; |
330 | using TParent::count; |
331 | using TParent::get; |
332 | using TParent::operator[]; |
333 | |
334 | private: |
335 | friend class Op; |
336 | // constructor |
337 | explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {} |
338 | }; |
339 | |
// Internal macro used by TVM_REGISTER_OP to declare a static, unused-tolerant reference to a
// registry entry. `__make_##Op` pastes into the identifier `__make_Op`; TVM_REGISTER_OP passes
// this macro to TVM_STR_CONCAT together with __COUNTER__ so each use presumably gets a distinct
// variable name (assumes TVM_STR_CONCAT expands its arguments before pasting — confirm).
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
342 | |
/*!
 * \def TVM_REGISTER_OP
 * \brief Register a new operator, or set attribute of the corresponding op.
 *
 * Expands to a static variable initialized from OpRegEntry::RegisterOrGet(OpName),
 * after set_name() has copied the registry name into the op (if the op has no name yet).
 * The OpRegEntry setters return a reference to the entry, so calls can be chained.
 *
 * \param OpName The name of the operator to register (passed to RegisterOrGet).
 *
 * \code
 *
 *  TVM_REGISTER_OP("add")
 *  .describe("add two inputs together")
 *  .set_num_inputs(2)
 *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
 *
 * \endcode
 */
#define TVM_REGISTER_OP(OpName)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
361 | |
362 | // implementations |
363 | |
364 | template <typename ValueType> |
365 | inline OpAttrMap<ValueType> Op::GetAttrMap(const String& key) { |
366 | return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key)); |
367 | } |
368 | |
369 | inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); } |
370 | |
371 | inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*) |
372 | get()->description = descr; |
373 | return *this; |
374 | } |
375 | |
376 | inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, |
377 | const std::string& description) { |
378 | auto n = make_object<AttrFieldInfoNode>(); |
379 | n->name = name; |
380 | n->type_info = type; |
381 | n->description = description; |
382 | get()->arguments.push_back(AttrFieldInfo(n)); |
383 | return *this; |
384 | } |
385 | |
386 | inline OpRegEntry& OpRegEntry::add_type_rel( |
387 | const std::string& rel_name, |
388 | runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)> |
389 | type_rel_func) { |
390 | auto func_name = std::string("tvm.relay.type_relation." ) + rel_name; |
391 | TypeRelationFn env_type_rel_func; |
392 | |
393 | if (runtime::Registry::Get(func_name)) { |
394 | auto env_func = EnvFunc::Get(func_name); |
395 | env_type_rel_func = env_func; |
396 | } else { |
397 | runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); |
398 | auto env_func = EnvFunc::Get(func_name); |
399 | env_type_rel_func = env_func; |
400 | } |
401 | |
402 | Array<TypeVar> type_params; |
403 | Array<Type> arg_types; |
404 | |
405 | // Add inputs. |
406 | std::string input_name_prefix = "in" ; |
407 | for (int i = 0; i < get()->num_inputs; i++) { |
408 | auto name = input_name_prefix + std::to_string(i); |
409 | auto param = TypeVar(name, TypeKind::kType); |
410 | type_params.push_back(param); |
411 | arg_types.push_back(param); |
412 | } |
413 | |
414 | Array<Type> ty_call_args = arg_types; |
415 | |
416 | // Add output type. |
417 | auto out_param = TypeVar("out" , TypeKind::kType); |
418 | type_params.push_back(out_param); |
419 | // this will trigger copy on write. |
420 | ty_call_args.push_back(out_param); |
421 | |
422 | // The attributes of primitive op is nullptr |
423 | // |
424 | // The attributes of primitive operator can vary at the call site. |
425 | // The type of sum is also dependent on Attrs being passed. |
426 | // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs. |
427 | // |
428 | // A common example is sum(x, axis), where the choice of axis |
429 | // can affect the type of the function. |
430 | TypeConstraint type_rel = |
431 | TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs()); |
432 | |
433 | auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); |
434 | |
435 | get()->op_type = func_type; |
436 | |
437 | return *this; |
438 | } |
439 | |
440 | inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*) |
441 | get()->num_inputs = n; |
442 | return *this; |
443 | } |
444 | |
445 | template <typename AttrsType> |
446 | inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*) |
447 | get()->attrs_type_key = AttrsType::_type_key; |
448 | get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); |
449 | return *this; |
450 | } |
451 | |
452 | inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*) |
453 | get()->attrs_type_key = key; |
454 | get()->attrs_type_index = Object::TypeKey2Index(key); |
455 | return *this; |
456 | } |
457 | |
458 | inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*) |
459 | get()->support_level = n; |
460 | return *this; |
461 | } |
462 | |
463 | template <typename ValueType> |
464 | inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) |
465 | const std::string& attr_name, const ValueType& value, int plevel) { |
466 | ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0" ; |
467 | runtime::TVMRetValue rv; |
468 | rv = value; |
469 | UpdateAttr(attr_name, rv, plevel); |
470 | return *this; |
471 | } |
472 | |
473 | // member functions of OpAttrMap |
474 | |
475 | template <typename ValueType> |
476 | inline ValueType OpAttrMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const { |
477 | ICHECK(expr.defined()); |
478 | if (const OpNode* op = expr.as<OpNode>()) { |
479 | return this->map_.get(GetRef<Op>(op), def_value); |
480 | } else { |
481 | return def_value; |
482 | } |
483 | } |
484 | |
485 | /*! |
486 | * \brief Check that an expression is a "primitive operator". |
487 | * |
488 | * Will return true if the expression is an operator which |
489 | * matches the form of primitive operators registered directly |
490 | * by the Relay codebase. |
491 | * |
492 | * That is the arguments are all type variables, and there is a single |
493 | * type relation applied to the input and output types. |
494 | * |
495 | * \param expr An expression. |
496 | * \return Whether the expression is primitive op. |
497 | */ |
498 | inline bool IsPrimitiveOp(const RelayExpr& expr) { |
499 | const auto* op = expr.as<OpNode>(); |
500 | return op != nullptr && op->IsPrimitiveOp(); |
501 | } |
502 | |
503 | } // namespace tvm |
504 | #endif // TVM_IR_OP_H_ |
505 | |