/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file constant_folding.cc
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

#include "../op/memory/on_device.h"
#include "./pattern_utils.h"

namespace tvm {
namespace relay {
namespace transform {

namespace {
/*!
 * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device"
 * annotation CallNode (which serves only to associate a \p VirtualDevice with the constant and
 * has no operational effect).
 */
bool IsSimpleConstant(const Expr& expr) {
  return AsIgnoringOnDevice<ConstantNode>(expr) != nullptr;
}
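
// For example (an illustrative sketch; the snippets below are pseudo-Relay, not parseable text):
//   3                                  is a simple constant
//   on_device(3, virtual_device=...)   is a simple constant behind an annotation
//   add(3, 4)                          is NOT simple (it is a call, even though its args are)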

/*!
 * \brief Returns whether \p expr is a simple constant (\p IsSimpleConstant), or is a tuple
 * whose fields are all (recursively) \p IsComplexConstant expressions.
 */
bool IsComplexConstant(const Expr& expr) {
  if (IsSimpleConstant(expr)) {
    return true;
  } else if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(expr)) {
    return std::all_of(tuple_node->fields.begin(), tuple_node->fields.end(), IsComplexConstant);
  } else {
    return false;
  }
}
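
// For example (again an illustrative sketch): the nested tuple
//   (3, (4, 5), on_device(6, virtual_device=...))
// is a "complex" constant, whereas (3, %x) is not, since %x is a variable rather than a constant.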

// TODO(tvm-team) Consider combining dead-code elimination with the constant folder,
// or making a more powerful partial evaluator.
class ConstantFolder : public MixedModeMutator {
 public:
  explicit ConstantFolder(IRModule module, bool fold_qnn)
      : module_(std::move(module)),
        fold_qnn_(fold_qnn),
        device_copy_op_(Op::Get("device_copy")),
        shape_of_op_(Op::Get("shape_of")),
        vm_shape_of_op_(Op::Get("vm.shape_of")),
        cast_op_(Op::Get("cast")),
        ndarray_size_op_(Op::Get("ndarray_size")) {}

 private:
  using ExprMutator::VisitExpr_;

  Expr VisitExpr_(const LetNode* let_node) final {
    auto pre_visit = [this](const LetNode* op) {
      // Rely on the Memoizer to cache pre-visit values
      Expr new_value = Mutate(op->value);
      if (IsSimpleConstant(new_value)) {
        // Inline new value (along with any on_device annotation wrapping it) at all occurrences
        // of the variable.
        //
        // We need to retain any "on_device" annotation so that downstream 'device aware'
        // passes can still retrieve the virtual device for the constant in its new position(s).
        // Eg:
        //   def @f(..., result_virtual_device=D) {
        //     let %x = on_device(... something we eval to a constant..., virtual_device=E)
        //     @f(..., %x, ...)
        //   }
        // Here the default virtual device is D, whereas the argument %x to @f is on E (and @f
        // expects that). No on_device annotation is required in the call according to the
        // convention used by the device-aware visitors.
        //
        // However once we've inlined the constant we need to insert an on_device, again to
        // respect the convention used by the device-aware visitors.
        //   def @f(..., result_virtual_device=D) {
        //     @f(..., on_device(...the constant..., virtual_device=E), ...)
        //   }
        VLOG(1) << "Replacing let-binding for " << op->var->name_hint()
                << " with constant:" << std::endl
                << PrettyPrint(new_value);
        memo_[op->var] = new_value;
      } else {
        this->Mutate(op->var);
      }
    };
    auto post_visit = [this](const LetNode* op) {
      Expr expr = GetRef<Expr>(op);
      // Rely on the Memoizer to cache pre-visit values
      Expr new_value = this->Mutate(op->value);
      if (IsSimpleConstant(new_value)) {
        // The let-bound value has been inlined, drop the let-binding itself.
        this->memo_[expr] = Mutate(op->body);
      } else {
        Var new_var = Downcast<Var>(this->Mutate(op->var));
        Expr new_body = this->Mutate(op->body);
        if (new_var.same_as(op->var) && new_value.same_as(op->value) &&
            new_body.same_as(op->body)) {
          this->memo_[expr] = expr;
        } else {
          this->memo_[expr] = Let(new_var, new_value, new_body, op->span);
        }
      }
    };
    ExpandANormalForm(let_node, pre_visit, post_visit);
    return memo_[GetRef<Expr>(let_node)];
  }

  Expr VisitExpr_(const FunctionNode* function_node) final {
    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
      ICHECK_EQ(inside_primitive_, false);
      inside_primitive_ = true;
      auto ret = ExprMutator::VisitExpr_(function_node);
      inside_primitive_ = false;
      return ret;
    } else {
      return ExprMutator::VisitExpr_(function_node);
    }
  }

  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
    Call pre_call = GetRef<Call>(pre_call_node);
    if (inside_primitive_) {
      return std::move(pre_call);
    }

    Call post_call = Downcast<Call>(post);

    if (post_call->args.empty()) {
      // We don't constant fold calls with zero arguments. This is a useful heuristic:
      // for example, it would be harmful to fold ones(shape=(4, 5)) into a large constant tensor.
      return std::move(pre_call);
    }

    const auto* op_node = post_call->op.as<OpNode>();
    if (op_node == nullptr) {
      // Only evaluate primitives.
      return std::move(post_call);
    }
    Op op = GetRef<Op>(op_node);
    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
    if (op_stateful.get(op, false)) {
      // Skip stateful ops.
      return std::move(post_call);
    }
    // Try to evaluate shape_of and ndarray_size ops.
    // Use the original pre_call rather than post_call here since it still has valid checked_type
    // fields. These operators don't care about the value of their argument anyway.
    if (Optional<Expr> opt_result = EvaluateShapeOf(pre_call)) {
      return opt_result.value();
    }
    // Use the original pre_call rather than post_call here since it still has valid checked_type
    // fields. This operator doesn't care about the value of its argument anyway.
    if (Optional<Expr> opt_result = EvaluateNdarraySize(pre_call)) {
      return opt_result.value();
    }
    static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
    static auto qnn_canonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize");
    bool is_no_qnn_canonicalized = !qnn_canonicalize.count(op);
    bool is_no_computational = fnoncomputational.count(op) && fnoncomputational[op];
    if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) {
      return std::move(post_call);
    }
    if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ ||
        op == ndarray_size_op_) {
      // We should think about supporting constant evaluation over these ops too.
      return std::move(post_call);
    }
    if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) {
      // At least one non-constant argument.
      return std::move(post_call);
    }
    // During evaluation we have obviously lost all on_device annotations. However any
    // on_device wrapping this call will be left in place.
    return ConstEvaluate(post_call);
  }
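
  // Illustrative sketch of the call rewriting above (pseudo-Relay, not verbatim pass output):
  //   add(const_a, const_b)   ==>   const_c     (const_c computed by the interpreter)
  //   add(%x, const_b)        ==>   unchanged   (non-constant argument %x)
  //   ones(shape=(4, 5))      ==>   unchanged   (zero-argument heuristic above)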

  Expr VisitExpr_(const IfNode* if_node) final {
    If new_if = Downcast<If>(ExprMutator::VisitExpr_(if_node));
    if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(new_if->cond)) {
      if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) {
        return new_if->true_branch;
      } else {
        return new_if->false_branch;
      }
    }
    return std::move(new_if);
  }
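
  // Illustrative sketch: once the condition folds to a constant,
  //   if (True) { %t } else { %f }   ==>   %t
  // and the untaken branch is dropped entirely.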

  Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node,
                const Expr& post_tuple_get_item) final {
    const auto* post_tuple_get_item_node = post_tuple_get_item.as<TupleGetItemNode>();
    if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(post_tuple_get_item_node->tuple)) {
      Expr result = tuple_node->fields[tuple_get_item_node->index];
      OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple);
      if (props.body.defined()) {
        // on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D)
        return MaybeOnDeviceWithProps(result, props);
      } else {
        return result;
      }
    }
    return post_tuple_get_item;
  }

  // Convert value to expression.
  Expr ObjectToExpr(const ObjectRef& value) {
    if (value->IsInstance<runtime::NDArray::ContainerType>()) {
      auto nd_array = Downcast<runtime::NDArray>(value);
      return Constant(nd_array);
    } else if (const auto* val = value.as<runtime::ADTObj>()) {
      runtime::ADT adt = GetRef<runtime::ADT>(val);
      Array<Expr> fields;
      for (size_t i = 0; i < adt.size(); ++i) {
        fields.push_back(ObjectToExpr(adt[i]));
      }
      return Tuple(fields);
    } else {
      LOG(FATAL) << "Cannot handle " << value->GetTypeKey();
    }
  }

  // Constant evaluate an expression.
  Expr ConstEvaluate(const Expr& expr) {
    VLOG_CONTEXT << "ConstEvaluate";
    VLOG(1) << "Evaluating :" << std::endl << PrettyPrint(expr);

    // We'll invoke the interpreter using the generic CPU device and target. Technically there's
    // no guarantee the results will be bitwise equal to what we'd get on the true device; however,
    // to support cross-compilation we don't want to assume the true device is available.

    // Use a fresh build context in case we are already in a build context.
    // This is needed for both execution and creation (due to JIT).
    With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());

    Map<String, ObjectRef> dict = (module_->attrs.defined())
                                      ? Map<String, ObjectRef>(module_->attrs.CopyOnWrite()->dict)
                                      : Map<String, ObjectRef>();

    // always use graph executor with no link-params
    dict.Set(tvm::attr::kExecutor,
             relay::Executor::Create("graph", {{"link-params", Bool(false)}}));
    Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(),
                                    eval_cpu_dev_, eval_cpu_target_, dict));
    VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result);
    return result;
  }

  /*!
   * \brief Returns the constant shape result of \p call if it is of the form \p shape_of(e) and
   * \p e has a non-dynamic tensor shape. Returns null otherwise.
   */
  Optional<Expr> EvaluateShapeOf(const Call& call) {
    if (call->op != shape_of_op_ && call->op != vm_shape_of_op_) {
      return {};
    }

    VLOG(1) << "Evaluating for shape_of:" << std::endl << PrettyPrint(call);
    ICHECK_EQ(call->args.size(), 1);
    const auto* param = call->attrs.as<ShapeOfAttrs>();
    ICHECK(param != nullptr);
    Expr input = call->args[0];

    tvm::Array<IndexExpr> ishape;
    if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) {
      ishape = opt_shape.value();
    } else {
      return {};
    }

    // Get the constant shape
    runtime::NDArray value;
    DLDataType cdtype = DataType::Int(32);
    if (ishape.empty()) {
      value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_);
    } else {
      ICHECK_NE(ishape.size(), 0);
      std::vector<int64_t> cshape = {static_cast<int64_t>(ishape.size())};
      value = runtime::NDArray::Empty(cshape, cdtype, eval_cpu_dev_);
      auto* dims = static_cast<int32_t*>(value->data);
      using ::tvm::tir::IntImmNode;
      for (size_t i = 0; i < ishape.size(); ++i) {
        if (const auto* dim = ishape[i].as<IntImmNode>()) {
          dims[i] = dim->value;
        } else {
          return {};
        }
      }
    }

    Constant shape = Downcast<Constant>(ObjectToExpr(value));

    if (shape->data.Shape().empty() && GetScalarFromConstant<int32_t>(shape) == 0) {
      auto ndarray = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_);
      shape = Constant(ndarray);
    }

    return CastValue(shape, param->dtype);
  }
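
  // Illustrative sketch: given %t of type Tensor[(3, 4, 5), float32],
  //   shape_of(%t)   ==>   constant [3, 4, 5]   (int32 values, cast to the requested dtype)
  // whereas a tensor with any dynamic dimension (e.g. Tensor[(?, 4, 5), float32]) is left alone.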

  /*!
   * \brief Returns the constant NDArray size of the result of \p call if it is of the form
   * \p ndarray_size(e) and \p e has a non-dynamic tensor type. Returns null otherwise.
   */
  Optional<Expr> EvaluateNdarraySize(const Call& call) {
    if (call->op != ndarray_size_op_) {
      return {};
    }
    VLOG(1) << "Evaluating for ndarray_size:" << std::endl << PrettyPrint(call);
    ICHECK_EQ(call->args.size(), 1);
    Expr input = call->args[0];
    const auto* param = call->attrs.as<NdarraySizeAttrs>();
    ICHECK(param != nullptr);

    tvm::Array<IndexExpr> ishape;
    if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) {
      ishape = opt_shape.value();
    } else {
      return {};
    }

    // Get the constant size
    runtime::NDArray value;
    DLDataType cdtype = DataType::Int(32);
    value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_);
    auto* data = static_cast<int32_t*>(value->data);
    if (ishape.empty()) {
      *data = 0;
    } else {
      *data = 1;
      using ::tvm::tir::IntImmNode;
      for (size_t i = 0; i < ishape.size(); ++i) {
        if (const auto* dim = ishape[i].as<IntImmNode>()) {
          *data *= dim->value;
        } else {
          return {};
        }
      }
    }

    Constant size = Downcast<Constant>(ObjectToExpr(value));
    return CastValue(size, param->dtype);
  }
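
  // Illustrative sketch: given %t of type Tensor[(3, 4, 5), float32],
  //   ndarray_size(%t)   ==>   constant 60   (3 * 4 * 5, cast to the requested dtype)
  // As with shape_of, the call is left unchanged if any dimension is dynamic.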

  Expr CastValue(const Expr& value, DataType dtype) {
    // Cast the constant into correct dtype
    auto cast_attrs = make_object<CastAttrs>();
    cast_attrs->dtype = dtype;
    Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
    return ConstEvaluate(ret);
  }

  Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
    if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(input)) {
      // TODO(mbs): This is not necessary since we only ever ask for the shapes for
      // pre-rewritten expressions which will always have a checked_type.
      return const_node->tensor_type()->shape;
    } else if (input->checked_type_.defined()) {
      return input->checked_type().as<TensorTypeNode>()->shape;
    } else {
      return {};
    }
  }

  // Module
  IRModule module_;

  // Whether to fold constants for QNN operations.
  bool fold_qnn_;

  // The kDLCPU device assumed to be available to the compiler. Used only when evaluating
  // sub-expressions.
  Device eval_cpu_dev_{kDLCPU, /*device_id=*/0};
  // The target for the above device assumed to be available to the compiler. Used only when
  // evaluating sub-expressions.
  Target eval_cpu_target_{"llvm"};

  // Cache the following ops for equivalence checking in this pass.
  const Op& device_copy_op_;
  const Op& shape_of_op_;
  const Op& vm_shape_of_op_;
  const Op& cast_op_;
  const Op& ndarray_size_op_;

  // True if currently within a "primitive" Relay Function.
  bool inside_primitive_ = false;
};

}  // namespace

TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(IsComplexConstant);

Expr FoldConstantExpr(const Expr& expr, const IRModule& mod, bool fold_qnn) {
  VLOG_CONTEXT << "FoldConstantExpr";
  VLOG(1) << "folding:" << std::endl << PrettyPrint(expr);
  Expr result = ConstantFolder(mod, fold_qnn).VisitExpr(expr);
  VLOG(1) << "folded to:" << std::endl << PrettyPrint(result);
  return result;
}

Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) {
  auto mod = IRModule::FromExpr(expr);
  return FoldConstantExpr(expr, mod, fold_qnn);
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr")
    .set_body_typed([](const Expr& expr, const IRModule& mod, bool fold_qnn) {
      return FoldConstantExpr(expr, mod, fold_qnn);
    });

Pass FoldConstant(bool fold_qnn) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext /* pc */) {
        return Downcast<Function>(FoldConstantExpr(f, m, fold_qnn));
      };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
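
// Typical usage (a minimal sketch, not taken from this file; assumes the standard Pass machinery,
// where a Pass can be applied to an IRModule directly or composed into a Sequential):
//   IRModule mod = ...;
//   mod = transform::FoldConstant(/*fold_qnn=*/false)(mod);
// or, from Python (illustrative): mod = relay.transform.FoldConstant()(mod)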

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);

}  // namespace transform
}  // namespace relay
}  // namespace tvm