1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/kernels/conditional_accumulator.h"
19#include "tensorflow/core/kernels/conditional_accumulator_base_op.h"
20
21namespace tensorflow {
22
23/**
24 * Defines a ConditionalAccumulatorOp, which constructs a ConditionalAccumulator
25 * and returns its handle.
26 */
27template <typename Device, typename T>
28class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
29 public:
30 explicit ConditionalAccumulatorOp(OpKernelConstruction* context)
31 : ConditionalAccumulatorBaseOp(context) {}
32
33 protected:
34 Creator GetCreator() const override {
35 return [this](ConditionalAccumulatorBase** ret) {
36 ConditionalAccumulator<Device, T>* accumulator =
37 new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
38 reduction_type_);
39 *ret = accumulator;
40 return OkStatus();
41 };
42 }
43
44 Status CheckSignature(OpKernelContext* ctx) override {
45 TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF}));
46 return OkStatus();
47 }
48
49 void SetHandleToOutput(OpKernelContext* ctx)
50 TF_SHARED_LOCKS_REQUIRED(mu_) override {
51 ctx->set_output_ref(0, &mu_, &accumulator_);
52 }
53
54 TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulatorOp);
55};
56
// Registers ConditionalAccumulatorOp<device_tag##Device, T> as the kernel for
// the "ConditionalAccumulator" op on DEVICE_<device_tag>, constrained to the
// case where the "dtype" attribute equals T.
#define REGISTER_KERNELS(T, device_tag)                        \
  REGISTER_KERNEL_BUILDER(Name("ConditionalAccumulator")       \
                              .Device(DEVICE_##device_tag)     \
                              .TypeConstraint<T>("dtype"),     \
                          ConditionalAccumulatorOp<device_tag##Device, T>)
62
63// Resource conditional accumulator
64template <typename Device, typename T>
65class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
66 public:
67 explicit ResourceConditionalAccumulatorOp(OpKernelConstruction* context)
68 : ConditionalAccumulatorBaseOp(context) {}
69
70 protected:
71 Creator GetCreator() const override {
72 return [this](ConditionalAccumulatorBase** ret) {
73 ConditionalAccumulator<Device, T>* accumulator =
74 new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
75 reduction_type_);
76 *ret = accumulator;
77 return OkStatus();
78 };
79 }
80
81 Status CheckSignature(OpKernelContext* ctx) override {
82 TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_RESOURCE}));
83 return OkStatus();
84 }
85
86 void SetHandleToOutput(OpKernelContext* ctx)
87 TF_SHARED_LOCKS_REQUIRED(mu_) override {
88 auto h = accumulator_.template flat<tstring>();
89 h(0) = cinfo_.container();
90 h(1) = cinfo_.name();
91 OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
92 ctx, 0, cinfo_.container(), cinfo_.name(),
93 TypeIndex::Make<ConditionalAccumulatorBase>()));
94 }
95
96 TF_DISALLOW_COPY_AND_ASSIGN(ResourceConditionalAccumulatorOp);
97};
98
// Registers ResourceConditionalAccumulatorOp<device_tag##Device, T> as the
// kernel for the "ResourceConditionalAccumulator" op on DEVICE_<device_tag>,
// constrained to the case where the "dtype" attribute equals T.
#define REGISTER_RESOURCE_KERNELS(T, device_tag)                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceConditionalAccumulator")  \
                              .Device(DEVICE_##device_tag)        \
                              .TypeConstraint<T>("dtype"),        \
                          ResourceConditionalAccumulatorOp<device_tag##Device, T>)
104
105// End of Resource conditional accumulator
106
107#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)
108
109TF_CALL_half(REGISTER_KERNELS_CPU);
110TF_CALL_float(REGISTER_KERNELS_CPU);
111TF_CALL_double(REGISTER_KERNELS_CPU);
112
113#undef REGISTER_KERNELS_CPU
114#undef REGISTER_KERNELS
115
116#define REGISTER_RESOURCE_KERNELS_CPU(type) REGISTER_RESOURCE_KERNELS(type, CPU)
117
118TF_CALL_half(REGISTER_RESOURCE_KERNELS_CPU);
119TF_CALL_float(REGISTER_RESOURCE_KERNELS_CPU);
120TF_CALL_double(REGISTER_RESOURCE_KERNELS_CPU);
121
122#undef REGISTER_KERNELS_CPU
123#undef REGISTER_KERNELS
124
/**
 * Defines an AccumulatorApplyGradientOp, the execution of which adds a
 * gradient to the given ConditionalAccumulator.
 */
129class AccumulatorApplyGradientOp
130 : public ConditionalAccumulatorBaseApplyGradientOp {
131 public:
132 explicit AccumulatorApplyGradientOp(OpKernelConstruction* context)
133 : ConditionalAccumulatorBaseApplyGradientOp(context) {}
134
135 DataTypeVector GetExpectedInputs(
136 ConditionalAccumulatorBase* accumulator) override {
137 DataTypeVector expected_inputs;
138 expected_inputs = {DT_STRING_REF, DT_INT64};
139 expected_inputs.push_back(accumulator->dtype());
140 return expected_inputs;
141 }
142
143 private:
144 TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorApplyGradientOp);
145};
146
147REGISTER_KERNEL_BUILDER(Name("AccumulatorApplyGradient").Device(DEVICE_CPU),
148 AccumulatorApplyGradientOp);
149
150class ResourceAccumulatorApplyGradientOp
151 : public ConditionalAccumulatorBaseApplyGradientOp {
152 public:
153 explicit ResourceAccumulatorApplyGradientOp(OpKernelConstruction* context)
154 : ConditionalAccumulatorBaseApplyGradientOp(context) {}
155
156 DataTypeVector GetExpectedInputs(
157 ConditionalAccumulatorBase* accumulator) override {
158 DataTypeVector expected_inputs;
159 expected_inputs = {DT_RESOURCE, DT_INT64};
160 expected_inputs.push_back(accumulator->dtype());
161 return expected_inputs;
162 }
163
164 private:
165 TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorApplyGradientOp);
166};
167
168REGISTER_KERNEL_BUILDER(
169 Name("ResourceAccumulatorApplyGradient").Device(DEVICE_CPU),
170 ResourceAccumulatorApplyGradientOp);
171
/**
 * Defines an AccumulatorTakeGradientOp, the execution of which returns the
 * average gradient accumulated by the given ConditionalAccumulator.
 */
