1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file topi/einsum.cc
22 * \brief Einstein summation op
23 */
#include <tvm/topi/broadcast.h>
#include <tvm/topi/einsum.h>

#include <algorithm>
#include <cctype>
26
27namespace tvm {
28namespace topi {
29
30EinsumEquation EinsumEquation::FromString(const std::string& equation) {
31 EinsumEquation result;
32 Subscript current;
33 bool has_arrow = false;
34 bool has_ellipsis = false;
35
36 for (int i = 0, n = equation.size(); i < n; ++i) {
37 switch (equation[i]) {
38 case ' ':
39 // Ignore spaces
40 break;
41 case '-':
42 // Arrow
43 CHECK(!has_arrow) << "Equation can only have one arrow";
44 CHECK(i + 1 < n && equation[i + 1] == '>')
45 << "Cannot parse the Einsum equation: invalid arrow";
46 i++;
47 has_arrow = true;
48 [[fallthrough]];
49 case ',':
50 // Delimiter between inputs, push current and start a new one
51 result.inputs.emplace_back(current);
52 current.clear();
53 has_ellipsis = false;
54 break;
55 case '.':
56 // Ellipsis
57 CHECK(!has_ellipsis) << "Ellipsis can only appear once for each input and output";
58 CHECK(i + 2 < n && equation[i + 1] == '.' && equation[i + 2] == '.')
59 << "Cannot parse the Einsum equation: invalid ellipsis";
60 current.push_back(kEllipsis);
61 has_ellipsis = true;
62 i += 2;
63 break;
64 default:
65 // Default case: current character is a subscript label
66 CHECK(std::isalpha(equation[i])) << "Cannot parse the Einsum equation: invalid character "
67 << equation[i] << " in equation " << equation;
68 current.emplace_back(equation[i]);
69 break;
70 }
71 }
72
73 if (has_arrow) {
74 // If there is an arrow, the last subscript is the output
75 result.output = current;
76 } else {
77 // Otherwise, the equation is in implicit mode, and the last subscript is an input
78 result.inputs.emplace_back(current);
79 }
80
81 // Convert the equation to explicit mode if it is in implicit mode
82 if (!has_arrow) {
83 // The output of the implicit mode is all repeated labels sorted in alphabetical order and the
84 // ellipsis in the leftmost if it exists in the inputs.
85 std::map<char, int> label_counts;
86 for (const Subscript& subscript : result.inputs) {
87 for (char label : subscript) {
88 label_counts[label]++;
89 }
90 }
91 for (auto [label, count] : label_counts) {
92 if (label == kEllipsis || count == 1) {
93 result.output.emplace_back(label);
94 }
95 }
96 }
97 return result;
98}
99
100PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) {
101 int64_t extent1_value = GetConstInt(extent1);
102 int64_t extent2_value = GetConstInt(extent2);
103 if (extent1_value == extent2_value) {
104 return extent1;
105 } else if (extent1_value == 1 || extent2_value == 1) {
106 return Integer(std::max(extent1_value, extent2_value));
107 }
108 LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
109 throw;
110}
111
112PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent,
113 const PrimExpr& broadcasted_extent) {
114 if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) {
115 return index;
116 } else {
117 return Integer(0);
118 }
119}
120
121/*! \brief The compute builder for Einsum */
122class EinsumBuilder {
123 public:
124 /*!
125 * \brief The constructor
126 * \param equation The Einsum equation
127 * \param input_shapes The shapes of the input tensors
128 */
129 EinsumBuilder(EinsumEquation equation, Array<Array<PrimExpr>> input_shapes)
130 : equation_(equation), input_shapes_(input_shapes) {}
131
132 /*!
133 * \brief Run the shape inference
134 * \return The inferred shape of the output
135 */
136 Array<PrimExpr> InferShape() {
137 CHECK_EQ(equation_.inputs.size(), input_shapes_.size())
138 << "Number of operands does not match the "
139 "equation";
140
141 std::vector<Array<PrimExpr>>
142 ellipis_shapes; // the sub-shape covered by the ellipsis for each operand
143
144 // Step 1: Collect the broadcasted extent for each label
145 for (int operand_index = 0; operand_index < static_cast<int>(input_shapes_.size());
146 ++operand_index) {
147 const EinsumEquation::Subscript subscript = equation_.inputs[operand_index];
148 const Array<PrimExpr>& input_shape = input_shapes_[operand_index];
149
150 int current_dim = 0;
151 for (auto label : subscript) {
152 if (label == EinsumEquation::kEllipsis) {
153 // Find the sub-shape covered by the ellipsis
154 int ellipsis_ndim =
155 static_cast<int>(input_shape.size()) - static_cast<int>(subscript.size()) + 1;
156 ellipis_shapes.emplace_back(input_shape.begin() + current_dim,
157 input_shape.begin() + current_dim + ellipsis_ndim);
158 current_dim += ellipsis_ndim;
159 } else {
160 const PrimExpr& extent = input_shape[current_dim++];
161 auto it = label_to_extent_.find(label);
162 if (it == label_to_extent_.end()) {
163 label_to_extent_[label] = extent;
164 } else {
165 it->second = GetBroadcastedExtent(it->second, extent);
166 }
167 }
168 }
169 ICHECK_EQ(current_dim, input_shape.size());
170 }
171
172 // Step 2: Infer the shape of the ellipsis if exists
173 // The ellipsis may cover different number of dimensions for each operand, these sub-shapes
174 // need to be broadcasted to the shape with the maximum number of dimensions
175 Array<PrimExpr> ellipsis_shape;
176 if (ellipis_shapes.size()) {
177 ellipsis_shape = *std::max_element(
178 ellipis_shapes.begin(), ellipis_shapes.end(),
179 [](const Array<PrimExpr>& a, const Array<PrimExpr>& b) { return a.size() < b.size(); });
180 for (const Array<PrimExpr>& shape : ellipis_shapes) {
181 auto common_shape = detail::BroadcastShape(ellipsis_shape, shape).common_shape;
182 ellipsis_shape = Array<PrimExpr>(common_shape.begin(), common_shape.end());
183 }
184 }
185
186 // Step 3: Infer output shape based on infered extent for each label
187 for (auto label : equation_.output) {
188 if (label == EinsumEquation::kEllipsis) {
189 output_shape_.insert(output_shape_.end(), ellipsis_shape.begin(), ellipsis_shape.end());
190 } else {
191 output_shape_.push_back(label_to_extent_[label]);
192 }
193 }
194 ellipsis_shape_ = std::move(ellipsis_shape);
195 return output_shape_;
196 }
197
198 PrimExpr BuildOutputExpr(const Array<Tensor> inputs, const Array<Var>& indices) {
199 std::unordered_map<EinsumEquation::Label, Var> label_to_index;
200 Array<Var> ellipsis_indices;
201 Array<IterVar> reduce_axes;
202
203 PrepareOutputIndicesMapping(indices, &label_to_index, &ellipsis_indices);
204 PrepareReductionIndicesMapping(indices, &label_to_index, &ellipsis_indices, &reduce_axes);
205
206 auto zero = make_zero(inputs[0]->dtype);
207
208 PrimExpr result = zero;
209 for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
210 auto term = inputs[i](GetIndicesForOperand(i, label_to_index, ellipsis_indices));
211 if (i == 0) {
212 result = term;
213 } else {
214 result = result * term;
215 }
216 }
217 if (reduce_axes.size() > 0) {
218 result = sum(result, reduce_axes, {zero});
219 }
220 return result;
221 }
222
223 private:
224 /*!
225 * \brief Prepare mapping from label (including ellipsis) to the output indices
226 */
227 void PrepareOutputIndicesMapping(const Array<Var>& indices,
228 std::unordered_map<EinsumEquation::Label, Var>* label_to_index,
229 Array<Var>* ellipsis_indices) {
230 int i = 0;
231 for (auto label : equation_.output) {
232 if (label == EinsumEquation::kEllipsis) {
233 auto ellipsis_ndim = ellipsis_shape_.value().size();
234 *ellipsis_indices = Array<Var>(indices.begin() + i, indices.begin() + i + ellipsis_ndim);
235 i += ellipsis_ndim;
236 } else {
237 label_to_index->emplace(label, indices[i++]);
238 }
239 }
240 ICHECK_EQ(i, indices.size());
241 }
242
243 /*!
244 * \brief Create reduction axes and prepare mapping from reduction label (including ellipsis if
245 * necessary) to the reduction axes
246 */
247 void PrepareReductionIndicesMapping(
248 const Array<Var>& indices, std::unordered_map<EinsumEquation::Label, Var>* label_to_index,
249 Array<Var>* ellipsis_indices, Array<IterVar>* reduction_axes) {
250 // Collect labels that need to be reduced, which is the union(input_labels) - output_labels
251 std::set<char> reduction_labels;
252 for (const EinsumEquation::Subscript& subscript : equation_.inputs) {
253 reduction_labels.insert(subscript.begin(), subscript.end());
254 }
255 for (auto label : equation_.output) {
256 reduction_labels.erase(label);
257 }
258
259 // Create reduction axes.The order of the reduction axes is not specified in the Einsum
260 // equation. Here we sort them alphabetically, with the ellipsis axes at the
261 // beginning if exists.
262 for (auto label : reduction_labels) {
263 if (label == EinsumEquation::kEllipsis) {
264 // Ellipsis
265 auto ellipsis_shape = ellipsis_shape_.value();
266 for (int i = 0; i < static_cast<int>(ellipsis_shape.size()); ++i) {
267 reduction_axes->push_back(
268 IterVar(Range(0, ellipsis_shape[i]), Var("k"), IterVarType::kCommReduce));
269 ellipsis_indices->push_back(reduction_axes->back()->var);
270 }
271 } else {
272 // Normal label
273 reduction_axes->push_back(IterVar(Range(0, label_to_extent_[label]),
274 Var(std::string(1, label)), IterVarType::kCommReduce));
275 label_to_index->emplace(label, reduction_axes->back()->var);
276 }
277 }
278 }
279
280 Array<PrimExpr> GetIndicesForOperand(
281 int operand_index, const std::unordered_map<EinsumEquation::Label, Var>& label_to_index,
282 const Array<Var>& ellipsis_indices) {
283 const EinsumEquation::Subscript& subscript = equation_.inputs[operand_index];
284 Array<PrimExpr> indices; // the indices for the operand
285 const Array<PrimExpr> input_shape = input_shapes_[operand_index];
286
287 int i = 0; // index of the operand shape
288 for (char label : subscript) {
289 if (label == EinsumEquation::kEllipsis) {
290 // Ellipsis
291 Array<PrimExpr> ellipsis_shape = ellipsis_shape_.value();
292 int ellipsis_ndim =
293 static_cast<int>(input_shape.size()) - static_cast<int>(subscript.size()) + 1;
294 // use last 'ellipsis_ndim' axes
295 for (int j = static_cast<int>(ellipsis_indices.size()) - ellipsis_ndim;
296 j < static_cast<int>(ellipsis_indices.size()); ++j) {
297 indices.push_back(
298 GetIndexForBroadcastedDim(ellipsis_indices[j], input_shape[i++], ellipsis_shape[j]));
299 }
300 } else {
301 // Normal label
302 indices.push_back(GetIndexForBroadcastedDim(label_to_index.at(label), input_shape[i++],
303 label_to_extent_.at(label)));
304 }
305 }
306 ICHECK_EQ(i, input_shape.size());
307 ICHECK_EQ(indices.size(), input_shape.size());
308 return indices;
309 }
310
311 EinsumEquation equation_;
312 Array<Array<PrimExpr>> input_shapes_;
313
314 // intermediate results of shape inference
315
316 // The output shape
317 Array<PrimExpr> output_shape_;
318 // The extent of each label with broadcast rules applied
319 std::unordered_map<EinsumEquation::Label, PrimExpr> label_to_extent_;
320 // The shape of the ellipsis if ellipsis is used. The shape covered by the
321 // ellipsis in each operand might be different from this, this is the common
322 // shape among them according to the broadcast rules.
323 Optional<Array<PrimExpr>> ellipsis_shape_;
324};
325
326Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs, std::string name,
327 std::string tag) {
328 EinsumEquation equation = EinsumEquation::FromString(subscripts_str);
329 Array<Array<PrimExpr>> input_shapes;
330 for (const Tensor& input : inputs) {
331 input_shapes.push_back(input->shape);
332 }
333 EinsumBuilder einsum_builder = EinsumBuilder(equation, input_shapes);
334 auto output_shape = einsum_builder.InferShape();
335 return te::compute(
336 output_shape,
337 [&](const Array<Var>& indices) { return einsum_builder.BuildOutputExpr(inputs, indices); },
338 name, tag);
339}
340
341Array<PrimExpr> InferEinsumShape(const std::string& subscripts,
342 const std::vector<Array<PrimExpr>>& operands) {
343 EinsumEquation equation = EinsumEquation::FromString(subscripts);
344 EinsumBuilder einsum_builder = EinsumBuilder(equation, operands);
345 return einsum_builder.InferShape();
346}
347
348TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
349 *rv = einsum(args[0], args[1]);
350});
351
352} // namespace topi
353} // namespace tvm
354