1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_kernel.h" |
17 | #include "tensorflow/core/kernels/variable_ops.h" |
18 | #include "tensorflow/core/lib/core/errors.h" |
19 | #include "tensorflow/core/lib/core/refcount.h" |
20 | #include "tensorflow/core/platform/mutex.h" |
21 | #include "tensorflow/core/platform/types.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | template <class T> |
26 | class CountUpToOp : public OpKernel { |
27 | public: |
28 | explicit CountUpToOp(OpKernelConstruction* context) : OpKernel(context) { |
29 | OP_REQUIRES_OK(context, context->GetAttr("limit" , &limit_)); |
30 | } |
31 | |
32 | void Compute(OpKernelContext* context) override { |
33 | T before_increment; |
34 | { |
35 | mutex_lock l(*context->input_ref_mutex(0)); |
36 | Tensor tensor = context->mutable_input(0, true); |
37 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(tensor.shape()), |
38 | errors::InvalidArgument("input is not a scalar: " , |
39 | tensor.shape().DebugString())); |
40 | T* ptr = &tensor.scalar<T>()(); |
41 | before_increment = *ptr; |
42 | if (*ptr >= limit_) { |
43 | context->SetStatus(errors::OutOfRange("Reached limit of " , limit_)); |
44 | return; |
45 | } |
46 | ++*ptr; |
47 | } |
48 | // Output if no error. |
49 | Tensor* out_tensor; |
50 | OP_REQUIRES_OK(context, context->allocate_output("output" , TensorShape({}), |
51 | &out_tensor)); |
52 | out_tensor->scalar<T>()() = before_increment; |
53 | } |
54 | |
55 | private: |
56 | T limit_; |
57 | }; |
58 | |
59 | template <class T> |
60 | class ResourceCountUpToOp : public OpKernel { |
61 | public: |
62 | explicit ResourceCountUpToOp(OpKernelConstruction* context) |
63 | : OpKernel(context) { |
64 | OP_REQUIRES_OK(context, context->GetAttr("limit" , &limit_)); |
65 | OP_REQUIRES_OK(context, context->GetAttr("T" , &dtype_)); |
66 | } |
67 | |
68 | void Compute(OpKernelContext* context) override { |
69 | core::RefCountPtr<Var> variable; |
70 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), |
71 | &variable)); |
72 | mutex_lock l(*variable->mu()); |
73 | Tensor before_increment = *variable->tensor(); |
74 | OP_REQUIRES( |
75 | context, TensorShapeUtils::IsScalar(before_increment.shape()), |
76 | errors::InvalidArgument("input is not a scalar: " , |
77 | before_increment.shape().DebugString())); |
78 | if (before_increment.scalar<T>()() >= limit_) { |
79 | context->SetStatus(errors::OutOfRange("Reached limit of " , limit_)); |
80 | return; |
81 | } |
82 | // Allocate new buffer |
83 | AllocatorAttributes attr; |
84 | attr.set_gpu_compatible(true); |
85 | attr.set_nic_compatible(true); |
86 | OP_REQUIRES_OK(context, context->allocate_temp(dtype_, TensorShape({}), |
87 | variable->tensor(), attr)); |
88 | variable->tensor()->scalar<T>()() = before_increment.scalar<T>()() + 1; |
89 | context->set_output(0, before_increment); |
90 | } |
91 | |
92 | private: |
93 | T limit_; |
94 | DataType dtype_; |
95 | }; |
96 | |
97 | #define REGISTER(TYPE) \ |
98 | REGISTER_KERNEL_BUILDER( \ |
99 | Name("CountUpTo").TypeConstraint<TYPE>("T").Device(DEVICE_CPU), \ |
100 | CountUpToOp<TYPE>) \ |
101 | REGISTER_KERNEL_BUILDER( \ |
102 | Name("ResourceCountUpTo").TypeConstraint<TYPE>("T").Device(DEVICE_CPU), \ |
103 | ResourceCountUpToOp<TYPE>) |
104 | |
105 | REGISTER(int32); |
106 | REGISTER(int64_t); |
107 | |
108 | #undef REGISTER |
109 | |
110 | } // namespace tensorflow |
111 | |