1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_ |
17 | #define TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_ |
18 | |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
23 | #include "tensorflow/core/util/bcast.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | // Simple wrapper over BCast specialized for MatMul. |
28 | // Provides utilities for broadcasting across batch dimensions for binary |
29 | // MatMul-like operations. If neither argument has batch dimensions (rank <= 2) |
30 | // then no broadcasting is needed and the operation MatMul operation is |
31 | // considered valid. |
32 | class MatMulBCast { |
33 | public: |
34 | using Vec = BCast::Vec; |
35 | |
36 | MatMulBCast(const Vec& x, const Vec& y) { |
37 | if (std::max(x.size(), y.size()) == 2) return; |
38 | const Vec x_resized(x.begin(), x.end() - 2); |
39 | const Vec y_resized(y.begin(), y.end() - 2); |
40 | |
41 | batch_bcast_ = |
42 | absl::make_unique<BCast>(std::move(x_resized), std::move(y_resized)); |
43 | if (!batch_bcast_->IsValid()) { |
44 | // Set broadcasting_required_ to true to make IsValid() return false; |
45 | broadcasting_required_ = true; |
46 | return; |
47 | } |
48 | |
49 | x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements(); |
50 | y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements(); |
51 | output_batch_shape_ = TensorShape(batch_bcast_->output_shape()); |
52 | output_batch_size_ = output_batch_shape_.num_elements(); |
53 | broadcasting_required_ = |
54 | std::min(x_batch_size_, y_batch_size_) != output_batch_size_; |
55 | |
56 | if (broadcasting_required_) { |
57 | ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(), |
58 | batch_bcast_->x_bcast(), &x_batch_indices_); |
59 | ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(), |
60 | batch_bcast_->y_bcast(), &y_batch_indices_); |
61 | } |
62 | } |
63 | |
64 | bool IsValid() const { |
65 | return !broadcasting_required_ || (batch_bcast_ && batch_bcast_->IsValid()); |
66 | } |
67 | bool IsBroadcastingRequired() const { return broadcasting_required_; } |
68 | |
69 | const int64_t output_batch_size() const { return output_batch_size_; } |
70 | const int64_t x_batch_size() const { return x_batch_size_; } |
71 | const int64_t y_batch_size() const { return y_batch_size_; } |
72 | const TensorShape& output_batch_shape() const { return output_batch_shape_; } |
73 | |
74 | // Returns the mapping from the flattened output batch indices to x's |
75 | // flattened batch indices. The result is a vector of length |
76 | // output_batch_size(). To compute the i'th batch output, a binary matmul-like |
77 | // operation should use the `x_batch_indices()[i]`th batch index of `x`. |
78 | // Note: Returns an empty vector if broadcasting is not required. Callers |
79 | // should only use this when IsBroadcastingRequired() returns true. |
80 | const std::vector<int64_t>& x_batch_indices() const { |
81 | return x_batch_indices_; |
82 | } |
83 | // Returns the mapping from the flattened output batch indices to y's |
84 | // flattened batch indices. Similar to x_batch_indices(). |
85 | // Note: Returns an empty vector if broadcasting is not required. Callers |
86 | // should only use this when IsBroadcastingRequired() returns true. |
87 | const std::vector<int64_t>& y_batch_indices() const { |
88 | return y_batch_indices_; |
89 | } |
90 | |
91 | private: |
92 | std::unique_ptr<BCast> batch_bcast_; |
93 | bool broadcasting_required_ = false; |
94 | int64_t x_batch_size_ = 1; |
95 | int64_t y_batch_size_ = 1; |
96 | TensorShape output_batch_shape_; |
97 | int64_t output_batch_size_ = 1; |
98 | std::vector<int64_t> x_batch_indices_; |
99 | std::vector<int64_t> y_batch_indices_; |
100 | }; |
101 | |
102 | } // namespace tensorflow |
103 | |
104 | #endif // TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_ |
105 | |