1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/tir/ir/function.cc |
22 | * \brief The function data structure. |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/function.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | namespace tvm { |
29 | namespace tir { |
30 | // Get the function type of a PrimFunc |
31 | PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type, |
32 | Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) { |
33 | // Assume void-return type for now |
34 | // TODO(tvm-team) consider type deduction from body. |
35 | if (!ret_type.defined()) { |
36 | ret_type = VoidType(); |
37 | } |
38 | auto n = make_object<PrimFuncNode>(); |
39 | n->params = std::move(params); |
40 | n->body = std::move(body); |
41 | n->ret_type = std::move(ret_type); |
42 | n->buffer_map = std::move(buffer_map); |
43 | n->attrs = std::move(attrs); |
44 | n->checked_type_ = n->func_type_annotation(); |
45 | n->span = std::move(span); |
46 | data_ = std::move(n); |
47 | } |
48 | |
49 | FuncType PrimFuncNode::func_type_annotation() const { |
50 | Array<Type> param_types; |
51 | for (auto param : this->params) { |
52 | param_types.push_back(GetType(param)); |
53 | } |
54 | return FuncType(param_types, ret_type, {}, {}); |
55 | } |
56 | |
57 | TVM_REGISTER_NODE_TYPE(PrimFuncNode); |
58 | |
59 | class TensorIntrinManager { |
60 | public: |
61 | Map<String, tir::TensorIntrin> reg; |
62 | |
63 | static TensorIntrinManager* Global() { |
64 | static TensorIntrinManager* inst = new TensorIntrinManager(); |
65 | return inst; |
66 | } |
67 | }; |
68 | |
69 | TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { |
70 | // Check the number of func var is equal |
71 | CHECK_EQ(desc->params.size(), impl->params.size()) |
72 | << "ValueError: The number of parameters of the description and the implementation of the " |
73 | "tensor intrinsic doesn't match." ; |
74 | for (size_t i = 0; i < desc->params.size(); i++) { |
75 | CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the " |
76 | "tensor intrinsic should be handle only." ; |
77 | CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of " |
78 | "the tensor intrinsic should be handle only." ; |
79 | } |
80 | ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); |
81 | |
82 | ObjectPtr<TensorIntrinNode> n = make_object<TensorIntrinNode>(); |
83 | n->desc = std::move(desc); |
84 | n->impl = std::move(impl); |
85 | data_ = std::move(n); |
86 | } |
87 | |
88 | void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) { |
89 | TensorIntrinManager* manager = TensorIntrinManager::Global(); |
90 | if (!override) { |
91 | CHECK_EQ(manager->reg.count(name), 0) |
92 | << "ValueError: TensorIntrin '" << name << "' has already been registered" ; |
93 | } |
94 | manager->reg.Set(name, intrin); |
95 | } |
96 | |
97 | Optional<TensorIntrin> TensorIntrin::Get(String name, bool allow_missing) { |
98 | const TensorIntrinManager* manager = TensorIntrinManager::Global(); |
99 | auto it = manager->reg.find(name); |
100 | if (it == manager->reg.end()) { |
101 | if (allow_missing) { |
102 | return NullOpt; |
103 | } else { |
104 | LOG(FATAL) << "ValueError: TensorIntrin '" << name << "' is not registered" ; |
105 | } |
106 | } |
107 | return (*it).second; |
108 | } |
109 | |
110 | TVM_REGISTER_NODE_TYPE(TensorIntrinNode); |
111 | |
112 | TVM_REGISTER_GLOBAL("tir.PrimFunc" ) |
113 | .set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type, |
114 | Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) { |
115 | return PrimFunc(params, body, ret_type, buffer_map, attrs, span); |
116 | }); |
117 | |
118 | TVM_REGISTER_GLOBAL("tir.TensorIntrin" ) |
119 | .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { |
120 | return TensorIntrin(desc_func, intrin_func); |
121 | }); |
122 | |
123 | TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister" ).set_body_typed(TensorIntrin::Register); |
124 | TVM_REGISTER_GLOBAL("tir.TensorIntrinGet" ).set_body_typed(TensorIntrin::Get); |
125 | |
126 | } // namespace tir |
127 | } // namespace tvm |
128 | |