1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file ad_utils.cc |
22 | * \brief Utility for tensor-level auto-differentiation. |
23 | */ |
24 | #include "ad_utils.h" |
25 | |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | |
29 | #include <set> |
30 | #include <string> |
31 | |
32 | #include "../schedule/operation_inline.h" |
33 | |
34 | namespace tvm { |
35 | namespace te { |
36 | |
37 | std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>& vars) { |
38 | Array<IterVar> new_vars; |
39 | Map<Var, PrimExpr> vmap; |
40 | for (const IterVar& iv : vars) { |
41 | IterVar new_v = IterVar(iv->dom, iv->var.copy_with_suffix("" ), iv->iter_type, iv->thread_tag); |
42 | new_vars.push_back(new_v); |
43 | vmap.Set(iv->var, new_v->var); |
44 | } |
45 | return std::make_pair(std::move(new_vars), std::move(vmap)); |
46 | } |
47 | |
48 | PrimExpr CloneReduction(const PrimExpr& expr) { |
49 | if (const ReduceNode* red = expr.as<ReduceNode>()) { |
50 | auto [new_axis, vmap] = CloneIterVars(red->axis); |
51 | |
52 | Array<PrimExpr> src_with_newaxis; |
53 | for (const auto& src : red->source) { |
54 | src_with_newaxis.push_back(tir::Substitute(src, vmap)); |
55 | } |
56 | Array<PrimExpr> init_with_newaxis; |
57 | for (const auto& init : red->init) { |
58 | init_with_newaxis.push_back(tir::Substitute(init, vmap)); |
59 | } |
60 | |
61 | return Reduce(red->combiner, src_with_newaxis, new_axis, tir::Substitute(red->condition, vmap), |
62 | red->value_index, init_with_newaxis); |
63 | } else { |
64 | return expr; |
65 | } |
66 | } |
67 | |
68 | Operation ComputeOpFromExprs(const Array<PrimExpr>& exprs, const Array<IterVar>& axis, |
69 | const std::string& name, const std::string& tag, |
70 | const Map<String, ObjectRef>& attrs, bool clone_axis) { |
71 | if (clone_axis) { |
72 | auto [new_axis, vmap] = CloneIterVars(axis); |
73 | Array<PrimExpr> new_exprs; |
74 | for (const PrimExpr& e : exprs) { |
75 | new_exprs.push_back(Substitute(CloneReduction(e), vmap)); |
76 | } |
77 | return ComputeOpFromExprs(new_exprs, new_axis, name, tag, attrs, false); |
78 | } |
79 | |
80 | Array<PrimExpr> new_exprs; |
81 | |
82 | // If this is a reduction then we have to replicate it |
83 | if (const ReduceNode* red = exprs[0].as<ReduceNode>()) { |
84 | for (size_t i = 0; i < red->source.size(); ++i) { |
85 | PrimExpr ith_red = |
86 | Reduce(red->combiner, red->source, red->axis, red->condition, i, red->init); |
87 | new_exprs.push_back(ith_red); |
88 | } |
89 | } else { |
90 | new_exprs = exprs; |
91 | } |
92 | |
93 | return ComputeOp(name, tag, attrs, axis, new_exprs); |
94 | } |
95 | |
96 | Tensor TensorFromExpr(const PrimExpr& expr, const Array<IterVar>& axis, const std::string& name, |
97 | const std::string& tag, const Map<String, ObjectRef>& attrs, |
98 | bool clone_axis) { |
99 | int new_value_index = 0; |
100 | if (const ReduceNode* red = expr.as<ReduceNode>()) { |
101 | new_value_index = red->value_index; |
102 | } |
103 | return ComputeOpFromExprs({expr}, axis, name, tag, attrs, clone_axis).output(new_value_index); |
104 | } |
105 | |
106 | Tensor TransformTensorBody( |
107 | const Tensor& tensor, |
108 | const std::function<PrimExpr(const PrimExpr&, const Array<IterVar>&)>& func) { |
109 | if (const ComputeOpNode* op = tensor->op.as<ComputeOpNode>()) { |
110 | // Transform only one body |
111 | PrimExpr new_body = func(op->body[tensor->value_index], op->axis); |
112 | |
113 | // If the body didn't change then we can return the same tensor |
114 | if (new_body.same_as(op->body[tensor->value_index])) { |
115 | return tensor; |
116 | } |
117 | |
118 | return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs); |
119 | } else { |
120 | return tensor; |
121 | } |
122 | } |
123 | |
124 | Tensor TransformTensorBody(const Tensor& tensor, |
125 | const std::function<PrimExpr(const PrimExpr&)>& func) { |
126 | return TransformTensorBody(tensor, |
127 | [func](const PrimExpr& e, const Array<IterVar>&) { return func(e); }); |
128 | } |
129 | |
130 | // If expr is a Tensor Access node, perform inlining, otherwise do nothing |
131 | PrimExpr InlineImmediateTensorAccess(const PrimExpr& expr) { |
132 | if (const ProducerLoadNode* op = expr.as<ProducerLoadNode>()) { |
133 | auto tensor = Downcast<te::Tensor>(op->producer); |
134 | if (const ComputeOpNode* op_comp = tensor->op.as<ComputeOpNode>()) { |
135 | Array<Var> tensor_axes; |
136 | for (const auto& var : op_comp->axis) { |
137 | tensor_axes.push_back(var->var); |
138 | } |
139 | |
140 | Stmt inlined = |
141 | Inline(Evaluate(expr), tensor->op, tensor_axes, op_comp->body[tensor->value_index]); |
142 | if (const EvaluateNode* ev = inlined.as<EvaluateNode>()) { |
143 | // If it is a reduction, clone it |
144 | return CloneReduction(ev->value); |
145 | } |
146 | } |
147 | } |
148 | return expr; |
149 | } |
150 | |
151 | // Implements InlineTensors by trying to inline every Call of the given Expr |
152 | class InlineTensorsMutator : public ExprMutator { |
153 | public: |
154 | explicit InlineTensorsMutator(const Array<Tensor>& inlineable, bool inline_reductions = false) |
155 | : inline_reductions_(inline_reductions) { |
156 | for (const Tensor& tensor : inlineable) { |
157 | inlineable_.emplace(tensor->op.operator->(), tensor->value_index); |
158 | } |
159 | } |
160 | |
161 | PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
162 | auto tensor = Downcast<te::Tensor>(op->producer); |
163 | if (const ComputeOpNode* op_comp = tensor->op.as<ComputeOpNode>()) { |
164 | // Inline only if the array of inlineable tensors is empty or contains this tensor |
165 | if (inlineable_.empty() || inlineable_.count({op_comp, tensor->value_index})) { |
166 | // Inline only compute nodes that are not reductions (unless inline reductions is allowed) |
167 | if (inline_reductions_ || !op_comp->body[0].as<ReduceNode>()) { |
168 | PrimExpr expr = GetRef<PrimExpr>(op); |
169 | // Inline this tensor access and then try to perform further inlining |
170 | return VisitExpr(InlineImmediateTensorAccess(expr)); |
171 | } |
172 | } |
173 | } |
174 | // If we cannot inline this call, we should try to do inlining in its arguments |
175 | return ExprMutator::VisitExpr_(op); |
176 | } |
177 | |
178 | private: |
179 | // Tensors which are allowed to be inlined, represented as pairs (op_node, value_index) |
180 | std::set<std::pair<const OperationNode*, int>> inlineable_; |
181 | bool inline_reductions_; |
182 | }; |
183 | |
184 | Tensor InlineTensorAccess(const Tensor& tensor, const Array<Tensor>& inlineable, |
185 | bool inline_reductions) { |
186 | auto transformation = [inlineable, inline_reductions](const PrimExpr& e) { |
187 | return InlineTensorsMutator(inlineable, inline_reductions)(e); |
188 | }; |
189 | return TransformTensorBody(tensor, transformation); |
190 | } |
191 | |
192 | Tensor InlineTailTensorAccess(const Tensor& tensor) { |
193 | return TransformTensorBody(tensor, InlineImmediateTensorAccess); |
194 | } |
195 | |
196 | } // namespace te |
197 | } // namespace tvm |
198 | |