176class AccumulatorTakeGradientOp
177 : public ConditionalAccumulatorBaseTakeGradientOp {
178 public:
179 explicit AccumulatorTakeGradientOp(OpKernelConstruction* context)
180 : ConditionalAccumulatorBaseTakeGradientOp(context) {}
181
182 DataTypeVector GetExpectedInputs(
183 ConditionalAccumulatorBase* accumulator) override {
184 return {DT_STRING_REF, DT_INT32};
185 }
186
187 private:
188 TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorTakeGradientOp);
189};
190REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient").Device(DEVICE_CPU),
191 AccumulatorTakeGradientOp);
192
193class ResourceAccumulatorTakeGradientOp
194 : public ConditionalAccumulatorBaseTakeGradientOp {
195 public:
196 explicit ResourceAccumulatorTakeGradientOp(OpKernelConstruction* context)
197 : ConditionalAccumulatorBaseTakeGradientOp(context) {}
198
199 DataTypeVector GetExpectedInputs(
200 ConditionalAccumulatorBase* accumulator) override {
201 return {DT_RESOURCE, DT_INT32};
202 }
203
204 private:
205 TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorTakeGradientOp);
206};
207
208REGISTER_KERNEL_BUILDER(
209 Name("ResourceAccumulatorTakeGradient").Device(DEVICE_CPU),
210 ResourceAccumulatorTakeGradientOp);
211
212} // namespace tensorflow
213