1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/op.cc
22 * \brief Primitive operators and intrinsics.
23 */
24#include <tvm/ir/op.h>
25#include <tvm/ir/type.h>
26#include <tvm/runtime/module.h>
27#include <tvm/runtime/packed_func.h>
28#include <tvm/tir/op_attr_types.h>
29
30#include <memory>
31
32#include "../node/attr_registry.h"
33
34namespace tvm {
35
36using runtime::PackedFunc;
37using runtime::TVMArgs;
38using runtime::TVMRetValue;
39using tir::FLowerIntrinsic;
40
41using OpRegistry = AttrRegistry<OpRegEntry, Op>;
42
43// find operator by name
44const Op& Op::Get(const String& name) {
45 const OpRegEntry* reg = OpRegistry::Global()->Get(name);
46 ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered";
47 return reg->op();
48}
49
50OpRegEntry::OpRegEntry(uint32_t reg_index) {
51 ObjectPtr<OpNode> n = make_object<OpNode>();
52 n->index_ = reg_index;
53 op_ = Op(n);
54}
55
56OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
57 return OpRegistry::Global()->RegisterOrGet(name);
58}
59
60// Get attribute map by key
61const AttrRegistryMapContainerMap<Op>& Op::GetAttrMapContainer(const String& attr_name) {
62 return OpRegistry::Global()->GetAttrMap(attr_name);
63}
64
65// Check if a key is present in the registry.
66bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); }
67
68// Resets attr of the OpAttrMap.
69void OpRegEntry::reset_attr(const std::string& attr_name) {
70 OpRegistry::Global()->ResetAttr(attr_name, op_);
71}
72
73void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
74 OpRegistry::Global()->UpdateAttr(key, op_, value, plevel);
75}
76
77// Frontend APIs
78TVM_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() {
79 return OpRegistry::Global()->ListAllNames();
80});
81
82TVM_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); });
83
84TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> TVMRetValue {
85 auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
86 TVMRetValue rv;
87 if (op_map.count(op)) {
88 rv = op_map[op];
89 }
90 return rv;
91});
92
93TVM_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool {
94 return Op::HasAttrMap(attr_name);
95});
96
97TVM_REGISTER_GLOBAL("ir.OpSetAttr")
98 .set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) {
99 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
100 reg.set_attr(attr_name, value, plevel);
101 });
102
103TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) {
104 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name);
105 reg.reset_attr(attr_name);
106});
107
108TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
109 const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
110 ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
111 auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
112 op.describe(descr);
113});
114
115// This is exposed FFI api for prototyping using in python.
116// Note: it is not full of the C++ type relation,
117// since in python side we don't have access to the type reporter,
118// and cannot propagate constraints to the inputs, only to the output.
119TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
120 .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
121 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
122 if (value.type_code() == kTVMPackedFuncHandle) {
123 // do an eager copy of the PackedFunc to avoid deleting function from frontend.
124 PackedFunc fcopy = value;
125 auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs,
126 const TypeReporter& reporter) -> bool {
127 Array<Type> input_types(args.begin(), args.end() - 1);
128 // call customized relation functions
129 // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type
130 Type ret_type = fcopy(input_types, attrs);
131 // when defined ret_type, inference of output type is ok, do type assign
132 // otherwise, inference failure happens
133 if (ret_type.defined()) {
134 // the last argument is output
135 // TODO(xqdan): support multiple outputs
136 reporter->Assign(args.back(), ret_type);
137 return true;
138 }
139 return false;
140 };
141 // adjust function call to call conventions of relay type system with TypeReporter
142 auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&,
143 const TypeReporter&)>(f);
144 reg.add_type_rel(rel_name, type_rel);
145 } else if (value.type_code() == kTVMNullptr) {
146 // Call relation functions of relay
147 auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
148 auto* f = runtime::Registry::Get(func_name);
149 ICHECK(f != nullptr) << "AddTypeRel error: no type_relation registered.";
150 reg.add_type_rel(rel_name, *f);
151 }
152 });
153
154TVM_REGISTER_GLOBAL("ir.OpAddArgument")
155 .set_body_typed([](Op op, String name, String type, String description) {
156 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
157 reg.add_argument(name, type, description);
158 });
159
160TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) {
161 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
162 reg.set_support_level(level);
163});
164
165TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) {
166 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
167 reg.set_num_inputs(n);
168});
169
170TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) {
171 auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
172 reg.set_attrs_type_key(key);
173});
174
175TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")
176 .set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) {
177 auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
178 // enable resgiteration and override of certain properties
179 if (attr_key == "num_inputs" && plevel > 128) {
180 reg.set_num_inputs(value);
181 } else if (attr_key == "attrs_type_key" && plevel > 128) {
182 LOG(FATAL) << "attrs type key no longer supported";
183 } else {
184 // normal attr table override.
185 if (value.type_code() == kTVMPackedFuncHandle) {
186 // do an eager copy of the PackedFunc
187 PackedFunc f = value;
188 reg.set_attr(attr_key, f, plevel);
189 } else {
190 reg.set_attr(attr_key, value, plevel);
191 }
192 }
193 });
194
195TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic")
196 .set_body_typed([](String name, PackedFunc f, String target, int plevel) {
197 tvm::OpRegEntry::RegisterOrGet(name).set_attr<FLowerIntrinsic>(target + ".FLowerIntrinsic", f,
198 plevel);
199 });
200
201// helper to get internal dev function in objectref.
202struct Op2ObjectPtr : public ObjectRef {
203 static ObjectPtr<Object> Get(const Op& op) { return GetDataPtr<Object>(op); }
204};
205
206ObjectPtr<Object> CreateOp(const std::string& name) {
207 // Hack use TVMRetValue as exchange
208 auto op = Op::Get(name);
209 ICHECK(op.defined()) << "Cannot find op \'" << name << '\'';
210 return Op2ObjectPtr::Get(op);
211}
212
213TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes(
214 [](const Object* n) -> std::string { return static_cast<const OpNode*>(n)->name; });
215
216TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
217 .set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
218 auto* node = static_cast<const OpNode*>(ref.get());
219 p->stream << "Op(" << node->name << ")";
220 });
221
222} // namespace tvm
223