1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
17#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
18
19#define EIGEN_USE_THREADS
20
21#include "tensorflow/core/kernels/conditional_accumulator_base.h"
22
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/resource_mgr.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/lib/core/errors.h"
30#include "tensorflow/core/platform/macros.h"
31#include "tensorflow/core/platform/mutex.h"
32#include "tensorflow/core/platform/thread_annotations.h"
33#include "tensorflow/core/platform/types.h"
34
35typedef Eigen::ThreadPoolDevice CPUDevice;
36
37typedef std::function<void()> DoneCallback;
38
39namespace tensorflow {
40
/**
 * Defines ConditionalAccumulatorBaseOp, which looks up or constructs a
 * ConditionalAccumulatorBase (via the subclass's Creator) and returns its
 * handle.
 */
45class ConditionalAccumulatorBaseOp : public OpKernel {
46 public:
47 explicit ConditionalAccumulatorBaseOp(OpKernelConstruction* context)
48 : OpKernel(context), accumulator_set_(false) {
49 OP_REQUIRES_OK(context, context->allocate_temp(DT_STRING, TensorShape({2}),
50 &accumulator_));
51 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
52 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
53 OP_REQUIRES_OK(context,
54 context->GetAttr("reduction_type", &reduction_type_));
55 }
56
57 void Compute(OpKernelContext* ctx) override {
58 mutex_lock l(mu_);
59 if (!accumulator_set_) {
60 OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx));
61 }
62 SetHandleToOutput(ctx);
63 }
64
65 protected:
66 ~ConditionalAccumulatorBaseOp() override {
67 // If the accumulator object was not shared, delete it.
68 if (accumulator_set_ && cinfo_.resource_is_private_to_kernel()) {
69 TF_CHECK_OK((cinfo_.resource_manager()
70 ->template Delete<ConditionalAccumulatorBase>(
71 cinfo_.container(), cinfo_.name())));
72 }
73 }
74
75 protected:
76 virtual void SetHandleToOutput(OpKernelContext* ctx)
77 TF_SHARED_LOCKS_REQUIRED(mu_) = 0;
78
79 virtual Status CheckSignature(OpKernelContext* ctx) = 0;
80
81 protected:
82 typedef std::function<Status(ConditionalAccumulatorBase**)> Creator;
83
84 // Subclasses must override this
85 virtual Creator GetCreator() const = 0;
86
87 // Variables required to construct ConditionalAccumulator
88 DataType dtype_;
89 PartialTensorShape shape_;
90 ContainerInfo cinfo_;
91 string reduction_type_;
92 mutex mu_;
93 Tensor accumulator_ TF_GUARDED_BY(mu_);
94 bool accumulator_set_ TF_GUARDED_BY(mu_);
95
96 private:
97 Status SetAccumulatorHandle(OpKernelContext* ctx)
98 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
99 TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
100
101 // Check input signature
102 TF_RETURN_IF_ERROR(CheckSignature(ctx));
103
104 Creator creator = GetCreator();
105 ConditionalAccumulatorBase* accumulator;
106 TF_RETURN_IF_ERROR(
107 (cinfo_.resource_manager()
108 ->template LookupOrCreate<ConditionalAccumulatorBase>(
109 cinfo_.container(), cinfo_.name(), &accumulator, creator)));
110 core::ScopedUnref unref_me(accumulator);
111
112 // Verify that the shared accumulator is compatible
113 // with the requested arguments.
114 TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def()));
115 auto h = accumulator_.template flat<tstring>();
116 h(0) = cinfo_.container();
117 h(1) = cinfo_.name();
118 accumulator_set_ = true;
119 return OkStatus();
120 }
121};
122
123// ------------------Sync kernels ------------------------------------------
124
/**
 * General synchronous OpKernel for ConditionalAccumulatorBase-related ops.
 */
128class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel {
129 public:
130 explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context)
131 : OpKernel(context) {}
132
133 void Compute(OpKernelContext* ctx) final {
134 ConditionalAccumulatorBase* accumulator;
135 OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator));
136 Compute(ctx, accumulator);
137 accumulator->Unref();
138 }
139
140 protected:
141 virtual void Compute(OpKernelContext* ctx,
142 ConditionalAccumulatorBase* accumulator) = 0;
143
144 virtual DataTypeVector GetExpectedInputs(
145 ConditionalAccumulatorBase* accumulator) = 0;
146
147 virtual void CheckSignature(OpKernelContext* ctx,
148 ConditionalAccumulatorBase* accumulator) {
149 // Check input signature
150 DataTypeVector expected_inputs = GetExpectedInputs(accumulator);
151 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
152 }
153};
154
/**
 * Defines ConditionalAccumulatorBaseApplyGradientOp, the execution of which
 * adds a gradient (tagged with the scalar local_step input) to the given
 * ConditionalAccumulator.
 */
