1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file topi/einsum.h |
22 | * \brief Einstein summation op |
23 | */ |
24 | #ifndef TVM_TOPI_EINSUM_H_ |
25 | #define TVM_TOPI_EINSUM_H_ |
26 | |
// Sizing constants used by the einsum implementation.
// NOTE(review): these are object-like macros defined in a public header (and before any
// #include), so they are visible in every translation unit that includes this file. The
// NPY_* names also match macros that NumPy's C headers define; including both in one TU
// would be a macro redefinition. Consider `constexpr` constants inside tvm::topi once all
// users are checked.
// Presumably the number of distinct label values (one slot per 7-bit ASCII code) — confirm.
#define LABELRANGE 128
// Presumably the maximum number of dimensions handled; name mirrors NumPy's — confirm.
#define NPY_MAXDIMS 16
// Presumably the maximum number of operands handled; name mirrors NumPy's — confirm.
#define NPY_MAXARGS 16
30 | |
31 | #include <tvm/te/operation.h> |
32 | #include <tvm/tir/data_layout.h> |
33 | #include <tvm/topi/detail/constant_utils.h> |
34 | #include <tvm/topi/detail/ravel_unravel.h> |
35 | #include <tvm/topi/detail/tensor_utils.h> |
36 | #include <tvm/topi/tags.h> |
37 | |
38 | #include <algorithm> |
39 | #include <bitset> |
40 | #include <iterator> |
41 | #include <string> |
42 | #include <tuple> |
43 | #include <unordered_set> |
44 | #include <vector> |
45 | |
46 | namespace tvm { |
47 | namespace topi { |
48 | |
// NOTE(review): namespace-scope using-directives in a header make every name from tvm::te
// and topi::detail visible inside tvm::topi for all includers. The declarations below appear
// to rely on the first one (unqualified `Tensor`); the second has no visible use in this
// header and may only be needed by the implementation file — confirm before removing.
using namespace tvm::te;
using namespace topi::detail;
51 | |
52 | /*! |
53 | * \brief Compute the shape of the output. |
54 | * \param subscripts input subscripts. |
55 | * \param operands operand tensors. |
56 | * |
57 | * \return the shape of the output. |
58 | */ |
59 | Array<PrimExpr> InferEinsumShape(const std::string& subscripts, |
60 | const std::vector<Array<PrimExpr>>& operands); |
61 | |
62 | /*! |
63 | * \brief Evaluates the Einstein summation convention on the operands. |
64 | * |
65 | * \param subscripts_str Specifies the subscripts for summation as comma separated list of |
66 | * subscript labels. |
67 | * \param inputs Arrays for the operation. |
68 | * \param name The name of the operation. |
69 | * \param tag The tag to mark the operation. |
70 | * |
71 | * \return The calculation based on the Einstein summation convention. |
72 | */ |
73 | Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs, |
74 | std::string name = "T_einsum" , std::string tag = kEinsum); |
75 | |
/*!
 * \brief Parsed form of an einsum equation: one subscript per input operand plus the
 *        subscript of the output.
 */
struct EinsumEquation {
  // A single subscript label (one character of the equation).
  using Label = char;
  // The sequence of labels indexing one operand (or the output).
  using Subscript = std::vector<Label>;

  // Label stored in place of an ellipsis. '\0' compares below every printable character,
  // so that in an ascending sort it orders before any letter label.
  static constexpr Label kEllipsis = '\0';

  /*!
   * \brief Parse an einsum equation string.
   * \param equation The equation string to parse.
   * \return The parsed equation. An equation written in implicit mode is converted to the
   *         explicit mode of Einsum.
   */
  static EinsumEquation FromString(const std::string& equation);

  // Subscripts of the Einsum operator's inputs, one per input operand.
  std::vector<Subscript> inputs;
  // Subscript of the Einsum result.
  Subscript output;
};
93 | |
94 | } // namespace topi |
95 | } // namespace tvm |
96 | #endif // TVM_TOPI_EINSUM_H_ |
97 | |