1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | /*! |
25 | * \brief Find the entry function of the given IRModule, i.e, functions marked by |
26 | * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. |
27 | * \param mod The IRModule to find the entry function. |
28 | * \return The entry function. |
29 | */ |
30 | inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { |
31 | // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` |
32 | int num_prim_func = 0; |
33 | const tir::PrimFuncNode* main_func = nullptr; |
34 | const tir::PrimFuncNode* last_func = nullptr; |
35 | for (const auto& kv : mod->functions) { |
36 | GlobalVar gv = kv.first; |
37 | BaseFunc base_func = kv.second; |
38 | if (const auto* func = base_func.as<tir::PrimFuncNode>()) { |
39 | last_func = func; |
40 | if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { |
41 | return GetRef<tir::PrimFunc>(func); |
42 | } |
43 | if (gv->name_hint == "main" ) { |
44 | main_func = func; |
45 | } |
46 | ++num_prim_func; |
47 | } |
48 | } |
49 | // Priority 2: PrimFunc whose name is `main` |
50 | if (main_func != nullptr) { |
51 | return GetRef<tir::PrimFunc>(main_func); |
52 | } |
53 | // Priority 3: The only PrimFunc in the IRModule |
54 | if (num_prim_func == 0) { |
55 | LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod; |
56 | } |
57 | if (num_prim_func > 1) { |
58 | LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " |
59 | "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" |
60 | << mod; |
61 | } |
62 | return GetRef<tir::PrimFunc>(last_func); |
63 | } |
64 | /******** ArgInfo ********/ |
65 | |
66 | ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { |
67 | // The JSON object is always an array whose first element is a tag. For example: |
68 | // `['TENSOR', 'float32', [1, 224, 224, 3]] |
69 | // Step 1. Extract the tag |
70 | String tag{runtime::ObjectPtr<runtime::StringObj>(nullptr)}; |
71 | try { |
72 | const ArrayNode* json_array = json_obj.as<ArrayNode>(); |
73 | CHECK(json_array && json_array->size() >= 1); |
74 | tag = Downcast<String>(json_array->at(0)); |
75 | } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error |
76 | LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj |
77 | << "\nThe error is: " << e.what(); |
78 | } |
79 | // Step 2. Dispatch the tag to corresponding subclass of ArgInfo |
80 | if (tag == "TENSOR" ) { |
81 | return TensorInfo::FromJSON(json_obj); |
82 | } |
83 | LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj; |
84 | throw; |
85 | } |
86 | |
87 | Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { |
88 | using support::AsVector; |
89 | Array<ArgInfo> result; |
90 | result.reserve(func->params.size()); |
91 | for (const tir::Var& arg : func->params) { |
92 | if (Optional<tir::Buffer> _buffer = func->buffer_map.Get(arg)) { |
93 | tir::Buffer buffer = _buffer.value(); |
94 | result.push_back(TensorInfo(/*dtype=*/buffer->dtype, |
95 | /*shape=*/AsVector<PrimExpr, int64_t>(buffer->shape))); |
96 | } else { |
97 | LOG(FATAL) << "ValueError: Unsupported argument type: " << arg; |
98 | } |
99 | } |
100 | return result; |
101 | } |
102 | |
103 | Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) { |
104 | if (remove_preproc) { |
105 | IRModule new_mod = |
106 | tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true)(mod); |
107 | return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod)); |
108 | } |
109 | return ArgInfo::FromPrimFunc(FindEntryFunc(mod)); |
110 | } |
111 | |
112 | /******** TensorInfo ********/ |
113 | |
114 | TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) { |
115 | ObjectPtr<TensorInfoNode> n = make_object<TensorInfoNode>(); |
116 | n->dtype = dtype; |
117 | n->shape = shape; |
118 | this->data_ = std::move(n); |
119 | } |
120 | |
121 | ObjectRef TensorInfoNode::AsJSON() const { |
122 | static String tag = "TENSOR" ; |
123 | String dtype = DLDataType2String(this->dtype); |
124 | Array<Integer> shape = support::AsArray(this->shape); |
125 | return Array<ObjectRef>{tag, dtype, shape}; |
126 | } |
127 | |
128 | TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { |
129 | DLDataType dtype; |
130 | Array<Integer> shape; |
131 | try { |
132 | const ArrayNode* json_array = json_obj.as<ArrayNode>(); |
133 | CHECK(json_array && json_array->size() == 3); |
134 | // Load json[1] => dtype |
135 | { |
136 | String dtype_str = Downcast<String>(json_array->at(1)); |
137 | dtype = runtime::String2DLDataType(dtype_str); |
138 | } |
139 | // Load json[2] => shape |
140 | shape = AsIntArray(json_array->at(2)); |
141 | } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error |
142 | LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj |
143 | << "\nThe error is: " << e.what(); |
144 | } |
145 | std::vector<int64_t> s; |
146 | std::transform(shape.begin(), shape.end(), std::back_inserter(s), |
147 | [](Integer i) { return i.IntValue(); }); |
148 | return TensorInfo(DataType(dtype), ShapeTuple(s.begin(), s.end())); |
149 | } |
150 | |
151 | /******** Repr ********/ |
152 | |
153 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
154 | .set_dispatch<TensorInfoNode>([](const ObjectRef& n, ReprPrinter* p) { |
155 | const auto* self = n.as<TensorInfoNode>(); |
156 | ICHECK(self); |
157 | p->stream << "TensorInfo(\"" << self->dtype << "\", " << self->shape << ")" ; |
158 | }); |
159 | |
160 | /******** FFI ********/ |
161 | |
// ArgInfo is an abstract base: registered as an object type only (no reflection).
TVM_REGISTER_OBJECT_TYPE(ArgInfoNode);
// TensorInfo is a concrete node: registered with full node reflection support.
TVM_REGISTER_NODE_TYPE(TensorInfoNode);

// Expose the ArgInfo/TensorInfo entry points to the FFI (e.g. the Python frontend).
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method<ArgInfo>(&ArgInfoNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc);
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc").set_body_typed(ArgInfo::FromEntryFunc);
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON);
// Constructor wrapper: builds a TensorInfo from a dtype and a shape tuple.
TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo")
    .set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo {
      return TensorInfo(dtype, shape);
    });
173 | |
174 | } // namespace meta_schedule |
175 | } // namespace tvm |
176 | |