1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ |
16 | #define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ |
17 | |
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
20 | |
21 | namespace tensorflow { |
22 | |
23 | using Labels = gtl::InlinedVector<int, 8>; |
24 | using OperandLabels = gtl::InlinedVector<Labels, 2>; |
25 | using LabelCounts = gtl::InlinedVector<int, 8>; |
26 | using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>; |
27 | |
28 | // Dummy axis label used to denote an ellipsis in an input or output subscript. |
29 | constexpr int kEllipsisLabel = -1; |
30 | |
31 | // Each dimension is categorized into exactly one of five types based on |
32 | // whether its corresponding label is present in the input and/or the output |
33 | // subscripts. |
34 | enum EinsumDimensionType { |
35 | // Batch dimensions are those present in two inputs as well as the output. |
36 | // They are part of the batch dimensions during Tensor contraction. Such |
37 | // dimensions may be broadcasting dimensions (those mapping to ellipsis) |
38 | // or explicit batch dimensions corresponding to named axis labels. |
39 | kBroadcasting = 0, |
40 | kBatch = 1, |
41 | // Free dimensions are present in exactly one of the inputs, and also the |
42 | // output. These are non-contracted axes in the Tensor contraction. |
43 | kFree = 2, |
44 | // Contract dimensions are present in two inputs, but not the output. These |
45 | // dimensions are contracted in Tensor contraction. |
46 | kContract = 3, |
47 | // Reduce dimensions are present in exactly one input; and not in the output |
48 | // and are summed over prior to Tensor contraction. |
49 | kReduce = 4, |
50 | }; |
51 | |
52 | // Parses and validates an einsum equation in explicit form. |
53 | Status ValidateEinsumEquation(const string& equation, |
54 | gtl::InlinedVector<string, 2>* input_subscripts, |
55 | string* output_subscript); |
56 | |
57 | // Parses and validates the equation and the input shapes. Single character |
58 | // labels are integerized and we populate input and output label subscripts |
59 | // and corresponding counts. Also create the mapping from (named) labels to |
60 | // their EinsumDimensionType. |
61 | Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels, |
62 | Labels* output_labels, |
63 | std::vector<EinsumDimensionType>* label_types, |
64 | OperandLabelCounts* input_label_counts, |
65 | LabelCounts* output_label_counts, |
66 | gtl::InlinedVector<bool, 2>* input_has_ellipsis, |
67 | bool* output_has_ellipsis); |
68 | |
69 | } // namespace tensorflow |
70 | |
71 | #endif // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ |
72 | |