1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/einsum_op_util.h" |
17 | |
18 | #include <string> |
19 | |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "absl/strings/str_split.h" |
22 | #include "tensorflow/core/lib/core/errors.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | Status ValidateEinsumEquation(const string& equation, |
29 | gtl::InlinedVector<string, 2>* input_subscripts, |
30 | string* output_subscript) { |
31 | gtl::InlinedVector<string, 2> inputs_and_output_subscripts = |
32 | absl::StrSplit(equation, "->" ); |
33 | if (inputs_and_output_subscripts.size() != 2) { |
34 | return errors::InvalidArgument( |
35 | "Expecting exactly one '->' in einsum equation: " , equation); |
36 | } |
37 | *output_subscript = std::move(inputs_and_output_subscripts[1]); |
38 | *input_subscripts = |
39 | absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ','); |
40 | if (input_subscripts->size() != 1 && input_subscripts->size() != 2) { |
41 | return errors::InvalidArgument( |
42 | "Expecting 1 or 2 input subscripts in equation '" , equation, |
43 | "' but got: " , input_subscripts->size()); |
44 | } |
45 | return OkStatus(); |
46 | } |
47 | |
48 | // Returns the EinsumDimensionType given whether the corresponding label is |
49 | // present in exactly one input subscript (is_unique) and whether it is absent |
50 | // from the output subscripts (is_removed). Does not handle broadcasting |
51 | // dimensions. |
52 | EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) { |
53 | if (!is_removed && !is_unique) |
54 | return kBatch; |
55 | else if (!is_removed && is_unique) |
56 | return kFree; |
57 | else if (is_removed && !is_unique) |
58 | return kContract; |
59 | else // is_removed && is_unique |
60 | return kReduce; |
61 | } |
62 | |
63 | // Maps the character labels to consecutive integers. |
64 | void MapToLabels(const string& subscript, Labels* labels, |
65 | absl::flat_hash_map<char, int>* label_mapping) { |
66 | for (int i = 0; i < subscript.size(); ++i) { |
67 | const char label_char = subscript[i]; |
68 | if (label_char == '.') { |
69 | labels->push_back(kEllipsisLabel); |
70 | i += 2; // Skip next 2 characters as well. |
71 | continue; |
72 | } |
73 | if (!label_mapping->contains(label_char)) { |
74 | const int next_label = label_mapping->size(); |
75 | (*label_mapping)[label_char] = next_label; |
76 | } |
77 | const int mapped_label = (*label_mapping)[label_char]; |
78 | labels->push_back(mapped_label); |
79 | } |
80 | } |
81 | |
82 | Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels, |
83 | Labels* output_labels, |
84 | std::vector<EinsumDimensionType>* label_types, |
85 | OperandLabelCounts* input_label_counts, |
86 | LabelCounts* output_label_counts, |
87 | gtl::InlinedVector<bool, 2>* input_has_ellipsis, |
88 | bool* output_has_ellipsis) { |
89 | gtl::InlinedVector<string, 2> input_str; |
90 | string output_str; |
91 | TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str)); |
92 | |
93 | // Temporary map from single character labels to (consecutive) integer labels. |
94 | absl::flat_hash_map<char, int> label_mapping; |
95 | int num_inputs = input_str.size(); |
96 | input_labels->resize(num_inputs); |
97 | |
98 | // Map from single characters to integer labels. |
99 | for (int i = 0; i < num_inputs; ++i) { |
100 | MapToLabels(input_str[i], &input_labels->at(i), &label_mapping); |
101 | } |
102 | MapToLabels(output_str, output_labels, &label_mapping); |
103 | |
104 | // Compute counts for input and output labels. |
105 | int num_labels = label_mapping.size(); |
106 | input_label_counts->resize(num_inputs); |
107 | input_has_ellipsis->resize(num_inputs); |
108 | for (int i = 0; i < num_inputs; ++i) { |
109 | input_label_counts->at(i).resize(num_labels); |
110 | input_has_ellipsis->at(i) = false; |
111 | for (const int label : input_labels->at(i)) { |
112 | if (label != kEllipsisLabel) |
113 | input_label_counts->at(i)[label] += 1; |
114 | else |
115 | input_has_ellipsis->at(i) = true; |
116 | } |
117 | } |
118 | output_label_counts->resize(num_labels); |
119 | *output_has_ellipsis = false; |
120 | for (const int label : *output_labels) { |
121 | if (label != kEllipsisLabel) |
122 | output_label_counts->at(label) += 1; |
123 | else |
124 | *output_has_ellipsis = true; |
125 | } |
126 | |
127 | // Map each label to a unique EinsumDimensionType. |
128 | label_types->resize(num_labels); |
129 | for (int label = 0; label < num_labels; ++label) { |
130 | if (label == kEllipsisLabel) continue; |
131 | bool removed = (*output_label_counts)[label] == 0; |
132 | bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 || |
133 | (*input_label_counts)[1][label] == 0; |
134 | (*label_types)[label] = GetDimensionType(removed, unique); |
135 | } |
136 | return OkStatus(); |
137 | } |
138 | |
139 | } // namespace tensorflow |
140 | |