1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file constant_folding.cc |
22 | */ |
23 | #include <tvm/relay/analysis.h> |
24 | #include <tvm/relay/attrs/annotation.h> |
25 | #include <tvm/relay/attrs/transform.h> |
26 | #include <tvm/relay/executor.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/interpreter.h> |
29 | #include <tvm/relay/op.h> |
30 | #include <tvm/relay/op_attr_types.h> |
31 | #include <tvm/relay/transform.h> |
32 | #include <tvm/runtime/ndarray.h> |
33 | #include <tvm/runtime/object.h> |
34 | |
35 | #include "../op/memory/on_device.h" |
36 | #include "./pattern_utils.h" |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | namespace transform { |
41 | |
42 | namespace { |
43 | /*! |
44 | * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device" |
45 | * annotation CallNode (which serves only to associate an \p VirtualDevice to the constant and has |
46 | * no operational effect). |
47 | */ |
48 | bool IsSimpleConstant(const Expr& expr) { |
49 | return AsIgnoringOnDevice<ConstantNode>(expr) != nullptr; |
50 | } |
51 | |
52 | /*! |
53 | * \brief Returns whether \p expr \p IsSimpleConstant directly or is a tuple of |
54 | * \p IsComplexConstant expressions. |
55 | */ |
56 | bool IsComplexConstant(const Expr& expr) { |
57 | if (IsSimpleConstant(expr)) { |
58 | return true; |
59 | } else if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(expr)) { |
60 | return std::all_of(tuple_node->fields.begin(), tuple_node->fields.end(), IsComplexConstant); |
61 | } else { |
62 | return false; |
63 | } |
64 | } |
65 | |
66 | // TODO(tvm-team) consider combine dead-code with constant folder. |
67 | // or make a more powerful partial evaluator. |
68 | class ConstantFolder : public MixedModeMutator { |
69 | public: |
70 | explicit ConstantFolder(IRModule module, bool fold_qnn) |
71 | : module_(std::move(module)), |
72 | fold_qnn_(fold_qnn), |
73 | device_copy_op_(Op::Get("device_copy" )), |
74 | shape_of_op_(Op::Get("shape_of" )), |
75 | vm_shape_of_op_(Op::Get("vm.shape_of" )), |
76 | cast_op_(Op::Get("cast" )), |
77 | ndarray_size_op_(Op::Get("ndarray_size" )) {} |
78 | |
79 | private: |
80 | using ExprMutator::VisitExpr_; |
81 | |
82 | Expr VisitExpr_(const LetNode* let_node) final { |
83 | auto pre_visit = [this](const LetNode* op) { |
84 | // Rely on the Memoizer to cache pre-visit values |
85 | Expr new_value = Mutate(op->value); |
86 | if (IsSimpleConstant(new_value)) { |
87 | // Inline new value (along with any on_device annotation wrapping it) at all occurrences of |
88 | // the variable. |
89 | // |
90 | // We need to retain any "on_device" annotation so that downstream 'device aware' |
91 | // passes can still retrieve the virtual device for the constant in its new position(s). Eg: |
92 | // def @f(..., result_virtual_device=D) { |
93 | // let %x = on_device(... something we eval to a constant..., virtual_device=E) |
94 | // @f(..., %x, ...) |
95 | // } |
96 | // Here the default virtual device is D, whereas the argument %x to @f is on E (and @f |
97 | // expects that). No on_device annotation is required in the call according to the |
98 | // convention used by the device-aware visitors. |
99 | // |
100 | // However once we've inlined the constant we need to insert an on_device, again to |
101 | // respect the convention used by the device-aware visitors. |
102 | // def @f(..., result_virtual_device=D) { |
103 | // @f(..., on_device(...the constant..., virtual_device=E), ...) |
104 | // } |
105 | VLOG(1) << "Replacing let-binding for " << op->var->name_hint() |
106 | << " with constant:" << std::endl |
107 | << PrettyPrint(new_value); |
108 | memo_[op->var] = new_value; |
109 | } else { |
110 | this->Mutate(op->var); |
111 | } |
112 | }; |
113 | auto post_visit = [this](const LetNode* op) { |
114 | Expr expr = GetRef<Expr>(op); |
115 | // Rely on the Memoizer to cache pre-visit values |
116 | Expr new_value = this->Mutate(op->value); |
117 | if (IsSimpleConstant(new_value)) { |
118 | // The let-bound value has been inlined, drop the let-binding itself. |
119 | this->memo_[expr] = Mutate(op->body); |
120 | } else { |
121 | Var new_var = Downcast<Var>(this->Mutate(op->var)); |
122 | Expr new_body = this->Mutate(op->body); |
123 | if (new_var.same_as(op->var) && new_value.same_as(op->value) && |
124 | new_body.same_as(op->body)) { |
125 | this->memo_[expr] = expr; |
126 | } else { |
127 | this->memo_[expr] = Let(new_var, new_value, new_body, op->span); |
128 | } |
129 | } |
130 | }; |
131 | ExpandANormalForm(let_node, pre_visit, post_visit); |
132 | return memo_[GetRef<Expr>(let_node)]; |
133 | } |
134 | |
135 | Expr VisitExpr_(const FunctionNode* function_node) final { |
136 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
137 | ICHECK_EQ(inside_primitive_, false); |
138 | inside_primitive_ = true; |
139 | auto ret = ExprMutator::VisitExpr_(function_node); |
140 | inside_primitive_ = false; |
141 | return ret; |
142 | } else { |
143 | return ExprMutator::VisitExpr_(function_node); |
144 | } |
145 | } |
146 | |
147 | Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { |
148 | Call pre_call = GetRef<Call>(pre_call_node); |
149 | if (inside_primitive_) { |
150 | return std::move(pre_call); |
151 | } |
152 | |
153 | Call post_call = Downcast<Call>(post); |
154 | |
155 | if (post_call->args.empty()) { |
156 | // We don't constant fold function with zero arguments. |
157 | // This is a heuristic that is useful. |
158 | // For example it is harmful to fold ones(shape=(4, 5)). |
159 | return std::move(pre_call); |
160 | } |
161 | |
162 | const auto* op_node = post_call->op.as<OpNode>(); |
163 | if (op_node == nullptr) { |
164 | // Only evaluate primitives. |
165 | return std::move(post_call); |
166 | } |
167 | Op op = GetRef<Op>(op_node); |
168 | static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful" ); |
169 | if (op_stateful.get(op, false)) { |
170 | // skip stateful ops. |
171 | return std::move(post_call); |
172 | } |
173 | // Try to evaluate shape_of and ndarray_size ops |
174 | // Use the original call rather than new_call here since it still has valid checked_type |
175 | // fields. These operators don't care about the value of their argument anyway. |
176 | if (Optional<Expr> opt_result = EvaluateShapeOf(pre_call)) { |
177 | return opt_result.value(); |
178 | } |
179 | // Use the original call rather than new_call here since it still has valid checked_type |
180 | // fields. This operator doesn't care about the value of its argument anyway. |
181 | if (Optional<Expr> opt_result = EvaluateNdarraySize(pre_call)) { |
182 | return opt_result.value(); |
183 | } |
184 | static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational" ); |
185 | static auto qnn_canonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize" ); |
186 | bool is_no_qnn_canonicalized = !qnn_canonicalize.count(op); |
187 | bool is_no_computational = fnoncomputational.count(op) && fnoncomputational[op]; |
188 | if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) { |
189 | return std::move(post_call); |
190 | } |
191 | if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ || |
192 | op == ndarray_size_op_) { |
193 | // We should think about potentially constant evaluation over these ops too. |
194 | return std::move(post_call); |
195 | } |
196 | if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) { |
197 | // At least one non-constant argument. |
198 | return std::move(post_call); |
199 | } |
200 | // During evaluation we have obviously lost all on_device annotations. However any |
201 | // on_device wrapping this call will be left in place. |
202 | return ConstEvaluate(post_call); |
203 | } |
204 | |
205 | Expr VisitExpr_(const IfNode* if_node) final { |
206 | If new_if = Downcast<If>(ExprMutator::VisitExpr_(if_node)); |
207 | if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(new_if->cond)) { |
208 | if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) { |
209 | return new_if->true_branch; |
210 | } else { |
211 | return new_if->false_branch; |
212 | } |
213 | } |
214 | return std::move(new_if); |
215 | } |
216 | |
217 | Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, |
218 | const Expr& post_tuple_get_item) final { |
219 | const auto* post_tuple_get_item_node = post_tuple_get_item.as<TupleGetItemNode>(); |
220 | if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(post_tuple_get_item_node->tuple)) { |
221 | Expr result = tuple_node->fields[tuple_get_item_node->index]; |
222 | OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); |
223 | if (props.body.defined()) { |
224 | // (on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D) |
225 | return MaybeOnDeviceWithProps(result, props); |
226 | } else { |
227 | return result; |
228 | } |
229 | } |
230 | return post_tuple_get_item; |
231 | } |
232 | |
233 | // Convert value to expression. |
234 | Expr ObjectToExpr(const ObjectRef& value) { |
235 | if (value->IsInstance<runtime::NDArray::ContainerType>()) { |
236 | auto nd_array = Downcast<runtime::NDArray>(value); |
237 | return Constant(nd_array); |
238 | } else if (const auto* val = value.as<runtime::ADTObj>()) { |
239 | runtime::ADT adt = GetRef<runtime::ADT>(val); |
240 | Array<Expr> fields; |
241 | for (size_t i = 0; i < adt.size(); ++i) { |
242 | fields.push_back(ObjectToExpr(adt[i])); |
243 | } |
244 | return Tuple(fields); |
245 | } else { |
246 | LOG(FATAL) << "Cannot handle " << value->GetTypeKey(); |
247 | } |
248 | } |
249 | |
250 | // Constant evaluate an expression. |
251 | Expr ConstEvaluate(const Expr& expr) { |
252 | VLOG_CONTEXT << "ConstEvaluate" ; |
253 | VLOG(1) << "Evaluating :" << std::endl << PrettyPrint(expr); |
254 | |
255 | // We'll invoke the interpreter using the generic CPU device and target. Technically there's |
256 | // no guarantee the results will be bitwise equal what we'd get on the true device, however to |
257 | // support cross-compilation we don't want to assume the true device is available. |
258 | |
259 | // Use a fresh build context in case we are already in a build context. |
260 | // needed for both execution and creation(due to JIT) |
261 | With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create()); |
262 | |
263 | Map<String, ObjectRef> dict = (module_->attrs.defined()) |
264 | ? Map<String, ObjectRef>(module_->attrs.CopyOnWrite()->dict) |
265 | : Map<String, ObjectRef>(); |
266 | |
267 | // always use graph executor with no link-params |
268 | dict.Set(tvm::attr::kExecutor, |
269 | relay::Executor::Create("graph" , {{"link-params" , Bool(false)}})); |
270 | Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), |
271 | eval_cpu_dev_, eval_cpu_target_, dict)); |
272 | VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); |
273 | return result; |
274 | } |
275 | |
276 | /*! |
277 | * \brief Returns constant shape result of \p call if it of form \p shape_of(e) and \p e has |
278 | * a non-dynamic tensor shape. Returns null otherwise. |
279 | */ |
280 | Optional<Expr> EvaluateShapeOf(const Call& call) { |
281 | if (call->op != shape_of_op_ && call->op != vm_shape_of_op_) { |
282 | return {}; |
283 | } |
284 | |
285 | VLOG(1) << "Evaluating for shape_of:" << std::endl << PrettyPrint(call); |
286 | ICHECK_EQ(call->args.size(), 1); |
287 | const auto* param = call->attrs.as<ShapeOfAttrs>(); |
288 | ICHECK(param != nullptr); |
289 | Expr input = call->args[0]; |
290 | |
291 | tvm::Array<IndexExpr> ishape; |
292 | if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) { |
293 | ishape = opt_shape.value(); |
294 | } else { |
295 | return {}; |
296 | } |
297 | |
298 | // Get the constant shape |
299 | runtime::NDArray value; |
300 | DLDataType cdtype = DataType::Int(32); |
301 | if (ishape.empty()) { |
302 | value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); |
303 | } else { |
304 | ICHECK_NE(ishape.size(), 0); |
305 | std::vector<int64_t> cshape = {static_cast<int64_t>(ishape.size())}; |
306 | value = runtime::NDArray::Empty(cshape, cdtype, eval_cpu_dev_); |
307 | auto* dims = static_cast<int32_t*>(value->data); |
308 | using ::tvm::tir::IntImmNode; |
309 | for (size_t i = 0; i < ishape.size(); ++i) { |
310 | if (const auto* dim = ishape[i].as<IntImmNode>()) { |
311 | dims[i] = dim->value; |
312 | } else { |
313 | return {}; |
314 | } |
315 | } |
316 | } |
317 | |
318 | Constant shape = Downcast<Constant>(ObjectToExpr(value)); |
319 | |
320 | if (shape->data.Shape().empty() && GetScalarFromConstant<int32_t>(shape) == 0) { |
321 | auto ndarray = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); |
322 | shape = Constant(ndarray); |
323 | } |
324 | |
325 | return CastValue(shape, param->dtype); |
326 | } |
327 | |
328 | /*! |
329 | * \brief Returns the constant NDArray size of result of \p call if it is of the form |
330 | * \p ndarray_size(e) and \p e has non-dynamic tensor type. Returns null otherwise. |
331 | */ |
332 | Optional<Expr> EvaluateNdarraySize(const Call& call) { |
333 | if (call->op != ndarray_size_op_) { |
334 | return {}; |
335 | } |
336 | VLOG(1) << "Evaluating for ndarray_size:" << std::endl << PrettyPrint(call); |
337 | ICHECK_EQ(call->args.size(), 1); |
338 | Expr input = call->args[0]; |
339 | const auto* param = call->attrs.as<NdarraySizeAttrs>(); |
340 | ICHECK(param != nullptr); |
341 | |
342 | tvm::Array<IndexExpr> ishape; |
343 | if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) { |
344 | ishape = opt_shape.value(); |
345 | } else { |
346 | return {}; |
347 | } |
348 | |
349 | // Get the constant size |
350 | runtime::NDArray value; |
351 | DLDataType cdtype = DataType::Int(32); |
352 | value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); |
353 | auto* data = static_cast<int32_t*>(value->data); |
354 | if (ishape.empty()) { |
355 | *data = 0; |
356 | } else { |
357 | *data = 1; |
358 | using ::tvm::tir::IntImmNode; |
359 | for (size_t i = 0; i < ishape.size(); ++i) { |
360 | if (const auto* dim = ishape[i].as<IntImmNode>()) { |
361 | *data *= dim->value; |
362 | } else { |
363 | return {}; |
364 | } |
365 | } |
366 | } |
367 | |
368 | Constant size = Downcast<Constant>(ObjectToExpr(value)); |
369 | return CastValue(size, param->dtype); |
370 | } |
371 | |
372 | Expr CastValue(const Expr& value, DataType dtype) { |
373 | // Cast the constant into correct dtype |
374 | auto cast_attrs = make_object<CastAttrs>(); |
375 | cast_attrs->dtype = dtype; |
376 | Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {}); |
377 | return ConstEvaluate(ret); |
378 | } |
379 | |
380 | Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) { |
381 | if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(input)) { |
382 | // TODO(mbs): This is not necessary since we only ever ask for the shapes for |
383 | // pre-rewritten expressions which will always have a checked_type. |
384 | return const_node->tensor_type()->shape; |
385 | } else if (input->checked_type_.defined()) { |
386 | return input->checked_type().as<TensorTypeNode>()->shape; |
387 | } else { |
388 | return {}; |
389 | } |
390 | } |
391 | |
392 | // Module |
393 | IRModule module_; |
394 | |
395 | // Whether to fold constants for QNN operations. |
396 | bool fold_qnn_; |
397 | |
398 | // The kDLCPU device assumed to be available to the compiler. Used only when evaluating |
399 | // sub-expressions. |
400 | Device eval_cpu_dev_{kDLCPU, /*device_id=*/0}; |
401 | // The target for the above device assumed to be available to the compiler. Used only when |
402 | // evaluating sub-expressions. |
403 | Target eval_cpu_target_{"llvm" }; |
404 | |
405 | // Cache the following ops for equivalence checking in this pass. |
406 | const Op& device_copy_op_; |
407 | const Op& shape_of_op_; |
408 | const Op& vm_shape_of_op_; |
409 | const Op& cast_op_; |
410 | const Op& ndarray_size_op_; |
411 | |
412 | // True if currently within a "primitive" Relay Function. |
413 | bool inside_primitive_ = false; |
414 | }; |
415 | |
416 | } // namespace |
417 | |
// Exposed to the Python analysis API: true iff the expression is a constant, or a (nested)
// tuple of constants, ignoring any "on_device" annotation wrappers.
TVM_REGISTER_GLOBAL("relay.analysis.check_constant" ).set_body_typed(IsComplexConstant);
419 | |
420 | Expr FoldConstantExpr(const Expr& expr, const IRModule& mod, bool fold_qnn) { |
421 | VLOG_CONTEXT << "FoldConstantExpr" ; |
422 | VLOG(1) << "folding:" << std::endl << PrettyPrint(expr); |
423 | Expr result = ConstantFolder(mod, fold_qnn).VisitExpr(expr); |
424 | VLOG(1) << "folded to:" << std::endl << PrettyPrint(result); |
425 | return result; |
426 | } |
427 | |
428 | Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) { |
429 | auto mod = IRModule::FromExpr(expr); |
430 | return FoldConstantExpr(expr, mod, fold_qnn); |
431 | } |
432 | |
// Exposed to the Python transform API: fold constants in a single expression under a given
// module context.
TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr" )
    .set_body_typed([](const Expr& expr, const IRModule& mod, bool fold_qnn) {
      return FoldConstantExpr(expr, mod, fold_qnn);
    });
437 | |
438 | Pass FoldConstant(bool fold_qnn) { |
439 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
440 | [=](Function f, IRModule m, PassContext /* pc */) { |
441 | return Downcast<Function>(FoldConstantExpr(f, m, fold_qnn)); |
442 | }; |
443 | return CreateFunctionPass(pass_func, 2, "FoldConstant" , {}); |
444 | } |
445 | |
// Exposed to the Python transform API: constructs the FoldConstant pass.
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant" ).set_body_typed(FoldConstant);
447 | |
448 | } // namespace transform |
449 | } // namespace relay |
450 | } // namespace tvm |
451 | |