1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/util/einsum_op_util.h"
17
18#include <string>
19
20#include "absl/container/flat_hash_map.h"
21#include "absl/strings/str_split.h"
22#include "tensorflow/core/lib/core/errors.h"
23#include "tensorflow/core/lib/core/status.h"
24#include "tensorflow/core/lib/gtl/inlined_vector.h"
25
26namespace tensorflow {
27
28Status ValidateEinsumEquation(const string& equation,
29 gtl::InlinedVector<string, 2>* input_subscripts,
30 string* output_subscript) {
31 gtl::InlinedVector<string, 2> inputs_and_output_subscripts =
32 absl::StrSplit(equation, "->");
33 if (inputs_and_output_subscripts.size() != 2) {
34 return errors::InvalidArgument(
35 "Expecting exactly one '->' in einsum equation: ", equation);
36 }
37 *output_subscript = std::move(inputs_and_output_subscripts[1]);
38 *input_subscripts =
39 absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ',');
40 if (input_subscripts->size() != 1 && input_subscripts->size() != 2) {
41 return errors::InvalidArgument(
42 "Expecting 1 or 2 input subscripts in equation '", equation,
43 "' but got: ", input_subscripts->size());
44 }
45 return OkStatus();
46}
47
48// Returns the EinsumDimensionType given whether the corresponding label is
49// present in exactly one input subscript (is_unique) and whether it is absent
50// from the output subscripts (is_removed). Does not handle broadcasting
51// dimensions.
52EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) {
53 if (!is_removed && !is_unique)
54 return kBatch;
55 else if (!is_removed && is_unique)
56 return kFree;
57 else if (is_removed && !is_unique)
58 return kContract;
59 else // is_removed && is_unique
60 return kReduce;
61}
62
63// Maps the character labels to consecutive integers.
64void MapToLabels(const string& subscript, Labels* labels,
65 absl::flat_hash_map<char, int>* label_mapping) {
66 for (int i = 0; i < subscript.size(); ++i) {
67 const char label_char = subscript[i];
68 if (label_char == '.') {
69 labels->push_back(kEllipsisLabel);
70 i += 2; // Skip next 2 characters as well.
71 continue;
72 }
73 if (!label_mapping->contains(label_char)) {
74 const int next_label = label_mapping->size();
75 (*label_mapping)[label_char] = next_label;
76 }
77 const int mapped_label = (*label_mapping)[label_char];
78 labels->push_back(mapped_label);
79 }
80}
81
82Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels,
83 Labels* output_labels,
84 std::vector<EinsumDimensionType>* label_types,
85 OperandLabelCounts* input_label_counts,
86 LabelCounts* output_label_counts,
87 gtl::InlinedVector<bool, 2>* input_has_ellipsis,
88 bool* output_has_ellipsis) {
89 gtl::InlinedVector<string, 2> input_str;
90 string output_str;
91 TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str));
92
93 // Temporary map from single character labels to (consecutive) integer labels.
94 absl::flat_hash_map<char, int> label_mapping;
95 int num_inputs = input_str.size();
96 input_labels->resize(num_inputs);
97
98 // Map from single characters to integer labels.
99 for (int i = 0; i < num_inputs; ++i) {
100 MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
101 }
102 MapToLabels(output_str, output_labels, &label_mapping);
103
104 // Compute counts for input and output labels.
105 int num_labels = label_mapping.size();
106 input_label_counts->resize(num_inputs);
107 input_has_ellipsis->resize(num_inputs);
108 for (int i = 0; i < num_inputs; ++i) {
109 input_label_counts->at(i).resize(num_labels);
110 input_has_ellipsis->at(i) = false;
111 for (const int label : input_labels->at(i)) {
112 if (label != kEllipsisLabel)
113 input_label_counts->at(i)[label] += 1;
114 else
115 input_has_ellipsis->at(i) = true;
116 }
117 }
118 output_label_counts->resize(num_labels);
119 *output_has_ellipsis = false;
120 for (const int label : *output_labels) {
121 if (label != kEllipsisLabel)
122 output_label_counts->at(label) += 1;
123 else
124 *output_has_ellipsis = true;
125 }
126
127 // Map each label to a unique EinsumDimensionType.
128 label_types->resize(num_labels);
129 for (int label = 0; label < num_labels; ++label) {
130 if (label == kEllipsisLabel) continue;
131 bool removed = (*output_label_counts)[label] == 0;
132 bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
133 (*input_label_counts)[1][label] == 0;
134 (*label_types)[label] = GetDimensionType(removed, unique);
135 }
136 return OkStatus();
137}
138
139} // namespace tensorflow
140