1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_inline_elem_wise.cc
22 */
23#include <tvm/runtime/registry.h>
24#include <tvm/te/operation.h>
25#include <tvm/te/schedule_pass.h>
26#include <tvm/tir/expr_functor.h>
27
28namespace tvm {
29namespace te {
30
31using namespace tir;
32
33class ElemWiseDetector : public tir::ExprVisitor {
34 public:
35 explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
36
37 void VisitExpr(const PrimExpr& e) final {
38 if (!is_elem_wise_) return;
39 ExprVisitor::VisitExpr(e);
40 }
41
42 void VisitExpr_(const ProducerLoadNode* op) final {
43 Array<PrimExpr> indices = op->indices;
44 if (axis_.size() != indices.size()) {
45 is_elem_wise_ = false;
46 return;
47 }
48
49 for (size_t i = 0; i < axis_.size(); ++i) {
50 if (!indices[i].same_as(axis_[i]->var)) {
51 is_elem_wise_ = false;
52 return;
53 }
54 }
55 ExprVisitor::VisitExpr_(op);
56 }
57
58 bool is_elem_wise_{true};
59
60 private:
61 Array<IterVar> axis_;
62};
63
64bool IsElemWise(const Operation& op) {
65 if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
66 ElemWiseDetector v = ElemWiseDetector(compute->axis);
67 for (auto& e : compute->body) v(e);
68 return v.is_elem_wise_;
69 }
70 return false;
71}
72
73void AutoInlineElemWise(Schedule sch) {
74 for (Stage s : sch->stages) {
75 if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
76 s.compute_inline();
77 }
78 }
79}
80
81bool IsBroadcast(const Operation& op) {
82 if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
83 if (compute->reduce_axis.size()) {
84 return false;
85 }
86 constexpr auto kBroadcast = "broadcast";
87 // broadcast op in topi has tag `broadcast`
88 if (op->tag == kBroadcast) {
89 return true;
90 }
91 }
92 return false;
93}
94
95void AutoInlineBroadcast(Schedule sch) {
96 for (Stage s : sch->stages) {
97 if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) {
98 s.compute_inline();
99 }
100 }
101}
102
103bool IsInjective(const Operation& op) {
104 if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
105 return compute->reduce_axis.size() == 0;
106 }
107 return false;
108}
109
110void AutoInlineInjective(Schedule sch) {
111 for (Stage s : sch->stages) {
112 if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) {
113 s.compute_inline();
114 }
115 }
116}
117
// Register the three auto-inline passes as global functions under the
// "schedule." name prefix, each bound to the typed C++ function of the same
// name defined above.

// Inlines stages accepted by IsElemWise.
118TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise);
119
// Inlines stages accepted by IsBroadcast.
120TVM_REGISTER_GLOBAL("schedule.AutoInlineBroadcast").set_body_typed(AutoInlineBroadcast);
121
// Inlines stages accepted by IsInjective.
122TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective);
123
124} // namespace te
125} // namespace tvm
126