1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/conditional_accumulator_base.h"
17#include "tensorflow/core/lib/core/errors.h"
18
19namespace tensorflow {
20
21ConditionalAccumulatorBase::ConditionalAccumulatorBase(
22 const DataType& dtype, const PartialTensorShape& shape, const string& name,
23 const string& reduction_type)
24 : dtype_(dtype),
25 shape_(shape),
26 name_(name),
27 reduction_type_(reduction_type) {
28 counter_ = 0;
29 current_global_step_ = 0;
30}
31
// Checks whether `node_def` is compatible with this accumulator (dtype,
// shape, etc.). Currently a stub: it accepts every NodeDef unconditionally.
Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) {
  // TODO(xinghao@): implement the checks for the node definition
  return OkStatus();
}
36
37/**
38 * Sets the time step of the accumulator to be in line with the global time
39 * step. Logs warning if the accumulator's time step is already larger than the
40 * provided time step.
41 */
42Status ConditionalAccumulatorBase::SetGlobalStep(int64_t new_global_step) {
43 mutex_lock lock(mu_);
44 if (new_global_step < current_global_step_) {
45 LOG(WARNING) << "Attempt to set current_global_step_ to smaller value: "
46 << "current_global_step_ = " << current_global_step_
47 << " >= " << new_global_step << " = new_global_step.";
48 }
49 current_global_step_ = new_global_step;
50 return OkStatus();
51}
52
53/**
54 * Logs an attempt to extract the average gradient, and tries to flush all
55 * TakeGrad attempts.
56 * A TakeGrad attempt is blocked until num_required > counter_, i.e.,
57 * sufficient gradients have been accumulated.
58 *
59 * num_required: Number of gradients that needs to be accumulated before the
60 * attempt is unblocked.
61 * ctx: Context in which the op is executed.
62 * callback: A callback to be executed after the attempt has been completed.
63 */
64void ConditionalAccumulatorBase::TryTakeGrad(int num_required,
65 OpKernelContext* ctx,
66 DoneCallback callback) {
67 if (num_required <= 0) {
68 ctx->CtxFailureWithWarning(errors::InvalidArgument(
69 "Argument num_required must be positive, but was ", num_required));
70 callback();
71 } else {
72 CancellationManager* cm = ctx->cancellation_manager();
73 CancellationToken token = cm->get_cancellation_token();
74 bool already_cancelled;
75 {
76 mutex_lock l(mu_);
77 already_cancelled = !cm->RegisterCallback(
78 token, [this, cm, token]() { Cancel(cm, token); });
79 if (!already_cancelled) {
80 takegrad_attempts_.emplace_back(
81 num_required, callback, ctx, cm, token,
82 [this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
83 if (counter_ >= attempt->elements_requested) {
84 bool successful_take_grad = TakeGradLockedHelper(
85 attempt->context, attempt->done_callback);
86 if (successful_take_grad) {
87 return kComplete;
88 } else {
89 // Try again
90 return kNoProgress;
91 }
92 } else {
93 return kNoProgress;
94 }
95 });
96 }
97 }
98 if (!already_cancelled) {
99 FlushUnlocked();
100 } else {
101 ctx->SetStatus(errors::Cancelled("TakeGrad operation was cancelled"));
102 callback();
103 }
104 }
105}
106
107/**
108 * Cancellation callback.
109 */
110void ConditionalAccumulatorBase::Cancel(
111 CancellationManager* cancellation_manager, CancellationToken token) {
112 DoneCallback callback = nullptr;
113 {
114 mutex_lock lock(mu_);
115
116 for (Attempt& attempt : takegrad_attempts_) {
117 if (attempt.cancellation_manager == cancellation_manager &&
118 attempt.cancellation_token == token) {
119 if (!attempt.is_cancelled) {
120 attempt.is_cancelled = true;
121 attempt.context->SetStatus(
122 errors::Cancelled("TakeGrad operation was cancelled"));
123 std::swap(callback, attempt.done_callback);
124 }
125 break;
126 }
127 }
128 }
129 if (callback) {
130 callback();
131 FlushUnlocked();
132 }
133}
134
135/**
136 * Try to flush logged, blocked TakeGrad attempts.
137 */
138bool ConditionalAccumulatorBase::TryAttemptLocked(
139 std::vector<CleanUp>* clean_up) {
140 bool progress = false;
141 bool done = false;
142 while (!done && !takegrad_attempts_.empty()) {
143 if (takegrad_attempts_.front().is_cancelled) {
144 VLOG(1) << "Skipping cancelled TakeGrad attempt";
145 takegrad_attempts_.pop_front();
146 } else {
147 Attempt* cur_attempt = &takegrad_attempts_.front();
148 switch (cur_attempt->run_callback(cur_attempt)) {
149 case kNoProgress:
150 done = true;
151 break;
152 case kComplete:
153 progress = true;
154 clean_up->emplace_back(std::move(cur_attempt->done_callback),
155 cur_attempt->cancellation_token,
156 cur_attempt->context->cancellation_manager());
157 takegrad_attempts_.pop_front();
158 break;
159 }
160 }
161 }
162 return progress;
163}
164
165/**
166 * Try to flush logged, blocked TakeGrad attempts.
167 */
168void ConditionalAccumulatorBase::FlushUnlocked() {
169 std::vector<CleanUp> clean_up;
170 Ref();
171 {
172 mutex_lock lock(mu_);
173 bool changed;
174 do {
175 changed = TryAttemptLocked(&clean_up);
176 } while (changed);
177 }
178 Unref();
179 for (const auto& to_clean : clean_up) {
180 if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
181 // NOTE(mrry): We can safely ignore the return value of
182 // DeregisterCallback because the mutex mu_ ensures that the
183 // cleanup action only executes once.
184 to_clean.cm->DeregisterCallback(to_clean.to_deregister);
185 }
186 to_clean.finished();
187 }
188}
189
190bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
191 DoneCallback callback) {
192 // At this point, the conditional should have been passed
193
194 // Implicitly increment global_step
195 current_global_step_++;
196
197 // Average the accumulated gradient
198 if (reduction_type_ == "MEAN") {
199 DivideAccumGradByCounter(ctx);
200 }
201
202 // Set output for accumulated gradient tensor
203 bool successful_set_output = SetOutput(ctx);
204
205 // Reset counter
206 if (successful_set_output) counter_ = 0;
207
208 return successful_set_output;
209}
210
211} // namespace tensorflow
212