1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/tir/ir/function.cc
22 * \brief The function data structure.
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/function.h>
26#include <tvm/tir/op.h>
27
28namespace tvm {
29namespace tir {
30// Get the function type of a PrimFunc
31PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
32 Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
33 // Assume void-return type for now
34 // TODO(tvm-team) consider type deduction from body.
35 if (!ret_type.defined()) {
36 ret_type = VoidType();
37 }
38 auto n = make_object<PrimFuncNode>();
39 n->params = std::move(params);
40 n->body = std::move(body);
41 n->ret_type = std::move(ret_type);
42 n->buffer_map = std::move(buffer_map);
43 n->attrs = std::move(attrs);
44 n->checked_type_ = n->func_type_annotation();
45 n->span = std::move(span);
46 data_ = std::move(n);
47}
48
49FuncType PrimFuncNode::func_type_annotation() const {
50 Array<Type> param_types;
51 for (auto param : this->params) {
52 param_types.push_back(GetType(param));
53 }
54 return FuncType(param_types, ret_type, {}, {});
55}
56
// Register PrimFuncNode through the TVM_REGISTER_NODE_TYPE macro.
TVM_REGISTER_NODE_TYPE(PrimFuncNode);
58
59class TensorIntrinManager {
60 public:
61 Map<String, tir::TensorIntrin> reg;
62
63 static TensorIntrinManager* Global() {
64 static TensorIntrinManager* inst = new TensorIntrinManager();
65 return inst;
66 }
67};
68
69TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) {
70 // Check the number of func var is equal
71 CHECK_EQ(desc->params.size(), impl->params.size())
72 << "ValueError: The number of parameters of the description and the implementation of the "
73 "tensor intrinsic doesn't match.";
74 for (size_t i = 0; i < desc->params.size(); i++) {
75 CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the "
76 "tensor intrinsic should be handle only.";
77 CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of "
78 "the tensor intrinsic should be handle only.";
79 }
80 ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size());
81
82 ObjectPtr<TensorIntrinNode> n = make_object<TensorIntrinNode>();
83 n->desc = std::move(desc);
84 n->impl = std::move(impl);
85 data_ = std::move(n);
86}
87
88void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) {
89 TensorIntrinManager* manager = TensorIntrinManager::Global();
90 if (!override) {
91 CHECK_EQ(manager->reg.count(name), 0)
92 << "ValueError: TensorIntrin '" << name << "' has already been registered";
93 }
94 manager->reg.Set(name, intrin);
95}
96
97Optional<TensorIntrin> TensorIntrin::Get(String name, bool allow_missing) {
98 const TensorIntrinManager* manager = TensorIntrinManager::Global();
99 auto it = manager->reg.find(name);
100 if (it == manager->reg.end()) {
101 if (allow_missing) {
102 return NullOpt;
103 } else {
104 LOG(FATAL) << "ValueError: TensorIntrin '" << name << "' is not registered";
105 }
106 }
107 return (*it).second;
108}
109
// Register TensorIntrinNode through the TVM_REGISTER_NODE_TYPE macro.
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
111
112TVM_REGISTER_GLOBAL("tir.PrimFunc")
113 .set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
114 Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
115 return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
116 });
117
118TVM_REGISTER_GLOBAL("tir.TensorIntrin")
119 .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) {
120 return TensorIntrin(desc_func, intrin_func);
121 });
122
// Expose TensorIntrin::Register and TensorIntrin::Get as the global functions
// "tir.TensorIntrinRegister" and "tir.TensorIntrinGet" respectively.
TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register);
TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get);
125
126} // namespace tir
127} // namespace tvm
128