1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/batch_util.h" |
17 | |
18 | #include "tensorflow/core/framework/register_types.h" |
19 | #include "tensorflow/core/framework/types.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | |
22 | #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) |
23 | |
24 | namespace tensorflow { |
25 | namespace batch_util { |
26 | |
27 | namespace { |
28 | |
29 | Status ValidateInput(const Tensor& parent, const Tensor& element, |
30 | int64_t index) { |
31 | DCHECK_NE(parent.dim_size(0), 0); |
32 | DCHECK_GE(index, 0); |
33 | if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) { |
34 | TensorShape chip_shape = parent.shape(); |
35 | chip_shape.RemoveDim(0); |
36 | return errors::Internal( |
37 | "ValidateInput Cannot perform copy: number of elements does not match. " |
38 | " Shapes are: [element]: " , |
39 | element.shape().DebugString(), |
40 | ", [parent slice]: " , chip_shape.DebugString()); |
41 | } |
42 | return OkStatus(); |
43 | } |
44 | |
45 | template <typename T> |
46 | Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest, |
47 | int64_t num_values) { |
48 | static_assert(tsl::is_simple_type<T>::value, |
49 | "Memcpy requires a simple type." ); |
50 | memcpy(dest, src, num_values * sizeof(T)); |
51 | return OkStatus(); |
52 | } |
53 | |
54 | template <> |
55 | Status HandleElementToSlice<tstring>(const Tensor& element, tstring* src, |
56 | tstring* dest, int64_t num_values) { |
57 | if (element.RefCountIsOne()) { |
58 | for (int64_t i = 0; i < num_values; ++i) { |
59 | *dest++ = std::move(*src++); |
60 | } |
61 | } else { |
62 | std::copy_n(src, num_values, dest); |
63 | } |
64 | return OkStatus(); |
65 | } |
66 | |
67 | template <> |
68 | Status HandleElementToSlice<Variant>(const Tensor& element, Variant* src, |
69 | Variant* dest, int64_t num_values) { |
70 | if (element.RefCountIsOne()) { |
71 | for (int64_t i = 0; i < num_values; ++i) { |
72 | *dest++ = std::move(*src++); |
73 | } |
74 | } else { |
75 | std::copy_n(src, num_values, dest); |
76 | } |
77 | return OkStatus(); |
78 | } |
79 | |
80 | template <> |
81 | Status HandleElementToSlice<ResourceHandle>(const Tensor& /* element */, |
82 | ResourceHandle* src, |
83 | ResourceHandle* dest, |
84 | int64_t num_values) { |
85 | std::copy_n(src, num_values, dest); |
86 | return OkStatus(); |
87 | } |
88 | |
89 | template <> |
90 | Status HandleElementToSlice<Eigen::half>(const Tensor& /* element */, |
91 | Eigen::half* src, Eigen::half* dest, |
92 | int64_t num_values) { |
93 | std::copy_n(src, num_values, dest); |
94 | return OkStatus(); |
95 | } |
96 | |
97 | template <typename T> |
98 | void HandleSliceToElement(const T* src, T* dest, int64_t num_values) { |
99 | static_assert(tsl::is_simple_type<T>::value, |
100 | "Memcpy requires a simple type." ); |
101 | memcpy(dest, src, num_values * sizeof(T)); |
102 | } |
103 | |
104 | template <> |
105 | void HandleSliceToElement<tstring>(const tstring* src, tstring* dest, |
106 | int64_t num_values) { |
107 | std::copy_n(src, num_values, dest); |
108 | } |
109 | |
110 | template <> |
111 | void HandleSliceToElement<Variant>(const Variant* src, Variant* dest, |
112 | int64_t num_values) { |
113 | std::copy_n(src, num_values, dest); |
114 | } |
115 | |
116 | template <> |
117 | void HandleSliceToElement<ResourceHandle>(const ResourceHandle* src, |
118 | ResourceHandle* dest, |
119 | int64_t num_values) { |
120 | std::copy_n(src, num_values, dest); |
121 | } |
122 | |
123 | template <> |
124 | void HandleSliceToElement<Eigen::half>(const Eigen::half* src, |
125 | Eigen::half* dest, int64_t num_values) { |
126 | std::copy_n(src, num_values, dest); |
127 | } |
128 | |
129 | template <typename T> |
130 | void HandleSliceToElement(Tensor* parent, T* src, T* dest, int64_t num_values) { |
131 | static_assert(tsl::is_simple_type<T>::value, |
132 | "Memcpy requires a simple type." ); |
133 | memcpy(dest, src, num_values * sizeof(T)); |
134 | } |
135 | |
136 | template <> |
137 | void HandleSliceToElement<tstring>(Tensor* parent, tstring* src, tstring* dest, |
138 | int64_t num_values) { |
139 | if (parent->RefCountIsOne()) { |
140 | for (int64_t i = 0; i < num_values; ++i) { |
141 | dest[i] = std::move(src[i]); |
142 | } |
143 | } else { |
144 | std::copy_n(src, num_values, dest); |
145 | } |
146 | } |
147 | |
148 | template <> |
149 | void HandleSliceToElement<Variant>(Tensor* parent, Variant* src, Variant* dest, |
150 | int64_t num_values) { |
151 | if (parent->RefCountIsOne()) { |
152 | for (int64_t i = 0; i < num_values; ++i) { |
153 | dest[i] = std::move(src[i]); |
154 | } |
155 | } else { |
156 | std::copy_n(src, num_values, dest); |
157 | } |
158 | } |
159 | |
160 | template <> |
161 | void HandleSliceToElement<ResourceHandle>(Tensor* parent, ResourceHandle* src, |
162 | ResourceHandle* dest, |
163 | int64_t num_values) { |
164 | std::copy_n(src, num_values, dest); |
165 | } |
166 | |
167 | template <> |
168 | void HandleSliceToElement<Eigen::half>(Tensor* parent, Eigen::half* src, |
169 | Eigen::half* dest, int64_t num_values) { |
170 | std::copy_n(src, num_values, dest); |
171 | } |
172 | |
173 | } // namespace |
174 | |
175 | // Copies element into the index^th slice of parent (in the 0th dimension). |
176 | Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) { |
177 | TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index)); |
178 | const int64_t num_values = element.NumElements(); |
179 | #define HANDLE_TYPE(T) \ |
180 | case DataTypeToEnum<T>::value: { \ |
181 | T* src = element.base<T>(); \ |
182 | T* dest = parent->base<T>() + (num_values * index); \ |
183 | return HandleElementToSlice<T>(element, src, dest, num_values); \ |
184 | } |
185 | |
186 | switch (element.dtype()) { |
187 | TF_CALL_ALL_TYPES(HANDLE_TYPE); |
188 | TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
189 | #undef HANDLE_TYPE |
190 | default: |
191 | return errors::Unimplemented("CopyElementToSlice Unhandled data type: " , |
192 | element.dtype()); |
193 | } |
194 | } |
195 | |
196 | // Copies the index^th slice of parent (in the 0th dimension) into element. |
197 | Status CopySliceToElement(const Tensor& parent, Tensor* element, |
198 | int64_t index) { |
199 | TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index)); |
200 | const int64_t num_values = element->NumElements(); |
201 | |
202 | #define HANDLE_TYPE(T) \ |
203 | case DataTypeToEnum<T>::value: { \ |
204 | const T* src = parent.base<T>() + (num_values * index); \ |
205 | T* dest = element->base<T>(); \ |
206 | HandleSliceToElement<T>(src, dest, num_values); \ |
207 | return OkStatus(); \ |
208 | } |
209 | |
210 | switch (parent.dtype()) { |
211 | TF_CALL_ALL_TYPES(HANDLE_TYPE); |
212 | TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
213 | #undef HANDLE_TYPE |
214 | default: |
215 | return errors::Unimplemented("CopySliceToElement Unhandled data type: " , |
216 | element->dtype()); |
217 | } |
218 | } |
219 | |
220 | Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, |
221 | int64_t dst_offset, int64_t num_slices, |
222 | Tensor* dst) { |
223 | if (src.dtype() != dst->dtype()) { |
224 | return errors::FailedPrecondition( |
225 | "CopyContiguousSlices cannot perform copy: src and dst have different " |
226 | "dtypes. Source dtype: " , |
227 | src.dtype(), " dstination dtype: " , dst->dtype(), "." ); |
228 | } |
229 | if (src.dims() < 1) { |
230 | return errors::FailedPrecondition( |
231 | "CopyContiguousSlices cannot perform copy: src has to be a tensor with " |
232 | "rank >= 1. Source shape: " , |
233 | src.shape().DebugString()); |
234 | } |
235 | |
236 | if (dst->dims() < 1) { |
237 | return errors::FailedPrecondition( |
238 | "CopyContiguousSlices cannot perform copy: dst has to be a tensor " |
239 | "with rank >= 1. Dest shape: " , |
240 | dst->shape().DebugString()); |
241 | } |
242 | |
243 | const int64_t src_dim0 = src.dim_size(0); |
244 | const int64_t dst_dim0 = dst->dim_size(0); |
245 | int64_t src_chip_size = 1; |
246 | int64_t dst_chip_size = 1; |
247 | for (int i = 1; i < src.dims(); ++i) { |
248 | src_chip_size *= src.dim_size(i); |
249 | } |
250 | for (int i = 1; i < dst->dims(); ++i) { |
251 | dst_chip_size *= dst->dim_size(i); |
252 | } |
253 | |
254 | if (src_chip_size != dst_chip_size) { |
255 | return errors::FailedPrecondition( |
256 | "CopyContiguousSlices cannot perform copy: source and dst shapes are" |
257 | "not compatible. Source shape: " , |
258 | src.shape().DebugString(), ", dst shape: " , dst->shape().DebugString()); |
259 | } |
260 | |
261 | if (src_chip_size == 0 && dst_chip_size == 0) { |
262 | return OkStatus(); |
263 | } |
264 | |
265 | if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 || |
266 | dst_offset + num_slices > dst_dim0) { |
267 | return errors::FailedPrecondition( |
268 | "CopyContiguousSlices cannot perform copy: index out of range. " |
269 | "src_offset: " , |
270 | src_offset, ", num_slices: " , num_slices, ", src_dim0: " , src_dim0, |
271 | ", dst_offset: " , dst_offset, ", dst_dim0: " , dst_dim0, "." ); |
272 | } |
273 | |
274 | #define HANDLE_TYPE(T) \ |
275 | case DataTypeToEnum<T>::value: { \ |
276 | const T* src_p = src.base<T>() + (src_chip_size * src_offset); \ |
277 | T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset); \ |
278 | HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \ |
279 | return OkStatus(); \ |
280 | } |
281 | |
282 | switch (src.dtype()) { |
283 | TF_CALL_ALL_TYPES(HANDLE_TYPE); |
284 | TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
285 | #undef HANDLE_TYPE |
286 | default: |
287 | return errors::Unimplemented("CopyContiguousSlices unhandled data type: " , |
288 | src.dtype()); |
289 | } |
290 | } |
291 | |
292 | // Copies the index^th slice of parent (in the 0th dimension) into element. |
293 | // |
294 | // NOTE(mrry): The implementation may be able to optimize the copy to a move. |
295 | // This is particularly important for DT_STRING tensors. |
296 | Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index) { |
297 | TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index)); |
298 | const int64_t num_values = element->NumElements(); |
299 | |
300 | #define HANDLE_TYPE(T) \ |
301 | case DataTypeToEnum<T>::value: { \ |
302 | T* src = parent->base<T>() + (num_values * index); \ |
303 | T* dest = element->base<T>(); \ |
304 | HandleSliceToElement<T>(parent, src, dest, num_values); \ |
305 | return OkStatus(); \ |
306 | } |
307 | |
308 | switch (parent->dtype()) { |
309 | TF_CALL_ALL_TYPES(HANDLE_TYPE); |
310 | TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
311 | #undef HANDLE_TYPE |
312 | default: |
313 | return errors::Unimplemented( |
314 | "MaybeMoveSliceToElement Unhandled data type: " , element->dtype()); |
315 | } |
316 | } |
317 | |
318 | // The following five functions are copied from padding_fifo_queue.cc. |
319 | // TODO(mrry): Reconcile these functions with the similar methods in the |
320 | // queue implementation. |
321 | Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { |
322 | DCHECK_NE(parent->dim_size(0), 0); |
323 | if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { |
324 | TensorShape chip_shape = parent->shape(); |
325 | chip_shape.RemoveDim(0); |
326 | return errors::Internal( |
327 | "HandleElementToLargerSlice Cannot copy slice: number of entries in " |
328 | "element is greater than number of elements in parent slice. " , |
329 | "Shapes are: [element]: " , element.shape().DebugString(), |
330 | ", [parent slice]: " , chip_shape.DebugString()); |
331 | } |
332 | return OkStatus(); |
333 | } |
334 | |
// Assigns `element` (rank NDIMS) into the index^th slice of `parent`
// (rank NDIMS + 1) using Eigen tensor slicing. The element may be smaller
// than the slice; only the element-shaped region starting at
// [index, 0, ..., 0] is written.
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
  if (element.NumElements() == 0) {
    // Nothing to copy.
    return OkStatus();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  // Start of the target region: [index, 0, ..., 0] (non-leading entries
  // are presumably zero-initialized by DSizes — TODO confirm for the Eigen
  // version in use).
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  // Extent of the target region: 1 along dim 0, element's dims thereafter.
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  // Reshape inserts the leading size-1 dim so shapes match for assignment.
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return OkStatus();
}
354 | |
// Dtype-dispatching wrapper around HandleElementToLargerSlice for a fixed
// element rank NDIMS.
template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                   \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          element.dtype());
  }
}
372 | |
// Copies `element` into the index^th slice of `parent` (0th dimension),
// where the element may be smaller than the slice. Requires
// parent->dims() == element.dims() + 1 and element rank <= 5.
Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
                                int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks. Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

// Rank dispatch: Eigen tensor slicing needs the rank as a compile-time
// constant, hence one case per supported NDIMS.
#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return OkStatus();                                                      \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
    HANDLE_DIMS(5);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}
402 | |
// Despite the name, fills every entry of `element` with the scalar value
// `padding` (which must be a scalar tensor of the same dtype), not
// necessarily zero.
Status SetElementZero(Tensor* element, const Tensor& padding) {
#define HANDLE_TYPE(T)                                       \
  if (element->dtype() == DataTypeToEnum<T>::value) {        \
    element->flat<T>().setConstant(padding.scalar<T>()());   \
    return OkStatus();                                       \
  }
  TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               element->dtype());
}
414 | |
415 | } // namespace batch_util |
416 | } // namespace tensorflow |
417 | |