1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24/*!
25 * \brief Find the entry function of the given IRModule, i.e, functions marked by
26 * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
27 * \param mod The IRModule to find the entry function.
28 * \return The entry function.
29 */
30inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
31 // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
32 int num_prim_func = 0;
33 const tir::PrimFuncNode* main_func = nullptr;
34 const tir::PrimFuncNode* last_func = nullptr;
35 for (const auto& kv : mod->functions) {
36 GlobalVar gv = kv.first;
37 BaseFunc base_func = kv.second;
38 if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
39 last_func = func;
40 if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
41 return GetRef<tir::PrimFunc>(func);
42 }
43 if (gv->name_hint == "main") {
44 main_func = func;
45 }
46 ++num_prim_func;
47 }
48 }
49 // Priority 2: PrimFunc whose name is `main`
50 if (main_func != nullptr) {
51 return GetRef<tir::PrimFunc>(main_func);
52 }
53 // Priority 3: The only PrimFunc in the IRModule
54 if (num_prim_func == 0) {
55 LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod;
56 }
57 if (num_prim_func > 1) {
58 LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are "
59 "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
60 << mod;
61 }
62 return GetRef<tir::PrimFunc>(last_func);
63}
64/******** ArgInfo ********/
65
66ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) {
67 // The JSON object is always an array whose first element is a tag. For example:
68 // `['TENSOR', 'float32', [1, 224, 224, 3]]
69 // Step 1. Extract the tag
70 String tag{runtime::ObjectPtr<runtime::StringObj>(nullptr)};
71 try {
72 const ArrayNode* json_array = json_obj.as<ArrayNode>();
73 CHECK(json_array && json_array->size() >= 1);
74 tag = Downcast<String>(json_array->at(0));
75 } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error
76 LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
77 << "\nThe error is: " << e.what();
78 }
79 // Step 2. Dispatch the tag to corresponding subclass of ArgInfo
80 if (tag == "TENSOR") {
81 return TensorInfo::FromJSON(json_obj);
82 }
83 LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj;
84 throw;
85}
86
87Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {
88 using support::AsVector;
89 Array<ArgInfo> result;
90 result.reserve(func->params.size());
91 for (const tir::Var& arg : func->params) {
92 if (Optional<tir::Buffer> _buffer = func->buffer_map.Get(arg)) {
93 tir::Buffer buffer = _buffer.value();
94 result.push_back(TensorInfo(/*dtype=*/buffer->dtype,
95 /*shape=*/AsVector<PrimExpr, int64_t>(buffer->shape)));
96 } else {
97 LOG(FATAL) << "ValueError: Unsupported argument type: " << arg;
98 }
99 }
100 return result;
101}
102
103Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
104 if (remove_preproc) {
105 IRModule new_mod =
106 tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true)(mod);
107 return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
108 }
109 return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
110}
111
112/******** TensorInfo ********/
113
114TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) {
115 ObjectPtr<TensorInfoNode> n = make_object<TensorInfoNode>();
116 n->dtype = dtype;
117 n->shape = shape;
118 this->data_ = std::move(n);
119}
120
121ObjectRef TensorInfoNode::AsJSON() const {
122 static String tag = "TENSOR";
123 String dtype = DLDataType2String(this->dtype);
124 Array<Integer> shape = support::AsArray(this->shape);
125 return Array<ObjectRef>{tag, dtype, shape};
126}
127
128TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) {
129 DLDataType dtype;
130 Array<Integer> shape;
131 try {
132 const ArrayNode* json_array = json_obj.as<ArrayNode>();
133 CHECK(json_array && json_array->size() == 3);
134 // Load json[1] => dtype
135 {
136 String dtype_str = Downcast<String>(json_array->at(1));
137 dtype = runtime::String2DLDataType(dtype_str);
138 }
139 // Load json[2] => shape
140 shape = AsIntArray(json_array->at(2));
141 } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error
142 LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
143 << "\nThe error is: " << e.what();
144 }
145 std::vector<int64_t> s;
146 std::transform(shape.begin(), shape.end(), std::back_inserter(s),
147 [](Integer i) { return i.IntValue(); });
148 return TensorInfo(DataType(dtype), ShapeTuple(s.begin(), s.end()));
149}
150
151/******** Repr ********/
152
153TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
154 .set_dispatch<TensorInfoNode>([](const ObjectRef& n, ReprPrinter* p) {
155 const auto* self = n.as<TensorInfoNode>();
156 ICHECK(self);
157 p->stream << "TensorInfo(\"" << self->dtype << "\", " << self->shape << ")";
158 });
159
160/******** FFI ********/
161
162TVM_REGISTER_OBJECT_TYPE(ArgInfoNode);
163TVM_REGISTER_NODE_TYPE(TensorInfoNode);
164
165TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method<ArgInfo>(&ArgInfoNode::AsJSON);
166TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc);
167TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc").set_body_typed(ArgInfo::FromEntryFunc);
168TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON);
169TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo")
170 .set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo {
171 return TensorInfo(dtype, shape);
172 });
173
174} // namespace meta_schedule
175} // namespace tvm
176