/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/strided_slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op_impl.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {
namespace {

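// Fast-path copier for 2-D simple slices: when T can be copied with memcpy,
// each output row is one contiguous memcpy from the corresponding input row.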
template <typename T>
struct MemCpyFunctor {
  // Returns true if the copy was made with memcpy, false otherwise.
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64_t, 4>& begin,
            const gtl::InlinedVector<int64_t, 4>& end, Tensor* result) {
    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
      auto in = input.tensor<T, 2>();
      auto output = result->tensor<T, 2>();
      // TODO(agarwal): Consider multi-threading if size[0] is large
      for (int row_in = begin[0], row_out = 0; row_in < end[0];
           ++row_in, ++row_out) {
        if (row_in + 1 < end[0]) {
          port::prefetch<port::PREFETCH_HINT_T0>(&output(row_in + 1, 0));
          port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1]));
        }
        memcpy(&output(row_out, 0), &in(row_in, begin[1]),
               (end[1] - begin[1]) * sizeof(T));
      }
      return true;
    }
    return false;
  }
};

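// ResourceHandle holds non-trivially-copyable state, so the memcpy fast path
// is never taken for it; the general element-wise path is used instead.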
template <>
struct MemCpyFunctor<ResourceHandle> {
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64_t, 4>& begin,
            const gtl::InlinedVector<int64_t, 4>& end, Tensor* result) {
    return false;
  }
};

}  // namespace

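// Computes a strided slice of the input. For example, the Python expression
// x[1:3, ::2] lowers to a StridedSlice op whose begin/end/strides inputs and
// bit masks encode the per-dimension ranges and step sizes.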
template <typename Device, typename T>
class StridedSliceOp : public OpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> end;
    gtl::InlinedVector<int64_t, 4> strides;

    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     context->input(0).shape(), begin_mask, end_mask,
                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
                     &processing_shape, &final_shape, &is_identity,
                     &is_simple_slice, &slice_dim0, &begin, &end, &strides));
    const Tensor& input = context->input(0);

    // Optimization #1, slice is a no-op plus reshape
    if (is_identity) {
      VLOG(1) << "Strided slice identity";
      Tensor tmp;
      OP_REQUIRES(context, tmp.CopyFrom(input, final_shape),
                  errors::Internal("Copy failed"));
      context->set_output(0, tmp);
      return;
    }

    // Optimization #2, slice is memory contiguous (only occurs in dim 0)
    if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], end[0])) {
      OP_REQUIRES(context, input.dims() >= 1,
                  errors::InvalidArgument(
                      "Input must have rank at least 1, got: ", input.dims()));
      // Otherwise, is_identity should be true.
      VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString();
      // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
      Tensor slice = input.Slice(std::min(begin[0], end[0]), end[0]);
      Tensor tmp;
      OP_REQUIRES(context, tmp.CopyFrom(slice, final_shape),
                  errors::Internal("Copy failed"));
      context->set_output(0, tmp);
      return;
    }

    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, final_shape, &result));
    const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();

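    // When the slice is empty there is nothing to copy; the output allocated
    // above is already the correct (zero-element) result.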
    if (processing_shape.num_elements() > 0) {
      // Optimization #3, slice has stride 1 in all dimensions
      // Optimization #3A, slice has only two dimensions
      // TODO(aselle): Here we are restricting to processing_shape and
      // final_shape being 2D. This isn't strictly necessary, but I don't
      // want to blow up code gen size, because to shape<> you need static
      // NDIM and T
      if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
          input_dims == 2 && processing_shape.dims() == 2 &&
          final_shape.dims() == 2 && new_axis_mask == 0) {
        MemCpyFunctor<T> functor;
        if (functor.Copy(input, begin, end, result)) {
          return;
        }
      }

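      // General case: dispatch to a kernel instantiated for the static rank
      // of the processing shape. Ranks 1 through 8 are supported.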
#define HANDLE_DIM(NDIM)                                                   \
  if (processing_dims == NDIM) {                                           \
    HandleStridedSliceCase<Device, T, NDIM>(context, begin, end, strides,  \
                                            processing_shape,              \
                                            is_simple_slice, result);      \
    return;                                                                \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);

#undef HANDLE_DIM

      OP_REQUIRES(context, false,
                  errors::Unimplemented("Unhandled input dimensions ",
                                        processing_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

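// Computes the gradient of StridedSlice: scatters the incoming gradient dy
// into a zero-initialized tensor with the shape of the original input.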
template <typename Device, typename T>
class StridedSliceGradOp : public OpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> end;
    gtl::InlinedVector<int64_t, 4> strides;

    TensorShape input_shape;
    const Tensor& input_shape_tensor = context->input(0);
    OP_REQUIRES(
        context, input_shape_tensor.dims() == 1,
        errors::InvalidArgument("shape must be 1-D, got shape.shape = ",
                                input_shape_tensor.shape().DebugString()));
    if (input_shape_tensor.dtype() == DT_INT32) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int32>(),
                                               &input_shape));
    } else if (input_shape_tensor.dtype() == DT_INT64) {
      OP_REQUIRES_OK(context,
                     TensorShapeUtils::MakeShape(
                         input_shape_tensor.vec<int64_t>(), &input_shape));
    } else {
      LOG(FATAL) << "shape must have type int32 or int64.";
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
            &is_simple_slice, &slice_dim0, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice.
    TensorShape dy_shape = context->input(4).shape();
    OP_REQUIRES(
        context, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    if (!context->status().ok()) return;

    const int processing_dims = processing_shape.dims();
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &result));

    if (processing_shape.dims() == 0) {
      auto in = context->input(4);
      OP_REQUIRES(context, result->CopyFrom(in, processing_shape),
                  errors::Internal("Copy failed"));
      return;
    }

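    // As in the forward op, dispatch on the static rank of the processing
    // shape.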
#define HANDLE_DIM(NDIM)                                                      \
  if (processing_dims == NDIM) {                                              \
    HandleStridedSliceGradCase<Device, T, NDIM>(context, begin, end, strides, \
                                                processing_shape,             \
                                                is_simple_slice, result);     \
    return;                                                                   \
  }

    HANDLE_DIM(1);
    HANDLE_DIM(2);
    HANDLE_DIM(3);
    HANDLE_DIM(4);
    HANDLE_DIM(5);
    HANDLE_DIM(6);
    HANDLE_DIM(7);
    HANDLE_DIM(8);

#undef HANDLE_DIM
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

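// Assigns a value into a strided slice of its input. When isTensor is true
// (TensorStridedSliceUpdate) the op is functional: the input is forwarded or
// copied to the output and then updated. When isTensor is false the op
// mutates a ref (StridedSliceAssign) or resource
// (ResourceStridedSliceAssign) variable in place.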
template <typename Device, typename T, bool isTensor>
class StridedSliceAssignOp : public OpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape;  // Without reshaping input.
    TensorShape final_shape;       // After reshaping input.
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> end;
    gtl::InlinedVector<int64_t, 4> strides;

    Tensor* old_lhs = nullptr;
    Tensor tmp;
    if (isTensor) {
      const Tensor& input = context->input(0);

      int forwarded_input;
      OP_REQUIRES_OK(context,
                     context->forward_input_or_allocate_output(
                         {0}, 0, input.shape(), &old_lhs, &forwarded_input));
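      // forwarded_input is -1 when the input buffer could not be reused, in
      // which case the input must first be copied into the freshly allocated
      // output.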
      if (forwarded_input < 0) {
        OP_REQUIRES_OK(context,
                       tensorflow::functor::DoCopy(
                           context->eigen_device<Device>(), input, old_lhs));
      }
    } else {
      if (context->input_dtype(0) == DT_RESOURCE) {
        core::RefCountPtr<Var> v;
        OP_REQUIRES_OK(
            context, LookupResource(context, HandleFromInput(context, 0), &v));
        OP_REQUIRES_OK(context,
                       EnsureSparseVariableAccess<Device, T>(context, v.get()));
        mutex_lock ml(*v->mu());
        old_lhs = v->tensor();
        OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
                    errors::InvalidArgument(
                        "l-value dtype ", DataTypeString(old_lhs->dtype()),
                        " does not match r-value dtype ",
                        DataTypeString(DataTypeToEnum<T>::value)));
      } else {
        context->forward_ref_input_to_ref_output(0, 0);
        tmp = context->mutable_input(0, true);
        old_lhs = &tmp;
      }
    }

    StridedSliceShapeSpec shape_spec;
    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
                     new_axis_mask, shrink_axis_mask, &processing_shape,
                     &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
                     &begin, &end, &strides, &shape_spec));

    if (processing_shape.num_elements() > 0) {
      const Tensor& input = context->input(4);
      TensorShape input_shape = input.shape();
      TensorShape original_shape = old_lhs->shape();
      const int processing_dims = processing_shape.dims();

      StridedSliceAssignBCast bcast(input_shape.dim_sizes(),
                                    final_shape.dim_sizes());
      OP_REQUIRES(context, bcast.IsValid(),
                  errors::InvalidArgument("Cannot broadcast input shape ",
                                          input_shape.DebugString(),
                                          " into final shape ",
                                          final_shape.DebugString()));

      // The assignment RHS and broadcast spec need to be remapped to
      // the same number of dimensions as the unstrided LHS (i.e. processing
      // dimensions). This adds back any shrink axes and removes new axes.
      bool remap_valid = bcast.RemapDimensions(
          processing_dims, shape_spec.output_to_processing_mapping);
      // Sanity check. The following should never fail.
      DCHECK(remap_valid) << "Failed to remap output shape "
                          << final_shape.DebugString()
                          << " to processing shape "
                          << processing_shape.DebugString();

      // The 0-dimensional case implies that the left- and right-hand sides
      // have exactly the same scalar shape.

      // Handle general dimensions. The do {} while (false) construct is a
      // common way to avoid pedantic extra-semicolon warnings.
#define HANDLE_DIM(NDIM)                                       \
  do {                                                         \
    if (processing_dims == NDIM) {                             \
      HandleStridedSliceAssignCase<Device, T, NDIM>()(         \
          context, begin, end, strides, bcast, old_lhs);       \
      return;                                                  \
    }                                                          \
  } while (false)
      HANDLE_DIM(0);
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);
#undef HANDLE_DIM

      OP_REQUIRES(context, false,
                  errors::Unimplemented("Unhandled input dimensions ",
                                        processing_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

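// The begin/end/strides (and shape) inputs are consumed on the host while
// computing the output shape, so they are pinned to host memory on every
// device.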
#define REGISTER_STRIDED_SLICE(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                          \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceOp<CPUDevice, type>)              \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                      \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("shape")                      \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceGradOp<CPUDevice, type>)          \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")                    \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<CPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")            \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("ref")                        \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<CPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")              \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<CPUDevice, type, true>)

TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_STRIDED_SLICE);

#undef REGISTER_STRIDED_SLICE

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type)                                              \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                          \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceOp<GPUDevice, type>)              \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                      \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("shape")                      \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceGradOp<GPUDevice, type>)          \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")                    \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<GPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")            \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("ref")                        \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<GPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")              \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<GPUDevice, type, true>)

TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("value"),
                        StridedSliceAssignOp<CPUDevice, int32, false>);
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("value"),
                        StridedSliceAssignOp<CPUDevice, int32, false>);
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("value")
                            .HostMemory("output"),
                        StridedSliceAssignOp<CPUDevice, int32, true>);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow