1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/kernels/conditional_accumulator_base_op.h" |
19 | #include "tensorflow/core/kernels/sparse_conditional_accumulator.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | /** |
24 | * Defines a SparseConditionalAccumulatorOp, which constructs a |
25 | * SparseConditionalAccumulator and returns its handle. |
26 | */ |
27 | template <typename Device, typename T> |
28 | class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { |
29 | public: |
30 | explicit SparseConditionalAccumulatorOp(OpKernelConstruction* context) |
31 | : ConditionalAccumulatorBaseOp(context) {} |
32 | |
33 | protected: |
34 | Creator GetCreator() const override { |
35 | return [this](ConditionalAccumulatorBase** ret) { |
36 | SparseConditionalAccumulator<Device, T>* accumulator = |
37 | new SparseConditionalAccumulator<Device, T>( |
38 | dtype_, shape_, cinfo_.name(), reduction_type_); |
39 | *ret = accumulator; |
40 | return OkStatus(); |
41 | }; |
42 | } |
43 | |
44 | // TODO(tanzheny): actually switch it to resource. You won't be able to use |
45 | // it with cond2 otherwise. |
46 | Status CheckSignature(OpKernelContext* ctx) override { |
47 | TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF})); |
48 | return OkStatus(); |
49 | } |
50 | |
51 | void SetHandleToOutput(OpKernelContext* ctx) |
52 | TF_SHARED_LOCKS_REQUIRED(mu_) override { |
53 | ctx->set_output_ref(0, &mu_, &accumulator_); |
54 | } |
55 | |
56 | TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulatorOp); |
57 | }; |
58 | |
// Registers a SparseConditionalAccumulator kernel for one (dtype, device)
// combination.
#define REGISTER_KERNELS(type, dev)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseConditionalAccumulator") \
                              .Device(DEVICE_##dev)            \
                              .TypeConstraint<type>("dtype"),  \
                          SparseConditionalAccumulatorOp<dev##Device, type>)

#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)

// CPU kernels are provided for half, float, and double element types only.
TF_CALL_half(REGISTER_KERNELS_CPU);
TF_CALL_float(REGISTER_KERNELS_CPU);
TF_CALL_double(REGISTER_KERNELS_CPU);

#undef REGISTER_KERNELS_CPU
#undef REGISTER_KERNELS
73 | |
74 | /** |
75 | * Defines a SparseAccumulateGradientOp, the execution of which adds a gradient |
76 | * to the given SparseConditionalAccumulator. |
77 | */ |
78 | class SparseAccumulatorApplyGradientOp |
79 | : public ConditionalAccumulatorBaseApplyGradientOp { |
80 | public: |
81 | explicit SparseAccumulatorApplyGradientOp(OpKernelConstruction* context) |
82 | : ConditionalAccumulatorBaseApplyGradientOp(context) {} |
83 | |
84 | protected: |
85 | DataTypeVector GetExpectedInputs( |
86 | ConditionalAccumulatorBase* accumulator) override { |
87 | DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64, DT_INT64}; |
88 | expected_inputs.push_back(accumulator->dtype()); |
89 | expected_inputs.push_back(DT_INT64); |
90 | return expected_inputs; |
91 | } |
92 | |
93 | private: |
94 | TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorApplyGradientOp); |
95 | }; |
96 | |
97 | REGISTER_KERNEL_BUILDER( |
98 | Name("SparseAccumulatorApplyGradient" ).Device(DEVICE_CPU), |
99 | SparseAccumulatorApplyGradientOp); |
100 | |
101 | /** |
102 | * Defines a SparseAccumulatorTakeGradientOp, the execution of which returns the |
103 | * average sparse gradient accumulated by the given ConditionalAccumulator. |
104 | */ |
105 | class SparseAccumulatorTakeGradientOp |
106 | : public ConditionalAccumulatorBaseTakeGradientOp { |
107 | public: |
108 | explicit SparseAccumulatorTakeGradientOp(OpKernelConstruction* context) |
109 | : ConditionalAccumulatorBaseTakeGradientOp(context) {} |
110 | |
111 | protected: |
112 | void CheckSignature(OpKernelContext* ctx, |
113 | ConditionalAccumulatorBase* accumulator, |
114 | DoneCallback callback) override { |
115 | // Check signature |
116 | OP_REQUIRES_OK_ASYNC( |
117 | ctx, |
118 | ctx->MatchSignature({DT_STRING_REF, DT_INT32}, |
119 | {DT_INT64, accumulator->dtype(), DT_INT64}), |
120 | callback); |
121 | } |
122 | |
123 | DataTypeVector GetExpectedInputs( |
124 | ConditionalAccumulatorBase* accumulator) override { |
125 | return {DT_STRING_REF, DT_INT32}; |
126 | } |
127 | |
128 | private: |
129 | TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorTakeGradientOp); |
130 | }; |
131 | |
132 | REGISTER_KERNEL_BUILDER( |
133 | Name("SparseAccumulatorTakeGradient" ).Device(DEVICE_CPU), |
134 | SparseAccumulatorTakeGradientOp); |
135 | |
136 | } // namespace tensorflow |
137 | |