1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file topi/einsum.h
22 * \brief Einstein summation op
23 */
24#ifndef TVM_TOPI_EINSUM_H_
25#define TVM_TOPI_EINSUM_H_
26
// Number of distinct values a `char` label can take in the 7-bit ASCII range; presumably used to
// size per-label lookup tables in the implementation (TODO confirm against einsum.cc).
#define LABELRANGE 128
// Presumably upper bounds on tensor rank and on the number of operands; the names appear to
// mirror NumPy's einsum limits.
// NOTE(review): these are unprefixed global macros defined in a public header. If a translation
// unit also includes NumPy's C headers, NPY_MAXDIMS / NPY_MAXARGS may be redefined with different
// values — verify before relying on them.
#define NPY_MAXDIMS 16
#define NPY_MAXARGS 16
30
31#include <tvm/te/operation.h>
32#include <tvm/tir/data_layout.h>
33#include <tvm/topi/detail/constant_utils.h>
34#include <tvm/topi/detail/ravel_unravel.h>
35#include <tvm/topi/detail/tensor_utils.h>
36#include <tvm/topi/tags.h>
37
38#include <algorithm>
39#include <bitset>
40#include <iterator>
41#include <string>
42#include <tuple>
43#include <unordered_set>
44#include <vector>
45
46namespace tvm {
47namespace topi {
48
// NOTE(review): these using-directives live in a public header, so they take effect inside
// namespace tvm::topi for every file that includes it. They let the declarations below name
// types such as Tensor (presumably from tvm::te) without qualification.
using namespace tvm::te;
using namespace topi::detail;
51
52/*!
53 * \brief Compute the shape of the output.
54 * \param subscripts input subscripts.
55 * \param operands operand tensors.
56 *
57 * \return the shape of the output.
58 */
59Array<PrimExpr> InferEinsumShape(const std::string& subscripts,
60 const std::vector<Array<PrimExpr>>& operands);
61
62/*!
63 * \brief Evaluates the Einstein summation convention on the operands.
64 *
65 * \param subscripts_str Specifies the subscripts for summation as comma separated list of
66 * subscript labels.
67 * \param inputs Arrays for the operation.
68 * \param name The name of the operation.
69 * \param tag The tag to mark the operation.
70 *
71 * \return The calculation based on the Einstein summation convention.
72 */
73Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
74 std::string name = "T_einsum", std::string tag = kEinsum);
75
/*!
 * \brief A parsed Einsum equation: the subscript of every input operand plus the output subscript.
 */
struct EinsumEquation {
  /*!
   * \brief Create EinsumEquation from a string.
   * The result will be converted to the explicit mode of Einsum if it is in implicit mode.
   * \param equation The Einsum equation string to parse.
   * \return The created EinsumEquation.
   */
  static EinsumEquation FromString(const std::string& equation);
  /*! \brief A single subscript label. */
  using Label = char;
  /*! \brief The ordered labels describing one operand (or the output). */
  using Subscript = std::vector<Label>;
  // Special label value for ellipsis. '\0' is chosen because it compares less than any letter,
  // which makes sorting labels easier.
  static constexpr Label kEllipsis = '\0';
  // The input subscripts, one for each operand of the Einsum operator.
  std::vector<Subscript> inputs;
  // The output subscript of the Einsum equation.
  Subscript output;
};
93
94} // namespace topi
95} // namespace tvm
96#endif // TVM_TOPI_EINSUM_H_
97