1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/op.cc |
22 | * \brief Primitive operators and intrinsics. |
23 | */ |
24 | #include <tvm/ir/op.h> |
25 | #include <tvm/ir/type.h> |
26 | #include <tvm/runtime/module.h> |
27 | #include <tvm/runtime/packed_func.h> |
28 | #include <tvm/tir/op_attr_types.h> |
29 | |
30 | #include <memory> |
31 | |
32 | #include "../node/attr_registry.h" |
33 | |
34 | namespace tvm { |
35 | |
36 | using runtime::PackedFunc; |
37 | using runtime::TVMArgs; |
38 | using runtime::TVMRetValue; |
39 | using tir::FLowerIntrinsic; |
40 | |
41 | using OpRegistry = AttrRegistry<OpRegEntry, Op>; |
42 | |
43 | // find operator by name |
44 | const Op& Op::Get(const String& name) { |
45 | const OpRegEntry* reg = OpRegistry::Global()->Get(name); |
46 | ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered" ; |
47 | return reg->op(); |
48 | } |
49 | |
50 | OpRegEntry::OpRegEntry(uint32_t reg_index) { |
51 | ObjectPtr<OpNode> n = make_object<OpNode>(); |
52 | n->index_ = reg_index; |
53 | op_ = Op(n); |
54 | } |
55 | |
56 | OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) { |
57 | return OpRegistry::Global()->RegisterOrGet(name); |
58 | } |
59 | |
60 | // Get attribute map by key |
61 | const AttrRegistryMapContainerMap<Op>& Op::GetAttrMapContainer(const String& attr_name) { |
62 | return OpRegistry::Global()->GetAttrMap(attr_name); |
63 | } |
64 | |
65 | // Check if a key is present in the registry. |
66 | bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); } |
67 | |
68 | // Resets attr of the OpAttrMap. |
69 | void OpRegEntry::reset_attr(const std::string& attr_name) { |
70 | OpRegistry::Global()->ResetAttr(attr_name, op_); |
71 | } |
72 | |
73 | void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { |
74 | OpRegistry::Global()->UpdateAttr(key, op_, value, plevel); |
75 | } |
76 | |
77 | // Frontend APIs |
78 | TVM_REGISTER_GLOBAL("ir.ListOpNames" ).set_body_typed([]() { |
79 | return OpRegistry::Global()->ListAllNames(); |
80 | }); |
81 | |
82 | TVM_REGISTER_GLOBAL("ir.GetOp" ).set_body_typed([](String name) -> Op { return Op::Get(name); }); |
83 | |
84 | TVM_REGISTER_GLOBAL("ir.OpGetAttr" ).set_body_typed([](Op op, String attr_name) -> TVMRetValue { |
85 | auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name); |
86 | TVMRetValue rv; |
87 | if (op_map.count(op)) { |
88 | rv = op_map[op]; |
89 | } |
90 | return rv; |
91 | }); |
92 | |
93 | TVM_REGISTER_GLOBAL("ir.OpHasAttr" ).set_body_typed([](Op op, String attr_name) -> bool { |
94 | return Op::HasAttrMap(attr_name); |
95 | }); |
96 | |
97 | TVM_REGISTER_GLOBAL("ir.OpSetAttr" ) |
98 | .set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) { |
99 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); |
100 | reg.set_attr(attr_name, value, plevel); |
101 | }); |
102 | |
103 | TVM_REGISTER_GLOBAL("ir.OpResetAttr" ).set_body_typed([](Op op, String attr_name) { |
104 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); |
105 | reg.reset_attr(attr_name); |
106 | }); |
107 | |
108 | TVM_REGISTER_GLOBAL("ir.RegisterOp" ).set_body_typed([](String op_name, String descr) { |
109 | const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); |
110 | ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before" ; |
111 | auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); |
112 | op.describe(descr); |
113 | }); |
114 | |
115 | // This is exposed FFI api for prototyping using in python. |
116 | // Note: it is not full of the C++ type relation, |
117 | // since in python side we don't have access to the type reporter, |
118 | // and cannot propagate constraints to the inputs, only to the output. |
119 | TVM_REGISTER_GLOBAL("ir.OpAddTypeRel" ) |
120 | .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) { |
121 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); |
122 | if (value.type_code() == kTVMPackedFuncHandle) { |
123 | // do an eager copy of the PackedFunc to avoid deleting function from frontend. |
124 | PackedFunc fcopy = value; |
125 | auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs, |
126 | const TypeReporter& reporter) -> bool { |
127 | Array<Type> input_types(args.begin(), args.end() - 1); |
128 | // call customized relation functions |
129 | // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type |
130 | Type ret_type = fcopy(input_types, attrs); |
131 | // when defined ret_type, inference of output type is ok, do type assign |
132 | // otherwise, inference failure happens |
133 | if (ret_type.defined()) { |
134 | // the last argument is output |
135 | // TODO(xqdan): support multiple outputs |
136 | reporter->Assign(args.back(), ret_type); |
137 | return true; |
138 | } |
139 | return false; |
140 | }; |
141 | // adjust function call to call conventions of relay type system with TypeReporter |
142 | auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, |
143 | const TypeReporter&)>(f); |
144 | reg.add_type_rel(rel_name, type_rel); |
145 | } else if (value.type_code() == kTVMNullptr) { |
146 | // Call relation functions of relay |
147 | auto func_name = std::string("tvm.relay.type_relation." ) + rel_name; |
148 | auto* f = runtime::Registry::Get(func_name); |
149 | ICHECK(f != nullptr) << "AddTypeRel error: no type_relation registered." ; |
150 | reg.add_type_rel(rel_name, *f); |
151 | } |
152 | }); |
153 | |
154 | TVM_REGISTER_GLOBAL("ir.OpAddArgument" ) |
155 | .set_body_typed([](Op op, String name, String type, String description) { |
156 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); |
157 | reg.add_argument(name, type, description); |
158 | }); |
159 | |
160 | TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel" ).set_body_typed([](Op op, int level) { |
161 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); |
162 | reg.set_support_level(level); |
163 | }); |
164 | |
165 | TVM_REGISTER_GLOBAL("ir.OpSetNumInputs" ).set_body_typed([](Op op, int n) { |
166 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); |
167 | reg.set_num_inputs(n); |
168 | }); |
169 | |
170 | TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey" ).set_body_typed([](Op op, String key) { |
171 | auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); |
172 | reg.set_attrs_type_key(key); |
173 | }); |
174 | |
175 | TVM_REGISTER_GLOBAL("ir.RegisterOpAttr" ) |
176 | .set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) { |
177 | auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); |
178 | // enable resgiteration and override of certain properties |
179 | if (attr_key == "num_inputs" && plevel > 128) { |
180 | reg.set_num_inputs(value); |
181 | } else if (attr_key == "attrs_type_key" && plevel > 128) { |
182 | LOG(FATAL) << "attrs type key no longer supported" ; |
183 | } else { |
184 | // normal attr table override. |
185 | if (value.type_code() == kTVMPackedFuncHandle) { |
186 | // do an eager copy of the PackedFunc |
187 | PackedFunc f = value; |
188 | reg.set_attr(attr_key, f, plevel); |
189 | } else { |
190 | reg.set_attr(attr_key, value, plevel); |
191 | } |
192 | } |
193 | }); |
194 | |
195 | TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic" ) |
196 | .set_body_typed([](String name, PackedFunc f, String target, int plevel) { |
197 | tvm::OpRegEntry::RegisterOrGet(name).set_attr<FLowerIntrinsic>(target + ".FLowerIntrinsic" , f, |
198 | plevel); |
199 | }); |
200 | |
201 | // helper to get internal dev function in objectref. |
202 | struct Op2ObjectPtr : public ObjectRef { |
203 | static ObjectPtr<Object> Get(const Op& op) { return GetDataPtr<Object>(op); } |
204 | }; |
205 | |
206 | ObjectPtr<Object> CreateOp(const std::string& name) { |
207 | // Hack use TVMRetValue as exchange |
208 | auto op = Op::Get(name); |
209 | ICHECK(op.defined()) << "Cannot find op \'" << name << '\''; |
210 | return Op2ObjectPtr::Get(op); |
211 | } |
212 | |
213 | TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes( |
214 | [](const Object* n) -> std::string { return static_cast<const OpNode*>(n)->name; }); |
215 | |
216 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
217 | .set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) { |
218 | auto* node = static_cast<const OpNode*>(ref.get()); |
219 | p->stream << "Op(" << node->name << ")" ; |
220 | }); |
221 | |
222 | } // namespace tvm |
223 | |