159class ConditionalAccumulatorBaseApplyGradientOp
160 : public ConditionalAccumulatorBaseSyncOpKernel {
161 public:
162 explicit ConditionalAccumulatorBaseApplyGradientOp(
163 OpKernelConstruction* context)
164 : ConditionalAccumulatorBaseSyncOpKernel(context) {}
165
166 protected:
167 void Compute(OpKernelContext* ctx,
168 ConditionalAccumulatorBase* accumulator) override {
169 // Check input signature
170 CheckSignature(ctx, accumulator);
171
172 // Get input local_step
173 const Tensor* local_step_tensor;
174 OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor));
175 if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) {
176 ctx->CtxFailureWithWarning(errors::InvalidArgument(
177 "Argument local_step must be scalar, but had bad shape ",
178 local_step_tensor->shape().DebugString()));
179 }
180
181 // Actually try to apply gradient now
182 accumulator->TryApplyGrad(local_step_tensor->scalar<int64_t>()(), ctx);
183 }
184};
185
186// -------------------- Async kernels --------------------------------------
/**
 * General asynchronous OpKernel for ConditionalAccumulatorBase-related ops.
 */
190class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel {
191 public:
192 explicit ConditionalAccumulatorBaseAsyncOpKernel(
193 OpKernelConstruction* context)
194 : AsyncOpKernel(context) {}
195
196 void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
197 ConditionalAccumulatorBase* accumulator;
198 OP_REQUIRES_OK_ASYNC(
199 ctx, GetResourceFromContext(ctx, "handle", &accumulator), callback);
200 ComputeAsync(ctx, accumulator, [callback, accumulator]() {
201 accumulator->Unref();
202 callback();
203 });
204 }
205
206 protected:
207 virtual void ComputeAsync(OpKernelContext* ctx,
208 ConditionalAccumulatorBase* accumulator,
209 DoneCallback callback) = 0;
210
211 virtual DataTypeVector GetExpectedInputs(
212 ConditionalAccumulatorBase* accumulator) = 0;
213
214 virtual void CheckSignature(OpKernelContext* ctx,
215 ConditionalAccumulatorBase* accumulator,
216 DoneCallback callback) {
217 // Check input signature
218 OP_REQUIRES_OK_ASYNC(ctx,
219 ctx->MatchSignature(GetExpectedInputs(accumulator),
220 {accumulator->dtype()}),
221 callback);
222 }
223};
224
/**
 * Defines ConditionalAccumulatorBaseTakeGradientOp, the execution of which
 * takes the accumulated gradient from the given ConditionalAccumulator.
 */
229class ConditionalAccumulatorBaseTakeGradientOp
230 : public ConditionalAccumulatorBaseAsyncOpKernel {
231 public:
232 explicit ConditionalAccumulatorBaseTakeGradientOp(
233 OpKernelConstruction* context)
234 : ConditionalAccumulatorBaseAsyncOpKernel(context) {}
235
236 protected:
237 void ComputeAsync(OpKernelContext* ctx,
238 ConditionalAccumulatorBase* accumulator,
239 DoneCallback callback) override {
240 // Check signature
241 CheckSignature(ctx, accumulator, callback);
242
243 // Get input num_required
244 const Tensor* num_required_tensor;
245 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_required", &num_required_tensor),
246 callback);
247 if (!TensorShapeUtils::IsScalar(num_required_tensor->shape())) {
248 ctx->CtxFailureWithWarning(errors::InvalidArgument(
249 "Argument num_required must be scalar, but had bad shape ",
250 num_required_tensor->shape().DebugString()));
251 callback();
252 }
253
254 // Actually try to take gradient now
255 accumulator->TryTakeGrad(num_required_tensor->scalar<int32>()(), ctx,
256 callback);
257 }
258};
259
260} // namespace tensorflow
261
262#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
263