1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/kernels/conditional_accumulator.h" |
19 | #include "tensorflow/core/kernels/conditional_accumulator_base_op.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | /** |
24 | * Defines a ConditionalAccumulatorOp, which constructs a ConditionalAccumulator |
25 | * and returns its handle. |
26 | */ |
27 | template <typename Device, typename T> |
28 | class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { |
29 | public: |
30 | explicit ConditionalAccumulatorOp(OpKernelConstruction* context) |
31 | : ConditionalAccumulatorBaseOp(context) {} |
32 | |
33 | protected: |
34 | Creator GetCreator() const override { |
35 | return [this](ConditionalAccumulatorBase** ret) { |
36 | ConditionalAccumulator<Device, T>* accumulator = |
37 | new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(), |
38 | reduction_type_); |
39 | *ret = accumulator; |
40 | return OkStatus(); |
41 | }; |
42 | } |
43 | |
44 | Status CheckSignature(OpKernelContext* ctx) override { |
45 | TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF})); |
46 | return OkStatus(); |
47 | } |
48 | |
49 | void SetHandleToOutput(OpKernelContext* ctx) |
50 | TF_SHARED_LOCKS_REQUIRED(mu_) override { |
51 | ctx->set_output_ref(0, &mu_, &accumulator_); |
52 | } |
53 | |
54 | TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulatorOp); |
55 | }; |
56 | |
// Registers ConditionalAccumulatorOp<DEV##Device, T> as the kernel for the
// "ConditionalAccumulator" op on DEVICE_##DEV, constrained to dtype == T.
#define REGISTER_KERNELS(T, DEV)                            \
  REGISTER_KERNEL_BUILDER(Name("ConditionalAccumulator")    \
                              .Device(DEVICE_##DEV)         \
                              .TypeConstraint<T>("dtype"),  \
                          ConditionalAccumulatorOp<DEV##Device, T>)
62 | |
63 | // Resource conditional accumulator |
64 | template <typename Device, typename T> |
65 | class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { |
66 | public: |
67 | explicit ResourceConditionalAccumulatorOp(OpKernelConstruction* context) |
68 | : ConditionalAccumulatorBaseOp(context) {} |
69 | |
70 | protected: |
71 | Creator GetCreator() const override { |
72 | return [this](ConditionalAccumulatorBase** ret) { |
73 | ConditionalAccumulator<Device, T>* accumulator = |
74 | new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(), |
75 | reduction_type_); |
76 | *ret = accumulator; |
77 | return OkStatus(); |
78 | }; |
79 | } |
80 | |
81 | Status CheckSignature(OpKernelContext* ctx) override { |
82 | TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_RESOURCE})); |
83 | return OkStatus(); |
84 | } |
85 | |
86 | void SetHandleToOutput(OpKernelContext* ctx) |
87 | TF_SHARED_LOCKS_REQUIRED(mu_) override { |
88 | auto h = accumulator_.template flat<tstring>(); |
89 | h(0) = cinfo_.container(); |
90 | h(1) = cinfo_.name(); |
91 | OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( |
92 | ctx, 0, cinfo_.container(), cinfo_.name(), |
93 | TypeIndex::Make<ConditionalAccumulatorBase>())); |
94 | } |
95 | |
96 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceConditionalAccumulatorOp); |
97 | }; |
98 | |
// Registers ResourceConditionalAccumulatorOp<DEV##Device, T> as the kernel for
// the "ResourceConditionalAccumulator" op on DEVICE_##DEV, constrained to
// dtype == T.
#define REGISTER_RESOURCE_KERNELS(T, DEV)                          \
  REGISTER_KERNEL_BUILDER(Name("ResourceConditionalAccumulator")   \
                              .Device(DEVICE_##DEV)                \
                              .TypeConstraint<T>("dtype"),         \
                          ResourceConditionalAccumulatorOp<DEV##Device, T>)
104 | |
105 | // End of Resource conditional accumulator |
106 | |
107 | #define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU) |
108 | |
109 | TF_CALL_half(REGISTER_KERNELS_CPU); |
110 | TF_CALL_float(REGISTER_KERNELS_CPU); |
111 | TF_CALL_double(REGISTER_KERNELS_CPU); |
112 | |
113 | #undef REGISTER_KERNELS_CPU |
114 | #undef REGISTER_KERNELS |
115 | |
116 | #define REGISTER_RESOURCE_KERNELS_CPU(type) REGISTER_RESOURCE_KERNELS(type, CPU) |
117 | |
118 | TF_CALL_half(REGISTER_RESOURCE_KERNELS_CPU); |
119 | TF_CALL_float(REGISTER_RESOURCE_KERNELS_CPU); |
120 | TF_CALL_double(REGISTER_RESOURCE_KERNELS_CPU); |
121 | |
122 | #undef REGISTER_KERNELS_CPU |
123 | #undef REGISTER_KERNELS |
124 | |
125 | /** |
126 | * Defines a AccumulateGradientOp, the execution of which adds a gradient to the |
127 | * given ConditionalAccumulator. |
128 | */ |
129 | class AccumulatorApplyGradientOp |
130 | : public ConditionalAccumulatorBaseApplyGradientOp { |
131 | public: |
132 | explicit AccumulatorApplyGradientOp(OpKernelConstruction* context) |
133 | : ConditionalAccumulatorBaseApplyGradientOp(context) {} |
134 | |
135 | DataTypeVector GetExpectedInputs( |
136 | ConditionalAccumulatorBase* accumulator) override { |
137 | DataTypeVector expected_inputs; |
138 | expected_inputs = {DT_STRING_REF, DT_INT64}; |
139 | expected_inputs.push_back(accumulator->dtype()); |
140 | return expected_inputs; |
141 | } |
142 | |
143 | private: |
144 | TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorApplyGradientOp); |
145 | }; |
146 | |
147 | REGISTER_KERNEL_BUILDER(Name("AccumulatorApplyGradient" ).Device(DEVICE_CPU), |
148 | AccumulatorApplyGradientOp); |
149 | |
150 | class ResourceAccumulatorApplyGradientOp |
151 | : public ConditionalAccumulatorBaseApplyGradientOp { |
152 | public: |
153 | explicit ResourceAccumulatorApplyGradientOp(OpKernelConstruction* context) |
154 | : ConditionalAccumulatorBaseApplyGradientOp(context) {} |
155 | |
156 | DataTypeVector GetExpectedInputs( |
157 | ConditionalAccumulatorBase* accumulator) override { |
158 | DataTypeVector expected_inputs; |
159 | expected_inputs = {DT_RESOURCE, DT_INT64}; |
160 | expected_inputs.push_back(accumulator->dtype()); |
161 | return expected_inputs; |
162 | } |
163 | |
164 | private: |
165 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorApplyGradientOp); |
166 | }; |
167 | |
168 | REGISTER_KERNEL_BUILDER( |
169 | Name("ResourceAccumulatorApplyGradient" ).Device(DEVICE_CPU), |
170 | ResourceAccumulatorApplyGradientOp); |
171 | |
172 | /** |
173 | * Defines a ConditionalAccumulatorBaseTakeGradientOp, the execution of which |
174 | * returns the average gradient accumulated by the given ConditionalAccumulator. |
175 | */ |
176 | class AccumulatorTakeGradientOp |
177 | : public ConditionalAccumulatorBaseTakeGradientOp { |
178 | public: |
179 | explicit AccumulatorTakeGradientOp(OpKernelConstruction* context) |
180 | : ConditionalAccumulatorBaseTakeGradientOp(context) {} |
181 | |
182 | DataTypeVector GetExpectedInputs( |
183 | ConditionalAccumulatorBase* accumulator) override { |
184 | return {DT_STRING_REF, DT_INT32}; |
185 | } |
186 | |
187 | private: |
188 | TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorTakeGradientOp); |
189 | }; |
190 | REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient" ).Device(DEVICE_CPU), |
191 | AccumulatorTakeGradientOp); |
192 | |
193 | class ResourceAccumulatorTakeGradientOp |
194 | : public ConditionalAccumulatorBaseTakeGradientOp { |
195 | public: |
196 | explicit ResourceAccumulatorTakeGradientOp(OpKernelConstruction* context) |
197 | : ConditionalAccumulatorBaseTakeGradientOp(context) {} |
198 | |
199 | DataTypeVector GetExpectedInputs( |
200 | ConditionalAccumulatorBase* accumulator) override { |
201 | return {DT_RESOURCE, DT_INT32}; |
202 | } |
203 | |
204 | private: |
205 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorTakeGradientOp); |
206 | }; |
207 | |
208 | REGISTER_KERNEL_BUILDER( |
209 | Name("ResourceAccumulatorTakeGradient" ).Device(DEVICE_CPU), |
210 | ResourceAccumulatorTakeGradientOp); |
211 | |
212 | } // namespace tensorflow |
213 | |