1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ |
18 | |
19 | #define EIGEN_USE_THREADS |
20 | |
21 | #include "tensorflow/core/kernels/conditional_accumulator_base.h" |
22 | |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/resource_mgr.h" |
26 | #include "tensorflow/core/framework/tensor.h" |
27 | #include "tensorflow/core/framework/tensor_shape.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | #include "tensorflow/core/platform/macros.h" |
31 | #include "tensorflow/core/platform/mutex.h" |
32 | #include "tensorflow/core/platform/thread_annotations.h" |
33 | #include "tensorflow/core/platform/types.h" |
34 | |
35 | typedef Eigen::ThreadPoolDevice CPUDevice; |
36 | |
37 | typedef std::function<void()> DoneCallback; |
38 | |
39 | namespace tensorflow { |
40 | |
41 | /** |
42 | * Defines a ConditionalAccumulatorBaseOp, which constructs a |
43 | * ConditionalAccumulatorBase (via sub-class's Creator) and returns its handle. |
44 | */ |
class ConditionalAccumulatorBaseOp : public OpKernel {
 public:
  // Reads the "shape", "dtype" and "reduction_type" attrs and pre-allocates a
  // 2-element DT_STRING tensor that will later hold the resource handle
  // (container name, accumulator name). The accumulator resource itself is
  // created lazily on the first Compute() call.
  explicit ConditionalAccumulatorBaseOp(OpKernelConstruction* context)
      : OpKernel(context), accumulator_set_(false) {
    OP_REQUIRES_OK(context, context->allocate_temp(DT_STRING, TensorShape({2}),
                                                   &accumulator_));
    OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("reduction_type", &reduction_type_));
  }

  // On first invocation, creates (or looks up) the shared accumulator resource
  // and records its handle; on every invocation, emits the handle as output.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);
    if (!accumulator_set_) {
      OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx));
    }
    SetHandleToOutput(ctx);
  }

 protected:
  ~ConditionalAccumulatorBaseOp() override {
    // If the accumulator object was not shared, delete it.
    if (accumulator_set_ && cinfo_.resource_is_private_to_kernel()) {
      TF_CHECK_OK((cinfo_.resource_manager()
                       ->template Delete<ConditionalAccumulatorBase>(
                           cinfo_.container(), cinfo_.name())));
    }
  }

 protected:
  // Writes the accumulator handle (accumulator_) to this op's output.
  // Called with mu_ held (at least shared).
  virtual void SetHandleToOutput(OpKernelContext* ctx)
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Verifies that the op's input/output signature matches the subclass's
  // expectations; returns a non-OK status on mismatch.
  virtual Status CheckSignature(OpKernelContext* ctx) = 0;

 protected:
  // Factory callback that constructs the concrete accumulator; passed to
  // ResourceMgr::LookupOrCreate so the resource is only built when absent.
  typedef std::function<Status(ConditionalAccumulatorBase**)> Creator;

  // Subclasses must override this
  virtual Creator GetCreator() const = 0;

  // Variables required to construct ConditionalAccumulator
  DataType dtype_;
  PartialTensorShape shape_;
  ContainerInfo cinfo_;
  string reduction_type_;
  mutex mu_;
  // 2-element string tensor: {container name, resource name}.
  Tensor accumulator_ TF_GUARDED_BY(mu_);
  // True once accumulator_ holds a valid handle (resource exists).
  bool accumulator_set_ TF_GUARDED_BY(mu_);

 private:
  // Creates or looks up the accumulator resource, verifies it is compatible
  // with this kernel's NodeDef, and stores {container, name} in accumulator_.
  // Called at most once (guarded by accumulator_set_), with mu_ held.
  Status SetAccumulatorHandle(OpKernelContext* ctx)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));

    // Check input signature
    TF_RETURN_IF_ERROR(CheckSignature(ctx));

    Creator creator = GetCreator();
    ConditionalAccumulatorBase* accumulator;
    TF_RETURN_IF_ERROR(
        (cinfo_.resource_manager()
             ->template LookupOrCreate<ConditionalAccumulatorBase>(
                 cinfo_.container(), cinfo_.name(), &accumulator, creator)));
    // LookupOrCreate returned an owned reference; release it on every exit
    // path (the resource manager keeps the resource alive).
    core::ScopedUnref unref_me(accumulator);

    // Verify that the shared accumulator is compatible
    // with the requested arguments.
    TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def()));
    auto h = accumulator_.template flat<tstring>();
    h(0) = cinfo_.container();
    h(1) = cinfo_.name();
    accumulator_set_ = true;
    return OkStatus();
  }
};
122 | |
123 | // ------------------Sync kernels ------------------------------------------ |
124 | |
125 | /** |
126 | * General OpKernel for ConditionalAccumulatorBase-related ops. |
127 | */ |
128 | class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel { |
129 | public: |
130 | explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context) |
131 | : OpKernel(context) {} |
132 | |
133 | void Compute(OpKernelContext* ctx) final { |
134 | ConditionalAccumulatorBase* accumulator; |
135 | OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle" , &accumulator)); |
136 | Compute(ctx, accumulator); |
137 | accumulator->Unref(); |
138 | } |
139 | |
140 | protected: |
141 | virtual void Compute(OpKernelContext* ctx, |
142 | ConditionalAccumulatorBase* accumulator) = 0; |
143 | |
144 | virtual DataTypeVector GetExpectedInputs( |
145 | ConditionalAccumulatorBase* accumulator) = 0; |
146 | |
147 | virtual void CheckSignature(OpKernelContext* ctx, |
148 | ConditionalAccumulatorBase* accumulator) { |
149 | // Check input signature |
150 | DataTypeVector expected_inputs = GetExpectedInputs(accumulator); |
151 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); |
152 | } |
153 | }; |
154 | |
/**
 * Defines an ApplyGradientOp, the execution of which adds a gradient to the
 * given ConditionalAccumulator.
 */
