/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conditional_accumulator_base_op.h"

namespace tensorflow {
/**
 * Defines an AccumulatorSetGlobalStepOp, the execution of which sets the
 * global_step variable of the given ConditionalAccumulator.
 */
class AccumulatorSetGlobalStepOp
    : public ConditionalAccumulatorBaseSyncOpKernel {
 public:
  explicit AccumulatorSetGlobalStepOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseSyncOpKernel(context) {}

 protected:
  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_STRING_REF, DT_INT64};
  }

  void Compute(OpKernelContext* ctx,
               ConditionalAccumulatorBase* accumulator) override {
    // Check signature
    CheckSignature(ctx, accumulator);

    // Get input new_global_step
    const Tensor* new_global_step_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->input("new_global_step", &new_global_step_tensor));
    if (!TensorShapeUtils::IsScalar(new_global_step_tensor->shape())) {
      ctx->CtxFailureWithWarning(errors::InvalidArgument(
          "Argument new_global_step must be scalar, but had shape ",
          new_global_step_tensor->shape().DebugString()));
      return;
    }

    Status status =
        accumulator->SetGlobalStep(new_global_step_tensor->scalar<int64_t>()());
    if (!status.ok()) ctx->CtxFailureWithWarning(status);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorSetGlobalStepOp);
};

REGISTER_KERNEL_BUILDER(Name("AccumulatorSetGlobalStep").Device(DEVICE_CPU),
                        AccumulatorSetGlobalStepOp);
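
// A minimal client-side sketch of how this kernel is typically driven from a
// graph, assuming the generated C++ ops in tensorflow/cc/ops/standard_ops.h
// and tensorflow/cc/client/client_session.h. Illustrative only; not part of
// this kernel's implementation.
//
//   Scope root = Scope::NewRootScope();
//   auto accum = ops::ConditionalAccumulator(root, DT_FLOAT,
//                                            PartialTensorShape());
//   auto set_step =
//       ops::AccumulatorSetGlobalStep(root, accum.handle, int64_t{10});
//   ClientSession session(root);
//   TF_CHECK_OK(session.Run({}, {}, {set_step.operation}, nullptr));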

/**
 * Defines a ResourceAccumulatorSetGlobalStepOp, which behaves identically to
 * AccumulatorSetGlobalStepOp except that the accumulator is looked up through
 * a DT_RESOURCE handle rather than a DT_STRING_REF handle.
 */
class ResourceAccumulatorSetGlobalStepOp : public AccumulatorSetGlobalStepOp {
 public:
  explicit ResourceAccumulatorSetGlobalStepOp(OpKernelConstruction* context)
      : AccumulatorSetGlobalStepOp(context) {}

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_RESOURCE, DT_INT64};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorSetGlobalStepOp);
};

REGISTER_KERNEL_BUILDER(
    Name("ResourceAccumulatorSetGlobalStep").Device(DEVICE_CPU),
    ResourceAccumulatorSetGlobalStepOp);

/**
 * Defines an AccumulatorNumAccumulatedOp, which emits the number of gradients
 * that have been accumulated in the given ConditionalAccumulator as a scalar
 * int32 output tensor.
 */
class AccumulatorNumAccumulatedOp
    : public ConditionalAccumulatorBaseSyncOpKernel {
 public:
  explicit AccumulatorNumAccumulatedOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseSyncOpKernel(context) {}

 protected:
  void CheckSignature(OpKernelContext* ctx,
                      ConditionalAccumulatorBase* accumulator) override {
    // Check input signature
    OP_REQUIRES_OK(
        ctx, ctx->MatchSignature(GetExpectedInputs(accumulator), {DT_INT32}));
  }

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_STRING_REF};
  }

  void Compute(OpKernelContext* ctx,
               ConditionalAccumulatorBase* accumulator) override {
    // Check signature
    CheckSignature(ctx, accumulator);

    Tensor* Taccumulator_size = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &Taccumulator_size));
    Taccumulator_size->flat<int32>().setConstant(
        accumulator->num_accumulated());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorNumAccumulatedOp);
};

REGISTER_KERNEL_BUILDER(Name("AccumulatorNumAccumulated").Device(DEVICE_CPU),
                        AccumulatorNumAccumulatedOp);
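
// A minimal client-side sketch of fetching this op's output, assuming the
// generated C++ ops in tensorflow/cc/ops/standard_ops.h and ClientSession.
// Illustrative only; not part of this kernel's implementation.
//
//   Scope root = Scope::NewRootScope();
//   auto accum = ops::ConditionalAccumulator(root, DT_FLOAT,
//                                            PartialTensorShape());
//   auto num = ops::AccumulatorNumAccumulated(root, accum.handle);
//   ClientSession session(root);
//   std::vector<Tensor> out;
//   TF_CHECK_OK(session.Run({num}, &out));
//   // out[0] holds a scalar int32: 0 until gradients have been applied.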

/**
 * Defines a ResourceAccumulatorNumAccumulatedOp, which behaves identically to
 * AccumulatorNumAccumulatedOp except that the accumulator is looked up through
 * a DT_RESOURCE handle rather than a DT_STRING_REF handle.
 */
class ResourceAccumulatorNumAccumulatedOp : public AccumulatorNumAccumulatedOp {
 public:
  explicit ResourceAccumulatorNumAccumulatedOp(OpKernelConstruction* context)
      : AccumulatorNumAccumulatedOp(context) {}

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_RESOURCE};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorNumAccumulatedOp);
};

REGISTER_KERNEL_BUILDER(
    Name("ResourceAccumulatorNumAccumulated").Device(DEVICE_CPU),
    ResourceAccumulatorNumAccumulatedOp);

}  // namespace tensorflow