1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file operation_inline.cc
22 */
23#include "operation_inline.h"
24
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/expr.h>
27#include <tvm/tir/stmt.h>
28#include <tvm/tir/stmt_functor.h>
29
30#include <utility>
31
32#include "../../tir/transforms/ir_utils.h"
33
34namespace tvm {
35namespace te {
36
37// inliner to inline a function
38// the result may not be SSA,
39// ConvertSSA need to be applied after this pass
40class OperationInliner final : public StmtExprMutator {
41 public:
42 OperationInliner(Operation op, Array<Var> args, PrimExpr body)
43 : operation_(op), args_(args), body_(body) {}
44
45 PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
46 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
47 op = expr.as<ProducerLoadNode>();
48 auto tensor = Downcast<Tensor>(op->producer);
49
50 if (tensor->op.same_as(operation_)) {
51 ICHECK_EQ(tensor->value_index, 0);
52 expr = body_;
53 ICHECK_EQ(args_.size(), op->indices.size());
54
55 bool has_side_effect = false;
56 for (size_t i = 0; i < op->indices.size(); ++i) {
57 if (SideEffect(op->indices[i]) > CallEffectKind::kReadState) has_side_effect = true;
58 }
59 if (has_side_effect) {
60 for (size_t i = 0; i < args_.size(); ++i) {
61 expr = Let(args_[i], op->indices[i], expr);
62 }
63 } else {
64 Map<Var, PrimExpr> vmap;
65 for (size_t i = 0; i < args_.size(); ++i) {
66 // cast indices to the type of the original indexing variable
67 vmap.Set(args_[i], cast(args_[i].dtype(), op->indices[i]));
68 }
69 expr = Substitute(Evaluate(expr), vmap).as<EvaluateNode>()->value;
70 }
71 return expr;
72 } else {
73 return expr;
74 }
75 }
76
77 private:
78 Operation operation_;
79 Array<Var> args_;
80 PrimExpr body_;
81};
82
83Stmt Inline(Stmt stmt, Operation f, Array<Var> args, PrimExpr body) {
84 ICHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation";
85 Stmt ret = OperationInliner(f, args, body)(std::move(stmt));
86 if (ret.same_as(stmt)) return ret;
87 return ConvertSSA(ret);
88}
89} // namespace te
90} // namespace tvm
91