/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
#define TENSORFLOW_CORE_UTIL_BCAST_H_
#include <algorithm>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCastList helper class. The i'th element denotes the
// (flattened) batch index of the input that must be used to compute the i'th
// batch output.
//
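// For example (an illustrative sketch): given reshape = {2, 1} and
// bcast = {1, 3}, the input has 2 (flattened) batch elements, the output
// has 6, and the computed mapping is {0, 0, 0, 1, 1, 1}: output batches
// 0..2 read input batch 0 and output batches 3..5 read input batch 1.
//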
inline void ComputeBatchIndices(const int64_t output_batch_size,
                                const gtl::InlinedVector<int64_t, 4>& reshape,
                                const gtl::InlinedVector<int64_t, 4>& bcast,
                                std::vector<int64_t>* out_indices) {
  // Populates the mapping in out_indices. This algorithm is identical to
  // the following steps:
  //  - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
  //  - Broadcast to the output shape.
  //  - Reshape back to a flat 1D vector.
  out_indices->resize(output_batch_size);
  int64_t num_output_elements = 1;
  int64_t num_input_elements = 1;
  for (int64_t i = reshape.size() - 1; i >= 0; --i) {
    // Replicate the already populated mapping an additional (dim - 1) times.
    // If we are broadcasting, just copy the existing mapping.
    // Otherwise, add another dimension from the input shape.
    const int64_t dim = std::max(reshape[i], bcast[i]);
    const int64_t incr = bcast[i] > 1 ? 0 : num_input_elements;
    for (int64_t k = 0; k < (dim - 1) * num_output_elements; ++k) {
      (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
    }
    num_output_elements *= dim;
    num_input_elements *= reshape[i];
  }
}

template <int N>
class BCastList {
 public:
  // A vector of int64 representing the shape of a tensor. The 0-th
  // element is the outer-most dimension and the last element is the
  // inner-most dimension. Note that we do not use TensorShape since
  // it's more convenient to manipulate Vec directly for this module.
  typedef gtl::InlinedVector<int64_t, 4> Vec;

  // Constructs all helper shapes, following the broadcasting rules
  // described in the comment on the BCast class below.
  //
  // If "fewer_dims_optimization" is set to true (the default), the
  // implementation tries to collapse adjacent dimensions so that fewer
  // intermediate dimensions are involved in the computation, which is
  // more efficient. This is transparent to the caller.
  //
  // If false, all intermediate shapes (except for grad_reduce_idx()) have
  // the same number of dimensions as the largest input.
  //
  // If "return_flattened_batch_indices" is true, the implementation computes,
  // for each position in the flattened output, the corresponding flattened
  // batch index of each input. This is disabled by default.
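  //
  // A minimal usage sketch (the shapes here are illustrative):
  //
  //   BCastList<3>::Vec shapes[3] = {{2, 1, 5}, {1, 4, 5}, {4, 5}};
  //   BCastList<3> b(shapes);
  //   if (b.IsValid()) {
  //     // b.output_shape() is {2, 4, 5}.
  //   }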
  explicit BCastList(const Vec (&x)[N],
                     const bool fewer_dims_optimization = true,
                     const bool return_flattened_batch_indices = false);
  ~BCastList() {}

  // Returns true iff all operands are compatible according to the
  // broadcasting rule.
  bool IsValid() const { return valid_; }
  bool IsBroadcastingRequired() const { return broadcasting_required_; }

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec& reshape(int i) const { return reshape_[i]; }
  const Vec& bcast(int i) const { return bcast_[i]; }
  const Vec& result_shape() const { return result_; }
  const Vec& output_shape() const { return output_; }
  const Vec& grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; }
  int64_t output_batch_size() const { return output_batch_size_; }

  // Returns the mapping from the flattened output batch indices to the
  // flattened batch indices of input `i`. The result is a vector of length
  // output_batch_size(). To compute the k'th batch of the output, a
  // matmul-like operation should use the `batch_indices(i)[k]`'th batch of
  // input `i`.
  // Note: returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64_t>& batch_indices(int i) const {
    return batch_indices_[i];
  }

 protected:
  bool valid_ = true;
  bool broadcasting_required_ = true;
  Vec reshape_[N];
  Vec bcast_[N];
  Vec result_;
  Vec output_;
  Vec grad_reduce_idx_[N];

  int64_t output_batch_size_;
  std::vector<int64_t> batch_indices_[N];

  static void Reverse(Vec* shape) {
    std::reverse(shape->begin(), shape->end());
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastList);
};

template <int N>
BCastList<N>::BCastList(const BCastList::Vec (&x)[N],
                        const bool fewer_dims_optimization,
                        const bool return_flattened_batch_indices) {
  typedef BCastList::Vec Vec;

  // Safely multiplies dimensions taking into account symbolic shapes.
  auto mul_dims = [](int64_t dim1, int64_t dim2) -> int64_t {
    return dim1 != 0 && dim2 != 0 && (dim1 < 0 || dim2 < 0) ? -1 : dim1 * dim2;
  };
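  // For instance (illustrative values): mul_dims(3, 4) == 12,
  // mul_dims(3, -1) == -1 (a negative, i.e. symbolic/unknown, dimension
  // propagates), and mul_dims(0, -1) == 0 (an empty tensor stays empty).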

  bool all_equal = true;
  size_t largest_rank = 0;
  output_batch_size_ = 1;
  for (int i = 0; i < N; ++i) {
    if (x[i] != x[0]) {
      all_equal = false;
    }
    if (x[i].size() > largest_rank) {
      largest_rank = x[i].size();
    }
  }
  if (all_equal) {
    broadcasting_required_ = false;
  }
  if (all_equal && TF_PREDICT_TRUE(fewer_dims_optimization)) {
    // Fast path for the common case of identical shapes.
    int64_t elements = 1;
    const int rank = x[0].size();
    output_.resize(rank);
    for (int i = 0; i < rank; i++) {
      const int64_t dim = x[0][i];
      elements = mul_dims(elements, dim);
      output_[i] = dim;
    }
    result_.push_back(elements);
    output_batch_size_ = elements;
    for (int i = 0; i < N; ++i) {
      reshape_[i].push_back(elements);
      bcast_[i].push_back(1);
    }
    // grad_reduce_idx_ is left empty.
    return;
  }

  // Reverse all the shapes for convenience.
  // After the reverse, the 0-th element is the inner-most dimension.
  Vec copy[N];
  for (int i = 0; i < N; ++i) {
    copy[i] = x[i];
    Reverse(&copy[i]);
  }

  // 1-extend and align all vectors.
  for (int i = 0; i < N; ++i) {
    if (copy[i].size() < largest_rank) {
      copy[i].resize(largest_rank, 1);
    }
  }
  // Going through each dimension starting from the inner-most
  // dimension, compares the dimensions of all inputs. They are
  // compatible if, among the inputs, every dimension that is not 1
  // is the same.

  // Tracks whether the j-th dimension of each input is 1, for the
  // current and the previous value of j.
  bool prev_is_one[N];
  bool current_is_one[N];
  for (int i = 0; i < N; ++i) {
    prev_is_one[i] = false;
    current_is_one[i] = false;
  }
  bool output_dim_set = false;
  int64_t output_dim = -1;
  bool none_is_one = true;
  bool set_one = false;
  for (int j = 0; j < largest_rank; ++j) {
    output_dim = -1;
    output_dim_set = false;
    none_is_one = true;
    // Find which indices are 1.
    for (int i = 0; i < N; ++i) {
      if (copy[i][j] == 1) {
        current_is_one[i] = true;
        none_is_one = false;
      } else {
        current_is_one[i] = false;
        if (!output_dim_set || copy[i][j] == output_dim) {
          output_dim = copy[i][j];
          output_dim_set = true;
        } else {
          valid_ = false;
          return;
        }
      }
    }
    output_.push_back(output_dim_set ? output_dim : 1);
    output_batch_size_ = mul_dims(output_batch_size_, output_.back());
    // All dimensions are 1.
    if (!output_dim_set) {
      if (!TF_PREDICT_TRUE(fewer_dims_optimization)) {
        for (int i = 0; i < N; ++i) {
          bcast_[i].push_back(1);
          reshape_[i].push_back(1);
        }
        result_.push_back(1);
      }
      for (int i = 0; i < N; ++i) {
        grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
      }
      // This will skip updating the previous state to the current one. We'll
      // explain why this is safe below.
      // Consider the previous state P, the current state C and the next
      // state N. In the case where N is also all ones (N == C), we'll do the
      // same optimization here (push back dimensions of size 1 if we need
      // to), which is safe and expected.
      //
      // When N != C, we'll continue as usual. However, we might trigger the
      // next block if N == P (because we didn't update the previous state).
      // The next block is only triggered when `fewer_dims_optimization` is
      // true, which means we did not modify the broadcast / reshape vectors
      // in this block (we skipped updating them, since dimensions of size 1
      // can be ignored). In essence, we only need to check whether the
      // previous non-one state is equal to the current non-one state.

      continue;
    } else if (TF_PREDICT_TRUE(fewer_dims_optimization) &&
               std::equal(current_is_one, current_is_one + N, prev_is_one) &&
               set_one) {
      // It is a run of the same broadcasting case as last time.
      // We can reshape the input so that fewer dimensions
      // are involved in the intermediate computation.
      result_.back() = mul_dims(result_.back(), output_dim);
      for (int i = 0; i < N; ++i) {
        reshape_[i].back() = mul_dims(reshape_[i].back(), copy[i][j]);
        bcast_[i].back() =
            mul_dims(bcast_[i].back(), current_is_one[i] ? output_dim : 1);
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
        }
      }
    } else {
      result_.push_back(output_dim);
      for (int i = 0; i < N; ++i) {
        reshape_[i].push_back(copy[i][j]);
        bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
        }
      }
    }
    set_one = true;
    for (int i = 0; i < N; ++i) {
      prev_is_one[i] = current_is_one[i];
    }
  }
  if (result_.empty()) {
    result_.push_back(1);
    for (int i = 0; i < N; ++i) {
      reshape_[i].push_back(1);
      bcast_[i].push_back(1);
    }
  }
  // Restore the original (outer-most first) dimension order.
  for (int i = 0; i < N; ++i) {
    Reverse(&reshape_[i]);
    Reverse(&bcast_[i]);
    Reverse(&grad_reduce_idx_[i]);
  }
  Reverse(&result_);
  Reverse(&output_);
  // Only compute the batch indices when the caller requested them,
  // broadcasting is actually required, and the output is non-empty;
  // otherwise it would be needless work.
  if (return_flattened_batch_indices && broadcasting_required_ &&
      output_batch_size_ > 0) {
    for (int i = 0; i < N; ++i) {
      ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i],
                          &batch_indices_[i]);
    }
  }
}

// BCast is a helper for broadcasted binary tensor operations.
// TensorFlow's broadcasting rule follows that of NumPy (see
// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
//
// The rule has the following properties:
//
// 1. suffix matching: the rule starts with the right-most
//    dimension, and works towards the left-most dimension. Since
//    TensorFlow is row-major, the right-most dimension (the last
//    element in the shape of a tensor) is the inner-most, a.k.a.
//    the fastest changing, dimension.
//
// 2. Two dimensions are compatible for broadcasting if both are the
//    same or either is 1.
//
// BCast takes the shape of two tensors and computes a few vectors of
// int64 that are useful for the caller to reshape the tensors, apply
// the right broadcasts to them, compute the broadcasted operation,
// and possibly the gradients. In a nutshell, the caller is expected
// to compute the broadcasted operation as follows:
//
//   BCast b(x.shape(), y.shape());
//   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
//            _op_
//            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
//
// For the gradient computation,
//   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
//            .reshape(x.shape())
//   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
//            .reshape(y.shape())
// backprop_x and backprop_y are functionals of the binary function "op",
// e.g.,
//   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
//   for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
//   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
//
// The multiplication in grad * backprop_x itself also follows the
// broadcasting rule.
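//
// As a concrete (illustrative) example, for x.shape() = [1, 3] and
// y.shape() = [2, 1]:
//
//   BCast b({1, 3}, {2, 1});
//   // b.IsValid()           == true
//   // b.output_shape()      == {2, 3}
//   // b.x_reshape()         == {1, 3},  b.x_bcast() == {2, 1}
//   // b.y_reshape()         == {2, 1},  b.y_bcast() == {1, 3}
//   // b.grad_x_reduce_idx() == {0},     b.grad_y_reduce_idx() == {1}
//
// With the (default) fewer_dims_optimization, adjacent dimensions that
// broadcast the same way are collapsed: e.g. for x.shape() = [2, 3, 5] and
// y.shape() = [2, 3, 1], b.x_reshape() == {6, 5} and b.y_reshape() == {6, 1},
// while b.output_shape() is still {2, 3, 5}.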
class BCast : public BCastList<2> {
 public:
  typedef gtl::InlinedVector<int64_t, 4> Vec;

  // Constructs all helper shapes, following the aforementioned rules.
  //
  // If "fewer_dims_optimization" is set to true (the default), the
  // implementation tries to collapse adjacent dimensions so that fewer
  // intermediate dimensions are involved in the computation, which is
  // more efficient. This is transparent to the caller.
  //
  // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx())
  // have the same number of dimensions as the larger of the two inputs.
  BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true,
        const bool return_flattened_batch_indices = false)
      : BCastList<2>({x, y}, fewer_dims_optimization,
                     return_flattened_batch_indices) {}

  ~BCast() {}

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec& x_reshape() const { return reshape_[0]; }
  const Vec& x_bcast() const { return bcast_[0]; }
  const Vec& y_reshape() const { return reshape_[1]; }
  const Vec& y_bcast() const { return bcast_[1]; }
  const Vec& result_shape() const { return result_; }
  const Vec& output_shape() const { return output_; }
  const Vec& grad_x_reduce_idx() const { return grad_reduce_idx_[0]; }
  const Vec& grad_y_reduce_idx() const { return grad_reduce_idx_[1]; }

  // Returns the mapping from the flattened output batch indices to x's
  // flattened batch indices. The result is a vector of length
  // output_batch_size(). To compute the i'th batch of the output, a
  // matmul-like operation should use the `x_batch_indices()[i]`'th batch
  // of `x`.
  // Note: returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64_t>& x_batch_indices() const {
    return batch_indices_[0];
  }
  // Returns the mapping from the flattened output batch indices to y's
  // flattened batch indices. Similar to x_batch_indices().
  // Note: returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64_t>& y_batch_indices() const {
    return batch_indices_[1];
  }
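
  // For example (an illustrative sketch): with x.shape() = [2, 1] and
  // y.shape() = [1, 3],
  //
  //   BCast b({2, 1}, {1, 3}, /*fewer_dims_optimization=*/true,
  //           /*return_flattened_batch_indices=*/true);
  //   // b.output_batch_size() == 6
  //   // b.x_batch_indices()   == {0, 0, 0, 1, 1, 1}
  //   // b.y_batch_indices()   == {0, 1, 2, 0, 1, 2}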

  template <typename IndexType, int NDIMS>
  static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
      const BCast::Vec& vec) {
    CHECK_EQ(vec.size(), NDIMS);
    Eigen::array<IndexType, NDIMS> ret;
    for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
    return ret;
  }

  template <int NDIMS>
  static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
      const BCast::Vec& vec) {
    return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
  }
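
  // These helpers convert a Vec into the fixed-size index array that
  // Eigen's tensor expressions expect. A hypothetical kernel might use
  // them like this (sketch only; `out`, `in`, and `b` are assumed to be a
  // 2-D Eigen tensor output, a 2-D Eigen tensor input, and a valid BCast):
  //
  //   out.device(d) = in.reshape(BCast::ToIndexArray<2>(b.x_reshape()))
  //                     .broadcast(BCast::ToIndexArray<2>(b.x_bcast()));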

  // Static helpers.
  static Vec FromShape(const TensorShape& shape);
  static TensorShape ToShape(const Vec& vec);
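  // FromShape and ToShape (defined out-of-line) convert between TensorShape
  // and Vec, e.g. (illustrative):
  //
  //   BCast::Vec v = BCast::FromShape(TensorShape({2, 3}));  // v == {2, 3}
  //   TensorShape s = BCast::ToShape(v);                     // s == [2, 3]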

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(BCast);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_BCAST_H_