1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/kernels/conditional_accumulator_base_op.h"
19#include "tensorflow/core/kernels/sparse_conditional_accumulator.h"
20
21namespace tensorflow {
22
23/**
24 * Defines a SparseConditionalAccumulatorOp, which constructs a
25 * SparseConditionalAccumulator and returns its handle.
26 */
27template <typename Device, typename T>
28class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
29 public:
30 explicit SparseConditionalAccumulatorOp(OpKernelConstruction* context)
31 : ConditionalAccumulatorBaseOp(context) {}
32
33 protected:
34 Creator GetCreator() const override {
35 return [this](ConditionalAccumulatorBase** ret) {
36 SparseConditionalAccumulator<Device, T>* accumulator =
37 new SparseConditionalAccumulator<Device, T>(
38 dtype_, shape_, cinfo_.name(), reduction_type_);
39 *ret = accumulator;
40 return OkStatus();
41 };
42 }
43
44 // TODO(tanzheny): actually switch it to resource. You won't be able to use
45 // it with cond2 otherwise.
46 Status CheckSignature(OpKernelContext* ctx) override {
47 TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF}));
48 return OkStatus();
49 }
50
51 void SetHandleToOutput(OpKernelContext* ctx)
52 TF_SHARED_LOCKS_REQUIRED(mu_) override {
53 ctx->set_output_ref(0, &mu_, &accumulator_);
54 }
55
56 TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulatorOp);
57};
58
59#define REGISTER_KERNELS(type, dev) \
60 REGISTER_KERNEL_BUILDER(Name("SparseConditionalAccumulator") \
61 .Device(DEVICE_##dev) \
62 .TypeConstraint<type>("dtype"), \
63 SparseConditionalAccumulatorOp<dev##Device, type>)
64
65#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)
66
67TF_CALL_half(REGISTER_KERNELS_CPU);
68TF_CALL_float(REGISTER_KERNELS_CPU);
69TF_CALL_double(REGISTER_KERNELS_CPU);
70
71#undef REGISTER_KERNELS_CPU
72#undef REGISTER_KERNELS
73
74/**
75 * Defines a SparseAccumulateGradientOp, the execution of which adds a gradient
76 * to the given SparseConditionalAccumulator.
77 */
78class SparseAccumulatorApplyGradientOp
79 : public ConditionalAccumulatorBaseApplyGradientOp {
80 public:
81 explicit SparseAccumulatorApplyGradientOp(OpKernelConstruction* context)
82 : ConditionalAccumulatorBaseApplyGradientOp(context) {}
83
84 protected:
85 DataTypeVector GetExpectedInputs(
86 ConditionalAccumulatorBase* accumulator) override {
87 DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64, DT_INT64};
88 expected_inputs.push_back(accumulator->dtype());
89 expected_inputs.push_back(DT_INT64);
90 return expected_inputs;
91 }
92
93 private:
94 TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorApplyGradientOp);
95};
96
97REGISTER_KERNEL_BUILDER(
98 Name("SparseAccumulatorApplyGradient").Device(DEVICE_CPU),
99 SparseAccumulatorApplyGradientOp);
100
101/**
102 * Defines a SparseAccumulatorTakeGradientOp, the execution of which returns the
103 * average sparse gradient accumulated by the given ConditionalAccumulator.
104 */
105class SparseAccumulatorTakeGradientOp
106 : public ConditionalAccumulatorBaseTakeGradientOp {
107 public:
108 explicit SparseAccumulatorTakeGradientOp(OpKernelConstruction* context)
109 : ConditionalAccumulatorBaseTakeGradientOp(context) {}
110
111 protected:
112 void CheckSignature(OpKernelContext* ctx,
113 ConditionalAccumulatorBase* accumulator,
114 DoneCallback callback) override {
115 // Check signature
116 OP_REQUIRES_OK_ASYNC(
117 ctx,
118 ctx->MatchSignature({DT_STRING_REF, DT_INT32},
119 {DT_INT64, accumulator->dtype(), DT_INT64}),
120 callback);
121 }
122
123 DataTypeVector GetExpectedInputs(
124 ConditionalAccumulatorBase* accumulator) override {
125 return {DT_STRING_REF, DT_INT32};
126 }
127
128 private:
129 TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorTakeGradientOp);
130};
131
132REGISTER_KERNEL_BUILDER(
133 Name("SparseAccumulatorTakeGradient").Device(DEVICE_CPU),
134 SparseAccumulatorTakeGradientOp);
135
136} // namespace tensorflow
137