1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/util/batch_util.h"
17
18#include "tensorflow/core/framework/register_types.h"
19#include "tensorflow/core/framework/types.h"
20#include "tensorflow/core/lib/core/errors.h"
21
22#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
23
24namespace tensorflow {
25namespace batch_util {
26
27namespace {
28
29Status ValidateInput(const Tensor& parent, const Tensor& element,
30 int64_t index) {
31 DCHECK_NE(parent.dim_size(0), 0);
32 DCHECK_GE(index, 0);
33 if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
34 TensorShape chip_shape = parent.shape();
35 chip_shape.RemoveDim(0);
36 return errors::Internal(
37 "ValidateInput Cannot perform copy: number of elements does not match. "
38 " Shapes are: [element]: ",
39 element.shape().DebugString(),
40 ", [parent slice]: ", chip_shape.DebugString());
41 }
42 return OkStatus();
43}
44
45template <typename T>
46Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest,
47 int64_t num_values) {
48 static_assert(tsl::is_simple_type<T>::value,
49 "Memcpy requires a simple type.");
50 memcpy(dest, src, num_values * sizeof(T));
51 return OkStatus();
52}
53
54template <>
55Status HandleElementToSlice<tstring>(const Tensor& element, tstring* src,
56 tstring* dest, int64_t num_values) {
57 if (element.RefCountIsOne()) {
58 for (int64_t i = 0; i < num_values; ++i) {
59 *dest++ = std::move(*src++);
60 }
61 } else {
62 std::copy_n(src, num_values, dest);
63 }
64 return OkStatus();
65}
66
67template <>
68Status HandleElementToSlice<Variant>(const Tensor& element, Variant* src,
69 Variant* dest, int64_t num_values) {
70 if (element.RefCountIsOne()) {
71 for (int64_t i = 0; i < num_values; ++i) {
72 *dest++ = std::move(*src++);
73 }
74 } else {
75 std::copy_n(src, num_values, dest);
76 }
77 return OkStatus();
78}
79
80template <>
81Status HandleElementToSlice<ResourceHandle>(const Tensor& /* element */,
82 ResourceHandle* src,
83 ResourceHandle* dest,
84 int64_t num_values) {
85 std::copy_n(src, num_values, dest);
86 return OkStatus();
87}
88
89template <>
90Status HandleElementToSlice<Eigen::half>(const Tensor& /* element */,
91 Eigen::half* src, Eigen::half* dest,
92 int64_t num_values) {
93 std::copy_n(src, num_values, dest);
94 return OkStatus();
95}
96
97template <typename T>
98void HandleSliceToElement(const T* src, T* dest, int64_t num_values) {
99 static_assert(tsl::is_simple_type<T>::value,
100 "Memcpy requires a simple type.");
101 memcpy(dest, src, num_values * sizeof(T));
102}
103
104template <>
105void HandleSliceToElement<tstring>(const tstring* src, tstring* dest,
106 int64_t num_values) {
107 std::copy_n(src, num_values, dest);
108}
109
110template <>
111void HandleSliceToElement<Variant>(const Variant* src, Variant* dest,
112 int64_t num_values) {
113 std::copy_n(src, num_values, dest);
114}
115
116template <>
117void HandleSliceToElement<ResourceHandle>(const ResourceHandle* src,
118 ResourceHandle* dest,
119 int64_t num_values) {
120 std::copy_n(src, num_values, dest);
121}
122
123template <>
124void HandleSliceToElement<Eigen::half>(const Eigen::half* src,
125 Eigen::half* dest, int64_t num_values) {
126 std::copy_n(src, num_values, dest);
127}
128
129template <typename T>
130void HandleSliceToElement(Tensor* parent, T* src, T* dest, int64_t num_values) {
131 static_assert(tsl::is_simple_type<T>::value,
132 "Memcpy requires a simple type.");
133 memcpy(dest, src, num_values * sizeof(T));
134}
135
136template <>
137void HandleSliceToElement<tstring>(Tensor* parent, tstring* src, tstring* dest,
138 int64_t num_values) {
139 if (parent->RefCountIsOne()) {
140 for (int64_t i = 0; i < num_values; ++i) {
141 dest[i] = std::move(src[i]);
142 }
143 } else {
144 std::copy_n(src, num_values, dest);
145 }
146}
147
148template <>
149void HandleSliceToElement<Variant>(Tensor* parent, Variant* src, Variant* dest,
150 int64_t num_values) {
151 if (parent->RefCountIsOne()) {
152 for (int64_t i = 0; i < num_values; ++i) {
153 dest[i] = std::move(src[i]);
154 }
155 } else {
156 std::copy_n(src, num_values, dest);
157 }
158}
159
160template <>
161void HandleSliceToElement<ResourceHandle>(Tensor* parent, ResourceHandle* src,
162 ResourceHandle* dest,
163 int64_t num_values) {
164 std::copy_n(src, num_values, dest);
165}
166
167template <>
168void HandleSliceToElement<Eigen::half>(Tensor* parent, Eigen::half* src,
169 Eigen::half* dest, int64_t num_values) {
170 std::copy_n(src, num_values, dest);
171}
172
173} // namespace
174
175// Copies element into the index^th slice of parent (in the 0th dimension).
176Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) {
177 TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
178 const int64_t num_values = element.NumElements();
179#define HANDLE_TYPE(T) \
180 case DataTypeToEnum<T>::value: { \
181 T* src = element.base<T>(); \
182 T* dest = parent->base<T>() + (num_values * index); \
183 return HandleElementToSlice<T>(element, src, dest, num_values); \
184 }
185
186 switch (element.dtype()) {
187 TF_CALL_ALL_TYPES(HANDLE_TYPE);
188 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
189#undef HANDLE_TYPE
190 default:
191 return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
192 element.dtype());
193 }
194}
195
196// Copies the index^th slice of parent (in the 0th dimension) into element.
197Status CopySliceToElement(const Tensor& parent, Tensor* element,
198 int64_t index) {
199 TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
200 const int64_t num_values = element->NumElements();
201
202#define HANDLE_TYPE(T) \
203 case DataTypeToEnum<T>::value: { \
204 const T* src = parent.base<T>() + (num_values * index); \
205 T* dest = element->base<T>(); \
206 HandleSliceToElement<T>(src, dest, num_values); \
207 return OkStatus(); \
208 }
209
210 switch (parent.dtype()) {
211 TF_CALL_ALL_TYPES(HANDLE_TYPE);
212 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
213#undef HANDLE_TYPE
214 default:
215 return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
216 element->dtype());
217 }
218}
219
220Status CopyContiguousSlices(const Tensor& src, int64_t src_offset,
221 int64_t dst_offset, int64_t num_slices,
222 Tensor* dst) {
223 if (src.dtype() != dst->dtype()) {
224 return errors::FailedPrecondition(
225 "CopyContiguousSlices cannot perform copy: src and dst have different "
226 "dtypes. Source dtype: ",
227 src.dtype(), " dstination dtype: ", dst->dtype(), ".");
228 }
229 if (src.dims() < 1) {
230 return errors::FailedPrecondition(
231 "CopyContiguousSlices cannot perform copy: src has to be a tensor with "
232 "rank >= 1. Source shape: ",
233 src.shape().DebugString());
234 }
235
236 if (dst->dims() < 1) {
237 return errors::FailedPrecondition(
238 "CopyContiguousSlices cannot perform copy: dst has to be a tensor "
239 "with rank >= 1. Dest shape: ",
240 dst->shape().DebugString());
241 }
242
243 const int64_t src_dim0 = src.dim_size(0);
244 const int64_t dst_dim0 = dst->dim_size(0);
245 int64_t src_chip_size = 1;
246 int64_t dst_chip_size = 1;
247 for (int i = 1; i < src.dims(); ++i) {
248 src_chip_size *= src.dim_size(i);
249 }
250 for (int i = 1; i < dst->dims(); ++i) {
251 dst_chip_size *= dst->dim_size(i);
252 }
253
254 if (src_chip_size != dst_chip_size) {
255 return errors::FailedPrecondition(
256 "CopyContiguousSlices cannot perform copy: source and dst shapes are"
257 "not compatible. Source shape: ",
258 src.shape().DebugString(), ", dst shape: ", dst->shape().DebugString());
259 }
260
261 if (src_chip_size == 0 && dst_chip_size == 0) {
262 return OkStatus();
263 }
264
265 if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 ||
266 dst_offset + num_slices > dst_dim0) {
267 return errors::FailedPrecondition(
268 "CopyContiguousSlices cannot perform copy: index out of range. "
269 "src_offset: ",
270 src_offset, ", num_slices: ", num_slices, ", src_dim0: ", src_dim0,
271 ", dst_offset: ", dst_offset, ", dst_dim0: ", dst_dim0, ".");
272 }
273
274#define HANDLE_TYPE(T) \
275 case DataTypeToEnum<T>::value: { \
276 const T* src_p = src.base<T>() + (src_chip_size * src_offset); \
277 T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset); \
278 HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \
279 return OkStatus(); \
280 }
281
282 switch (src.dtype()) {
283 TF_CALL_ALL_TYPES(HANDLE_TYPE);
284 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
285#undef HANDLE_TYPE
286 default:
287 return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
288 src.dtype());
289 }
290}
291
292// Copies the index^th slice of parent (in the 0th dimension) into element.
293//
294// NOTE(mrry): The implementation may be able to optimize the copy to a move.
295// This is particularly important for DT_STRING tensors.
296Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index) {
297 TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
298 const int64_t num_values = element->NumElements();
299
300#define HANDLE_TYPE(T) \
301 case DataTypeToEnum<T>::value: { \
302 T* src = parent->base<T>() + (num_values * index); \
303 T* dest = element->base<T>(); \
304 HandleSliceToElement<T>(parent, src, dest, num_values); \
305 return OkStatus(); \
306 }
307
308 switch (parent->dtype()) {
309 TF_CALL_ALL_TYPES(HANDLE_TYPE);
310 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
311#undef HANDLE_TYPE
312 default:
313 return errors::Unimplemented(
314 "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
315 }
316}
317
318// The following five functions are copied from padding_fifo_queue.cc.
319// TODO(mrry): Reconcile these functions with the similar methods in the
320// queue implementation.
321Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
322 DCHECK_NE(parent->dim_size(0), 0);
323 if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
324 TensorShape chip_shape = parent->shape();
325 chip_shape.RemoveDim(0);
326 return errors::Internal(
327 "HandleElementToLargerSlice Cannot copy slice: number of entries in "
328 "element is greater than number of elements in parent slice. ",
329 "Shapes are: [element]: ", element.shape().DebugString(),
330 ", [parent slice]: ", chip_shape.DebugString());
331 }
332 return OkStatus();
333}
334
335template <typename T, int NDIMS>
336Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
337 int index) {
338 TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
339 if (element.NumElements() == 0) {
340 return OkStatus();
341 }
342 auto element_t = element.tensor<T, NDIMS>();
343 auto parent_t = parent->tensor<T, NDIMS + 1>();
344 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
345 slice_indices[0] = index;
346 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
347 slice_size[0] = 1;
348 for (size_t i = 1; i < slice_size.size(); ++i) {
349 slice_size[i] = element_t.dimension(i - 1);
350 }
351 parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
352 return OkStatus();
353}
354
355template <int NDIMS>
356Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
357 int index) {
358#define HANDLE_TYPE(T) \
359 case DataTypeToEnum<T>::value: { \
360 return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
361 }
362
363 switch (element.dtype()) {
364 TF_CALL_DATASET_TYPES(HANDLE_TYPE);
365#undef HANDLE_TYPE
366 default:
367 return errors::Unimplemented(
368 "HandleElementToLargerSliceWithRank Unhandled data type: ",
369 element.dtype());
370 }
371}
372
373Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
374 int index) {
375 if (parent->dims() != element.dims() + 1) {
376 return errors::Internal(
377 "Mismatched ranks. Element's rank is: ", element.dims(),
378 " but element is meant to be a slice in output Tensor having rank: ",
379 parent->dims(), " (should be: ", element.dims() + 1, ")");
380 }
381
382#define HANDLE_DIMS(NDIMS) \
383 case NDIMS: { \
384 TF_RETURN_IF_ERROR( \
385 HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
386 return OkStatus(); \
387 }
388
389 switch (element.dims()) {
390 HANDLE_DIMS(0);
391 HANDLE_DIMS(1);
392 HANDLE_DIMS(2);
393 HANDLE_DIMS(3);
394 HANDLE_DIMS(4);
395 HANDLE_DIMS(5);
396#undef HANDLE_DIMS
397 default:
398 return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
399 element.dims());
400 }
401}
402
403Status SetElementZero(Tensor* element, const Tensor& padding) {
404#define HANDLE_TYPE(T) \
405 if (element->dtype() == DataTypeToEnum<T>::value) { \
406 element->flat<T>().setConstant(padding.scalar<T>()()); \
407 return OkStatus(); \
408 }
409 TF_CALL_DATASET_TYPES(HANDLE_TYPE);
410#undef HANDLE_TYPE
411 return errors::Unimplemented("SetElementZero Unhandled data type: ",
412 element->dtype());
413}
414
415} // namespace batch_util
416} // namespace tensorflow
417