1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/conditional_accumulator_base.h" |
17 | #include "tensorflow/core/lib/core/errors.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | ConditionalAccumulatorBase::ConditionalAccumulatorBase( |
22 | const DataType& dtype, const PartialTensorShape& shape, const string& name, |
23 | const string& reduction_type) |
24 | : dtype_(dtype), |
25 | shape_(shape), |
26 | name_(name), |
27 | reduction_type_(reduction_type) { |
28 | counter_ = 0; |
29 | current_global_step_ = 0; |
30 | } |
31 | |
32 | Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) { |
33 | // TODO(xinghao@): implement the checks for the node definition |
34 | return OkStatus(); |
35 | } |
36 | |
37 | /** |
38 | * Sets the time step of the accumulator to be in line with the global time |
39 | * step. Logs warning if the accumulator's time step is already larger than the |
40 | * provided time step. |
41 | */ |
42 | Status ConditionalAccumulatorBase::SetGlobalStep(int64_t new_global_step) { |
43 | mutex_lock lock(mu_); |
44 | if (new_global_step < current_global_step_) { |
45 | LOG(WARNING) << "Attempt to set current_global_step_ to smaller value: " |
46 | << "current_global_step_ = " << current_global_step_ |
47 | << " >= " << new_global_step << " = new_global_step." ; |
48 | } |
49 | current_global_step_ = new_global_step; |
50 | return OkStatus(); |
51 | } |
52 | |
/**
 * Logs an attempt to extract the average gradient, and tries to flush all
 * TakeGrad attempts.
 * A TakeGrad attempt is blocked until counter_ >= num_required, i.e.,
 * sufficient gradients have been accumulated.
 *
 * num_required: Number of gradients that needs to be accumulated before the
 * attempt is unblocked.
 * ctx: Context in which the op is executed.
 * callback: A callback to be executed after the attempt has been completed.
 */
void ConditionalAccumulatorBase::TryTakeGrad(int num_required,
                                             OpKernelContext* ctx,
                                             DoneCallback callback) {
  if (num_required <= 0) {
    // Invalid request: fail the op, but still run the callback so the caller
    // is not left waiting.
    ctx->CtxFailureWithWarning(errors::InvalidArgument(
        "Argument num_required must be positive, but was ", num_required));
    callback();
  } else {
    CancellationManager* cm = ctx->cancellation_manager();
    CancellationToken token = cm->get_cancellation_token();
    bool already_cancelled;
    {
      mutex_lock l(mu_);
      // Register the cancellation callback while holding mu_, so Cancel()
      // cannot observe a half-queued attempt.
      already_cancelled = !cm->RegisterCallback(
          token, [this, cm, token]() { Cancel(cm, token); });
      if (!already_cancelled) {
        // Queue the attempt; the lambda is re-run by TryAttemptLocked (under
        // mu_) each time progress might be possible.
        takegrad_attempts_.emplace_back(
            num_required, callback, ctx, cm, token,
            [this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              if (counter_ >= attempt->elements_requested) {
                bool successful_take_grad = TakeGradLockedHelper(
                    attempt->context, attempt->done_callback);
                if (successful_take_grad) {
                  return kComplete;
                } else {
                  // Try again
                  return kNoProgress;
                }
              } else {
                // Not enough gradients accumulated yet; stay queued.
                return kNoProgress;
              }
            });
      }
    }
    if (!already_cancelled) {
      // Give the just-queued attempt (and any earlier ones) a chance to run.
      FlushUnlocked();
    } else {
      // The op was cancelled before the callback could be registered.
      ctx->SetStatus(errors::Cancelled("TakeGrad operation was cancelled"));
      callback();
    }
  }
}
106 | |
/**
 * Cancellation callback. Marks the pending TakeGrad attempt registered with
 * (cancellation_manager, token) as cancelled, if it is still queued.
 */
