/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
#define TENSORFLOW_CORE_UTIL_BCAST_H_

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCastList helper class. The i'th element denotes the
// (flattened) batch index of the input that must be used to compute the i'th
// batch output.
//
inline void ComputeBatchIndices(const int64_t output_batch_size,
                                const gtl::InlinedVector<int64_t, 4>& reshape,
                                const gtl::InlinedVector<int64_t, 4>& bcast,
                                std::vector<int64_t>* out_indices) {
  // Populates the mapping in out_indices. This algorithm is identical to
  // the following steps:
  // - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
  // - Broadcast to the output shape.
  // - Reshape back to a flat 1D vector.
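  //
  // For example (illustrative values, not from the original source): with
  // reshape = {2, 1} and bcast = {1, 3}, the input has 2 batch elements, the
  // output has 6, and the computed mapping is {0, 0, 0, 1, 1, 1}: each of the
  // three broadcast copies in a row reuses that row's single input batch.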
  out_indices->resize(output_batch_size);
  int64_t num_output_elements = 1;
  int64_t num_input_elements = 1;
  for (int64_t i = reshape.size() - 1; i >= 0; --i) {
    // Replicate the already populated mapping an additional (dim - 1) times.
    // If we are broadcasting, just copy the existing mapping.
    // Otherwise, add another dimension from the input shape.
    const int64_t dim = std::max(reshape[i], bcast[i]);
    const int64_t incr = bcast[i] > 1 ? 0 : num_input_elements;
    for (int64_t k = 0; k < (dim - 1) * num_output_elements; ++k) {
      (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
    }
    num_output_elements *= dim;
    num_input_elements *= reshape[i];
  }
}

template <int N>
class BCastList {
 public:
  // A vector of int64 representing the shape of a tensor. The 0-th
  // element is the outer-most dimension and the last element is the
  // inner-most dimension. Note that we do not use TensorShape since
  // it's more convenient to manipulate Vec directly for this module.
  typedef gtl::InlinedVector<int64_t, 4> Vec;

  // Constructs all helper shapes, following the broadcasting rules described
  // for BCast below. (A usage sketch follows this class definition.)
  //
  // If "fewer_dims_optimization" is set to true (the default), the
  // implementation tries to reduce the number of intermediate dimensions, for
  // efficiency. This is transparent to the caller.
  //
  // If false, all intermediate shapes (except for grad_reduce_idx()) have the
  // same number of dimensions as the input with the largest rank.
  //
  // If "return_flattened_batch_indices" is true, the implementation computes,
  // for each element of the flattened output batch, which batch index of each
  // input corresponds to it. This is disabled by default.
  explicit BCastList(const Vec (&x)[N],
                     const bool fewer_dims_optimization = true,
                     const bool return_flattened_batch_indices = false);
  ~BCastList() {}

  // Returns true iff the operands are compatible according to the
  // broadcasting rule.
  bool IsValid() const { return valid_; }
  bool IsBroadcastingRequired() const { return broadcasting_required_; }

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec& reshape(int i) const { return reshape_[i]; }
  const Vec& bcast(int i) const { return bcast_[i]; }
  const Vec& result_shape() const { return result_; }
  const Vec& output_shape() const { return output_; }
  const Vec& grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; }
  int64_t output_batch_size() const { return output_batch_size_; }

  // Returns the mapping from the flattened output batch indices to input i's
  // flattened batch indices. The result is a vector of length
  // output_batch_size(). To compute the k'th batch output, a binary
  // matmul-like operation should use the `batch_indices(i)[k]`'th batch of
  // input i.
  // Note: Returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64_t>& batch_indices(int i) const {
    return batch_indices_[i];
  }

 protected:
  bool valid_ = true;
  bool broadcasting_required_ = true;
  Vec reshape_[N];
  Vec bcast_[N];
  Vec result_;
  Vec output_;
  Vec grad_reduce_idx_[N];

  int64_t output_batch_size_;
  std::vector<int64_t> batch_indices_[N];

  static void Reverse(Vec* shape) {
    std::reverse(shape->begin(), shape->end());
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastList);
};
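
// A minimal BCastList usage sketch (illustrative; the shapes below are
// hypothetical):
//
//   BCastList<3>::Vec shapes[3] = {{2, 1, 3}, {1, 4, 3}, {2, 4, 3}};
//   BCastList<3> b(shapes);
//   if (!b.IsValid()) {
//     // The shapes are not broadcast-compatible.
//   }
//   // b.output_shape() is {2, 4, 3}; b.reshape(i) and b.bcast(i) describe
//   // how input i is reshaped and broadcast to produce that output.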

template <int N>
BCastList<N>::BCastList(const BCastList::Vec (&x)[N],
                        const bool fewer_dims_optimization,
                        const bool return_flattened_batch_indices) {
  typedef BCastList::Vec Vec;

  // Safely multiplies dimensions taking into account symbolic shapes.
  auto mul_dims = [](int64_t dim1, int64_t dim2) -> int64_t {
    return dim1 != 0 && dim2 != 0 && (dim1 < 0 || dim2 < 0) ? -1 : dim1 * dim2;
  };
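
  // For example (illustrative): mul_dims(3, 4) == 12, mul_dims(3, -1) == -1
  // (an unknown dimension propagates), and mul_dims(0, -1) == 0 (a zero
  // dimension wins over an unknown one).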

  bool all_equal = true;
  size_t largest_rank = 0;
  output_batch_size_ = 1;
  for (int i = 0; i < N; ++i) {
    if (x[i] != x[0]) {
      all_equal = false;
    }
    if (x[i].size() > largest_rank) {
      largest_rank = x[i].size();
    }
  }
  if (all_equal) {
    broadcasting_required_ = false;
  }
  if (all_equal && TF_PREDICT_TRUE(fewer_dims_optimization)) {
    // Fast path for common case of identical shapes.
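    // For example (illustrative): if every input shape is {2, 3}, this path
    // yields output_shape() == {2, 3}, result_shape() == {6}, and for each
    // input reshape(i) == {6} and bcast(i) == {1}, i.e. the operation can run
    // on flat vectors of 6 elements with no broadcasting.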
    int64_t elements = 1;
    const int rank = x[0].size();
    output_.resize(rank);
    for (int i = 0; i < rank; i++) {
      const int64_t dim = x[0][i];
      elements = mul_dims(elements, dim);
      output_[i] = dim;
    }
    result_.push_back(elements);
    output_batch_size_ = elements;
    for (int i = 0; i < N; ++i) {
      reshape_[i].push_back(elements);
      bcast_[i].push_back(1);
    }
    // grad_reduce_idx_ is left empty.
    return;
  }

  // Reverse all the shapes for convenience.
  // After the reverse, 0-th is the inner-most dimension.
  Vec copy[N];
  for (int i = 0; i < N; ++i) {
    copy[i] = x[i];
    Reverse(&copy[i]);
  }

  // 1-extend and align all vectors.
  for (int i = 0; i < N; ++i) {
    if (copy[i].size() < largest_rank) {
      copy[i].resize(largest_rank, 1);
    }
  }
  // Go through each dimension starting from the inner-most one and compare
  // the dimensions of all inputs. A set of dimensions is compatible if all
  // entries that are not 1 are equal.

  // Tracks which inputs have size 1 in the j-th (reversed) dimension.
  bool prev_is_one[N];
  bool current_is_one[N];
  for (int i = 0; i < N; ++i) {
    prev_is_one[i] = false;
    current_is_one[i] = false;
  }
  Vec output;
  bool output_dim_set = false;
  int64_t output_dim = -1;
  bool none_is_one = true;
  bool set_one = false;
  for (int j = 0; j < largest_rank; ++j) {
    output_dim = -1;
    output_dim_set = false;
    none_is_one = true;
    // Find which indices are 1.
    for (int i = 0; i < N; ++i) {
      // Keep track of which indices are 1.
      if (copy[i][j] == 1) {
        current_is_one[i] = true;
        none_is_one = false;
      } else {
        current_is_one[i] = false;
        if (!output_dim_set || copy[i][j] == output_dim) {
          output_dim = copy[i][j];
          output_dim_set = true;
        } else {
          valid_ = false;
          return;
        }
      }
    }
    output_.push_back(output_dim_set ? output_dim : 1);
    output_batch_size_ = mul_dims(output_batch_size_, output_.back());
    // All dimensions are 1.
    if (!output_dim_set) {
      if (!TF_PREDICT_TRUE(fewer_dims_optimization)) {
        for (int i = 0; i < N; ++i) {
          bcast_[i].push_back(1);
          reshape_[i].push_back(1);
        }
        result_.push_back(1);
      }
      for (int i = 0; i < N; ++i) {
        grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
      }
      // This will skip updating the previous state to the current one. We'll
      // explain why this is safe below.
      // Consider the previous state P, the current state C and the next state
      // N. In the case where N is also all ones (N == C), we'll do the same
      // optimization here (push back one dimensions if we need to), which is
      // safe and is expected.
      //
      // When N != C, we'll continue as usual. However, we might trigger the
      // next block if N == P (because we didn't update the previous state).
      // We trigger the next block only if `fewer_dims_optimization` is true.
      // This is safe because we did not modify any of the broadcast / reshape
      // entries in this block (we skipped updating them, since dimensions of
      // size one can be ignored). In essence, we only need to check whether
      // the previous non-one state is equal to the current non-one state.
      continue;
    } else if (TF_PREDICT_TRUE(fewer_dims_optimization) &&
               std::equal(current_is_one, current_is_one + N, prev_is_one) &&
               set_one) {
      // It is a run of the same broadcasting case as last time.
      // We can reshape the input so that fewer dimensions
      // are involved in the intermediate computation.
      result_.back() = mul_dims(result_.back(), output_dim);
      for (int i = 0; i < N; ++i) {
        reshape_[i].back() = mul_dims(reshape_[i].back(), copy[i][j]);
        bcast_[i].back() =
            mul_dims(bcast_[i].back(), current_is_one[i] ? output_dim : 1);
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
        }
      }
    } else {
      result_.push_back(output_dim);
      for (int i = 0; i < N; ++i) {
        reshape_[i].push_back(copy[i][j]);
        bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
        }
      }
    }
    set_one = true;
    for (int i = 0; i < N; ++i) {
      prev_is_one[i] = current_is_one[i];
    }
  }
  if (result_.empty()) {
    result_.push_back(1);
    for (int i = 0; i < N; ++i) {
      reshape_[i].push_back(1);
      bcast_[i].push_back(1);
    }
  }
  // Restore the original (outer-most first) dimension order.
  for (int i = 0; i < N; ++i) {
    Reverse(&reshape_[i]);
    Reverse(&bcast_[i]);
    Reverse(&grad_reduce_idx_[i]);
  }
  Reverse(&result_);
  Reverse(&output_);
  // Only compute batch indices when broadcasting is required, the caller
  // requested them, and the output is non-empty; otherwise it would be
  // needless work.
  if (return_flattened_batch_indices && broadcasting_required_ &&
      output_batch_size_ > 0) {
    for (int i = 0; i < N; ++i) {
      ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i],
                          &batch_indices_[i]);
    }
  }
}

// BCast is a helper for broadcasting a binary tensor operation.
// TensorFlow's broadcasting rule follows that of numpy (see
// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
//
// The rule has the following properties:
//
//   1. suffix matching: the rule starts with the right-most
//      dimension, and works towards the left-most dimension. Since
//      TensorFlow is row-major, the right-most dimension (the last
//      element in the shape of a tensor) is the inner-most, a.k.a.
//      the fastest changing, dimension.
//
//   2. Two dimensions are compatible for broadcasting if both are the
//      same or either is 1.
//
// BCast takes the shape of two tensors and computes a few vectors of
// int64 that are useful for the caller to reshape the tensors, apply
// the right broadcasts to them, compute the broadcasted operation,
// and possibly the gradients. In a nutshell, the caller is expected
// to compute the broadcasted operation as follows:
//
//   BCast b(x.shape(), y.shape());
//   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
//            _op_
//            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
//
// For the gradient computation,
//   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
//            .reshape(x.shape())
//   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
//            .reshape(y.shape())
// backprop_x and backprop_y are functionals of the binary function "op",
// e.g.,
//   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
//   for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
//   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
//
// The multiplication in the grad * backprop_x itself is also
// broadcasting following the same rule.
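//
// A worked example (illustrative, with fewer_dims_optimization disabled so
// the shapes keep their full rank): for x.shape() = [2, 3, 1] and
// y.shape() = [1, 4],
//   output_shape()      = [2, 3, 4]
//   x_reshape()         = [2, 3, 1],  x_bcast() = [1, 1, 4]
//   y_reshape()         = [1, 1, 4],  y_bcast() = [2, 3, 1]
//   grad_x_reduce_idx() = [2]         (x was broadcast along its last dim)
//   grad_y_reduce_idx() = [0, 1]      (y was 1-extended and broadcast)
// With the default fewer_dims_optimization, adjacent dimensions with the same
// broadcasting pattern are merged, e.g. x_reshape() becomes [6, 1] and
// x_bcast() becomes [1, 4], while output_shape() is unchanged.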
class BCast : public BCastList<2> {
 public:
  // Constructs all helper shapes, following the aforementioned rules.
  //
  // If "fewer_dims_optimization" is set to true (the default), the
  // implementation tries to reduce the number of intermediate dimensions, for
  // efficiency. This is transparent to the caller.
  //
  // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx())
  // have the same number of dimensions as the larger of the two inputs.
  typedef gtl::InlinedVector<int64_t, 4> Vec;

  BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true,
        const bool return_flattened_batch_indices = false)
      : BCastList<2>({x, y}, fewer_dims_optimization,
                     return_flattened_batch_indices) {}

  ~BCast() {}

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec& x_reshape() const { return reshape_[0]; }
  const Vec& x_bcast() const { return bcast_[0]; }
  const Vec& y_reshape() const { return reshape_[1]; }
  const Vec& y_bcast() const { return bcast_[1]; }
  const Vec& result_shape() const { return result_; }
  const Vec& output_shape() const { return output_; }
  const Vec& grad_x_reduce_idx() const { return grad_reduce_idx_[0]; }
  const Vec& grad_y_reduce_idx() const { return grad_reduce_idx_[1]; }

  // Returns the mapping from the flattened output batch indices to x's
  // flattened batch indices. The result is a vector of length
  // output_batch_size(). To compute the i'th batch output, a binary
  // matmul-like operation should use the `x_batch_indices()[i]`'th batch
  // index of `x`.
  // Note: Returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64_t>& x_batch_indices() const {
    return batch_indices_[0];
  }
  // Returns the mapping from the flattened output batch indices to y's
  // flattened batch indices. Similar to x_batch_indices().
  // Note: Returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64_t>& y_batch_indices() const {
    return batch_indices_[1];
  }
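
  // For example (illustrative): for batch shapes x = [2, 1] and y = [1, 3],
  // output_batch_size() is 6, x_batch_indices() is {0, 0, 0, 1, 1, 1}, and
  // y_batch_indices() is {0, 1, 2, 0, 1, 2}.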

  template <typename IndexType, int NDIMS>
  static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
      const BCast::Vec& vec) {
    CHECK_EQ(vec.size(), NDIMS);
    Eigen::array<IndexType, NDIMS> ret;
    for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
    return ret;
  }

  template <int NDIMS>
  static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
      const BCast::Vec& vec) {
    return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
  }
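
  // A minimal usage sketch (illustrative; `in`, `out`, and `d` are
  // hypothetical Eigen tensors and a device, not part of this header):
  //
  //   BCast b(x_shape, y_shape);  // x_shape and y_shape are BCast::Vec.
  //   out.device(d) = in.reshape(BCast::ToIndexArray<3>(b.x_reshape()))
  //                       .broadcast(BCast::ToIndexArray<3>(b.x_bcast()));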

  // Static helpers.
  static Vec FromShape(const TensorShape& shape);
  static TensorShape ToShape(const Vec& vec);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(BCast);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_BCAST_H_