1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ |
18 | |
19 | #include <deque> |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/numeric_op.h" |
23 | |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/framework/resource_mgr.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | /** |
30 | * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation |
31 | * object for adding gradients. |
32 | * The two main methods of this class are TryApplyGrad and TryTakeGrad. |
33 | * |
34 | * TryApplyGrad tries add a gradient to the accumulator. The attempt is |
35 | * successful if local_step >= global_step, i.e., if the gradient is not stale, |
36 | * having been computed using up-to-date information. Otherwise, the gradient is |
37 | * silently dropped. |
38 | * |
39 | * TryTakeGrad logs an attempt to read the average gradient. The attempt is |
40 | * blocked until the number of gradients accumulated (via TryApplyGrad) is equal |
41 | * or exceeds the number requested by TryTakeGrad. |
42 | * Once this condition is satisfied, the following actions are taken: |
43 | * (1) the value of the average gradient is returned |
44 | * (2) the count of accumulated gradients is reset to 0 |
45 | * (3) the internal global_step value (current_global_step_) is incremented by 1 |
46 | */ |
47 | class ConditionalAccumulatorBase : public ResourceBase { |
48 | public: |
49 | // Args: |
50 | // dtype: The datatype of the gradients to be accumulated. |
51 | // shape: The shape of the accumulated gradients. |
52 | // name: A name to use for the ConditionalAccumulator. |
53 | ConditionalAccumulatorBase(const DataType& dtype, |
54 | const PartialTensorShape& shape, |
55 | const string& name, const string& reduction_type); |
56 | |
57 | typedef AsyncOpKernel::DoneCallback DoneCallback; |
58 | |
59 | virtual void TryApplyGrad(int64_t local_step, OpKernelContext* ctx) = 0; |
60 | void TryTakeGrad(int num_required, OpKernelContext* ctx, |
61 | DoneCallback callback); |
62 | |
63 | // Accessor methods |
64 | uint32 num_accumulated() { |
65 | mutex_lock lock(mu_); |
66 | return counter_; |
67 | } |
68 | |
69 | const DataType& dtype() const { return dtype_; } |
70 | |
71 | string DebugString() const override { return "A conditional accumulator" ; } |
72 | |
73 | // SetGlobalStep is a modifier method for current_global_step. |
74 | // It returns an InvalidArgument error if the new_global_step is less than |
75 | // current_global_step. |
76 | Status SetGlobalStep(int64_t new_global_step); |
77 | |
78 | Status MatchesNodeDef(const NodeDef& node_def); |
79 | |
80 | protected: |
81 | // Virtual methods to be implemented by sub-classes for different datatypes. |
82 | // Implements arithmetic operations specific to datatype. |
83 | virtual void DivideAccumGradByCounter(OpKernelContext* ctx) |
84 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; |
85 | virtual bool SetOutput(OpKernelContext* ctx) = 0; |
86 | |
87 | enum RunResult { kNoProgress, kComplete }; |
88 | |
89 | // Helper struct holding information about a TakeGrad attempt |
90 | struct Attempt; |
91 | typedef std::function<RunResult(Attempt*)> RunCallback; |
92 | struct Attempt { |
93 | int elements_requested; |
94 | DoneCallback done_callback; // must be run outside mu_ |
95 | OpKernelContext* context; |
96 | CancellationManager* cancellation_manager; // not owned |
97 | CancellationToken cancellation_token; |
98 | RunCallback run_callback; // must be run while holding mu_ |
99 | bool is_cancelled; |
100 | |
101 | Attempt(int elements_requested, DoneCallback done_callback, |
102 | OpKernelContext* context, CancellationManager* cancellation_manager, |
103 | CancellationToken cancellation_token, RunCallback run_callback) |
104 | : elements_requested(elements_requested), |
105 | done_callback(std::move(done_callback)), |
106 | context(context), |
107 | cancellation_manager(cancellation_manager), |
108 | cancellation_token(cancellation_token), |
109 | run_callback(std::move(run_callback)), |
110 | is_cancelled(false) {} |
111 | }; |
112 | |
113 | // Helper struct for deregistration of a cancellation token and executing a |
114 | // DoneCallback after a TakeGrad attempt is complete. |
115 | struct CleanUp { |
116 | CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) |
117 | : finished(f), to_deregister(ct), cm(cm) {} |
118 | DoneCallback finished; |
119 | CancellationToken to_deregister; |
120 | CancellationManager* cm; |
121 | }; |
122 | |
123 | // Fields |
124 | |
125 | const DataType dtype_; |
126 | const PartialTensorShape shape_; |
127 | const string name_; |
128 | const string reduction_type_; |
129 | mutex mu_; |
130 | int counter_ TF_GUARDED_BY(mu_); |
131 | int64_t current_global_step_ TF_GUARDED_BY(mu_); |
132 | |
133 | std::deque<Attempt> takegrad_attempts_ TF_GUARDED_BY(mu_); |
134 | |
135 | // Methods |
136 | |
137 | // Helper function for creating cancellation callback |
138 | void Cancel(CancellationManager* cancellation_manager, |
139 | CancellationToken token); |
140 | |
141 | // Helper functions to process TakeGrad attempts. |
142 | // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad |
143 | // calls to try to clear the TakeGrad attempts. This in turn calls |
144 | // TryAttemptLocked, which then executes the RunCallback of the logged |
145 | // attempts. |
146 | // Both functions are modeled after core/kernels/queue_base. |
147 | // Note: ApplyGrad attempts never block -- unlike in a queue with limited |
148 | // capacity, we can always add the newest gradient to our accumulator |
149 | // (if it is not stale) or drop it silently (if it is stale). |
150 | void FlushUnlocked(); |
151 | bool TryAttemptLocked(std::vector<CleanUp>* clean_up) |
152 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
153 | |
154 | // Helper methods |
155 | // void DeepCopy(Tensor* dst); |
156 | bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback) |
157 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
158 | }; |
159 | |
/*
 * Modifications to convenience macros defined in core/framework/op_kernel.h.
 * When the check fails, the macros below report the failure through the
 * context and make the enclosing function return `false`, so that the calling
 * function can get an indication that a failure has occurred.
 */
