1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/kernels/conditional_accumulator_base_op.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | /** |
23 | * Defines a AccumulatorSetGlobalStepOp, the execution of which sets the |
24 | * global_step variable of the given ConditionalAccumulator. |
25 | */ |
26 | class AccumulatorSetGlobalStepOp |
27 | : public ConditionalAccumulatorBaseSyncOpKernel { |
28 | public: |
29 | explicit AccumulatorSetGlobalStepOp(OpKernelConstruction* context) |
30 | : ConditionalAccumulatorBaseSyncOpKernel(context) {} |
31 | |
32 | protected: |
33 | DataTypeVector GetExpectedInputs( |
34 | ConditionalAccumulatorBase* accumulator) override { |
35 | return {DT_STRING_REF, DT_INT64}; |
36 | } |
37 | |
38 | void Compute(OpKernelContext* ctx, |
39 | ConditionalAccumulatorBase* accumulator) override { |
40 | // Check signature |
41 | CheckSignature(ctx, accumulator); |
42 | |
43 | // Get input new_global_step |
44 | const Tensor* new_global_step_tensor; |
45 | OP_REQUIRES_OK(ctx, ctx->input("new_global_step" , &new_global_step_tensor)); |
46 | if (!TensorShapeUtils::IsScalar(new_global_step_tensor->shape())) { |
47 | ctx->CtxFailureWithWarning(errors::InvalidArgument( |
48 | "Argument num_required must be scalar, but had bad shape " , |
49 | new_global_step_tensor->shape().DebugString())); |
50 | } |
51 | |
52 | Status status = |
53 | accumulator->SetGlobalStep(new_global_step_tensor->scalar<int64_t>()()); |
54 | if (!status.ok()) ctx->CtxFailureWithWarning(status); |
55 | } |
56 | |
57 | private: |
58 | TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorSetGlobalStepOp); |
59 | }; |
60 | |
61 | REGISTER_KERNEL_BUILDER(Name("AccumulatorSetGlobalStep" ).Device(DEVICE_CPU), |
62 | AccumulatorSetGlobalStepOp); |
63 | |
64 | class ResourceAccumulatorSetGlobalStepOp : public AccumulatorSetGlobalStepOp { |
65 | public: |
66 | explicit ResourceAccumulatorSetGlobalStepOp(OpKernelConstruction* context) |
67 | : AccumulatorSetGlobalStepOp(context) {} |
68 | |
69 | DataTypeVector GetExpectedInputs( |
70 | ConditionalAccumulatorBase* accumulator) override { |
71 | return {DT_RESOURCE, DT_INT64}; |
72 | } |
73 | |
74 | private: |
75 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorSetGlobalStepOp); |
76 | }; |
77 | |
78 | REGISTER_KERNEL_BUILDER( |
79 | Name("ResourceAccumulatorSetGlobalStep" ).Device(DEVICE_CPU), |
80 | ResourceAccumulatorSetGlobalStepOp); |
81 | |
82 | /** |
83 | * Defines a AccumulatorNumAccumulatedOp, which returns the number of gradients |
84 | * that have been accumulated in the given ConditionalAccumulator, and emits it |
85 | * as an output tensor. |
86 | */ |
87 | class AccumulatorNumAccumulatedOp |
88 | : public ConditionalAccumulatorBaseSyncOpKernel { |
89 | public: |
90 | explicit AccumulatorNumAccumulatedOp(OpKernelConstruction* context) |
91 | : ConditionalAccumulatorBaseSyncOpKernel(context) {} |
92 | |
93 | protected: |
94 | void CheckSignature(OpKernelContext* ctx, |
95 | ConditionalAccumulatorBase* accumulator) override { |
96 | // Check input signature |
97 | OP_REQUIRES_OK( |
98 | ctx, ctx->MatchSignature(GetExpectedInputs(accumulator), {DT_INT32})); |
99 | } |
100 | |
101 | DataTypeVector GetExpectedInputs( |
102 | ConditionalAccumulatorBase* accumulator) override { |
103 | return {DT_STRING_REF}; |
104 | } |
105 | |
106 | void Compute(OpKernelContext* ctx, |
107 | ConditionalAccumulatorBase* accumulator) override { |
108 | // Check signature |
109 | CheckSignature(ctx, accumulator); |
110 | |
111 | Tensor* Taccumulator_size = nullptr; |
112 | OP_REQUIRES_OK( |
113 | ctx, ctx->allocate_output(0, TensorShape({}), &Taccumulator_size)); |
114 | Taccumulator_size->flat<int32>().setConstant( |
115 | accumulator->num_accumulated()); |
116 | } |
117 | |
118 | private: |
119 | TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorNumAccumulatedOp); |
120 | }; |
121 | |
122 | REGISTER_KERNEL_BUILDER(Name("AccumulatorNumAccumulated" ).Device(DEVICE_CPU), |
123 | AccumulatorNumAccumulatedOp); |
124 | |
125 | class ResourceAccumulatorNumAccumulatedOp : public AccumulatorNumAccumulatedOp { |
126 | public: |
127 | explicit ResourceAccumulatorNumAccumulatedOp(OpKernelConstruction* context) |
128 | : AccumulatorNumAccumulatedOp(context) {} |
129 | |
130 | DataTypeVector GetExpectedInputs( |
131 | ConditionalAccumulatorBase* accumulator) override { |
132 | return {DT_RESOURCE}; |
133 | } |
134 | |
135 | private: |
136 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorNumAccumulatedOp); |
137 | }; |
138 | |
139 | REGISTER_KERNEL_BUILDER( |
140 | Name("ResourceAccumulatorNumAccumulated" ).Device(DEVICE_CPU), |
141 | ResourceAccumulatorNumAccumulatedOp); |
142 | |
143 | } // namespace tensorflow |
144 | |