159 | class ConditionalAccumulatorBaseApplyGradientOp |
160 | : public ConditionalAccumulatorBaseSyncOpKernel { |
161 | public: |
162 | explicit ConditionalAccumulatorBaseApplyGradientOp( |
163 | OpKernelConstruction* context) |
164 | : ConditionalAccumulatorBaseSyncOpKernel(context) {} |
165 | |
166 | protected: |
167 | void Compute(OpKernelContext* ctx, |
168 | ConditionalAccumulatorBase* accumulator) override { |
169 | // Check input signature |
170 | CheckSignature(ctx, accumulator); |
171 | |
172 | // Get input local_step |
173 | const Tensor* local_step_tensor; |
174 | OP_REQUIRES_OK(ctx, ctx->input("local_step" , &local_step_tensor)); |
175 | if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) { |
176 | ctx->CtxFailureWithWarning(errors::InvalidArgument( |
177 | "Argument local_step must be scalar, but had bad shape " , |
178 | local_step_tensor->shape().DebugString())); |
179 | } |
180 | |
181 | // Actually try to apply gradient now |
182 | accumulator->TryApplyGrad(local_step_tensor->scalar<int64_t>()(), ctx); |
183 | } |
184 | }; |
185 | |
186 | // -------------------- Async kernels -------------------------------------- |
187 | /** |
188 | * General OpKernel for ConditionalAccumulatorBase-related ops. |
189 | */ |
190 | class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel { |
191 | public: |
192 | explicit ConditionalAccumulatorBaseAsyncOpKernel( |
193 | OpKernelConstruction* context) |
194 | : AsyncOpKernel(context) {} |
195 | |
196 | void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { |
197 | ConditionalAccumulatorBase* accumulator; |
198 | OP_REQUIRES_OK_ASYNC( |
199 | ctx, GetResourceFromContext(ctx, "handle" , &accumulator), callback); |
200 | ComputeAsync(ctx, accumulator, [callback, accumulator]() { |
201 | accumulator->Unref(); |
202 | callback(); |
203 | }); |
204 | } |
205 | |
206 | protected: |
207 | virtual void ComputeAsync(OpKernelContext* ctx, |
208 | ConditionalAccumulatorBase* accumulator, |
209 | DoneCallback callback) = 0; |
210 | |
211 | virtual DataTypeVector GetExpectedInputs( |
212 | ConditionalAccumulatorBase* accumulator) = 0; |
213 | |
214 | virtual void CheckSignature(OpKernelContext* ctx, |
215 | ConditionalAccumulatorBase* accumulator, |
216 | DoneCallback callback) { |
217 | // Check input signature |
218 | OP_REQUIRES_OK_ASYNC(ctx, |
219 | ctx->MatchSignature(GetExpectedInputs(accumulator), |
220 | {accumulator->dtype()}), |
221 | callback); |
222 | } |
223 | }; |
224 | |
/**
 * Defines a TakeAccumulatedGradientOp, the execution of which returns the
 * average gradient accumulated by the given ConditionalAccumulator.
 */
229 | class ConditionalAccumulatorBaseTakeGradientOp |
230 | : public ConditionalAccumulatorBaseAsyncOpKernel { |
231 | public: |
232 | explicit ConditionalAccumulatorBaseTakeGradientOp( |
233 | OpKernelConstruction* context) |
234 | : ConditionalAccumulatorBaseAsyncOpKernel(context) {} |
235 | |
236 | protected: |
237 | void ComputeAsync(OpKernelContext* ctx, |
238 | ConditionalAccumulatorBase* accumulator, |
239 | DoneCallback callback) override { |
240 | // Check signature |
241 | CheckSignature(ctx, accumulator, callback); |
242 | |
243 | // Get input num_required |
244 | const Tensor* num_required_tensor; |
245 | OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_required" , &num_required_tensor), |
246 | callback); |
247 | if (!TensorShapeUtils::IsScalar(num_required_tensor->shape())) { |
248 | ctx->CtxFailureWithWarning(errors::InvalidArgument( |
249 | "Argument num_required must be scalar, but had bad shape " , |
250 | num_required_tensor->shape().DebugString())); |
251 | callback(); |
252 | } |
253 | |
254 | // Actually try to take gradient now |
255 | accumulator->TryTakeGrad(num_required_tensor->scalar<int32>()(), ctx, |
256 | callback); |
257 | } |
258 | }; |
259 | |
260 | } // namespace tensorflow |
261 | |
262 | #endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ |
263 | |