/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | |
/*!
 * \file auto_inline_elem_wise.cc
 */
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/te/operation.h> |
25 | #include <tvm/te/schedule_pass.h> |
26 | #include <tvm/tir/expr_functor.h> |
27 | |
28 | namespace tvm { |
29 | namespace te { |
30 | |
31 | using namespace tir; |
32 | |
33 | class ElemWiseDetector : public tir::ExprVisitor { |
34 | public: |
35 | explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {} |
36 | |
37 | void VisitExpr(const PrimExpr& e) final { |
38 | if (!is_elem_wise_) return; |
39 | ExprVisitor::VisitExpr(e); |
40 | } |
41 | |
42 | void VisitExpr_(const ProducerLoadNode* op) final { |
43 | Array<PrimExpr> indices = op->indices; |
44 | if (axis_.size() != indices.size()) { |
45 | is_elem_wise_ = false; |
46 | return; |
47 | } |
48 | |
49 | for (size_t i = 0; i < axis_.size(); ++i) { |
50 | if (!indices[i].same_as(axis_[i]->var)) { |
51 | is_elem_wise_ = false; |
52 | return; |
53 | } |
54 | } |
55 | ExprVisitor::VisitExpr_(op); |
56 | } |
57 | |
58 | bool is_elem_wise_{true}; |
59 | |
60 | private: |
61 | Array<IterVar> axis_; |
62 | }; |
63 | |
64 | bool IsElemWise(const Operation& op) { |
65 | if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) { |
66 | ElemWiseDetector v = ElemWiseDetector(compute->axis); |
67 | for (auto& e : compute->body) v(e); |
68 | return v.is_elem_wise_; |
69 | } |
70 | return false; |
71 | } |
72 | |
73 | void AutoInlineElemWise(Schedule sch) { |
74 | for (Stage s : sch->stages) { |
75 | if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) { |
76 | s.compute_inline(); |
77 | } |
78 | } |
79 | } |
80 | |
81 | bool IsBroadcast(const Operation& op) { |
82 | if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) { |
83 | if (compute->reduce_axis.size()) { |
84 | return false; |
85 | } |
86 | constexpr auto kBroadcast = "broadcast"; |
87 | // broadcast op in topi has tag `broadcast` |
88 | if (op->tag == kBroadcast) { |
89 | return true; |
90 | } |
91 | } |
92 | return false; |
93 | } |
94 | |
95 | void AutoInlineBroadcast(Schedule sch) { |
96 | for (Stage s : sch->stages) { |
97 | if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) { |
98 | s.compute_inline(); |
99 | } |
100 | } |
101 | } |
102 | |
103 | bool IsInjective(const Operation& op) { |
104 | if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) { |
105 | return compute->reduce_axis.size() == 0; |
106 | } |
107 | return false; |
108 | } |
109 | |
110 | void AutoInlineInjective(Schedule sch) { |
111 | for (Stage s : sch->stages) { |
112 | if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) { |
113 | s.compute_inline(); |
114 | } |
115 | } |
116 | } |
117 | |
118 | TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise); |
119 | |
120 | TVM_REGISTER_GLOBAL("schedule.AutoInlineBroadcast").set_body_typed(AutoInlineBroadcast); |
121 | |
122 | TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective); |
123 | |
124 | } // namespace te |
125 | } // namespace tvm |
126 |