1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file ad_utils.cc
22 * \brief Utility for tensor-level auto-differentiation.
23 */
24#include "ad_utils.h"
25
26#include <tvm/tir/expr.h>
27#include <tvm/tir/stmt_functor.h>
28
29#include <set>
30#include <string>
31
32#include "../schedule/operation_inline.h"
33
34namespace tvm {
35namespace te {
36
37std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>& vars) {
38 Array<IterVar> new_vars;
39 Map<Var, PrimExpr> vmap;
40 for (const IterVar& iv : vars) {
41 IterVar new_v = IterVar(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag);
42 new_vars.push_back(new_v);
43 vmap.Set(iv->var, new_v->var);
44 }
45 return std::make_pair(std::move(new_vars), std::move(vmap));
46}
47
48PrimExpr CloneReduction(const PrimExpr& expr) {
49 if (const ReduceNode* red = expr.as<ReduceNode>()) {
50 auto [new_axis, vmap] = CloneIterVars(red->axis);
51
52 Array<PrimExpr> src_with_newaxis;
53 for (const auto& src : red->source) {
54 src_with_newaxis.push_back(tir::Substitute(src, vmap));
55 }
56 Array<PrimExpr> init_with_newaxis;
57 for (const auto& init : red->init) {
58 init_with_newaxis.push_back(tir::Substitute(init, vmap));
59 }
60
61 return Reduce(red->combiner, src_with_newaxis, new_axis, tir::Substitute(red->condition, vmap),
62 red->value_index, init_with_newaxis);
63 } else {
64 return expr;
65 }
66}
67
68Operation ComputeOpFromExprs(const Array<PrimExpr>& exprs, const Array<IterVar>& axis,
69 const std::string& name, const std::string& tag,
70 const Map<String, ObjectRef>& attrs, bool clone_axis) {
71 if (clone_axis) {
72 auto [new_axis, vmap] = CloneIterVars(axis);
73 Array<PrimExpr> new_exprs;
74 for (const PrimExpr& e : exprs) {
75 new_exprs.push_back(Substitute(CloneReduction(e), vmap));
76 }
77 return ComputeOpFromExprs(new_exprs, new_axis, name, tag, attrs, false);
78 }
79
80 Array<PrimExpr> new_exprs;
81
82 // If this is a reduction then we have to replicate it
83 if (const ReduceNode* red = exprs[0].as<ReduceNode>()) {
84 for (size_t i = 0; i < red->source.size(); ++i) {
85 PrimExpr ith_red =
86 Reduce(red->combiner, red->source, red->axis, red->condition, i, red->init);
87 new_exprs.push_back(ith_red);
88 }
89 } else {
90 new_exprs = exprs;
91 }
92
93 return ComputeOp(name, tag, attrs, axis, new_exprs);
94}
95
96Tensor TensorFromExpr(const PrimExpr& expr, const Array<IterVar>& axis, const std::string& name,
97 const std::string& tag, const Map<String, ObjectRef>& attrs,
98 bool clone_axis) {
99 int new_value_index = 0;
100 if (const ReduceNode* red = expr.as<ReduceNode>()) {
101 new_value_index = red->value_index;
102 }
103 return ComputeOpFromExprs({expr}, axis, name, tag, attrs, clone_axis).output(new_value_index);
104}
105
106Tensor TransformTensorBody(
107 const Tensor& tensor,
108 const std::function<PrimExpr(const PrimExpr&, const Array<IterVar>&)>& func) {
109 if (const ComputeOpNode* op = tensor->op.as<ComputeOpNode>()) {
110 // Transform only one body
111 PrimExpr new_body = func(op->body[tensor->value_index], op->axis);
112
113 // If the body didn't change then we can return the same tensor
114 if (new_body.same_as(op->body[tensor->value_index])) {
115 return tensor;
116 }
117
118 return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs);
119 } else {
120 return tensor;
121 }
122}
123
124Tensor TransformTensorBody(const Tensor& tensor,
125 const std::function<PrimExpr(const PrimExpr&)>& func) {
126 return TransformTensorBody(tensor,
127 [func](const PrimExpr& e, const Array<IterVar>&) { return func(e); });
128}
129
130// If expr is a Tensor Access node, perform inlining, otherwise do nothing
131PrimExpr InlineImmediateTensorAccess(const PrimExpr& expr) {
132 if (const ProducerLoadNode* op = expr.as<ProducerLoadNode>()) {
133 auto tensor = Downcast<te::Tensor>(op->producer);
134 if (const ComputeOpNode* op_comp = tensor->op.as<ComputeOpNode>()) {
135 Array<Var> tensor_axes;
136 for (const auto& var : op_comp->axis) {
137 tensor_axes.push_back(var->var);
138 }
139
140 Stmt inlined =
141 Inline(Evaluate(expr), tensor->op, tensor_axes, op_comp->body[tensor->value_index]);
142 if (const EvaluateNode* ev = inlined.as<EvaluateNode>()) {
143 // If it is a reduction, clone it
144 return CloneReduction(ev->value);
145 }
146 }
147 }
148 return expr;
149}
150
151// Implements InlineTensors by trying to inline every Call of the given Expr
152class InlineTensorsMutator : public ExprMutator {
153 public:
154 explicit InlineTensorsMutator(const Array<Tensor>& inlineable, bool inline_reductions = false)
155 : inline_reductions_(inline_reductions) {
156 for (const Tensor& tensor : inlineable) {
157 inlineable_.emplace(tensor->op.operator->(), tensor->value_index);
158 }
159 }
160
161 PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
162 auto tensor = Downcast<te::Tensor>(op->producer);
163 if (const ComputeOpNode* op_comp = tensor->op.as<ComputeOpNode>()) {
164 // Inline only if the array of inlineable tensors is empty or contains this tensor
165 if (inlineable_.empty() || inlineable_.count({op_comp, tensor->value_index})) {
166 // Inline only compute nodes that are not reductions (unless inline reductions is allowed)
167 if (inline_reductions_ || !op_comp->body[0].as<ReduceNode>()) {
168 PrimExpr expr = GetRef<PrimExpr>(op);
169 // Inline this tensor access and then try to perform further inlining
170 return VisitExpr(InlineImmediateTensorAccess(expr));
171 }
172 }
173 }
174 // If we cannot inline this call, we should try to do inlining in its arguments
175 return ExprMutator::VisitExpr_(op);
176 }
177
178 private:
179 // Tensors which are allowed to be inlined, represented as pairs (op_node, value_index)
180 std::set<std::pair<const OperationNode*, int>> inlineable_;
181 bool inline_reductions_;
182};
183
184Tensor InlineTensorAccess(const Tensor& tensor, const Array<Tensor>& inlineable,
185 bool inline_reductions) {
186 auto transformation = [inlineable, inline_reductions](const PrimExpr& e) {
187 return InlineTensorsMutator(inlineable, inline_reductions)(e);
188 };
189 return TransformTensorBody(tensor, transformation);
190}
191
192Tensor InlineTailTensorAccess(const Tensor& tensor) {
193 return TransformTensorBody(tensor, InlineImmediateTensorAccess);
194}
195
196} // namespace te
197} // namespace tvm
198