1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file topi/einsum.cc |
22 | * \brief Einstein summation op |
23 | */ |
24 | #include <tvm/topi/broadcast.h> |
25 | #include <tvm/topi/einsum.h> |
26 | |
27 | namespace tvm { |
28 | namespace topi { |
29 | |
30 | EinsumEquation EinsumEquation::FromString(const std::string& equation) { |
31 | EinsumEquation result; |
32 | Subscript current; |
33 | bool has_arrow = false; |
34 | bool has_ellipsis = false; |
35 | |
36 | for (int i = 0, n = equation.size(); i < n; ++i) { |
37 | switch (equation[i]) { |
38 | case ' ': |
39 | // Ignore spaces |
40 | break; |
41 | case '-': |
42 | // Arrow |
43 | CHECK(!has_arrow) << "Equation can only have one arrow" ; |
44 | CHECK(i + 1 < n && equation[i + 1] == '>') |
45 | << "Cannot parse the Einsum equation: invalid arrow" ; |
46 | i++; |
47 | has_arrow = true; |
48 | [[fallthrough]]; |
49 | case ',': |
50 | // Delimiter between inputs, push current and start a new one |
51 | result.inputs.emplace_back(current); |
52 | current.clear(); |
53 | has_ellipsis = false; |
54 | break; |
55 | case '.': |
56 | // Ellipsis |
57 | CHECK(!has_ellipsis) << "Ellipsis can only appear once for each input and output" ; |
58 | CHECK(i + 2 < n && equation[i + 1] == '.' && equation[i + 2] == '.') |
59 | << "Cannot parse the Einsum equation: invalid ellipsis" ; |
60 | current.push_back(kEllipsis); |
61 | has_ellipsis = true; |
62 | i += 2; |
63 | break; |
64 | default: |
65 | // Default case: current character is a subscript label |
66 | CHECK(std::isalpha(equation[i])) << "Cannot parse the Einsum equation: invalid character " |
67 | << equation[i] << " in equation " << equation; |
68 | current.emplace_back(equation[i]); |
69 | break; |
70 | } |
71 | } |
72 | |
73 | if (has_arrow) { |
74 | // If there is an arrow, the last subscript is the output |
75 | result.output = current; |
76 | } else { |
77 | // Otherwise, the equation is in implicit mode, and the last subscript is an input |
78 | result.inputs.emplace_back(current); |
79 | } |
80 | |
81 | // Convert the equation to explicit mode if it is in implicit mode |
82 | if (!has_arrow) { |
83 | // The output of the implicit mode is all repeated labels sorted in alphabetical order and the |
84 | // ellipsis in the leftmost if it exists in the inputs. |
85 | std::map<char, int> label_counts; |
86 | for (const Subscript& subscript : result.inputs) { |
87 | for (char label : subscript) { |
88 | label_counts[label]++; |
89 | } |
90 | } |
91 | for (auto [label, count] : label_counts) { |
92 | if (label == kEllipsis || count == 1) { |
93 | result.output.emplace_back(label); |
94 | } |
95 | } |
96 | } |
97 | return result; |
98 | } |
99 | |
100 | PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) { |
101 | int64_t extent1_value = GetConstInt(extent1); |
102 | int64_t extent2_value = GetConstInt(extent2); |
103 | if (extent1_value == extent2_value) { |
104 | return extent1; |
105 | } else if (extent1_value == 1 || extent2_value == 1) { |
106 | return Integer(std::max(extent1_value, extent2_value)); |
107 | } |
108 | LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2; |
109 | throw; |
110 | } |
111 | |
112 | PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent, |
113 | const PrimExpr& broadcasted_extent) { |
114 | if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) { |
115 | return index; |
116 | } else { |
117 | return Integer(0); |
118 | } |
119 | } |
120 | |
121 | /*! \brief The compute builder for Einsum */ |
122 | class EinsumBuilder { |
123 | public: |
124 | /*! |
125 | * \brief The constructor |
126 | * \param equation The Einsum equation |
127 | * \param input_shapes The shapes of the input tensors |
128 | */ |
129 | EinsumBuilder(EinsumEquation equation, Array<Array<PrimExpr>> input_shapes) |
130 | : equation_(equation), input_shapes_(input_shapes) {} |
131 | |
132 | /*! |
133 | * \brief Run the shape inference |
134 | * \return The inferred shape of the output |
135 | */ |
136 | Array<PrimExpr> InferShape() { |
137 | CHECK_EQ(equation_.inputs.size(), input_shapes_.size()) |
138 | << "Number of operands does not match the " |
139 | "equation" ; |
140 | |
141 | std::vector<Array<PrimExpr>> |
142 | ellipis_shapes; // the sub-shape covered by the ellipsis for each operand |
143 | |
144 | // Step 1: Collect the broadcasted extent for each label |
145 | for (int operand_index = 0; operand_index < static_cast<int>(input_shapes_.size()); |
146 | ++operand_index) { |
147 | const EinsumEquation::Subscript subscript = equation_.inputs[operand_index]; |
148 | const Array<PrimExpr>& input_shape = input_shapes_[operand_index]; |
149 | |
150 | int current_dim = 0; |
151 | for (auto label : subscript) { |
152 | if (label == EinsumEquation::kEllipsis) { |
153 | // Find the sub-shape covered by the ellipsis |
154 | int ellipsis_ndim = |
155 | static_cast<int>(input_shape.size()) - static_cast<int>(subscript.size()) + 1; |
156 | ellipis_shapes.emplace_back(input_shape.begin() + current_dim, |
157 | input_shape.begin() + current_dim + ellipsis_ndim); |
158 | current_dim += ellipsis_ndim; |
159 | } else { |
160 | const PrimExpr& extent = input_shape[current_dim++]; |
161 | auto it = label_to_extent_.find(label); |
162 | if (it == label_to_extent_.end()) { |
163 | label_to_extent_[label] = extent; |
164 | } else { |
165 | it->second = GetBroadcastedExtent(it->second, extent); |
166 | } |
167 | } |
168 | } |
169 | ICHECK_EQ(current_dim, input_shape.size()); |
170 | } |
171 | |
172 | // Step 2: Infer the shape of the ellipsis if exists |
173 | // The ellipsis may cover different number of dimensions for each operand, these sub-shapes |
174 | // need to be broadcasted to the shape with the maximum number of dimensions |
175 | Array<PrimExpr> ellipsis_shape; |
176 | if (ellipis_shapes.size()) { |
177 | ellipsis_shape = *std::max_element( |
178 | ellipis_shapes.begin(), ellipis_shapes.end(), |
179 | [](const Array<PrimExpr>& a, const Array<PrimExpr>& b) { return a.size() < b.size(); }); |
180 | for (const Array<PrimExpr>& shape : ellipis_shapes) { |
181 | auto common_shape = detail::BroadcastShape(ellipsis_shape, shape).common_shape; |
182 | ellipsis_shape = Array<PrimExpr>(common_shape.begin(), common_shape.end()); |
183 | } |
184 | } |
185 | |
186 | // Step 3: Infer output shape based on infered extent for each label |
187 | for (auto label : equation_.output) { |
188 | if (label == EinsumEquation::kEllipsis) { |
189 | output_shape_.insert(output_shape_.end(), ellipsis_shape.begin(), ellipsis_shape.end()); |
190 | } else { |
191 | output_shape_.push_back(label_to_extent_[label]); |
192 | } |
193 | } |
194 | ellipsis_shape_ = std::move(ellipsis_shape); |
195 | return output_shape_; |
196 | } |
197 | |
198 | PrimExpr BuildOutputExpr(const Array<Tensor> inputs, const Array<Var>& indices) { |
199 | std::unordered_map<EinsumEquation::Label, Var> label_to_index; |
200 | Array<Var> ellipsis_indices; |
201 | Array<IterVar> reduce_axes; |
202 | |
203 | PrepareOutputIndicesMapping(indices, &label_to_index, &ellipsis_indices); |
204 | PrepareReductionIndicesMapping(indices, &label_to_index, &ellipsis_indices, &reduce_axes); |
205 | |
206 | auto zero = make_zero(inputs[0]->dtype); |
207 | |
208 | PrimExpr result = zero; |
209 | for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) { |
210 | auto term = inputs[i](GetIndicesForOperand(i, label_to_index, ellipsis_indices)); |
211 | if (i == 0) { |
212 | result = term; |
213 | } else { |
214 | result = result * term; |
215 | } |
216 | } |
217 | if (reduce_axes.size() > 0) { |
218 | result = sum(result, reduce_axes, {zero}); |
219 | } |
220 | return result; |
221 | } |
222 | |
223 | private: |
224 | /*! |
225 | * \brief Prepare mapping from label (including ellipsis) to the output indices |
226 | */ |
227 | void PrepareOutputIndicesMapping(const Array<Var>& indices, |
228 | std::unordered_map<EinsumEquation::Label, Var>* label_to_index, |
229 | Array<Var>* ellipsis_indices) { |
230 | int i = 0; |
231 | for (auto label : equation_.output) { |
232 | if (label == EinsumEquation::kEllipsis) { |
233 | auto ellipsis_ndim = ellipsis_shape_.value().size(); |
234 | *ellipsis_indices = Array<Var>(indices.begin() + i, indices.begin() + i + ellipsis_ndim); |
235 | i += ellipsis_ndim; |
236 | } else { |
237 | label_to_index->emplace(label, indices[i++]); |
238 | } |
239 | } |
240 | ICHECK_EQ(i, indices.size()); |
241 | } |
242 | |
243 | /*! |
244 | * \brief Create reduction axes and prepare mapping from reduction label (including ellipsis if |
245 | * necessary) to the reduction axes |
246 | */ |
247 | void PrepareReductionIndicesMapping( |
248 | const Array<Var>& indices, std::unordered_map<EinsumEquation::Label, Var>* label_to_index, |
249 | Array<Var>* ellipsis_indices, Array<IterVar>* reduction_axes) { |
250 | // Collect labels that need to be reduced, which is the union(input_labels) - output_labels |
251 | std::set<char> reduction_labels; |
252 | for (const EinsumEquation::Subscript& subscript : equation_.inputs) { |
253 | reduction_labels.insert(subscript.begin(), subscript.end()); |
254 | } |
255 | for (auto label : equation_.output) { |
256 | reduction_labels.erase(label); |
257 | } |
258 | |
259 | // Create reduction axes.The order of the reduction axes is not specified in the Einsum |
260 | // equation. Here we sort them alphabetically, with the ellipsis axes at the |
261 | // beginning if exists. |
262 | for (auto label : reduction_labels) { |
263 | if (label == EinsumEquation::kEllipsis) { |
264 | // Ellipsis |
265 | auto ellipsis_shape = ellipsis_shape_.value(); |
266 | for (int i = 0; i < static_cast<int>(ellipsis_shape.size()); ++i) { |
267 | reduction_axes->push_back( |
268 | IterVar(Range(0, ellipsis_shape[i]), Var("k" ), IterVarType::kCommReduce)); |
269 | ellipsis_indices->push_back(reduction_axes->back()->var); |
270 | } |
271 | } else { |
272 | // Normal label |
273 | reduction_axes->push_back(IterVar(Range(0, label_to_extent_[label]), |
274 | Var(std::string(1, label)), IterVarType::kCommReduce)); |
275 | label_to_index->emplace(label, reduction_axes->back()->var); |
276 | } |
277 | } |
278 | } |
279 | |
280 | Array<PrimExpr> GetIndicesForOperand( |
281 | int operand_index, const std::unordered_map<EinsumEquation::Label, Var>& label_to_index, |
282 | const Array<Var>& ellipsis_indices) { |
283 | const EinsumEquation::Subscript& subscript = equation_.inputs[operand_index]; |
284 | Array<PrimExpr> indices; // the indices for the operand |
285 | const Array<PrimExpr> input_shape = input_shapes_[operand_index]; |
286 | |
287 | int i = 0; // index of the operand shape |
288 | for (char label : subscript) { |
289 | if (label == EinsumEquation::kEllipsis) { |
290 | // Ellipsis |
291 | Array<PrimExpr> ellipsis_shape = ellipsis_shape_.value(); |
292 | int ellipsis_ndim = |
293 | static_cast<int>(input_shape.size()) - static_cast<int>(subscript.size()) + 1; |
294 | // use last 'ellipsis_ndim' axes |
295 | for (int j = static_cast<int>(ellipsis_indices.size()) - ellipsis_ndim; |
296 | j < static_cast<int>(ellipsis_indices.size()); ++j) { |
297 | indices.push_back( |
298 | GetIndexForBroadcastedDim(ellipsis_indices[j], input_shape[i++], ellipsis_shape[j])); |
299 | } |
300 | } else { |
301 | // Normal label |
302 | indices.push_back(GetIndexForBroadcastedDim(label_to_index.at(label), input_shape[i++], |
303 | label_to_extent_.at(label))); |
304 | } |
305 | } |
306 | ICHECK_EQ(i, input_shape.size()); |
307 | ICHECK_EQ(indices.size(), input_shape.size()); |
308 | return indices; |
309 | } |
310 | |
311 | EinsumEquation equation_; |
312 | Array<Array<PrimExpr>> input_shapes_; |
313 | |
314 | // intermediate results of shape inference |
315 | |
316 | // The output shape |
317 | Array<PrimExpr> output_shape_; |
318 | // The extent of each label with broadcast rules applied |
319 | std::unordered_map<EinsumEquation::Label, PrimExpr> label_to_extent_; |
320 | // The shape of the ellipsis if ellipsis is used. The shape covered by the |
321 | // ellipsis in each operand might be different from this, this is the common |
322 | // shape among them according to the broadcast rules. |
323 | Optional<Array<PrimExpr>> ellipsis_shape_; |
324 | }; |
325 | |
326 | Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs, std::string name, |
327 | std::string tag) { |
328 | EinsumEquation equation = EinsumEquation::FromString(subscripts_str); |
329 | Array<Array<PrimExpr>> input_shapes; |
330 | for (const Tensor& input : inputs) { |
331 | input_shapes.push_back(input->shape); |
332 | } |
333 | EinsumBuilder einsum_builder = EinsumBuilder(equation, input_shapes); |
334 | auto output_shape = einsum_builder.InferShape(); |
335 | return te::compute( |
336 | output_shape, |
337 | [&](const Array<Var>& indices) { return einsum_builder.BuildOutputExpr(inputs, indices); }, |
338 | name, tag); |
339 | } |
340 | |
341 | Array<PrimExpr> InferEinsumShape(const std::string& subscripts, |
342 | const std::vector<Array<PrimExpr>>& operands) { |
343 | EinsumEquation equation = EinsumEquation::FromString(subscripts); |
344 | EinsumBuilder einsum_builder = EinsumBuilder(equation, operands); |
345 | return einsum_builder.InferShape(); |
346 | } |
347 | |
348 | TVM_REGISTER_GLOBAL("topi.einsum" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
349 | *rv = einsum(args[0], args[1]); |
350 | }); |
351 | |
352 | } // namespace topi |
353 | } // namespace tvm |
354 | |