/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_

#include <deque>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"

namespace tensorflow {

/**
 * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation
 * object for adding gradients.
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
 * successful if local_step >= global_step, i.e., if the gradient is not stale,
 * having been computed using up-to-date information. Otherwise, the gradient
 * is silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned
 * (2) the count of accumulated gradients is reset to 0
 * (3) the internal global_step value (current_global_step_) is incremented by 1
 */
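// A minimal usage sketch (illustrative only: `accum` stands for a pointer to
// a concrete subclass instance, and `ctx` / `done` for a live
// OpKernelContext* and DoneCallback; none of these names are defined in this
// header):
//
//   accum->TryApplyGrad(/*local_step=*/10, ctx);  // kept iff 10 >= global_step
//   accum->TryTakeGrad(/*num_required=*/2, ctx, done);  // `done` runs once two
//                                                       // gradients have been
//                                                       // accumulated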
class ConditionalAccumulatorBase : public ResourceBase {
 public:
  // Args:
  //   dtype: The datatype of the gradients to be accumulated.
  //   shape: The shape of the accumulated gradients.
  //   name: A name to use for the ConditionalAccumulator.
  ConditionalAccumulatorBase(const DataType& dtype,
                             const PartialTensorShape& shape,
                             const string& name, const string& reduction_type);

  typedef AsyncOpKernel::DoneCallback DoneCallback;

  virtual void TryApplyGrad(int64_t local_step, OpKernelContext* ctx) = 0;
  void TryTakeGrad(int num_required, OpKernelContext* ctx,
                   DoneCallback callback);

  // Accessor methods
  uint32 num_accumulated() {
    mutex_lock lock(mu_);
    return counter_;
  }

  const DataType& dtype() const { return dtype_; }

  string DebugString() const override { return "A conditional accumulator"; }
  // SetGlobalStep is a modifier method for current_global_step_.
  // It returns an InvalidArgument error if new_global_step is less than
  // current_global_step_.
  Status SetGlobalStep(int64_t new_global_step);

  Status MatchesNodeDef(const NodeDef& node_def);

 protected:
  // Virtual methods to be implemented by sub-classes for different datatypes.
  // Implements arithmetic operations specific to datatype.
  virtual void DivideAccumGradByCounter(OpKernelContext* ctx)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
  virtual bool SetOutput(OpKernelContext* ctx) = 0;

  enum RunResult { kNoProgress, kComplete };

  // Helper struct holding information about a TakeGrad attempt
  struct Attempt;
  typedef std::function<RunResult(Attempt*)> RunCallback;
  struct Attempt {
    int elements_requested;
    DoneCallback done_callback;  // must be run outside mu_
    OpKernelContext* context;
    CancellationManager* cancellation_manager;  // not owned
    CancellationToken cancellation_token;
    RunCallback run_callback;  // must be run while holding mu_
    bool is_cancelled;

    Attempt(int elements_requested, DoneCallback done_callback,
            OpKernelContext* context,
            CancellationManager* cancellation_manager,
            CancellationToken cancellation_token, RunCallback run_callback)
        : elements_requested(elements_requested),
          done_callback(std::move(done_callback)),
          context(context),
          cancellation_manager(cancellation_manager),
          cancellation_token(cancellation_token),
          run_callback(std::move(run_callback)),
          is_cancelled(false) {}
  };
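
  // A hedged sketch of how a TakeGrad attempt could be queued (the real logic
  // lives in TryTakeGrad in the .cc file; the lambda body below is an
  // assumption for illustration, not the actual implementation):
  //
  //   CancellationManager* cm = ctx->cancellation_manager();
  //   CancellationToken token = cm->get_cancellation_token();
  //   takegrad_attempts_.emplace_back(
  //       num_required, std::move(callback), ctx, cm, token,
  //       [this](Attempt* attempt) {
  //         return counter_ >= attempt->elements_requested ? kComplete
  //                                                        : kNoProgress;
  //       });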

  // Helper struct for deregistration of a cancellation token and executing a
  // DoneCallback after a TakeGrad attempt is complete.
  struct CleanUp {
    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
        : finished(f), to_deregister(ct), cm(cm) {}
    DoneCallback finished;
    CancellationToken to_deregister;
    CancellationManager* cm;
  };

  // Fields

  const DataType dtype_;
  const PartialTensorShape shape_;
  const string name_;
  const string reduction_type_;
  mutex mu_;
  int counter_ TF_GUARDED_BY(mu_);
  int64_t current_global_step_ TF_GUARDED_BY(mu_);

  std::deque<Attempt> takegrad_attempts_ TF_GUARDED_BY(mu_);

  // Methods

  // Helper function for creating a cancellation callback
  void Cancel(CancellationManager* cancellation_manager,
              CancellationToken token);

  // Helper functions to process TakeGrad attempts.
  // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad
  // call to try to clear the pending TakeGrad attempts. This in turn calls
  // TryAttemptLocked, which then executes the RunCallback of the logged
  // attempts.
  // Both functions are modeled after core/kernels/queue_base.
  // Note: ApplyGrad attempts never block -- unlike in a queue with limited
  // capacity, we can always add the newest gradient to our accumulator
  // (if it is not stale) or drop it silently (if it is stale).
  void FlushUnlocked();
  bool TryAttemptLocked(std::vector<CleanUp>* clean_up)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
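
  // Sketch of the flush pattern these helpers implement (illustrative,
  // mirroring the core/kernels/queue_base approach referenced above): collect
  // CleanUp entries under the lock, then run the DoneCallbacks outside it.
  //
  //   std::vector<CleanUp> clean_up;
  //   {
  //     mutex_lock lock(mu_);
  //     TryAttemptLocked(&clean_up);
  //   }
  //   for (const auto& cu : clean_up) {
  //     if (cu.cm != nullptr) cu.cm->TryDeregisterCallback(cu.to_deregister);
  //     cu.finished();  // DoneCallback must run outside mu_
  //   }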

  // Helper methods
  // void DeepCopy(Tensor* dst);
  bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
};

/*
 * Modifications to convenience macros defined in core/framework/op_kernel.h.
 * The macros below return false if the test fails, so that the calling
 * function can get an indication that a failure has occurred.
 */
#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS)          \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      return false;                                    \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS)                 \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return false;                                         \
    }                                                       \
  } while (0)
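
// Example of the intended use (a hypothetical helper, not part of this
// header): because the *_BOOLEAN variants `return false` on failure, they may
// only be used inside functions that return bool, such as SetOutput above.
//
//   bool ValidateSomething(OpKernelContext* ctx, const TensorShape& shape) {
//     OP_REQUIRES_BOOLEAN(ctx, shape.dims() == 2,
//                         errors::InvalidArgument("expected a matrix"));
//     return true;
//   }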

/*
 * Convenience classes for helping to convert between numeric types.
 * The specialization for Eigen::half here simplifies specialization of
 * ConditionalAccumulator classes later.
 */
template <typename T, typename U>
class TypeConverter {
 public:
  static T ConvertUToT(U c) { return c; /* implicit conversion */ }
};

template <typename U>
class TypeConverter<Eigen::half, U> {
 public:
  static Eigen::half ConvertUToT(U c) { return static_cast<Eigen::half>(c); }
};
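
// Example (illustrative): the primary template relies on an implicit U -> T
// conversion, while the Eigen::half specialization routes through an explicit
// static_cast instead.
//
//   float f = TypeConverter<float, int>::ConvertUToT(2);             // implicit
//   Eigen::half h = TypeConverter<Eigen::half, int>::ConvertUToT(2);  // cast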

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_