1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
16#define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
17
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
20
21namespace tensorflow {
22
23using Labels = gtl::InlinedVector<int, 8>;
24using OperandLabels = gtl::InlinedVector<Labels, 2>;
25using LabelCounts = gtl::InlinedVector<int, 8>;
26using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
27
28// Dummy axis label used to denote an ellipsis in an input or output subscript.
29constexpr int kEllipsisLabel = -1;
30
31// Each dimension is categorized into exactly one of five types based on
32// whether its corresponding label is present in the input and/or the output
33// subscripts.
34enum EinsumDimensionType {
35 // Batch dimensions are those present in two inputs as well as the output.
36 // They are part of the batch dimensions during Tensor contraction. Such
37 // dimensions may be broadcasting dimensions (those mapping to ellipsis)
38 // or explicit batch dimensions corresponding to named axis labels.
39 kBroadcasting = 0,
40 kBatch = 1,
41 // Free dimensions are present in exactly one of the inputs, and also the
42 // output. These are non-contracted axes in the Tensor contraction.
43 kFree = 2,
44 // Contract dimensions are present in two inputs, but not the output. These
45 // dimensions are contracted in Tensor contraction.
46 kContract = 3,
47 // Reduce dimensions are present in exactly one input; and not in the output
48 // and are summed over prior to Tensor contraction.
49 kReduce = 4,
50};
51
52// Parses and validates an einsum equation in explicit form.
53Status ValidateEinsumEquation(const string& equation,
54 gtl::InlinedVector<string, 2>* input_subscripts,
55 string* output_subscript);
56
57// Parses and validates the equation and the input shapes. Single character
58// labels are integerized and we populate input and output label subscripts
59// and corresponding counts. Also create the mapping from (named) labels to
60// their EinsumDimensionType.
61Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels,
62 Labels* output_labels,
63 std::vector<EinsumDimensionType>* label_types,
64 OperandLabelCounts* input_label_counts,
65 LabelCounts* output_label_counts,
66 gtl::InlinedVector<bool, 2>* input_has_ellipsis,
67 bool* output_has_ellipsis);
68
69} // namespace tensorflow
70
71#endif // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
72