void ConditionalAccumulatorBase::Cancel(
    CancellationManager* cancellation_manager, CancellationToken token) {
  DoneCallback callback = nullptr;
  {
    mutex_lock lock(mu_);

    // Find the matching attempt; mark it cancelled and steal its done
    // callback (via swap) so the callback can be invoked after mu_ is
    // released. The is_cancelled flag guarantees the callback runs at most
    // once even if Cancel races with completion.
    for (Attempt& attempt : takegrad_attempts_) {
      if (attempt.cancellation_manager == cancellation_manager &&
          attempt.cancellation_token == token) {
        if (!attempt.is_cancelled) {
          attempt.is_cancelled = true;
          attempt.context->SetStatus(
              errors::Cancelled("TakeGrad operation was cancelled"));
          std::swap(callback, attempt.done_callback);
        }
        break;
      }
    }
  }
  // Invoke the stolen callback outside the lock, then flush so the cancelled
  // attempt is popped and later attempts can make progress.
  if (callback) {
    callback();
    FlushUnlocked();
  }
}
134 | |
135 | /** |
136 | * Try to flush logged, blocked TakeGrad attempts. |
137 | */ |
138 | bool ConditionalAccumulatorBase::TryAttemptLocked( |
139 | std::vector<CleanUp>* clean_up) { |
140 | bool progress = false; |
141 | bool done = false; |
142 | while (!done && !takegrad_attempts_.empty()) { |
143 | if (takegrad_attempts_.front().is_cancelled) { |
144 | VLOG(1) << "Skipping cancelled TakeGrad attempt" ; |
145 | takegrad_attempts_.pop_front(); |
146 | } else { |
147 | Attempt* cur_attempt = &takegrad_attempts_.front(); |
148 | switch (cur_attempt->run_callback(cur_attempt)) { |
149 | case kNoProgress: |
150 | done = true; |
151 | break; |
152 | case kComplete: |
153 | progress = true; |
154 | clean_up->emplace_back(std::move(cur_attempt->done_callback), |
155 | cur_attempt->cancellation_token, |
156 | cur_attempt->context->cancellation_manager()); |
157 | takegrad_attempts_.pop_front(); |
158 | break; |
159 | } |
160 | } |
161 | } |
162 | return progress; |
163 | } |
164 | |
165 | /** |
166 | * Try to flush logged, blocked TakeGrad attempts. |
167 | */ |
168 | void ConditionalAccumulatorBase::FlushUnlocked() { |
169 | std::vector<CleanUp> clean_up; |
170 | Ref(); |
171 | { |
172 | mutex_lock lock(mu_); |
173 | bool changed; |
174 | do { |
175 | changed = TryAttemptLocked(&clean_up); |
176 | } while (changed); |
177 | } |
178 | Unref(); |
179 | for (const auto& to_clean : clean_up) { |
180 | if (to_clean.to_deregister != CancellationManager::kInvalidToken) { |
181 | // NOTE(mrry): We can safely ignore the return value of |
182 | // DeregisterCallback because the mutex mu_ ensures that the |
183 | // cleanup action only executes once. |
184 | to_clean.cm->DeregisterCallback(to_clean.to_deregister); |
185 | } |
186 | to_clean.finished(); |
187 | } |
188 | } |
189 | |
190 | bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx, |
191 | DoneCallback callback) { |
192 | // At this point, the conditional should have been passed |
193 | |
194 | // Implicitly increment global_step |
195 | current_global_step_++; |
196 | |
197 | // Average the accumulated gradient |
198 | if (reduction_type_ == "MEAN" ) { |
199 | DivideAccumGradByCounter(ctx); |
200 | } |
201 | |
202 | // Set output for accumulated gradient tensor |
203 | bool successful_set_output = SetOutput(ctx); |
204 | |
205 | // Reset counter |
206 | if (successful_set_output) counter_ = 0; |
207 | |
208 | return successful_set_output; |
209 | } |
210 | |
211 | } // namespace tensorflow |
212 | |