// Checks EXP; if it evaluates to false, reports STATUS via
// (CTX)->CtxFailure(__FILE__, __LINE__, STATUS) and executes `return false;`
// in the enclosing function. It can therefore only be used inside functions
// whose return type is constructible from `false`. STATUS is evaluated only
// on the failure path. The do { ... } while (0) wrapper makes the expansion a
// single statement that must be terminated with a semicolon by the caller.
#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS)          \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      return false;                                    \
    }                                                  \
  } while (0)
172 | |
// Evaluates STATUS exactly once into a local ::tensorflow::Status named `_s`;
// if it is not ok(), reports it via
// (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s) and executes
// `return false;` in the enclosing function. It can therefore only be used
// inside functions whose return type is constructible from `false`.
// Because the local is named `_s`, the STATUS expression must not itself
// refer to a variable called `_s`.
#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS)                 \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return false;                                         \
    }                                                       \
  } while (0)
181 | |
/*
 * Small helper for converting a value of numeric type U into numeric type T.
 *
 * The primary template relies purely on the language's implicit conversion
 * from U to T. A partial specialization for T = Eigen::half is declared
 * separately so that ConditionalAccumulator classes can be written
 * uniformly for all element types.
 */
template <typename T, typename U>
class TypeConverter {
 public:
  // Returns `c` converted to T. Only implicit conversions are used, so this
  // does not compile for pairs (T, U) that require an explicit cast.
  static T ConvertUToT(U c) {
    T converted = c;  // copy-initialization: implicit conversion only
    return converted;
  }
};
192 | |
193 | template <typename U> |
194 | class TypeConverter<Eigen::half, U> { |
195 | public: |
196 | static Eigen::half ConvertUToT(U c) { return static_cast<Eigen::half>(c); } |
197 | }; |
198 | |
199 | } // namespace tensorflow |
200 | |
201 | #endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ |
202 | |